用Detr训练自定义数据

前面记录了Detr及其改进Deformable Detr。这一篇记录一下用Detr训练自己的数据集。先看下Detr附录中给出的大体源码,整体非常清晰。

接下来记录大体实现过程

一、数据准备

借助labelme对数据进行标注

然后将标注数据转换成COCO格式,得到以下几个文件

其中JPEGImages存放所有图片,Visualization存放可视化结果,annotations.json保存所有图片的标注信息

二、模型训练

2.1 编写DataLoader

在detr/datasets目录下创建一个custom_data.py文件用于处理自己的数据。创建一个类,主要包含__getitem____len__方法。

在新建一个build方法用于detr构建数据。

再到当前目录下的__init__.py文件中添加新的数据类型

def build_dataset(image_set, args):if args.dataset_file == 'coco':return build_coco(image_set, args)if args.dataset_file == 'coco_panoptic':# to avoid making panopticapi required for cocofrom .coco_panoptic import build as build_coco_panopticreturn build_coco_panoptic(image_set, args)if args.dataset_file == 'tooth':from .custom_data import build as build_tooth  return build_tooth(image_set, args)

2.2 训练

修改配置参数
mian.py中新增数据路径参数

修改类别数量,在models/detr.py中修改类别数,类别数要设置为实际类型+1,加1是添加背景类。

num_classes = 2 if args.dataset_file != 'coco' else 91 

加载预训练模型

if args.resume:if args.resume.startswith('https'):checkpoint = torch.hub.load_state_dict_from_url(args.resume, map_location='cpu', check_hash=True)else:checkpoint = torch.load(args.resume, map_location='cpu')# ==============================================================# 这一段是修改了的,去除多余的参数,并将load_state_dict设置为strict=False,这样它便会只加载模型结构相同部分的预训练参数del checkpoint["model"]["class_embed.weight"]del checkpoint["model"]["class_embed.bias"]del checkpoint["model"]["query_embed.weight"]model_without_ddp.load_state_dict(checkpoint['model'], strict=False)

开始训练

python -m torch.distributed.launch --nproc_per_node=1 --use_env main.py --tooth_path /home/jinhai_zhou/data/2D_seg/ --dataset_file tooth --output_dir ./output/path/box_model --resume "./models/detr-r50-e632da11.pth" 

我这里检测训练了500次左右开始收敛,分割训练了大概200多次开始接近收敛

如果训练分割模型,建议分两步,先训练检测模型,然后再训练分割头。

三、测试

新增一个predict.py文件,用于测试
里面主要包含检测和画图两部分内容

  • 检测
def detect(im, model, transform, threshold=0.7):# mean-std normalize the input image (batch-size: 1)img = transform(im).unsqueeze(0)print("image.shape:", img.shape)# demo model only support by default images with aspect ratio between 0.5 and 2# if you want to use images with an aspect ratio outside this range# rescale your image so that the maximum size is at most 1333 for best results# assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'# propagate through the modeloutputs = model(img)# keep only predictions with 0.7+ confidenceprobas = outputs['pred_logits'].softmax(-1)[0, :, :-1]keep = probas.max(-1).values > threshold# convert boxes from [0; 1] to image scalesbboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)return probas[keep], bboxes_scaled
  • 绘制结果
def plot_results(pil_img, prob, boxes, output):CLASSES = ['N/A', 'teeth']# colors for visualizationCOLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],[0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]plt.figure(figsize=(16,10))plt.imshow(pil_img)ax = plt.gca()for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), COLORS * 100):ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,fill=False, color=c, linewidth=3))cl = p.argmax()text = f'{CLASSES[cl]}: {p[cl]:0.2f}'ax.text(xmin, ymin, text, fontsize=15,bbox=dict(facecolor='yellow', alpha=0.5))plt.axis('off')plt.savefig(output)plt.close()# plt.show()

