卷积注意力模块 CBAM | CBAM: Convolutional Block Attention Module

在这里插入图片描述

论文名称:《CBAM: Convolutional Block Attention Module》

论文地址:https://arxiv.org/pdf/1807.06521.pdf


我们提出了卷积块注意力模块(CBAM),这是一种简单但有效的前馈卷积神经网络注意力模块。给定一个中间特征图,我们的模块会按顺序沿两个独立的维度进行注意力推理,分别是通道和空间,然后将注意力图与输入特征图相乘,以实现自适应特征优化。由于 CBAM 是一个轻量级且通用的模块,它可以无缝地集成到任何 CNN 架构中,增加的负担微不足道,并且可以与基础 CNN 一起进行端到端的训练。

我们通过在 ImageNet-1KMS COCO 检测和 VOC 2007 检测数据集上的广泛实验验证了 CBAM。我们的实验显示,各种模型在分类和检测性能上都取得了稳定的改进,证明了 CBAM 的广泛适用性。代码和模型将公开提供。


问题背景

在计算机视觉领域,尤其是图像识别任务中,卷积神经网络(CNN)已经取得了显著的成功。然而,随着研究的深入,学者们开始关注如何通过增强网络的深度、宽度和模型复杂度来进一步提升性能。这篇文章介绍了一个全新的视角——注意力机制(Attention Mechanism),它能够更智能地从输入的特征图中提取有用信息,通过突出重要的特征并抑制不必要的信息来增强模型的表现力。


核心概念

文章的核心概念是引入了一个名为“卷积块注意力模块”(Convolutional Block Attention Module, CBAM),这是一种轻量级但非常有效的注意力机制,设计用来顺序推断通道和空间维度的注意力图,这些注意力图随后被用来乘以输入特征图,以进行自适应特征提炼。CBAM 的设计易于集成到任何 CNN 架构中,几乎不增加额外计算负担。


模块的操作步骤


在这里插入图片描述

CBAM 的概述。该模块包含两个顺序执行的子模块:通道和空间。在深度网络的每个卷积块中,中间特征图都会通过我们的模块(CBAM)进行自适应优化。


CBAM 模块包括两个主要的子模块:通道注意力模块和空间注意力模块。首先,通道注意力模块通过压缩特征图的空间维度来强调“重要的”通道;其次,空间注意力模块则聚焦于“哪里”是图像的重要部分,这是对通道注意力的补充。这两个模块的顺序应用确保了网络能够综合考虑哪些特征是重要的,从而更精确地调整特征图。


在这里插入图片描述

每个注意力子模块的示意图。如图所示,通道子模块同时利用最大池化输出和平均池化输出,并通过一个共享的网络进行处理;而空间子模块则利用在通道轴上进行池化的两个相似输出,并将它们传递给一个卷积层。


文章贡献

这篇文章的主要贡献包括:

  1. 提出了一个简单而有效的注意力模块 CBAM,能够广泛应用于各种 CNN 模型以增强其表征能力。
  2. 通过在 ImageNet-1KMS COCOVOC 2007 数据集上的广泛实验验证了 CBAM 的有效性。
  3. 展示了 CBAM 在多种模型上一致地提升分类和检测性能,证明了其广泛的适用性。

实验结果与应用

实验结果表明,CBAM 能在不同的网络架构中提供一致的性能提升。无论是在图像分类还是在对象检测的任务中,CBAM 增强的网络都比基线模型表现更好。例如,在使用 ResNet-50 基线的 ImageNet 分类任务中,CBAM 能显著降低误差,从而提高准确率。此外,CBAM 对计算和参数的额外要求极低,使其非常适合集成到现有的复杂网络中,甚至是轻量级网络中,如 MobileNet


对未来工作的启示

CBAM 的成功展示了注意力机制在深度学习中的潜力,特别是在自动强调重要特征并抑制次要特征方面。这种机制不仅可以提高模型的表现,还可以提高模型对输入数据中的噪声和不相关信息的鲁棒性。未来的工作可以探索将 CBAM 集成到更多类型的神经网络中,或者开发更先进的注意力机制,以解决更广泛的问题,如视频处理和自然语言处理。CBAM 的设计思想也可能激发研究人员思考如何通过注意力机制来优化模型的计算效率和性能。


代码

