深度学习系列53:mmdetection上手

1. 安装

使用openmim安装:

pip install -U openmim
mim install "mmengine>=0.7.0"
mim install "mmcv>=2.0.0rc4"

2. 测试案例

下载代码和模型:

git clone https://github.com/open-mmlab/mmdetection.git
mkdir ./checkpoints
mim download mmdet --config rtmdet_tiny_8xb32-300e_coco --dest ./checkpoints

运行代码,核心是定义inferencer和使用inferencer进行推理两行:

from mmdet.apis import DetInferencer# Choose to use a config
model_name = 'rtmdet_tiny_8xb32-300e_coco'
# Setup a checkpoint file to load
checkpoint = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'# Set the device to be used for evaluation
device = 'cpu'# Initialize the DetInferencer
inferencer = DetInferencer(model_name, checkpoint, device)# Use the detector to do inference
img = 'demo.jpg'
result = inferencer(img, out_dir='./output')# Show the structure of result dict
from rich.pretty import pprint
pprint(result, max_length=4)# Show the output image
from PIL import Image
Image.open('./output/vis/demo.jpg')

3. 自定义数据进行训练

3.1 准备数据

建议使用coco格式,参见https://cocodataset.org/#format-data。文件从头至尾按照顺序分为以下段落:

{
“info”: info,
“licenses”: [license],
“images”: [image],
“annotations”: [annotation],
“categories”: [category]
}
下面是从instances_val2017.json文件中摘出的一个annotation的实例,这里的segmentation就是polygon格式:

{
“segmentation”: [[510.66,423.01,511.72,420.03,510.45…]],
“area”: 702.1057499999998,
“iscrowd”: 0,
“image_id”: 289343,
“bbox”: [473.07,395.93,38.65,28.67],
“category_id”: 18,
“id”: 1768
},
从instances_val2017.json文件中摘出的2个category实例如下所示:

{
“supercategory”: “person”,
“id”: 1,
“name”: “person”
},
{
“supercategory”: “vehicle”,
“id”: 2,
“name”: “bicycle”
},

我们来看测试案例的例子,包含三个大字段,其中categories非常简单,只有一个balloon(我们需要训练的目标)
在这里插入图片描述
images则是如下的清单:
在这里插入图片描述
annotations如下:
在这里插入图片描述

3.2 配置config文件

config文件中需要定义数据,模型,训练参数,优化器等各种参数。测试案例如下:

config_balloon = """
# Inherit and overwrite part of the config based on this config
_base_ = './rtmdet_tiny_8xb32-300e_coco.py'data_root = 'data/balloon/' # dataset roottrain_batch_size_per_gpu = 4
train_num_workers = 2max_epochs = 20
stage2_num_epochs = 1
base_lr = 0.00008metainfo = {'classes': ('balloon', ),'palette': [(220, 20, 60),]
}train_dataloader = dict(batch_size=train_batch_size_per_gpu,num_workers=train_num_workers,dataset=dict(data_root=data_root,metainfo=metainfo,data_prefix=dict(img='train/'),ann_file='train.json'))val_dataloader = dict(dataset=dict(data_root=data_root,metainfo=metainfo,data_prefix=dict(img='val/'),ann_file='val.json'))test_dataloader = val_dataloaderval_evaluator = dict(ann_file=data_root + 'val.json')test_evaluator = val_evaluatormodel = dict(bbox_head=dict(num_classes=1))# learning rate
param_scheduler = [dict(type='LinearLR',start_factor=1.0e-5,by_epoch=False,begin=0,end=10),dict(# use cosine lr from 10 to 20 epochtype='CosineAnnealingLR',eta_min=base_lr * 0.05,begin=max_epochs // 2,end=max_epochs,T_max=max_epochs // 2,by_epoch=True,convert_to_iter_based=True),
]train_pipeline_stage2 = [dict(type='LoadImageFromFile', backend_args=None),dict(type='LoadAnnotations', with_bbox=True),dict(type='RandomResize',scale=(640, 640),ratio_range=(0.1, 2.0),keep_ratio=True),dict(type='RandomCrop', crop_size=(640, 640)),dict(type='YOLOXHSVRandomAug'),dict(type='RandomFlip', prob=0.5),dict(type='Pad', size=(640, 640), pad_val=dict(img=(114, 114, 114))),dict(type='PackDetInputs')
]# optimizer
optim_wrapper = dict(_delete_=True,type='OptimWrapper',optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),paramwise_cfg=dict(norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))default_hooks = dict(checkpoint=dict(interval=5,max_keep_ckpts=2,  # only keep latest 2 checkpointssave_best='auto'),logger=dict(type='LoggerHook', interval=5))custom_hooks = [dict(type='PipelineSwitchHook',switch_epoch=max_epochs - stage2_num_epochs,switch_pipeline=train_pipeline_stage2)
]# load COCO pre-trained weight
load_from = './checkpoints/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth'train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=max_epochs, val_interval=1)
visualizer = dict(vis_backends=[dict(type='LocalVisBackend'),dict(type='TensorboardVisBackend')])
"""with open('../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py', 'w') as f:f.write(config_balloon)

3.3 开始训练

使用Mac M2芯片需要修改3个地方。首先是需要设置

export PYTORCH_ENABLE_MPS_FALLBACK=1

