免费阅读篇 | 芒果YOLOv8改进111:注意力机制CBAM:轻量级卷积块注意力模块,无缝集成到任何CNN架构中,开销可以忽略不计

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程

该篇博客为免费阅读内容,YOLOv8+CBAM改进内容🚀🚀🚀

文章目录

      • 1. CBAM 论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 修改部分
      • 2.3 YOLOv8-CBAM 网络配置文件
      • 2.4 运行代码
      • 改进说明


1. CBAM 论文

在这里插入图片描述

我们提出了卷积块注意力模块(CBAM),这是一种用于前馈卷积神经网络的简单而有效的注意力模块。 给定中间特征图,我们的模块沿着两个独立的维度(通道和空间)顺序推断注意力图,然后将注意力图乘以输入特征图以进行自适应特征细化。 由于 CBAM 是一个轻量级通用模块,因此它可以无缝集成到任何 CNN 架构中,且开销可以忽略不计,并且可以与基础 CNN 一起进行端到端训练。 我们通过在 ImageNet-1K、MS~COCO 检测和 VOC~2007 检测数据集上进行大量实验来验证我们的 CBAM。 我们的实验表明各种模型的分类和检测性能得到了一致的改进,证明了 CBAM 的广泛适用性。 代码和模型将公开。

在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/1807.06521.pdf


2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 cbam.py文件,新增以下代码

import numpy as np
import torch
from torch import nn
from torch.nn import initclass ChannelAttentionModule(nn.Module):def __init__(self, c1, reduction=16):super(ChannelAttentionModule, self).__init__()mid_channel = c1 // reductionself.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.shared_MLP = nn.Sequential(nn.Linear(in_features=c1, out_features=mid_channel),nn.LeakyReLU(0.1, inplace=True),nn.Linear(in_features=mid_channel, out_features=c1))self.act = nn.Sigmoid()#self.act=nn.SiLU()def forward(self, x):avgout = self.shared_MLP(self.avg_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)maxout = self.shared_MLP(self.max_pool(x).view(x.size(0),-1)).unsqueeze(2).unsqueeze(3)return self.act(avgout + maxout)class SpatialAttentionModule(nn.Module):def __init__(self):super(SpatialAttentionModule, self).__init__()self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.act = nn.Sigmoid()def forward(self, x):avgout = torch.mean(x, dim=1, keepdim=True)maxout, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avgout, maxout], dim=1)out = self.act(self.conv2d(out))return outclass CBAM(nn.Module):def __init__(self, c1,c2):super(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1)self.spatial_attention = SpatialAttentionModule()def forward(self, x):out = self.channel_attention(x) * xout = self.spatial_attention(out) * outreturn out   
2.2 修改部分

在ultralytics/nn/modules/init.py中导入 定义在 cbam.py 里面的模块

from .cbam import CBAM'CBAM' 加到 __all__ = [...] 里面

第一步:
ultralytics/nn/tasks.py文件中,新增

from ultralytics.nn.modules import CBAM

然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:args = [ch[f]]

在这句上面加一个

        elif m is CBAM:c1, c2 = ch[f], args[0]if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]

2.3 YOLOv8-CBAM 网络配置文件

新增YOLOv8-CBAM.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 3, CBAM, [1024]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 运行代码

直接替换YOLOv8-CBAM.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/542774.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JAVA基础—JVM内存结构基础需知

1.JVM内存结构 JVM内存结构分为5个区域:方法区,虚拟机栈,本地方法栈、堆、程序计数器。 1.方法区(Method Area):用于存储类的结构信息、常量、静态变量、即使编译器编译后的代码等数据。方法区也是所有线…

邮件协议(SMTP、POP3、IMAP4)

电子邮件系统 1、概述 (1)网络电子邮件系统,好处在于价格低廉,速度非常快 (2)形式多样化(文字、图像、声音……) 2、电子邮件系统组成部分 用户代理 MUA(Mail User …

z1-5输入编码器实验

一. 实验内容 1、制作LED计数电路,输入是编号为1~5的5个开关,输出是5个发光二极管(LED) 点几号开关,就有几个LED发光。 2、制作一个5位输入3位输出的编码器, 输入的第5位为1,输出就是数字5对应…

如何通过做自己喜欢的事来赚钱?

今天想要跟大家分享一本我今年反复读过最多次的一本书《The Almanack of Naval Ravikant 纳瓦尔宝典》。我之前也有介绍过Naval Ravikant,他是硅谷创业界的一位传奇人物,创办了知名的天使投资平台AngelList。早期他还投资超过了200家科技公司,其中很多都成为了今天的科技巨头…

数据结构——动态顺序表

数据结构的动态顺序表有以下几个操作:创建,销毁,初始化,增删查改和打印以及内存空间不够时的扩容 本文的宏定义: #define SeqTypeData int 1.动态顺序表的创建 typedef struct SeqListInit{//动态顺序表的创建SeqT…

Acwing.1262 鱼塘钓鱼(多路归并)

题目 有 N个鱼塘排成一排,每个鱼塘中有一定数量的鱼,例如:N5时,如下表: 即:在第 1个鱼塘中钓鱼第 1分钟内可钓到 10条鱼,第 2分钟内只能钓到 8条鱼,……,第 5分钟以后再…

中国首个基于区块链的分布式算力网络上线

随着美国人工智能公司OpenAI近期发布的Sora视频模型,全球对高性能算力的需求突破了历史新高。Sora的创新在于它能够以超长生成时间、多角度镜头捕捉,理解物理世界的能力,这不仅是技术的一大突破,更是对算力需求的一大挑战。在这样…

前端项目,个人笔记(一)【Vue-cli - 定制化主题 + 路由设计】

目录 1、项目准备 1.1、项目初始化 1.2、elementPlus按需引入 注:使用cnpm安装elementplus及两个插件,会报错:vueelement-plus报错TypeError: Cannot read properties of null (reading isCE ) ,修改: 测试&#…

计算机服务器中了360后缀勒索病毒怎么办,勒索病毒解密工具与流程

对于众多的企业来说,利用网络开展各项工作业务是必不可少的环节,网络为企业的生产运营提供了有利条件,但网络是一把双刃剑,在为人们提供便利的同时,也为企业的数据安全带来严重威胁。近期,云天数据恢复中心…

MySQL锁整理

MySQL锁信息来源 MySQL锁太多,内容太杂。写篇文章记录一下

PHP 线性搜索算法

线性搜索被定义为一种顺序搜索算法,从一端开始,遍历列表中的每个元素,直到找到所需的元素,否则搜索将继续,直到数据集的末尾。 线性搜索算法 线性搜索算法如何工作? 在线性搜索算法中: …