前面记录了Detr及其改进Deformable Detr。这一篇记录一下用Detr训练自己的数据集。先看下Detr附录中给出的大体源码,整体非常清晰。
接下来记录大体实现过程
一、数据准备
借助labelme对数据进行标注
然后将标注数据转换成COCO格式,得到以下几个文件
其中JPEGImages
存放所有图片,Visualization
存放可视化结果,annotations.json
保存所有图片的标注信息
二、模型训练
2.1 编写DataLoader
在detr/datasets目录下创建一个custom_data.py
文件用于处理自己的数据。创建一个类,主要包含__getitem__
和__len__
方法。
在新建一个build
方法用于detr构建数据。
再到当前目录下的__init__.py
文件中添加新的数据类型
def build_dataset(image_set, args):if args.dataset_file == 'coco':return build_coco(image_set, args)if args.dataset_file == 'coco_panoptic':# to avoid making panopticapi required for cocofrom .coco_panoptic import build as build_coco_panopticreturn build_coco_panoptic(image_set, args)if args.dataset_file == 'tooth':from .custom_data import build as build_tooth return build_tooth(image_set, args)
2.2 训练
修改配置参数
在mian.py
中新增数据路径参数
修改类别数量,在models/detr.py
中修改类别数,类别数要设置为实际类型+1,加1是添加背景类。
num_classes = 2 if args.dataset_file != 'coco' else 91
加载预训练模型
if args.resume:if args.resume.startswith('https'):checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)else:checkpoint = torch.load(args.resume, map_location='cpu')# ==============================================================# 这一段是修改了的,去除多余的参数,并将load_state_dict设置为strict=False,这样它便会只加载模型结构相同部分的预训练参数del checkpoint["model"]["class_embed.weight"]del checkpoint["model"]["class_embed.bias"]del checkpoint["model"]["query_embed.weight"]model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
开始训练
python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --tooth_path /home/jinhai_zhou/data/2D_seg/ --dataset_file tooth --output_dir ./output/path/box_model --resume "./models/detr-r50-e632da11.pth"
我这里检测训练了500次左右开始收敛,分割训练了大概200多次开始接近收敛
如果训练分割模型,建议分两步,先训练检测模型,然后再训练分割头。
三、测试
新增一个predict.py
文件,用于测试
里面主要包含检测和画图两部分内容
- 检测
def detect(im, model, transform, threshold=0.7):# mean-std normalize the input image (batch-size: 1)img = transform(im).unsqueeze(0)print("image.shape:", img.shape)# demo model only support by default images with aspect ratio between 0.5 and 2# if you want to use images with an aspect ratio outside this range# rescale your image so that the maximum size is at most 1333 for best results# assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'# propagate through the modeloutputs = model(img)# keep only predictions with 0.7+ confidenceprobas = outputs['pred_logits'].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > threshold# convert boxes from [0; 1] to image scalesbboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)return probas[keep], bboxes_scaled
- 绘制结果
def plot_results(pil_img, prob, boxes, output):CLASSES = ['N/A', 'teeth']# colors for visualizationCOLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}: {p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.savefig(output)plt.close()# plt.show()
测试
device = torch.device(args.device)# fix the seed for reproducibilityseed = args.seed + utils.get_rank()torch.manual_seed(seed)np.random.seed(seed)random.seed(seed)model, criterion, postprocessors = build_model(args)model.to(device)n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)print('number of params:', n_parameters)output_dir = Path(args.output_dir)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model'], strict=args.strict)print("load model {} is success!".format(args.resume))else:print("Don't load model!")return# standard PyTorch mean-std input image normalizationtransform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])if args.img_path is not None:assert Path(args.img_path).is_file(), "{} not an image path".format(args.img_path)im = Image.open(img_path)scores, boxes = detect(im, model, transform=transform)print("scores: ", scores)print("boxes: ", boxes)if args.img_dirs is not None:assert Path(args.img_dirs).is_dir(), "{} not a dir path".format(args.img_dirs)img_paths = Path(args.img_dirs).glob("*.jpg")# print("loads {} images".format(len(list(img_paths))))for idx, img_path in enumerate(img_paths):print(img_path)im = Image.open(img_path)scores, boxes = detect(im, model, transform=transform)print(" scores: ", scores)print("boxes: ", boxes)out_path = Path(output_dir) / img_path.nameprint("out_path: ", out_path)plot_results(im, scores, boxes, out_path)