YOLOv8改进算法之添加CA注意力机制

1. CA注意力机制

CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地理解不同位置之间的关系。如下图:

1. 输入特征: CA 注意力机制的输入通常是一个特征图,它通常是卷积神经网络(CNN)中的某一层的输出,具有以下形状:[C, H, W],其中:

  • C 是通道数,表示特征图中的不同特征通道。
  • H 是高度,表示特征图的垂直维度。
  • W 是宽度,表示特征图的水平维度。

2. 全局平均池化: CA 注意力机制首先对输入特征图进行两次全局平均池化,一次在宽度方向上,一次在高度方向上。这两次操作分别得到两个特征映射:

  • 在宽度方向上的平均池化得到特征映射 [C, H, 1]
  • 在高度方向上的平均池化得到特征映射 [C, 1, W]

这两个特征映射分别捕捉了在宽度和高度方向上的全局特征。

3. 合并宽高特征: 将上述两个特征映射合并,通常通过简单的堆叠操作,得到一个新的特征层,形状为 [C, 1, H + W],其中 H + W 表示在宽度和高度两个方向上的维度合并在一起。

4. 卷积+标准化+激活函数: 对合并后的特征层进行卷积操作,通常是 1x1 卷积,以捕捉宽度和高度维度之间的关系。然后,通常会应用标准化(如批量标准化)和激活函数(如ReLU)来进一步处理特征,得到一个更加丰富的表示。

5. 再次分开: 分别从上述特征层中分离出宽度和高度方向的特征:

  • 一个分支得到特征层 [C, 1, H]
  • 另一个分支得到特征层 [C, 1, W]

6. 转置: 对分开的两个特征层进行转置操作,以恢复宽度和高度的维度,得到两个特征层分别为 [C, H, 1][C, 1, W]

7. 通道调整和 Sigmoid: 对两个分开的特征层分别应用 1x1 卷积,以调整通道数,使其适应注意力计算。然后,应用 Sigmoid 激活函数,得到在宽度和高度维度上的注意力分数。这些分数用于指示不同位置的重要性。

8. 应用注意力: 将原始输入特征图与宽度和高度方向上的注意力分数相乘,得到 CA 注意力机制的输出。

2. YOLOv8添加CA注意力机制

加入注意力机制,在ultralytics包中的nn包的modules里添加CA注意力模块,我这里选择在conv.py文件中添加CA注意力机制。

CA注意力机制代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass h_sigmoid(nn.Module):def __init__(self, inplace=True):super(h_sigmoid, self).__init__()self.relu = nn.ReLU6(inplace=inplace)def forward(self, x):return self.relu(x + 3) / 6class h_swish(nn.Module):def __init__(self, inplace=True):super(h_swish, self).__init__()self.sigmoid = h_sigmoid(inplace=inplace)def forward(self, x):return x * self.sigmoid(x)class CoordAtt(nn.Module):def __init__(self, inp, reduction=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // reduction)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.act = h_swish()self.conv_h = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)self.conv_w = nn.Conv2d(mip, inp, kernel_size=1, stride=1, padding=0)def forward(self, x):identity = xn, c, h, w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.act(y)x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)a_h = self.conv_h(x_h).sigmoid()a_w = self.conv_w(x_w).sigmoid()out = identity * a_w * a_hreturn out

CA注意力机制的注册和引用如下:

 ultralytics/nn/modules/_init_.py文件中:

  ultralytics/nn/tasks.py文件夹中:

 在tasks.py中的parse_model中添加如下代码:

        elif m in {CoordAtt}:args=[ch[f],*args]

新建相应的yolov8s-CA.yaml文件,代码如下:

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1,1,CoordAtt,[]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1,1,CoordAtt,[]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1,1,CoordAtt,[]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 8], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 5], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 15], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 12], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[18, 21, 24], 1, Detect, [nc]]  # Detect(P3, P4, P5)

在main.py文件中进行训练:

if __name__ == '__main__':# 使用yaml配置文件来创建模型,并导入预训练权重.model = YOLO('ultralytics/cfg/models/v8/yolov8s-CA.yaml')# model.load('yolov8n.pt')model.train(**{'cfg': 'ultralytics/cfg/default.yaml', 'data': 'dataset/data.yaml'})

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/128169.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

美容美甲小程序商城的作用是什么

