YOLOV8注意力改进方法:Deformable Attention Transformer(附改进代码)

原论文地址:原论文下载地址

论文摘要:Transformer最近在各种视觉任务上表现出了优异的性能。对于巨大的甚至是全局性的感受野赋予Transformer模型比CNN模型更高的表现力。然而,仅仅扩大感受野也会引起一些担忧。一方面,在ViT中使用密集的注意力会导致过度的内存和计算成本,并且特征可能会受到感兴趣区域之外的不相关部分的影响。另一方面,在PVT或Swin Transformer中采用的稀疏注意力可能会限制建立long range关系模型的能力。为了缓解这些问题,我们提出了一种新的可变形的自注意模块,其中以数据相关的方式选择自注意中的键-值对的位置。这种灵活的方案使自注意模块能够专注于相关区域并捕获其特征。在此基础上,我们提出了Deformable Attention Transformer,这是一种用于图像分类和密集预测任务的具有可变形注意力的通用主干模型。大量实验表明,我们的模型在综合基准上取得了持续改进的结果。

具体的内容学习可以参考这篇博客介绍的比较详细:

博客地址

可变形注意机制的图示。 (a) 呈现可变形注意力的信息流。在左侧部分,一组参考点均匀地放置在特征图上,其偏移量是通过偏移网络从查询中学习到的。然后根据变形点从采样的特征中投影出变形的键和值,如右图所示。相对位置偏差也由变形点计算,增强了输出变形特征的多头注意力。我们只展示了 4 个参考点以进行清晰的展示,实际实施中还有更多参考点。 (b) 揭示了偏移生

Deformable CNN and attention
可变形卷积是一种处理以输入数据为条件的灵活变换空间位置的机制。最近,它已应用于ViT。Deformable DETR通过为CNN backbone顶部的每个query选择少量的key,改进了DETR的收敛性。由于缺少key限制了表示能力,其可变形注意力不适合用于作为特征提取的视觉主干。此外,Deformable DETR中的注意力来自简单学习的线性投影,并且query token之间不共享key。DPT和PS ViT构建可变形模块以优化视觉token。具体而言,DPT提出了一种可变形patch embedding方法来细化各个stage的patch,而PS ViT在ViT主干之前引入了一个空间采样模块来改进视觉token。它们都没有将变形注意力纳入视觉中枢。相比之下,我们的可变形注意力采用了一种强大而简单的设计来学习视觉token之间共享的一组全局key,可以作为各种视觉任务的通用主干。我们的方法也可以被视为一种空间自适应机制,这在各种工作中都被证明是有效的。

与 DCN 类似,将设计好的可变形注意力模块加入到 ViT 模型的最后两个阶段,得到 Deformable Attention Transformer,如下图所示:

2.DAT加入到yolov8的步骤:

2.1 加入ultralytics/nn/attention/deformable_attention_2d.py

