7、DETR:基于Transformer的端到端目标检测

目录

一、论文题目

二、背景与动机

三、创新与卖点

四、具体实现细节

模型架构

简易代码

五、结论与展望

六、一些资料


一、论文题目

End-to-End Object Detection with Transformersicon-default.png?t=N7T8https://arxiv.org/abs/2005.12872

二、背景与动机

        在计算机视觉领域,目标检测一直是一个核心问题。传统的目标检测方法,如Faster R-CNN和SSD,依赖于一系列复杂的预处理步骤,包括锚框生成、非极大值抑制(NMS)等。这些步骤通常需要精心的设计和调整,同时也引入了额外的计算复杂度。

        随着Transformer模型在自然语言处理(NLP)领域的成功,其背后的注意力机制被发现对于处理序列数据非常有效。注意力机制的核心优势在于能够捕捉数据之间的长距离依赖性,这在处理图像时也非常有用。因此,研究者们开始探索Transformer在计算机视觉任务中的应用。

        DETR(Detection Transformer)是一个创新的目标检测模型,它将Transformer模型整合到目标检测流程中,并提出了一种端到端的检测方法,摒弃了传统的预处理步骤。

三、创新与卖点

DETR的主要卖点在于其简洁而高效的设计。它通过以下几点实现了对目标检测的重新思考:

  1. 端到端学习:DETR彻底改变了目标检测的传统流程,实现了真正的端到端训练,将图像特征提取、目标定位和分类任务全部整合在Transformer中,提升了模型的整体优化效果。

  2. 无锚框设计:不同于以往的目标检测器需要预先定义一系列大小不一的锚框进行匹配,DETR直接预测出有限数量(例如100个)的物体候选框及其对应的类别概率,大大简化了检测过程。

  3. 集合预测与解码器-编码器结构:DETR采用Transformer中的编码器-解码器结构,其中编码器负责捕获全局上下文信息,解码器则生成一组潜在的目标框。独特的“集合预测”机制允许模型以并行的方式预测所有目标,无需繁琐的排序或筛选操作。

  4. 注意力机制:DETR利用Transformer中的自注意力机制,使得模型能够更好地理解图像中各个部分之间的关系,进一步提升目标检测的精度和鲁棒性。

四、具体实现细节

模型架构

DETR的架构分为三个部分:CNN Backbone、Transformer和FFN(Feed Forward Network)。

  1. CNN Backbone: 用于提取图像的特征。这些特征随后被展平并传递给Transformer模型。

  2. Transformer: 包括编码器和解码器。编码器使用自注意力处理图像特征,解码器接收位置编码(learnable positional encodings)和来自编码器的特征表示,并通过自注意力机制和交叉注意力机制,生成一组固定长度的向量序列。每个向量代表一个潜在的目标框及其对应类别的预测。

  3. FFN和输出层: FFN对每个解码器输出进行处理,输出层则生成最终的边界框和类别标签。

        模型训练时,使用的损失函数是匈牙利损失和边界框回归损失的组合。匈牙利损失确保了预测和真实标签之间的有效匹配,而边界框回归损失则优化了框的精确位置。

一下内容引自这里:DETR目标检测新范式带来的思考 - 知乎

Transformer

CNN提取的特征拉直(flatten)后加入位置编码(positional encoding)得到序列特征,作为Transformer encoder的输入。Transformer中的attention机制具有全局感受野,能够实现全局上下文的关系建模,其中encoder和decoder均由多个encoder、decoder层堆叠而成。每个encoder层中包含self-attention机制,每个decoder中包含self-attention和cross-attention。

object queries

transformer解码器中的序列是object queries。每个query对应图像中的一个物体实例( 包含背景实例 ϕ),它通过cross-attention从编码器输出的序列中对特定物体实例的特征做聚合,又通过self-attention建模该物体实例域其他物体实例之间的关系。最终,FFN基于特征聚合后的object queries做分类的检测框的回归。

        值得一提的是,object queries是可学习的embedding,与当前输入图像的内容无关(不由当前图像内容计算得到)。论文中对不同object query在COCO数据集上输出检测框的位置做了统计(如上图所示),可以看不同object query是具有一定位置倾向性的。对object queries的理解可以有多个角度。首先,它随机初始化,并随着网络的训练而更新,因此隐式建模了整个训练集上的统计信息。其次,在目标检测中每个object query可以看作是一种可学习的动态anchor,可以发现,不同于Faster RCNN, RetinaNet等方法在特征的每个像素上构建稠密的anchor不同,detr只用少量稀疏的anchor(object queries)做预测,这也启发了后续的一系列工作。