测试

    device = torch.device(args.device)# fix the seed for reproducibilityseed = args.seed + utils.get_rank()torch.manual_seed(seed)np.random.seed(seed)random.seed(seed)model, criterion, postprocessors = build_model(args)model.to(device)n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)print('number of params:', n_parameters)output_dir = Path(args.output_dir)if args.resume:checkpoint = torch.load(args.resume, map_location='cpu') model.load_state_dict(checkpoint['model'], strict=args.strict)print("load model {} is success!".format(args.resume))else:print("Don't load model!")return# standard PyTorch mean-std input image normalizationtransform = T.Compose([T.Resize(800),T.ToTensor(),T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])if args.img_path is not None:assert Path(args.img_path).is_file(), "{} not an image path".format(args.img_path)im = Image.open(img_path)scores, boxes = detect(im, model, transform=transform)print("scores: ", scores)print("boxes: ", boxes)if args.img_dirs is not None:assert Path(args.img_dirs).is_dir(), "{} not a dir path".format(args.img_dirs)img_paths = Path(args.img_dirs).glob("*.jpg")# print("loads {} images".format(len(list(img_paths))))for idx, img_path in enumerate(img_paths):print(img_path)im = Image.open(img_path)scores, boxes = detect(im, model, transform=transform)print(" scores: ", scores)print("boxes: ", boxes)out_path = Path(output_dir) / img_path.nameprint("out_path: ", out_path)plot_results(im, scores, boxes, out_path) 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/859487.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

8086汇编(16位汇编)学习笔记05.asm基础语法和串操作

https://bpsend.net/thread-121-1-2.htmlasm基础语法 1. 环境配置xp环境配置 1.拷贝masm615到指定目录 2.将masm615目录添加进环境变量 3.在cmd中输入ml,可以识别即配置成功dosbox环境配置 1.拷贝masm611到指定目录 2.将masm611所在目录添挂载进dosbox 3.将masm611目录在dosbo…

WinNTSetup 系统安装利器 v5.4.0 单文件版

软件介绍 WinNTSetup,系统安装利器,目前最好用的系统安装器,Windows系统安装部署工具。支持所有Windows平台,支持多系统安装、完全格式化C盘、支持创建VHD虚拟硬盘、在Windows及PE系统下运行,允许在安装前对系统进行预优化设置、集成驱动程序、启用第三方主题支持、加入无…

解决 Cannot GET /favicon.ico

一、报错 二、定位(项目所在文件夹) 三、改名(添加图片,重命名)

Java编程规范-DO / BO / DTO / VO / AO的使用

Java 开发 DO / BO / DTO / VO / AO 的作用 Java 开发中,DO(Data Object)、BO(Business Object)、DTO(Data Transfer Object)、VO(View Object) 和 AO(Application Object) 是常用的对象类型,每种类型都在特定的层次和场景中发挥不同的作用。以下是它们的定义和使用…

硬件开发笔记(三十二):TPS54331电源设计(五):原理图BOM表导出、元器件封装核对

前言一个12V转5V、3.3V和4V的电源电路设计好了,下一步导出BOM表,二次核对元器件型号封装,这是可以生产前的最后一步了。 导出BOM表步骤一:打开原理图打开项目,双击点开原理图:   步骤二:报告-元器件列表列宽一点,板子元器件种类规格不多的时候,导出的东西也不多,因…

数字孪生-智能制造

1、数字企业内循环:打造端到端的数字化应用体验 2、GARTNER分层架构 3、企业数字化架构 4、数字企业的两大核心特征 6、产品数字主线赋能企业转型 7、数字主线关键技术:基于统一架构构建产品全量数字模型 8、闭环数字化解决方案 9、基于数字主线的设计-仿真-试验协同 10、产品数…

C# WPF PrintDialog 打印(3)

前面https://www.cnblogs.com/yinyu5/p/18634080使用PrintDocument方法打印了Canvas,这里打印下面的DataGrid列表内容:这里DataGrid的数据源是DataTable,后台代码:1 private void PrintDocument_DataTable_Method(string Title, DataTable dataTable)2 {3 …

【JAVA代码审计】记一次某java类的cms最最最详细的代码审计

前言 刚好遇到一个授权的渗透是通过该cms实现getshell,所以顺便审计一下java类的cms,这个管理系统是一个内容管理系统,下载地址 https://gitee.com/oufu/ofcms/tree/V1.1.3/tomcat下载地址 https://dlcdn.apache.org/tomcat/tomcat-8/v8.5.78/bin/apache-tomcat-8.5.78-wind…

12.26日每日总结

昨天在调试51单片机的串口时,发现芯片手册上有一句话,在使用定时器1产生串口的波特率时,定时器1就不能使能了。不是不能用,是直接不让使能了,使能后会出错,导致发送的数据不稳定。 今天继续研究了触摸滑条,发现滑条输出的值为从小到大,如下图所示的样子,这就导致从最上…

Minio使用教程

Minio MinIO 是一个高性能的对象存储服务器,用于构建云存储解决方案。它使用Golang编写,专为私有云、公有云和混合云环境设计。它是兼容Amazon S3 API的,并可以作为一个独立的存储后端或与其他流行的开源解决方案(如Kubernetes)集成。 MinIO 允许你存储非结构化数据(如图…