免费阅读篇 | 芒果YOLOv8改进110:注意力机制GAM:用于保留信息以增强渠道空间互动

💡🚀🚀🚀本博客 改进源代码改进 适用于 YOLOv8 按步骤操作运行改进后的代码即可

该专栏完整目录链接: 芒果YOLOv8深度改进教程

该篇博客为免费阅读内容,直接改进即可🚀🚀🚀

文章目录

      • 1. GAM论文
      • 2. YOLOv8 核心代码改进部分
      • 2.1 核心新增代码
        • 2.2 修改部分
      • 2.3 YOLOv8-gam 网络配置文件
      • 2.4 运行代码
      • 改进说明


1. GAM论文

在这里插入图片描述

研究了多种注意力机制来提高各种计算机视觉任务的性能。然而,现有的方法忽略了保留通道和空间方面的信息以增强跨维度交互的重要性。因此,我们提出了一种全局注意力机制,通过减少信息缩减和放大全局交互式表示来提高深度神经网络的性能。我们引入了带有多层感知器的 3D 排列,用于通道注意力以及卷积空间注意力子模块。对CIFAR-100和ImageNet-1K上图像分类任务的所提机制的评估表明,我们的方法在ResNet和轻量级MobileNet上都稳定地优于最近的几种注意力机制。

在这里插入图片描述

具体细节可以去看原论文:https://arxiv.org/pdf/2112.05561v1.pdf


2. YOLOv8 核心代码改进部分

2.1 核心新增代码

首先在ultralytics/nn/modules文件夹下,创建一个 gam.py文件,新增以下代码

import numpy as np
import torch
from torch import nn
from torch.nn import initclass GAMAttention(nn.Module):#https://paperswithcode.com/paper/global-attention-mechanism-retain-informationdef __init__(self, c1, c2, group=True,rate=4):super(GAMAttention, self).__init__()self.channel_attention = nn.Sequential(nn.Linear(c1, int(c1 / rate)),nn.ReLU(inplace=True),nn.Linear(int(c1 / rate), c1))self.spatial_attention = nn.Sequential(nn.Conv2d(c1, c1//rate, kernel_size=7, padding=3,groups=rate)if group else nn.Conv2d(c1, int(c1 / rate), kernel_size=7, padding=3), nn.BatchNorm2d(int(c1 /rate)),nn.ReLU(inplace=True),nn.Conv2d(c1//rate, c2, kernel_size=7, padding=3,groups=rate) if group else nn.Conv2d(int(c1 / rate), c2, kernel_size=7, padding=3), nn.BatchNorm2d(c2))def forward(self, x):b, c, h, w = x.shapex_permute = x.permute(0, 2, 3, 1).view(b, -1, c)x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)x_channel_att = x_att_permute.permute(0, 3, 1, 2)x = x * x_channel_attx_spatial_att = self.spatial_attention(x).sigmoid()x_spatial_att=channel_shuffle(x_spatial_att,4) #last shuffle out = x * x_spatial_attreturn out  def channel_shuffle(x, groups=2):B, C, H, W = x.size()out = x.view(B, groups, C // groups, H, W).permute(0, 2, 1, 3, 4).contiguous()out=out.view(B, C, H, W) return out   
2.2 修改部分

在ultralytics/nn/modules/init.py中导入 定义在 gam.py 里面的模块

from .gam import GAMAttention'GAMAttention' 加到 __all__ = [...] 里面

第一步:
ultralytics/nn/tasks.py文件中,新增

from ultralytics.nn.modules import GAMAttention

然后在 在tasks.py中配置
找到

        elif m is nn.BatchNorm2d:args = [ch[f]]

在这句上面加一个

        elif m is GAMAttention:c1, c2 = ch[f], args[0]if c2 != nc:  # if c2 not equal to number of classes (i.e. for Classify() output)c2 = make_divisible(min(c2, max_channels) * width, 8)args = [c1, c2, *args[1:]]