import torch
import torch.nn.functional as F
from torch import nn, einsumfrom einops import rearrange, repeat# helper functionsdef exists(val):return val is not Nonedef default(val, d):return val if exists(val) else ddef divisible_by(numer, denom):return (numer % denom) == 0# tensor helpersdef create_grid_like(t, dim = 0):h, w, device = *t.shape[-2:], t.devicegrid = torch.stack(torch.meshgrid(torch.arange(h, device = device),torch.arange(w, device = device),indexing = 'ij'), dim = dim)grid.requires_grad = Falsegrid = grid.type_as(t)return griddef normalize_grid(grid, dim = 1, out_dim = -1):# normalizes a grid to range from -1 to 1h, w = grid.shape[-2:]grid_h, grid_w = grid.unbind(dim = dim)grid_h = 2.0 * grid_h / max(h - 1, 1) - 1.0grid_w = 2.0 * grid_w / max(w - 1, 1) - 1.0return torch.stack((grid_h, grid_w), dim = out_dim)class Scale(nn.Module):def __init__(self, scale):super().__init__()self.scale = scaledef forward(self, x):return x * self.scale# continuous positional bias from SwinV2class CPB(nn.Module):""" https://arxiv.org/abs/2111.09883v1 """def __init__(self, dim, heads, offset_groups, depth):super().__init__()self.heads = headsself.offset_groups = offset_groupsself.mlp = nn.ModuleList([])self.mlp.append(nn.Sequential(nn.Linear(2, dim),nn.ReLU()))for _ in range(depth - 1):self.mlp.append(nn.Sequential(nn.Linear(dim, dim),nn.ReLU()))self.mlp.append(nn.Linear(dim, heads // offset_groups))def forward(self, grid_q, grid_kv):device, dtype = grid_q.device, grid_kv.dtypegrid_q = rearrange(grid_q, 'h w c -> 1 (h w) c')grid_kv = rearrange(grid_kv, 'b h w c -> b (h w) c')pos = rearrange(grid_q, 'b i c -> b i 1 c') - rearrange(grid_kv, 'b j c -> b 1 j c')bias = torch.sign(pos) * torch.log(pos.abs() + 1)  # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1)for layer in self.mlp:bias = layer(bias)bias = rearrange(bias, '(b g) i j o -> b (g o) i j', g = self.offset_groups)return bias# main classclass DeformableAttention2D(nn.Module):def __init__(self,dim,dim_head = 64,heads = 8,dropout = 0.,downsample_factor = 4,offset_scale = None,offset_groups = None,offset_kernel_size = 6,group_queries = True,group_key_values = True):super().__init__()offset_scale = default(offset_scale, downsample_factor)assert offset_kernel_size >= downsample_factor, 'offset kernel size must be greater than or equal to the downsample factor'assert divisible_by(offset_kernel_size - downsample_factor, 2)offset_groups = default(offset_groups, heads)assert divisible_by(heads, offset_groups)inner_dim = dim_head * headsself.scale = dim_head ** -0.5self.heads = headsself.offset_groups = offset_groupsoffset_dims = inner_dim // offset_groupsself.downsample_factor = downsample_factorself.to_offsets = nn.Sequential(nn.Conv2d(offset_dims, offset_dims, offset_kernel_size, groups = offset_dims, stride = downsample_factor, padding = (offset_kernel_size - downsample_factor) // 2),nn.GELU(),nn.Conv2d(offset_dims, 2, 1, bias = False),nn.Tanh(),Scale(offset_scale))self.rel_pos_bias = CPB(dim // 4, offset_groups = offset_groups, heads = heads, depth = 2)self.dropout = nn.Dropout(dropout)self.to_q = nn.Conv2d(dim, inner_dim, 1, groups = offset_groups if group_queries else 1, bias = False)self.to_k = nn.Conv2d(dim, inner_dim, 1, groups = offset_groups if group_key_values else 1, bias = False)self.to_v = nn.Conv2d(dim, inner_dim, 1, groups = offset_groups if group_key_values else 1, bias = False)self.to_out = nn.Conv2d(inner_dim, dim, 1)def forward(self, x, return_vgrid = False):"""b - batchh - headsx - heighty - widthd - dimensiong - offset groups"""heads, b, h, w, downsample_factor, device = self.heads, x.shape[0], *x.shape[-2:], self.downsample_factor, x.device# queriesq = self.to_q(x)# calculate offsets - offset MLP shared across all groupsgroup = lambda t: rearrange(t, 'b (g d) ... -> (b g) d ...', g = self.offset_groups)grouped_queries = group(q)offsets = self.to_offsets(grouped_queries)# calculate grid + offsetsgrid =create_grid_like(offsets)vgrid = grid + offsetsvgrid_scaled = normalize_grid(vgrid)kv_feats = F.grid_sample(group(x),vgrid_scaled,mode = 'bilinear', padding_mode = 'zeros', align_corners = False)kv_feats = rearrange(kv_feats, '(b g) d ... -> b (g d) ...', b = b)# derive key / valuesk, v = self.to_k(kv_feats), self.to_v(kv_feats)# scale queriesq = q * self.scale# split out headsq, k, v = map(lambda t: rearrange(t, 'b (h d) ... -> b h (...) d', h = heads), (q, k, v))# query / key similaritysim = einsum('b h i d, b h j d -> b h i j', q, k)# relative positional biasgrid = create_grid_like(x)grid_scaled = normalize_grid(grid, dim = 0)rel_pos_bias = self.rel_pos_bias(grid_scaled, vgrid_scaled)sim = sim + rel_pos_bias# numerical stabilitysim = sim - sim.amax(dim = -1, keepdim = True).detach()# attentionattn = sim.softmax(dim = -1)attn = self.dropout(attn)# aggregate and combine headsout = einsum('b h i j, b h j d -> b h i d', attn, v)out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)out = self.to_out(out)if return_vgrid:return out, vgridreturn out

 2.2 注册ultralytics/nn/tasks.py

在tasks.py文件的上面导入部分粘贴下面的代码

from ultralytics.nn.attention.deformable_attention_2d import DeformableAttention2D

修改def parse_model(d, ch, verbose=True): # model_dict, input_channels(3)

