RTDETR论文快速理解和代码快速实现(训练与预测)

文章目录

  • 前言
  • 一、摘要
  • 二、论文目的
  • 三、论文贡献
  • 四、模型结构
    • 1、模型整体结构
    • 2、backbone结构
    • 3、neck结构
    • 4、混合编码器(neck)
  • 五、RTDERT模型训练(data-->train)
    • 1、环境安装
    • 2、训练
      • 1、数据准备
      • 2、数据yaml文件
      • 3、训练代码
      • 4、训练运行结果
    • 3、推理
      • 1、推理代码
      • 2、推理运行结果
  • 总结


前言

最近,我们想比较基于DETR的transformer模型与基于CNN的yolo模型效果,而百度RT-DETR模型声称“在实时目标检测领域打败YOLO”。从数据的角度来看,RT-DETR似乎确实在某些方面超越了YOLO。我选择RT-DETR模型与YOLO模型比较。本篇文章将介绍RT-DETR模型原理–>环境安装–>数据准备–>训练实现–>预测实现。


一、摘要

近期,端到端基于transformer检测器DETRs已有显著性能。然而,DETR的计算成本限制其实际应用,也阻止其无后处理的优势(如:NMS)。在这篇论文,我们首次分析NMS对目标检测的速度与精确率影响,并构建了端到端的speed基准。为了解决这些问题,我们提出RT-DETR模型,据我们所知,这是第一个实时端到端检测模型。特别的,我们设计一个高效混合编码器加工多尺度特征与特征交互和融合,并提出IOU感知查询,通过像解码器提供更高初始目标来提示性能。除此之外,我们提出的检测模型,可使用解码层without retraining灵活调整推理速度,这样可适应多样的实时场景。我们模型RT-DETR-L在coco2017实现了53%的AP和114FPS on T4 gpu,而RT-DETR-X实现54.8%AP和74FPS,超过同规模模型的yolo。此外,我们的 RT-DETR-R50 实现了53.1%的AP和108FPS的速度,准确性优于 DINO-Deformable-DETR-R50 约 2.2% AP,帧率约为其21倍。
在这里插入图片描述

二、论文目的

实时目标检测是一个重要的研究领域,而DETR的高计算成本问题尚未得到有效解决,这限制了DETR的实际应用,并导致无法充分利用其优势(后处理)。换句话说,RTDETR解决问题是

为了实现上述目标,我们重新思考了DETR,并对其关键组件进行了详细分析和实验,以减少不必要的计算冗余。具体而言,我们发现虽然引入多尺度特征有助于加快训练收敛和提高性能[43],但它也导致输入编码器的序列长度显著增加。因此,由于高计算成本,Transformer编码器成为模型的计算瓶颈。为了实现实时目标检测,我们设计了一个高效的混合编码器来替代原始的Transformer编码器。通过解耦多尺度特征的内尺度交互和跨尺度融合,编码器能够高效处理具有不同尺度的特征。此外,先前的研究[35, 20]表明,解码器的对象查询初始化方案对于检测性能至关重要。为了进一步提高性能,我们提出了基于IoU的查询选择方法,通过在训练过程中提供IoU约束,为解码器提供更高质量的初始对象查询。此外,我们提出的检测器支持通过使用不同的解码器层对推理速度进行灵活调节,无需重新训练,这得益于DETR架构中解码器的设计,并有助于实时检测器的实际应用。

三、论文贡献

本论文的主要贡献总结如下:

1、我们提出了第一个实时端到端目标检测器,不仅在准确性和速度方面优于当前最先进的实时检测器,而且不需要后处理,因此推理速度不会延迟并保持稳定;

2、我们详细分析了NMS对实时检测器的影响,并从后处理的角度得出了关于基于CNN的实时检测器的结论;

3、我们提出的IoU-aware查询选择在模型中展现出卓越的性能改进,为改进目标查询的初始化方案提供了新的思路;

4、我们的工作为端到端检测器的实时实现提供了可行的解决方案,所提出的检测器可以通过使用不同的解码器层进行灵活调整模型大小和推理速度,无需重新训练。