2.3 YOLOv8-gam 网络配置文件

新增YOLOv8-gam.yaml

# Ultralytics YOLO 🚀, AGPL-3.0 license
# YOLOv8 object detection model with P3-P5 outputs. For Usage examples see https://docs.ultralytics.com/tasks/detect# Parameters
nc: 80  # number of classes
scales: # model compound scaling constants, i.e. 'model=yolov8n.yaml' will call yolov8.yaml with scale 'n'# [depth, width, max_channels]n: [0.33, 0.25, 1024]  # YOLOv8n summary: 225 layers,  3157200 parameters,  3157184 gradients,   8.9 GFLOPss: [0.33, 0.50, 1024]  # YOLOv8s summary: 225 layers, 11166560 parameters, 11166544 gradients,  28.8 GFLOPsm: [0.67, 0.75, 768]   # YOLOv8m summary: 295 layers, 25902640 parameters, 25902624 gradients,  79.3 GFLOPsl: [1.00, 1.00, 512]   # YOLOv8l summary: 365 layers, 43691520 parameters, 43691504 gradients, 165.7 GFLOPsx: [1.00, 1.25, 512]   # YOLOv8x summary: 365 layers, 68229648 parameters, 68229632 gradients, 258.5 GFLOPs# YOLOv8.0n backbone
backbone:# [from, repeats, module, args]- [-1, 1, Conv, [64, 3, 2]]  # 0-P1/2- [-1, 1, Conv, [128, 3, 2]]  # 1-P2/4- [-1, 3, C2f, [128, True]]- [-1, 1, Conv, [256, 3, 2]]  # 3-P3/8- [-1, 6, C2f, [256, True]]- [-1, 1, Conv, [512, 3, 2]]  # 5-P4/16- [-1, 6, C2f, [512, True]]- [-1, 1, Conv, [1024, 3, 2]]  # 7-P5/32- [-1, 3, C2f, [1024, True]]- [-1, 3, GAMAttention, [1024]]- [-1, 1, SPPF, [1024, 5]]  # 9# YOLOv8.0n head
head:- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 6], 1, Concat, [1]]  # cat backbone P4- [-1, 3, C2f, [512]]  # 12- [-1, 1, nn.Upsample, [None, 2, 'nearest']]- [[-1, 4], 1, Concat, [1]]  # cat backbone P3- [-1, 3, C2f, [256]]  # 15 (P3/8-small)- [-1, 1, Conv, [256, 3, 2]]- [[-1, 13], 1, Concat, [1]]  # cat head P4- [-1, 3, C2f, [512]]  # 18 (P4/16-medium)- [-1, 1, Conv, [512, 3, 2]]- [[-1, 10], 1, Concat, [1]]  # cat head P5- [-1, 3, C2f, [1024]]  # 21 (P5/32-large)- [[16, 19, 22], 1, Detect, [nc]]  # Detect(P3, P4, P5)

2.4 运行代码

直接替换YOLOv8-gam.yaml 进行训练即可

到这里就完成了这篇的改进。

改进说明

这里改进是放在了主干后面,如果想放在改进其他地方,也是可以的。直接新增,然后调整通道,配齐即可,如果有不懂的,可以添加博主联系方式,如下


🥇🥇🥇
添加博主联系方式:

友好的读者可以添加博主QQ: 2434798737, 有空可以回答一些答疑和问题

🚀🚀🚀


参考

https://github.com/ultralytics/ultralytics

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/549335.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

设计模式学习笔记 - 设计原则与思想总结:2.运用学过的设计原则和思想完善之前性能计数器项目

概述 在 《设计原则 - 10.实战:针对非业务的通用框架开发,如何做需求分析和设计及如何实现一个支持各种统计规则的性能计数器》中,我们讲解了如何对一个性能计数器框架进行分析、设计与实现,并且实践了一些设计原则和设计思想。当…