# https://openaccess.thecvf.com/content_ECCV_2018/papers/Sanghyun_Woo_Convolutional_Block_Attention_ECCV_2018_paper.pdf
import numpy as np
import torch
from torch import nn
from torch.nn import initclass ChannelAttention(nn.Module):def __init__(self, channel, reduction=16):super().__init__()self.maxpool = nn.AdaptiveMaxPool2d(1)self.avgpool = nn.AdaptiveAvgPool2d(1)self.se = nn.Sequential(nn.Conv2d(channel, channel // reduction, 1, bias=False),nn.ReLU(),nn.Conv2d(channel // reduction, channel, 1, bias=False),)self.sigmoid = nn.Sigmoid()def forward(self, x):max_result = self.maxpool(x)avg_result = self.avgpool(x)max_out = self.se(max_result)avg_out = self.se(avg_result)output = self.sigmoid(max_out + avg_out)return outputclass SpatialAttention(nn.Module):def __init__(self, kernel_size=7):super().__init__()self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=kernel_size // 2)self.sigmoid = nn.Sigmoid()def forward(self, x):max_result, _ = torch.max(x, dim=1, keepdim=True)avg_result = torch.mean(x, dim=1, keepdim=True)result = torch.cat([max_result, avg_result], 1)output = self.conv(result)output = self.sigmoid(output)return outputclass CBAM(nn.Module):def __init__(self, channel=512, reduction=16, kernel_size=7):super().__init__()self.ca = ChannelAttention(channel=channel, reduction=reduction)self.sa = SpatialAttention(kernel_size=kernel_size)def forward(self, x):b, c, _, _ = x.size()residual = xout = x * self.ca(x)out = out * self.sa(out)return out + residualif __name__ == "__main__":input = torch.randn(64, 256, 8, 8)model = CBAM(channel=256, reduction=16, kernel_size=7)output = model(input)print(output.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/661339.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于遗传算法的TSP算法(matlab实现)

一、理论基础 TSP(traveling salesman problem,旅行商问题)是典型的NP完全问题,即其最坏情况下的时间复杂度随着问题规模的增大按指数方式增长,到目前为止还未找到一个多项式时间的有效算法。TSP问题可描述为:已知n个城市相互之间的距离&…

【报错处理】ib_write_bw执行遇到Couldn‘t listen to port 18515原因与解决办法?

要点 要点: ib默认使用18515命令 相关命令: netstat -tuln | grep 18515 ib_write_bw --help |grep port# server ib_write_bw --ib-devmlx5_1 --port88990 # client ib_write_bw --ib-devmlx5_0 1.1.1.1 --port88990现象: 根因&#xff…

报错“Install Js dependencies failed”【鸿蒙开发Bug已解决】

文章目录 项目场景:问题描述原因分析:解决方案:此Bug解决方案总结Bug解决方案寄语项目场景: 最近也是遇到了这个问题,看到网上也有人在询问这个问题,本文总结了自己和其他人的解决经验,解决了【报错“Install Js dependencies failed”】的问题。 报错如下 问题描述 …

element的el-table 解决表格多页选择数据时,数据被清空

问题:切换页码时,勾选的数据会被清空 重点看我圈出来的,直接复制,注意,我这里 return row.productId;一般大家的是 return row.id,根据接口定的唯一变量 :row-key"getRowKeys"​​​​​​​:reserve-sele…

【GitHub】github学生认证,在vscode中使用copilot的教程

github学生认证并使用copilot教程 写在最前面一.注册github账号1.1、注册1.2、完善你的profile 二、Github 学生认证注意事项:不完善的说明 三、Copilot四、在 Visual Studio Code 中安装 GitHub Copilot 扩展4.1 安装 Copilot 插件4.2 配置 Copilot 插件&#xff0…

光伏储能是什么意思?有什么好处?

随着全球能源需求的持续增长和对环保要求的不断提高,新能源技术的发展已成为全球的热门话题。光伏储能作为其中的一项重要技术,正在逐渐受到人们的关注。那么,光伏储能是什么意思?它又有哪些好处呢? 一、光伏储能的定义…

YOLOv5入门(二)处理自己数据集(标签统计、数据集划分、数据增强)

上一节中我们讲到如何使用Labelimg工具标注自己的数据集,链接:YOLOv5利用Labelimg标注自己数据集,完成1658张数据集的预处理,接下来将进一步处理这批数据,通常是先划分再做数据增强。 目录 一、统计txt文件各标签类型…

Java中优雅实现泛型类型的强制转换

在Java中经常遇到将对象强制转换成泛型类的情况&#xff1a; Map<String, Object> data Map.of("name", "XiaoMing","age", 17,"scores", List.of(80, 90, 70) );List<Integer> scores (List<Integer>) data.get…

Python_GUI框架 PyQt 与 Pyside6的介绍

Python_GUI框架 PyQt 与 Pyside6的介绍 一、简介 在Python的GUI&#xff08;图形用户界面&#xff09;开发领域&#xff0c;PyQt和PySide6是两个非常重要的工具包。它们都基于Qt库&#xff0c;为Python开发者提供了丰富的GUI组件和强大的功能。当然Python也有一些其他的GUI工…

Linux 的静态库和动态库

本文目录 一、静态库1. 创建静态库2. 静态库的使用 二、动态库1. 为什么要引入动态库呢&#xff1f;2. 创建动态库3. 动态库的使用4. 查看可执行文件依赖的动态库 一、静态库 在编译程序的链接阶段&#xff0c;会将源码汇编生成的目标文件.o与引用到的库&#xff08;包括静态库…

Java学习第01天-Java及开发序言

目录 Java技术体系 Java安装 Hello World程序 JDK & JRE IDEA安装和使用 Java技术体系 技术体系说明Java SE(Java Standard Edition)&#xff1a;标准版 Java技术的核心和基础Java EE(Java Enterprise Edition)&#xff1a;企业版企业级应用开发的一套解决方案Java M…

设计模式 --6组合模式

文章目录 组合模式应用场景组合模式概念组合模式结构图透明方式和安全方式什么时候使用组合模式公司管理系统使用 组合模式来构架组合模式的好处 组合模式应用场景 整体和部分可以被一致性对待 比如人力资源部 财务部的管理功能可以复用于分公司的功能 可以引入一种 树状的结构…