其次是mmcv中的nms需要转到cpu上计算,打开mmcv/ops/nms.py,将class NMSop(torch.autograd.Function)中的inds = ext_module.nms(bboxes, scores…)改为inds = ext_module.nms(bboxes.cpu(), scores.cpu()…)
运行后会出现一个assert报错,找到源代码,把那一行assert删掉即可。
运行完成后,可以查看tensorboard:

%load_ext tensorboard# see curves in tensorboard
%tensorboard --logdir ./work_dirs

然后查看测试结果

from mmdet.apis import DetInferencer
import glob# Choose to use a config
config = '../configs/rtmdet/rtmdet_tiny_1xb4-20e_balloon.py'
# Setup a checkpoint file to load
checkpoint = glob.glob('./work_dirs/rtmdet_tiny_1xb4-20e_balloon/best_coco*.pth')[0]# Set the device to be used for evaluation
device = 'cpu'# Initialize the DetInferencer
inferencer = DetInferencer(config, checkpoint, device)# Use the detector to do inference
img = './data/balloon/val/4838031651_3e7b5ea5c7_b.jpg'
result = inferencer(img, out_dir='./output')
# Show the output image
Image.open('./output/vis/4838031651_3e7b5ea5c7_b.jpg')

在这里插入图片描述

4. 其他

MMYOLO:传统的目标检测库
MMRotate:旋转检测库
MMDetection3D:三维检测库
下面几期一一介绍。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/193194.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

网络运维与网络安全 学习笔记2023.11.18

网络运维与网络安全 学习笔记 第十九天 今日目标 冲突域和交换机工作原理、广播域和VLAN原理 VLAN配置、TRUNK原理与配置、HYBRID原理与配置 冲突域和交换机工作原理 冲突域概述 定义 网络设备发送的数据,产生冲突的区域(范围) 对象 “数…

Nacos注册表解读

基本介绍 在 Nacos 中,注册表是其中一个重要的组件,用于管理服务的注册和发现。 注册表是一个存储服务实例信息的数据库,它记录了所有已注册的服务实例的相关信息,包括服务名称、IP 地址、端口号等。 通过注册表,服…

定时获取公网ip并发送邮件提醒

前一段时间路由器刷的老毛子固件“穿透服务”中定时更新阿里DDNS失败了,用了很久第一次遇到。所以需要做个备用的措施用来实时获取公网ip信息 1、基于python实现 开启邮箱的SMTP功能拿到授权码(不是登录密码) #!/usr/bin/python # -*- coding: UTF-8 -*- import …

【RocketMq系列-01】RocketMq安装和基本概念

RocketMq系列整体栏目 内容链接地址【一】RocketMq安装和基本概念https://zhenghuisheng.blog.csdn.net/article/details/134486709 RocketMq安装和基本概念 一,RocketMq安装和基本概念1,RocketMq基本安装(本地安装)2,Rocketmq的核心概念2.1&…

6.7二叉树的最小深度(LC111)

审题要清楚: 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。注意是叶子节点(左右孩子都为空的节点才是叶子节点!)。 算法: 既可以求最小高度,也可以直接求深度。 最小高度: 后序…

JVM面试必备

目录 JVM三大问题 一、JVM内存区域划分 ​编辑 二、JVM类加载机制 双亲委派模型(常考) 类加载的格式,类卸载 三、垃圾回收(GC) 具体垃圾回收GC步骤 1.判定对象是否为垃圾 方案1:引用计数 方案2:可达性分析 2.释放对象的…

供应链|顶刊MSOM论文解读:服务竞争下的库存共享

问题背景 在汽车、玩具等行业中,零售商之间的库存共享变得十分常见。库存共享可以解决由需求不确定导致的库存错配问题。如果零售商之间同意共享库存,那么当需求较少、自身库存过剩时,可以将过剩库存卖给其他零售商;反之&#xf…

图像分类(一) 全面解读复现AlexNet

解读 论文原文:http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Abstract-摘要 翻译 我们训练了一个庞大的深层卷积神经网络,将ImageNet LSVRC-2010比赛中的120万张高分辨率图像分为1000个不…

WinForms C# 导入和导出 CSV 文件 Spread.NET

使用 WinForms C# 和 VB.NET 导入和导出 CSV 文件 2023 年 11 月 17 日 使用 Spread.NET 直接在 .NET WinForms 应用程序中处理 CSV 文件。 Spread.NET可帮助您创建电子表格、网格、仪表板和表单。它包括一个强大的计算引擎,具有 450 多个函数以及导入和导出 Micros…

React+后端实现导出Excle表格的功能

最近在做一个基于Reactantd前端框架的Excel导出功能,我主要在后端做了处理,这个功能完成后,便总结成一篇技术分享文章,感兴趣的小伙伴可以参考该分享来做导出excle表格功能,以下步骤同样适用于vue框架,或者…

从0开始学习JavaScript--JavaScript 字符串与文本内容使用

JavaScript中的字符串和文本内容处理是前端开发中的核心技能之一。本文将深入研究字符串的创建、操作,以及文本内容的获取、修改等操作,并通过丰富的示例代码,帮助读者更全面地了解和应用这些概念。 JavaScript 字符串基础 字符串是JavaScr…

在VS Code中使用VIM

文章目录 安装和基本使用设置 安装和基本使用 VIM是VS Code的强大对手,其简化版本VI是Linux内置的文本编辑器,堪称VS Code问世之前最流行的编辑器,也是VS Code问世之后,我仍在使用的编辑器。 对VIM无法割舍的原因有二&#xff0…