【深度学习注意力机制系列】—— CBAM注意力机制(附pytorch实现)

CBAM(Convolutional Block Attention Module)是一种用于增强卷积神经网络(CNN)性能的注意力机制模块。它由Sanghyun Woo等人在2018年的论文[1807.06521] CBAM: Convolutional Block Attention Module (arxiv.org)中提出。CBAM的主要目标是通过在CNN中引入通道注意力和空间注意力来提高模型的感知能力,从而在不增加网络复杂性的情况下改善性能。

1、概述

CBAM旨在克服传统卷积神经网络在处理不同尺度、形状和方向信息时的局限性。为此,CBAM引入了两种注意力机制:通道注意力和空间注意力。通道注意力有助于增强不同通道的特征表示,而空间注意力有助于提取空间中不同位置的关键信息。

2、模型结构

CBAM由两个关键部分组成:通道注意力模块(C-channel)空间注意力模块(S-channel)。这两个模块可以分别嵌入到CNN中的不同层,以增强特征表示。

2.1 通道注意力模块

在这里插入图片描述

通道注意力模块的目标是增强每个通道的特征表达。以下是实现通道注意力模块的步骤:

  1. 全局最大池化和全局平均池化: 对于输入特征图,首先对每个通道执行全局最大池化和全局平均池化操作,计算每个通道上的最大特征值和平均特征值。这会生成两个包含通道数的向量,分别表示每个通道的全局最大特征和平均特征。

  2. 全连接层: 将全局最大池化和平均池化后的特征向量输入到一个共享全连接层中。这个全连接层用于学习每个通道的注意力权重。通过学习,网络可以自适应地决定哪些通道对于当前任务更加重要。将全局最大特征向量和平均特征向相交,得到最终注意力权重向量。

  3. Sigmoid激活: 为了确保注意力权重位于0到1之间,应用Sigmoid激活函数来产生通道注意力权重。这些权重将应用于原始特征图的每个通道。

  4. 注意力加权: 使用得到的注意力权重,将它们与原始特征图的每个通道相乘,得到注意力加权后的通道特征图。这将强调对当前任务有帮助的通道,并抑制无关的通道。

代码实现

class ChannelAttention(nn.Module):"""CBAM混合注意力机制的通道注意力"""def __init__(self, in_channels, ratio=16):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc = nn.Sequential(# 全连接层# nn.Linear(in_planes, in_planes // ratio, bias=False),# nn.ReLU(),# nn.Linear(in_planes // ratio, in_planes, bias=False)# 利用1x1卷积代替全连接,避免输入必须尺度固定的问题,并减小计算量nn.Conv2d(in_channels, in_channels // ratio, 1, bias=False),nn.ReLU(inplace=True),nn.Conv2d(in_channels // ratio, in_channels, 1, bias=False))self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = self.fc(self.avg_pool(x))max_out = self.fc(self.max_pool(x))out = avg_out + max_outout = self.sigmoid(out)return out * x

2.2 空间注意力模块

在这里插入图片描述

空间注意力模块的目标是强调图像中不同位置的重要性。以下是实现空间注意力模块的步骤:

  1. 全局最大池化和全局平均池化: 对于输入特征图,分别执行全局最大池化和全局平均池化操作,生成不同上下文尺度的特征。
  2. 连接和卷积: 将全局最大池化和全局平均池化后的特征沿着通道维度进行连接(拼接),得到一个具有不同尺度上下文信息的特征图。然后,通过卷积层处理这个特征图,以生成空间注意力权重。
  3. Sigmoid激活: 类似于通道注意力模块,对生成的空间注意力权重应用Sigmoid激活函数,将权重限制在0到1之间。
  4. 注意力加权: 将得到的空间注意力权重应用于原始特征图,对每个空间位置的特征进行加权。这样可以突出重要的图像区域,并减少不重要的区域的影响。

代码实现

class SpatialAttention(nn.Module):"""CBAM混合注意力机制的空间注意力"""def __init__(self, kernel_size=7):super(SpatialAttention, self).__init__()assert kernel_size in (3, 7), 'kernel size must be 3 or 7'padding = 3 if kernel_size == 7 else 1self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):avg_out = torch.mean(x, dim=1, keepdim=True)max_out, _ = torch.max(x, dim=1, keepdim=True)out = torch.cat([avg_out, max_out], dim=1)out = self.sigmoid(self.conv1(out))return out * x

2.3 混合注意力模块

在这里插入图片描述

CBAM就是将通道注意力模块和空间注意力模块的输出特征逐元素相乘,得到最终的注意力增强特征。这个增强的特征将用作后续网络层的输入,以在保留关键信息的同时,抑制噪声和无关信息。原文实验证明先进行通道维度的整合,再进行空间维度的整合,模型效果更好(有效玄学炼丹的感觉)。

代码实现