Axure 中继器的Repeater属性的使用

dataCount 中继器当中存在多少条数据,总数。 visibleltemCount 中继器列表中可见项数量,也就是当前页面显示的数量。 pageCount 获取中继器分页的总数量,即能够获取分页后共有多少页。 pageIndex 获取中继器当前显示的页码

攻防世界新手模式例题(Web)

PHP2 首先我们查看页面,查看前端代码 发现均没有什么有效信息,由题目可知,此问题与php相关,于是我们可以看一下他的index.php文件 查看时用?index.phps 补充知识:phps文件就是php的源代码文件,通常用于…

Javaweb学习记录(二)web开发入门(请求响应)

第一个基于springboot的web请求程序 通过创建一个带有springboot的spring项目,项目会自动生成一个程序启动类,该类启动时会启动该整个项目,而我们需要写一个web请求类,要求在本地浏览器上发送请求后,浏览器显示Hello&…

Chrome历史版本下载地址:Google Chrome Older Versions Download (Windows, Linux Mac)

最近升级到最新版本Chrome后发现页面居然显示错乱,是在无语, 打算退回原来的版本, 又发现官方只提供最新的版本下载, 为了解决这个问题所有收集了Chrome历史版本的下载地址分享给大家. Google Chrome Windows version 32-bit VersionSizeDate104.0.5112.10279.68 MB2022-05-30…

day03vue学习

day03 一、今日目标 1.生命周期 生命周期介绍生命周期的四个阶段生命周期钩子声明周期案例 2.综合案例-小黑记账清单 列表渲染添加/删除饼图渲染 3.工程化开发入门 工程化开发和脚手架项目运行流程组件化组件注册 4.综合案例-小兔仙首页 拆分模块-局部注册结构样式完善…

【开源】SpringBoot框架开发学生综合素质评价系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 学生功能2.2 教师功能2.3 教务处功能 三、系统展示四、核心代码4.1 查询我的学科竞赛4.2 保存单个问卷4.3 根据类型查询学生问卷4.4 填写语数外评价4.5 填写品德自评问卷分 五、免责说明 一、摘要 1.1 项目介绍 基于J…

DOcker搭建Rancher

简介 Rancher 是供采用容器的团队使用的完整软件堆栈。它解决了管理多个Kubernetes集群的运营和安全挑战,并为DevOps团队提供用于运行容器化工作负载的集成工具。 官网地址:https://www.rancher.cn/ 安装 拉取镜像 docker pull rancher/rancher:stab…

CSS元素定位(学习笔记)

一、 z-index 1.1 作用 规定元素的堆叠顺序,取值越大,层级越往上 1.2 属性值 属性值为数字,可以取负值,不推荐 默认值:auto(跟父元素同一层级) 1.3 注意 必须配合定位(static除外)使用,默认情况下,后面的元…

Tomcat Session ID---会话保持

简单拓补图 一、负载均衡、反向代理 7-1nginx代理服务器配置 [rootdlnginx ~]#yum install epel-release.noarch -y ###安装额外源[rootdlnginx ~]#yum install nginx -y[rootdlnginx ~]#systemctl start nginx.service[rootdlnginx ~]#systemctl status nginx.service [ro…

力扣-1351 统计有序矩阵中的负数

给你一个 m * n 的矩阵 grid,矩阵中的元素无论是按行还是按列,都以非严格递减顺序排列。 请你统计并返回 grid 中 负数 的数目。 示例 1: 输入:grid [[4,3,2,-1],[3,2,1,-1],[1,1,-1,-2],[-1,-1,-2,-3]] 输出:8 解释&…

Unity UGUI之Toggle基本了解

在Unity中,Toggle一般用于两种状态之间的切换,通常用于开关或复选框等功能。 它的基本属性如图: 其中, Interactable(可交互):指示Toggle是否可以与用户交互。设置为false时,禁用To…