美容院往往有很高需求,女性悦己经济崛起,加之爱美化程度提升,无论线下环境还是线上互联网信息冲击,美容服务、化妆产品等市场规格一直稳增不减。 通过【雨科】平台制作美容美甲商城,售卖相关服务/产品,模块…

计算机网络-计算机网络体系结构-概述,模型

目录 一、计算机网络概述 二、性能指标 速率 带宽 吞吐量 时延 往返时延RTT 利用率 三、计算机网络体系结构 分层结构 IOS模型 应用层-> 表示层-> 会话层-> 传输层-> 网络层-> 数据链路层-> 物理层-> TCP/IP模型 一、计算机网络概述 计…

轻松实现视频、音频、文案批量合并,享受批量剪辑的便捷

在日常生活中,我们经常会需要将多个视频、音频和文案进行合并剪辑,以制作出符合我们需求的短视频。然而,这个过程通常需要花费大量的时间和精力。幸运的是,现在有一款名为“固乔智剪软件”的工具可以帮助我们轻松完成这个任务。 首…

Nginx搭建Rtmp流媒体服务,并使用Ffmpeg推流

文章目录 1.rtmp流媒体服务框架图2.nginx配置3.配置nginx4.使用ffmpeg推流5.实时推摄像头流 本项目在开发板上使用nginx搭建流媒体服务,利用ffmpeg进行推流,在pc上使用vlc media进行拉流播放。 1.rtmp流媒体服务框架图 2.nginx配置 下载:wge…

GitHub爬虫项目详解

前言 闲来无事浏览GitHub的时候,看到一个仓库,里边列举了Java的优秀开源项目列表,包括说明、仓库地址等,还是很具有学习意义的。但是大家也知道,国内访问GitHub的时候,经常存在访问超时的问题,…

JVM技术文档--JVM诊断调优工具Arthas--阿里巴巴开源工具--一文搞懂Arthas--快速上手--国庆开卷!!

​ Arthas首页 简介 | arthas Arthas官网文档 Arthas首页、文档和下载 - 开源 Java 诊断工具 - OSCHINA - 中文开源技术交流社区 阿丹: 之前聊过了一些关于JMV中的分区等等,但是有同学还是在后台问我,还有私信问我,学了这些…

[SWPUCTF 2021 新生赛]sql - 联合注入

[SWPUCTF 2021 新生赛]sql 一、思路分析二、解题流程 一、思路分析 这题可以参考文章:[SWPUCTF 2021 新生赛]easy_sql - 联合注入||报错注入||sqlmap 这题相比于参考文章的题目多了waf过滤 二、解题流程 首先,仍然是网站标题提示参数是wllm 1、fuzz看…

10-Node.js模块化

01.模块化简介 目标 了解模块化概念和好处,以及 CommonJS 标准语法导出和导入 讲解 在 Node.js 中每个文件都被当做是一个独立的模块,模块内定义的变量和函数都是独立作用域的,因为 Node.js 在执行模块代码时,将使用如下所示的…

springboot和vue:七、mybatis/mybatisplus多表查询+分页查询

mybatisplus实际上只对单表查询做了增强(速度会更快),从传统的手写sql语句,自己做映射,变为封装好的QueryWrapper。 本篇文章的内容是有两张表,分别是用户表和订单表,在不直接在数据库做表连接的…

OLED透明屏交互技术:开创未来科技的新篇章

OLED透明屏交互技术作为一项前沿的科技创新,正在以其高透明度、触摸和手势交互等特点,引领着未来科技的发展。 不仅在智能手机、可穿戴设备和汽车行业有着广泛应用,还在广告和展示领域展现出巨大的潜力。 在这篇文章中,尼伽将深…

【Docker】简易版harbor部署

文章目录 依赖于docker-compose下载添加执行权限测试 安装harbor下载解压修改配置文件部署配置开机自启动登录验证 使用harbor登录打标签上传下载 常见问题 依赖于docker-compose 下载 curl -L “https://github.com/docker/compose/releases/download/2.22.0/docker-compose-…

C++并发与多线程(3) | 其他创建线程的方式

1. 用类(可调用对象) 必须要重载括号运算符,否则不是可调用对象。这种方式其实就是一个仿函数。 示例: #include <iostream> #include <thread> using namespace std;class TA { public:void operator() ()// 不能带参数 {cout << "子线程operato…