将目标检测问题看做Set Prediction问题

DETR中将目标检测问题看做Set Prediction问题,即将图像中所有感兴趣的物体看作是一个集和,要实现的目标是预测出这一集和。也就是说在DETR的视角下,目标检测不再是单独预测多个感兴趣的物体,而是从全局上将检测出所有目标所构成的整体作为目标。对应的,DETR站在全局的视角,用二分图匹配算法(匈牙利算法)计算prediction与ground truth之间的最佳匹配,从而实现label assignment。以上过程中需要定义什么是最佳匹配,也就是对所有可能的匹配做排序,DETR将一种匹配下模型的总定位和分类损失作为评判标准,损失越低,匹配越佳。注意,该匹配过程是不回传梯度的。DETR这种从全局的视角来实现label assignment的范式也启发了后续的一系列工作。

简易代码

import torch
import torch.nn as nn
from torchvision.models import resnet50
from torch.nn.functional import cross_entropy, softmaxclass DETR(nn.Module):def __init__(self, num_classes, num_queries, hidden_dim=256, nheads=8, num_encoder_layers=6, num_decoder_layers=6):super(DETR, self).__init__()# CNN Backbone: 使用 ResNet50,移除最后的分类层和池化层。self.backbone = nn.Sequential(*list(resnet50(pretrained=True).children())[:-2])# 将 CNN 特征映射到 Transformer 的维度self.conv = nn.Conv2d(2048, hidden_dim, 1)# Transformer: 使用 PyTorch 的 nn.Transformerself.transformer = nn.Transformer(d_model=hidden_dim, nhead=nheads, num_encoder_layers=num_encoder_layers, num_decoder_layers=num_decoder_layers)# 类别预测:每个查询预测所有类别的概率self.class_pred = nn.Linear(hidden_dim, num_classes)# 边界框预测:每个查询预测一个边界框self.bbox_pred = nn.Linear(hidden_dim, 4)# 对象查询:固定数目的查询向量self.query_embed = nn.Embedding(num_queries, hidden_dim)def forward(self, images):# 获取 CNN 特征features = self.backbone(images)# 调整特征维度匹配 Transformer 的输入要求h, _, _, _ = features.shapefeatures = self.conv(features)features = features.flatten(2).permute(2, 0, 1)# 对象查询向量queries = self.query_embed.weight.unsqueeze(1).repeat(1, h, 1)# 通过 Transformer 进行特征和查询的交互transformer_out = self.transformer(features, queries)# 预测类别和边界框class_logits = self.class_pred(transformer_out)bbox_logits = self.bbox_pred(transformer_out).sigmoid()return class_logits, bbox_logits# 实例化模型
num_classes = 91  # COCO 数据集的类别数目
num_queries = 100  # 根据任务需求设定的查询数目
model = DETR(num_classes, num_queries)# 输入图像张量 (batch_size, channels, height, width)
images = torch.rand(2, 3, 800, 800)  # 示例图像张量# 预测
class_logits, bbox_logits = model(images)# 计算损失(需要实际标签才能完成)
# class_loss = cross_entropy(class_logits, labels)
# bbox_loss = ... # 边界框损失计算
# total_loss = class_loss + bbox_loss

        请注意,这个示例代码没有包含完整的 DETR 模型的所有细节,例如边界框损失的计算和匈牙利匹配算法。此外,为了训练 DETR 模型,还需要定义适当的数据加载器、优化器、学习率调度器以及训练循环。这个示例仅用于说明 DETR 架构的基本组件和数据流。 

五、结论与展望

        DETR提供了一种全新的视角来解决目标检测问题。它通过利用Transformer强大的编码能力和端到端的优势,显著简化了检测流程,同时在准确率上与传统方法保持竞争力。尽管在速度上可能不如一些专门为实时应用设计的检测模型,DETR的架构为未来的研究和应用提供了一个有趣的新方向。

六、一些资料

官方源码icon-default.png?t=N7T8https://github.com/facebookresearch/detrDETR目标检测新范式带来的思考 - 知乎公众号:将门创投2020年,Transformer在计算机视觉领域大放异彩。Detection Transformer (DETR) [1]就是Transformer在目标检测领域的成功应用。利用Transformer中attention机制能够有效建模图像中的长程关系(long…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/398940573

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/411997.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用docker搭建LNMP架构

目录 环境准备 下载安装包 服务器环境 任务分析 nginx部分 建立工作目录 编写 Dockerfile 脚本 准备 nginx.conf 配置文件 生成镜像 创建自定义网络 启动镜像容器 验证nginx MySQL部分 建立工作目录 编写 Dockerfile 准备 my.cnf 配置文件 生成镜像 启动镜像…