四、模型结构

1、模型整体结构

RT-DETR模型由主干网络(backbone)、混合编码器(hybrid encoder)和带有辅助预测头的Transformer解码器组成。模型的整体架构概述如下图所示。具体来说,我们利用主干网络最后三个阶段的输出特征{S3,S4,S5}作为编码器的输入。混合编码器通过内部尺度交互和跨尺度融合,将多尺度特征转换为图像特征序列。随后,采用IoU感知的查询选择机制,从编码器的输出序列中选择固定数量的图像特征作为解码器的初始对象查询。最后,带有辅助预测头的解码器迭代优化对象查询,生成边界框和置信度分数。
在这里插入图片描述
RT-DETR模型架构图显示了主干网络的最后三个阶段{S3,S4,S5}作为编码器的输入。高效的混合编码器通过内部尺度特征交互(AIFI)和跨尺度特征融合模块(CCFM)将多尺度特征转化为图像特征序列。采用IoU感知的查询选择方法,选择固定数量的图像特征作为解码器的初始对象查询。最后,解码器通过辅助预测头迭代优化对象查询,生成边界框和置信度分数
本文最重要是设计AIFI与CCFM结构

2、backbone结构

与YOLO相似,RT-DETR最终会输出三种不同尺寸的特征图,它们相对于输入图像的分辨率下采样倍数分别是 8 倍、16 倍和 32 倍。这与主流的YOLO算法相似。除此之外,在主干结构的其他方面,RT-DETR并没有特别的地方。

3、neck结构

对于颈部网络部分,RT-DETR 采用了一层 Transformer 的 Encoder ,文中这个颈部网络叫做 Efficient Hybrid Encoder,其包括两部分:Attention-based Intra-scale Feature Interaction (AIFI) 和 CNN-based Cross-scale Feature-fusion Module (CCFM),这个AIFI模块有一点值得注意,这个模块只对S5特征图进行处理

对于AIFI模块(如下左图),它首先将二维的 S5 特征拉成向量,然后交给AIFI模块处理,其数学过程就是多头自注意力与 FFN,随后,再将输出Reshape回二维,记作 F5,以便去完成后续的所谓的“跨尺度特征融合”。

对于CCFM模块(如下右图),以YOLO的角度看这个结构的话,这个CCFM模块就是一个FPN/PAN结构。关于CCFM模块中的Fusion文中也给了详细的结构图,是由 2 个1×1 卷积和 N 个 RepBlock 构成的,这里之所以写成 N ,我觉得是因为 RT-DETR 可以进行缩放处理,通过调整 CCFM中RepBlock 的数量和 Encoder 的编码维度分别控制 Hybrid Encoder 的深度和宽度,同时对 backbone 进行相应的调整即可实现检测器的缩放。
在这里插入图片描述

4、混合编码器(neck)

在3已经介绍neck最终结构,而设计neck结构时,作者为了实时性与减少冗余,设计了一些列结构,其原因是注意力机制的改进减少了计算开销,却输入序列的大幅增加仍导致编码器成为计算瓶颈,不太好实时场景中使用。作者分析了多尺度变换器编码器中存在的计算冗余,设计了一系列变种来证明同时进行内部尺度和跨尺度特征交互在计算上效率低下。

在这里插入图片描述
A → B:变体B插入了一个单尺度的Transformer编码器,它使用了一个Transformer块的层。每个尺度的特征共享编码器,进行内部尺度的特征交互,然后将输出的多尺度特征进行连接。
B → C:变体C在B的基础上引入了基于尺度的特征融合,将连接的多尺度特征输入编码器进行特征交互。
C → D:变体D将多尺度特征的内部尺度交互和跨尺度融合解耦。首先,使用单尺度的Transformer编码器进行内部尺度交互,然后利用类似于PANet [21]的结构进行跨尺度融合。
D → E:变体E在D的基础上进一步优化多尺度特征的内部尺度交互和跨尺度融合,采用了我们设计的高效混合编码器。