只需要加入 DeformableAttention2D,加入以下代码:

 if m in (Classify, Conv, ConvTranspose, GhostConv, Bottleneck, GhostBottleneck, SPP, SPPF, DWConv, Focus,BottleneckCSP, C1, C2, C2f, C3, C3TR, C3Ghost, nn.ConvTranspose2d, DWConvTranspose2d, C3x, DeformableAttention2D):

2.3 yolov8_DeformableAttention2D.yaml

# Ultralytics YOLO 🚀, GPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 1  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 1, SPPF, [1024, 5]]  # 9- [-1, 1, DeformableAttention2D, [1024]]  # 22# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/601733.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

01 华为华三模拟器解决兼容问题

01 华为华三模拟器解决兼容问题 AI思路 要让华为和华三的模拟器兼容,您可以尝试以下方法: 更新模拟器版本:确保您使用的华为和华三模拟器都是最新版本。在华为官方网站或华三官方网站上下载最新的模拟器版本。 检查系统要求:确保…

Java调试之JDB命令行调试入门

0.前言 Java 调试器 (JDB) 是一个简单的 Java 类命令行调试器。 jdb 命令及其选项调用 JDB。 jdb 命令演示了 Java 平台调试器架构,并提供本地或远程 JVM 的检查和调试。 1.准备待调试的Java应用程序 public class JDB {public static int sum(int a,int b){int …

【简单讲解下PHP AES加解密示例】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

[C++][算法基础]合并集合(并查集)

一共有 n 个数,编号是 1∼n,最开始每个数各自在一个集合中。 现在要进行 m 个操作,操作共有两种: M a b,将编号为 a 和 b 的两个数所在的集合合并,如果两个数已经在同一个集合中,则忽略这个操…

如何恢复被.locked勒索病毒加密的服务器和数据库?

.locked勒索病毒有什么特点? .locked勒索病毒的特点主要包括以下几个方面: 文件加密:.locked勒索病毒会对受感染设备上的所有文件进行加密,包括图片、文档、视频和其他各种类型的重要文件。一旦文件被加密,文件的扩展…

指针的深入理解(六)

指针的深入理解(六) 个人主页:大白的编程日记 感谢遇见,我们一起学习进步! 文章目录 指针的深入理解(六)前言一. sizeof和strlen1.1sizeof1.2strlen1.3sizeof和strlen对比 二.数组名和指针加减…

一文搞懂从爬楼梯到最小花费(力扣70,746)

文章目录 题目前知动态规划简介动态规划模版 爬楼梯一、思路二、解题方法三、Code 使用最小花费爬楼梯一、思路二、解题方法三、Code 总结 在计算机科学中,动态规划是一种强大的算法范例,用于解决多种优化问题。本文将介绍动态规划的核心思想&#xff0c…

吴恩达机器学习理论基础解读

吴恩达机器学习理论基础 机器学习最常见的形式监督学习,无监督学习 线性回归模型概述 应用场景一:根据房屋大小预测房价 应用场景二:分类算法(猫狗分类) 核心概念:将训练模型的数据称为数据集(学习数据…

MySQL学习笔记(二)

1、把查询结果中去除重复记录 2、连接查询 从一张表中单独查询,称为单表查询。emp表和dept表联合起来查询数据,从emp表中取员工名字,从dept表中取部门名字,这种跨表查询,多张表联合起来查询数据,被称为连…

信阳附大医院-市民心中的健康守护者

信阳附大医院,一所集医疗、预防、保健、科研、教学、康复于一体的现代化综合医院,坐落于信阳市工区路600号,是市卫生部门批准成立的医疗机构,更是市民心中的健康守护者. 医院环境优雅,设施先进,服务周到,汇聚了一支技术精湛、经验丰富的医疗团队.医师们以患者为中心,用心倾听,精…

皮具5G智能制造工厂数字孪生可视化平台,推进企业数字化转型

皮具5G智能制造工厂数字孪生可视化平台,推进企业数字化转型。随着信息技术的快速发展,数字化转型已成为企业提升竞争力、实现可持续发展的关键路径。皮具行业,作为一个传统的手工制造业,正面临着巨大的市场变革和技术挑战。如何在…

一款轻量、干净的 Vue 后台管理框架

开始之前 在开始介绍之前我想谈谈为什么要自己做一个后台管理,我知道很多人都用一些开源的后台管理项目,这些老前辈有很多亮点值得学习,但是存在的一些问题同样不可忽视,我认为很多开发者会被困扰(仅代表个人观点) 技术栈老旧不升…