class CBAM(nn.Module):"""CBAM混合注意力机制"""def __init__(self, in_channels, ratio=16, kernel_size=3):super(CBAM_Block, self).__init__()self.channelattention = ChannelAttention(in_channels, ratio=ratio)self.spatialattention = SpatialAttention(kernel_size=kernel_size)def forward(self, x):x = self.channelattention(x)x = self.spatialattention(x)return x

总结

总之,CBAM模块通过自适应地学习通道和空间注意力权重,以提高卷积神经网络的特征表达能力。通过将通道注意力和空间注意力结合起来,CBAM模块能够在不同维度上捕获特征之间的相关性,从而提升图像识别任务的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/61054.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

接口自动化测试框架及接口测试自动化主要知识点

接口自动化测试框架: 接口测试框架:使用最流行的Requests进行接口测试接口请求构造:常见的GET/POST/PUT/HEAD等HTTP请求构造 接口测试断言:状态码、返回内容等断言JSON/XML请求:发送json\xml请求JSON/XML响应断言&…

数据结构:各种结构函数参数辨析

(一)顺序表 1)结构 typedef int SLDateType;typedef struct SeqList {SLDateType* data;int size;int capacity; }SeqList;SeqList ps { 0 }; 2)函数参数 // 对数据的管理:增删查改 void SeqListInit(SeqList* ps); void Seq…

Linux系统目录结构介绍

Linux系统目录结构介绍 一、目录结构 Linux系统的目录结构是一颗倒状树: “/”表示最顶层的目录,叫做根目录。 (1)pwd可以显示当前所在的目录。 (2)cd可以切换当前的目录,例如,…

【Linux】TCP协议——传输层

目录 TCP协议 谈谈可靠性 TCP协议格式 序号与确认序号 窗口大小 六个标志位 确认应答机制(ACK) 超时重传机制 连接管理机制 三次握手 四次挥手 流量控制 滑动窗口 拥塞控制 延迟应答 捎带应答 面向字节流 粘包问题 TCP异常情况 TC…

【文献阅读笔记】深度异常检测模型

文章目录 导读相关关键词及其英文描述记录深度异常检测模型Supervised deep anomaly detection 有监督深度异常检测Semi-Supervised deep anomaly detection 半监督深度异常检测Hybrid deep anomaly detection 混合深度异常检测One-class neural network for anomaly detection…

心跳跟随的心形灯(STM32(HAL)+WS2812+MAX30102)

文章目录 前言介绍系统框架原项目地址本项目开发开源地址硬件PCB软件功能 详细内容硬件外壳制作WS2812级联及控制MAX30102血氧传感器0.96OLEDFreeRTOS 效果视频总结 前言 在好几年前,我好像就看到了焊武帝 jiripraus在纪念结婚五周年时,制作的一个心跳跟…

Android进阶之SeekBar动态显示进度

SeekBar 在开发中并不陌生,默认的SeekBar是不显示进度的,当然用吐司或者文案在旁边实时显示也是可以的,那能不能移动的时候才显示,默认不显示呢,当然网上花哨的三方工具类太多了,但是我只是单纯的想在SeekBar的基础上去添加一个可以跟随移动显示的气泡而…

Python爬虫在电商数据挖掘中的应用

作为一名长期扎根在爬虫行业的专业的技术员,我今天要和大家分享一些有关Python爬虫在电商数据挖掘中的应用与案例分析。在如今数字化的时代,电商数据蕴含着丰富的信息,通过使用爬虫技术,我们可以轻松获取电商网站上的产品信息、用…

在WebStorm中通过live-server插件搭建Ajax运行环境

1.下载node.js 官网: https://nodejs.cn/download/ 2.配置Node.js的HTTPS 使用淘宝的镜像: npm config set registry https://registry.npm.taobao.org 也可以使用cnpm npm install -g cnpm --registryhttps://registry.npm.taobao.org 配置之后可以验证是否成…

http相关知识点

文章目录 长链接http周边会话保持方案1方案2 基本工具postmanFiddlerFiddler的原理 长链接 一张网页实际上可能会有多种元素组成,这也就说明了网页需要多次的http请求。可由于http是基于TCP的,而TCP创建链接是有代价的,因此频繁的创建链接会…

中国信息安全测评中心CISP家族认证一览

随着国家对网络安全的重视,中国信息安全测评中心根据国家政策、未来趋势、重点内容陆续增添了很多CISP细分认证。 今日份详细介绍,部分CISP及其子品牌相关认证内容,一定要收藏哟! 校园版CISP NISP国家信息安全水平考试&#xff…

【C# 基础精讲】循环语句:for、while、do-while

循环语句是C#编程中用于重复执行一段代码块的关键结构。C#支持for、while和do-while三种常见的循环语句,它们允许根据条件来控制代码块的重复执行。在本文中,我们将详细介绍这三种循环语句的语法和使用方法。 for循环 for循环是一种常见的循环结构&…