RT-DETR认为S5特征相对于较浅的S3和S4特征来说,具有更深、更高级和更丰富的语义特征。这些语义特征对于Transformer模型更加重要,因为它们对于区分不同物体的特征非常有用,而浅层特征由于缺乏良好的语义特征并不是很丰富。

五、RTDERT模型训练(data–>train)

我将在此部分介绍环境安装、数据准备格式、训练相关配置与代码、预测相关内容与代码,我也将数据、官网提供权重放在这里,有需要自行下载。

1、环境安装

使用命令安装,如下:

conda create -n yolov8 python=3.8
conda activate yolov8
git clone https://github.com/ultralytics/ultralytics.git
cd ultralytics
pip install -r requirement.txt
pip install ultralytics

使用上面命令安装可能会报错Could not load library libcudnn_cnn_train.so.8 ,解决方法点击这里,建议先安装较低点的torch版本。

2、训练

我们使用yolov8集成的RTDETR模型,训练与预测文件大致如下图。
在这里插入图片描述

1、数据准备

实际为yolo数据格式,可按照yolov5或v8格式准备即可。

2、数据yaml文件

其数据yaml文件与yolo差不多,但少了nc且将names变成字典的映射,coco8.yaml内容如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# COCO8 dataset (first 8 images from COCO train2017) by Ultralytics
# Example usage: yolo train data=coco8.yaml
# parent
# ├── ultralytics
# └── datasets
#     └── coco8  ← downloads here (1 MB)# Train/val/test sets as 1) dir: path/to/imgs, 2) file: path/to/imgs.txt, or 3) list: [path/to/imgs1, path/to/imgs2, ..]
path: C:/Users/Administrator/Desktop/rtdetr/coco128  # dataset root dir
train: images/train  # train images (relative to 'path') 4 images
val: images/train  # val images (relative to 'path') 4 images
test:  # test images (optional)# Classes
names:0: person1: bicycle2: car3: motorcycle4: airplane5: bus6: train7: truck8: boat9: traffic light10: fire hydrant11: stop sign12: parking meter13: bench14: bird15: cat16: dog17: horse18: sheep19: cow20: elephant21: bear22: zebra23: giraffe24: backpack25: umbrella26: handbag27: tie28: suitcase29: frisbee30: skis31: snowboard32: sports ball33: kite34: baseball bat35: baseball glove36: skateboard37: surfboard38: tennis racket39: bottle40: wine glass41: cup42: fork43: knife44: spoon45: bowl46: banana47: apple48: sandwich49: orange50: broccoli51: carrot52: hot dog53: pizza54: donut55: cake56: chair57: couch58: potted plant59: bed60: dining table61: toilet62: tv63: laptop64: mouse65: remote66: keyboard67: cell phone68: microwave69: oven70: toaster71: sink72: refrigerator73: book74: clock75: vase76: scissors77: teddy bear78: hair drier79: toothbrush# Download script/URL (optional)
download: https://ultralytics.com/assets/coco8.zip

3、训练代码

我们使用命令训练,如下代码:

yolo train model=rtdetr-l.pt data=coco8.yaml epochs=100 imgsz=640 batch=2 amp=False

4、训练运行结果

配置好以上内容即可训练,执行过程如下显示
在这里插入图片描述

3、推理

1、推理代码

这里不在过多介绍推理代码,朋友们可自行查阅。