如何利用SD-WAN升级企业网络,混合组网稳定性更高?

随着企业信息化的升级,传统网络架构已经无法满足企业复杂的、多样化的组网互联需求。 企业多样化的组网需求包括: 一是需要将各办公点互联起来进行数据传输、资源共享; 二是视频会议、ERP、OA、邮箱系统、云服务应用程序等访问需求&#xff…

F-Droid:开源Android应用的宝库

F-Droid:开源Android应用的宝库 引言 F-Droid是一个开源应用程序存储库,旨在为安卓用户提供自由、隐私和安全的应用程序。它最初于2010年由Ciaran Gultnieks创建,因为他认为Google Play Store上的应用程序不够透明和安全。F-Droid的目标是为…

elasticsearch[二]-DSL查询语法:全文检索、精准查询(term/range)、地理坐标查询(矩阵、范围)、复合查询(相关性算法)、布尔查询

ES-DSL查询语法(全文检索、精准查询、地理坐标查询) 1.DSL查询文档 elasticsearch 的查询依然是基于 JSON 风格的 DSL 来实现的。 1.1.DSL 查询分类 Elasticsearch 提供了基于 JSON 的 DSL(Domain Specific Language)来定义查…

码住!软件测试人员的基本有哪些?

在软件测试领域,许多人误以为软件测试只是简单的点点鼠标、看看屏幕就能完成。然而,软件测试的复杂性远不止于此。作为一名软件测试人员,你需要具备多项技能和素质来保证测试的有效性和质量。 打字技能可以事半功倍 打字是软件测试人员必备的…

[linux]使用libqrencode库生成二维码数据

一、需求 要将一段数据生成为二维码&#xff0c; 二、方案 使用linux标准库&#xff0c;通过libqrencode将需要写入的信息转为二维码图片数据。 三、实现 3.1编写c文件 #include <stdio.h> #include <stdlib.h> #include <qrencode.h> int main() {QRc…

Facebook广告优化

通过Facebook广告优化来提高产品销量&#xff0c;以下是一些步骤和技巧&#xff1a; 1、确定目标受众&#xff1a;在Facebook广告平台上&#xff0c;您可以根据性别、年龄、地理位置、兴趣爱好等多种因素来定义您的目标受众。通过细分目标受众&#xff0c;您可以更精准地将广告…

[足式机器人]Part2 Dr. CAN学习笔记- Kalman Filter卡尔曼滤波器Ch05

本文仅供学习使用 本文参考&#xff1a; B站&#xff1a;DR_CAN Dr. CAN学习笔记 - Kalman Filter卡尔曼滤波器 Ch05 1. Recursive Algirithm 递归算法2. Data Fusion 数据融合Covarince Matrix协方差矩阵State Space状态空间方程 Observation观测器3. Step by step : Deriatio…

jmeter-线程数设置为1,循环10次没问题,循环100次出现异常

一、多次尝试&#xff0c;发现出现异常的接口大致相同。 解决办法&#xff1a;在第一个出现异常的接口下添加超时时间&#xff0c;固定定时器&#xff1a;2000ms&#xff0c;再次运行就没问题了。 二、压力机自身存在的问题 1&#xff09;在网络编程中&#xff0c;特别是在短…

查找国外文献的技巧

文章目录 一、方法二、配置参考 一、方法 xrelay&#xff08;1年&#xff09; 其他手段&#xff1a; 手段1手段2 需要自己去看怎么配置 二、配置 google浏览器走代理的配置&#xff1a; 配置步骤&#xff1a; 方法1&#xff1a;https://steemit.com/cn/causenet/7-switc…

Axure全面指南:正确打开并高效使用的步骤!

AxureRP是目前流行的设计精美的用户界面和交互软件。AxureRP根据其应用领域提供了一组丰富的UI控制。作为Axure的国内替代品&#xff0c;即时设计可以在线协作&#xff0c;浏览器可以在无需下载客户端的情况下打开和使用。如果以前使用Axure&#xff0c;很容易切换到即时设计。…

x-www-form-urlencoded接收方式代码示例

数据回推方式是 “x-www-form-urlencoded”&#xff0c;可以选择使用 GET 或 POST 方法来接收数据回推。 使用 GET 方法接收数据回推时&#xff0c;您可以将数据作为查询参数附加在请求的 URL 中。例如&#xff1a; http://example.com/callback?param1value1&param2val…