import cv2
import torch
import numpy as np
from ultralytics.nn.autobackend import AutoBackenddef preprocess(image):image = cv2.resize(image, (640, 640))image = (image[..., ::-1] / 255.0).astype(np.float32) # BGR to RGB, 0 - 255 to 0.0 - 1.0image = image.transpose(2, 0, 1)[None]  # BHWC to BCHW (n, 3, h, w)image = torch.from_numpy(image)return imagedef postprocess(pred, oh, ow, conf_thres=0.25):# 输入是模型推理的结果,即300个预测框# 1,300,84 [cx,cy,w,h,class*80]boxes = []for item in pred[0]:cx, cy, w, h = item[:4]label = item[4:].argmax()confidence = item[4 + label]if confidence < conf_thres:continueleft    = cx - w * 0.5top     = cy - h * 0.5right   = cx + w * 0.5bottom  = cy + h * 0.5boxes.append([left, top, right, bottom, confidence, label])boxes = np.array(boxes)lr = boxes[:,[0, 2]]tb = boxes[:,[1, 3]]boxes[:,[0,2]] = ow * lrboxes[:,[1,3]] = oh * tbreturn boxesdef hsv2bgr(h, s, v):h_i = int(h * 6)f = h * 6 - h_ip = v * (1 - s)q = v * (1 - f * s)t = v * (1 - (1 - f) * s)r, g, b = 0, 0, 0if h_i == 0:r, g, b = v, t, pelif h_i == 1:r, g, b = q, v, pelif h_i == 2:r, g, b = p, v, telif h_i == 3:r, g, b = p, q, velif h_i == 4:r, g, b = t, p, velif h_i == 5:r, g, b = v, p, qreturn int(b * 255), int(g * 255), int(r * 255)def random_color(id):h_plane = (((id << 2) ^ 0x937151) % 100) / 100.0s_plane = (((id << 3) ^ 0x315793) % 100) / 100.0return hsv2bgr(h_plane, s_plane, 1)if __name__ == "__main__":img = cv2.imread("1.jpg")oh, ow = img.shape[:2]img_pre = preprocess(img)# postprocess# ultralytics/models/rtdetr/predict.pymodel  = AutoBackend(weights="rtdetr-l.pt")names  = model.namesresult = model(img_pre)[0]  # 1,300,84boxes  = postprocess(result, oh, ow)for obj in boxes:left, top, right, bottom = int(obj[0]), int(obj[1]), int(obj[2]), int(obj[3])confidence = obj[4]label = int(obj[5])color = random_color(label)cv2.rectangle(img, (left, top), (right, bottom), color=color ,thickness=2, lineType=cv2.LINE_AA)caption = f"{names[label]} {confidence:.2f}"w, h = cv2.getTextSize(caption, 0, 1, 2)[0]cv2.rectangle(img, (left - 3, top - 33), (left + w + 10, top), color, -1)cv2.putText(img, caption, (left, top - 5), 0, 1, (0, 0, 0), 2, 16)cv2.imwrite("infer.jpg", img)print("save done")  

注:若下载了文件可直接 python detect.py执行,可得结果

2、推理运行结果

在这里插入图片描述


总结

文章主要是更换backbone(个人觉得不是文章重点),而使用S5在结合作者多个neck模块实验,该neck结构主打消除计算实现实时。
代码可使用百度官网代码,也可使用yolov8自带代码(高效实现)。
后期,我将仿yolov5一键训练与预测,直接使用xml文件格式训练有预测RTDETR文章。

参考博客链接:
https://blog.csdn.net/qq_40672115/article/details/134356250
https://blog.csdn.net/weixin_43694096/article/details/131353118

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/289077.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PyQt5连接mysql失败解决

一&#xff1a;背景 最近研究一个项目&#xff0c;里面用的Pyqt5编写的桌面应用&#xff0c;跑了下源码发现连接数据库那块出来问题&#xff0c;最终调试发现里面用的QtSql去连接mysql提示驱动找不到。 具体报错信息如下&#xff1a; Could not parse stylesheet of object …

CAS为什么还存在线程安全问题(从所谓的ABA问题再学CAS)

概述 之前学习 CAS&#xff0c;从 Java 代码层面知道其原理&#xff0c;借助一条 CPU 原子指令&#xff0c;通过不断地自旋去比较&#xff08;compare&#xff09;和&#xff08;and&#xff09;赋值&#xff08;set&#xff09;。当时对线程安全的认知停留在将多条 Java 语句…

MySQL实战45讲课后问题

1、第一章 如果表T中没有字段k&#xff0c;而你执行了这个语句 select *fromTwhere k1, 那肯定是会报“不存在这个列”的错误&#xff1a; “Unknown column ‘k’ in ‘where clause’”。你觉得这个错误是在我们上面提到的哪个阶段报出来的呢&#xff1f; 解答&#xff1a;…

STL stack练习

CSTL之stack栈容器 - 数据结构教程 - C语言网CSTL之stack栈容器1.再谈栈回顾一下之前所学的栈&#xff0c;栈是一种先进后出的数据结构&#xff0c;而实现方式需要创建多个结构体&#xff0c;通过链式的方式进行实现&#xff0c;这是标准的栈的思路&#xff0c;而在STL中栈可以…

中央空调冷却塔循环水全自动加药装置PH电导率设备工作原理动画

一&#xff1a;全自动加药装置【概述】 随着在给水、排水处理过程中&#xff0c;常常投加各类化学药剂作为阻垢、杀菌灭藻、混凝、絮凝用&#xff0c;以达到净化水的目的。这些药剂有固体颗粒、液体&#xff0c;在投加过程中必须溶解、稀释及按配比定量投加方能取得最佳效果&am…

数据安全扫描仪荣膺网络安全优秀创新成果大赛优胜奖 - 凸显多重优势

近日&#xff0c;由中国网络安全产业联盟&#xff08;CCIA&#xff09;主办、CCI数据安全工作委员会中国电子技术标准化研究院等单位承办的“2023年网络安全优秀创新成果大赛”获奖名单公布。天空卫士数据安全扫描仪&#xff08;DSS&#xff09;产品获得创新成果大赛优胜奖。 本…

jenkins入门

文章目录 前言一、 jenkins的安装二、新建简单任务总结 前言 本篇文章是 jenkins 的入门级别案例&#xff0c;包括安装、基础概念介绍、新建简单任务 一、 jenkins的安装 下载 jenkins https://www.jenkins.io/download/ 当前案例下载的是 2.426.2 LTS 版本 下载安装jdk11 …

Leetcode—75.颜色分类【中等】

2023每日刷题&#xff08;六十五&#xff09; Leetcode—75.颜色分类 实现代码 class Solution { public:void sortColors(vector<int>& nums) {int red 0, white 0, blue 0;for(auto num: nums) {if(num 0) {red;} else if(num 1) {white;} else {blue;}}for…

Modbus转Profinet网关如何选型应该注意哪些事项

针对当今市场上众多Modbus转Profinet网关&#xff0c;怎样才能选择适合现场使用的Modbus转Profinet网关呢&#xff01; 在进行网关选型之前&#xff0c;一定要清楚设备支持哪些协议&#xff0c;如Modbus或Profinet&#xff0c;ModbusRTU&#xff0c;ModbusTCP;再或者是CANOPEN&…

pbootcms模板(自适应手机端)响应式相册图片网站模板 图片壁纸类网站源码下载

附详细搭建教程&#xff0c;赠送运营资料&#xff1a;SEO教程&#xff0c;友链交换&#xff0c;站长提交收录&#xff0c;优化推广联盟&#xff0c;自媒体等各大优质运营平台 (自适应手机端)响应式相册图片网站模板 图片壁纸类网站源码下载 PbootCMS内核开发的网站模板&#…

博客3万访问量了……

博客有3万访问量了呢。自从第一次用了赠送的1500的流量券&#xff0c;粉丝了从零突破了&#xff0c;到现在有150个粉丝了。 之前预想的写博客的初衷&#xff0c;也是记录自己的学习过程&#xff0c;毕竟好记忆不如烂笔头&#xff0c;记录下来就是长长久久的&#xff0c;随时可以…

图片编辑文字用什么软件?带你了解这5个

图片编辑文字用什么软件&#xff1f;在当今数字化的时代&#xff0c;图片编辑已经成为我们日常生活中不可或缺的一部分。有时候&#xff0c;我们需要在图片上添加文字&#xff0c;以增强图片的视觉效果或传达特定的信息。那么&#xff0c;有哪些可以在图片上编辑文字的软件呢&a…