CBAM注意力机制(结构图加逐行代码注释讲解)

学CBAM前建议先学会SEnet(因为本篇涉及SEnet的重合部分会略加带过)->传送门

⒈结构图

下面这个是自绘的,有些许草率。。。

因为CBAM机制是由通道和空间两部分组成的,所以有这两个模块(左边是通道注意力机制,右边是空间注意力机制)

下面这两个是官方论文里的:

⒉机制流程讲解

SEnet只关注了通道注意力机制而忽略了空间上的一些简单特征,相比之下,CBAM将通道注意力机制和空间注意力机制进行一个结合,对输入进来的特征层,分别进行通道注意力机制的处理和空间注意力机制的处理,而是是先通道后空间,也就是第一张结构图表达的意思。

①首先是通道机制:

对于输入特征层,分别作全局最大池化和全局平均池化,输出结果分别送入一个共享全连接层(官方源码在这里和SEnet的全连接层一模一样),为什么叫共享全连接层?因为最大池化和平均池化的两条路线用的是这同一个全连接层。然后对两个结果(maxout和avgout)做加法,最后进行归一化操作,获得通道上的权重矩阵。

②然后是空间机制:

对于输入特征层,在每一个特征点的通道上取最大值和平均值,(这里和通道机制的最大池化和平均池化完全不同,通道机制里是在H、W两个维度求最大或平均,空间机制是在C一个维度上求最大和平均。)然后对两个结果(maxout和avgout)做拼接,也就是maxout的1*H*W与avgout的1*H*W进行拼接,得到2*H*W的张量,因此紧接着下一步就要进行一个7*7的卷积(conv)将通道压缩回1,最后还是进行归一化操作,获得空间上的权重矩阵。

③整体上:

对于输入特征层,输入特征层先乘上通道机制的输出权重(channel_out),然后再乘上空间上的输出权重(spatial_out)

⒊源码(pytorch框架实现)及逐行解释

import torch
from torch import nn
from torchsummary import summaryclass ChannelModule(nn.Module):def __init__(self, inputs, ratio=16):super(ChannelModule, self).__init__()_, c, _, _ = inputs.size()self.maxpool = nn.AdaptiveMaxPool2d(1)self.avgpool = nn.AdaptiveAvgPool2d(1)self.share_liner = nn.Sequential(nn.Linear(c, c // ratio),nn.ReLU(),nn.Linear(c // ratio, c))self.sigmoid = nn.Sigmoid()def forward(self, inputs):x = self.maxpool(inputs).view(inputs.size(0), -1)#ncmaxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#nchwy = self.avgpool(inputs).view(inputs.size(0), -1)avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)return self.sigmoid(maxout + avgout)class SpatialModule(nn.Module):def __init__(self):super(SpatialModule, self).__init__()self.maxpool = torch.maxself.avgpool = torch.meanself.concat = torch.catself.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)self.sigmoid = nn.Sigmoid()def forward(self, inputs):maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#n1hwavgout = self.avgpool(inputs, dim=1, keepdim=True)#n1hwouts = self.concat([maxout, avgout], dim=1)#n2hwouts = self.conv(outs)#n1hwreturn self.sigmoid(outs)class CBAM(nn.Module):def __init__(self, inputs):super(CBAM, self).__init__()self.channel_out = ChannelModule(inputs)self.spatial_out = SpatialModule()def forward(self, inputs):outs = self.channel_out(inputs) * inputsreturn self.spatial_out(outs) * outs

 解释:

①依赖包和SEnet解释的一样。

②整体上看,将通道机制和空间机制分别封装成类,再封装一个CBAM类来对这两个机制调用,其中用到的__init__构造方法(python称魔术方法)和foward函数(前向传播过程),这些模板和上面介绍SEnet时是一模一样的。

先来看通道机制:

class ChannelModule(nn.Module):#继承nn模块的Module类def __init__(self, inputs, ratio=16):#self必写,inputs接收输入特征张量,ratio是通道衰减因子super(ChannelModule, self).__init__()#调用父类构造_, c, _, _ = inputs.size()#获取通道数self.maxpool = nn.AdaptiveMaxPool2d(1)#nn模块的自适应二维最大池化self.avgpool = nn.AdaptiveAvgPool2d(1)#nn模块的自适应二维平均池化self.share_liner = nn.Sequential(nn.Linear(c, c // ratio),nn.ReLU(),nn.Linear(c // ratio, c))#这个共享全连接的3层和SEnet的一模一样,这里借助Sequential这个容器把这3个层整合在一起,方便forward函数去执行,直接调用share_liner(x)相当于直接执行了里面这3层self.sigmoid = nn.Sigmoid()#nn模块的Sigmoid函数def forward(self, inputs):x = self.maxpool(inputs).view(inputs.size(0), -1)#对于输入特征张量,做完最大池化后再重塑形状,view的第一个参数inputs.size(0)表示第一维度,显然就是n;-1表示会自适应的调整剩余的维度,在这里就将原来的(n,c,1,1)调整为了(n,c*1*1),后面才能送入全连接层(fc层)maxout = self.share_liner(x).unsqueeze(2).unsqueeze(3)#做完全连接后,再用unsqueeze解压缩,也就是还原指定维度,这里用了两次,分别还原2维度的h,和3维度的wy = self.avgpool(inputs).view(inputs.size(0), -1)avgout = self.share_liner(y).unsqueeze(2).unsqueeze(3)#y走的平均池化路线的代码和x是一样的解释return self.sigmoid(maxout + avgout)#最后相加两个结果并作归一化

再来看空间机制:(重复的模板就不再反复赘述了)

class SpatialModule(nn.Module):def __init__(self):super(SpatialModule, self).__init__()self.maxpool = torch.maxself.avgpool = torch.mean#和通道机制不一样!这里要进行的是在C这一个维度上求最大和平均,分别用的是torch库里的max方法和mean方法self.concat = torch.cat#torch的cat方法,用于拼接两个张量self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)#nn模块的二维卷积,其中的参数分别是:输入通道(2),输出通道(1),卷积核大小(7*7),步长(1),灰度填充(3)self.sigmoid = nn.Sigmoid()def forward(self, inputs):maxout, _ = self.maxpool(inputs, dim=1, keepdim=True)#maxout接收特征点的最大值很好理解,为什么还要一个占位符?因为torch.max不仅返回张量最大值,还会返回索引,索引用不着所以直接忽略,dim=1表示在维度1(也就是nchw的c)上求最大值,keepdim=True表示要保持原来张量的形状avgout = self.avgpool(inputs, dim=1, keepdim=True)#torch.mean则只返回张量的平均值,至于参数的解释和上面是一样的outs = self.concat([maxout, avgout], dim=1)#torch.cat方法,传入一个列表,将列表中的张量在指定维度,这里是维度1(也就是nchw的c)拼接,即n*1*h*w拼接n*1*h*w得到n*2*h*wouts = self.conv(outs)#卷积压缩上面的n*2*h*w,又得到n*1*h*wreturn self.sigmoid(outs)

 最后看整体:

class CBAM(nn.Module):def __init__(self, inputs):super(CBAM, self).__init__()self.channel_out = ChannelModule(inputs)#获得通道权重self.spatial_out = SpatialModule()#获得空间权重def forward(self, inputs):outs = self.channel_out(inputs) * inputs #先乘上通道权重return self.spatial_out(outs) * outs #在乘完通道权重的基础上再乘上空间权重

⒋测试结果

大问题没有,但还是少了一些关键层,尤其是空间机制那里的拼接maxout和avgout,通道变为2再用卷积压缩回1的过程都没体现。。。只能说summary确实不太好使,或者说我没用对?网络层简写导致的?(最不可能是这个原因,因为我拿官方的源码测试也是summary出这些结果)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/193725.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

论文阅读:JINA EMBEDDINGS: A Novel Set of High-Performance Sentence Embedding Models

Abstract JINA EMBEDINGS构成了一组高性能的句子嵌入模型,擅长将文本输入转换为数字表示,捕捉文本的语义。这些模型在密集检索和语义文本相似性等应用中表现出色。文章详细介绍了JINA EMBEDINGS的开发,从创建高质量的成对(pairwi…

Java(二)(String的常见方法,ArrayList的常见方法)

String 创建string对象 package Helloworld;public class dome1 {public static void main(String[] args) {// 1.直接双引号得到字符串对象,封装字符串对象String name "lihao";System.out.println(name);// 2. new String 创建字符串对象,并调用构造器初始化字符…

搭建网关服务器实现DHCP自动分配、HTTP服务和免密登录

目录 一. 实验要求 二. 实验准备 三. 实验过程 1. 网关服务器新建网卡并改为仅主机模式 2. 修改新建网卡IP配置文件并重启服务 3. 搭建网关服务器的dhcp服务 4. 修改server2网卡配置文件重启服务并效验 5. 设置主机1的网络连接为仅主机模式 6. 给server2和网关服务器之…

【汇编】处理字符问题

文章目录 前言一、处理字符问题1.1 汇编语言如何处理字符1.2 asciiascii码是什么?ascii码表是什么? 1.3 汇编语言字符示例代码 二、大小写转换2.1 问题:对datasg中的字符串2.2 逻辑与和逻辑或2.3 程序:解决大小写转换的问题一个新…

devops底层是怎么实现的

DevOps的3大核心基础架构 简而言之,实现DevOps工具链,基本需要3个核心基础架构: SCM配置管理系统 Automation自动化系统 Cloud云(或者说可伸缩的、自服务的、虚拟化系统) SCM配置管理系统 SCM中所放置的内容又可以再…

ScalableMap

问题引入 传统方案在处理线性地图元素时忽略了其结构性约束,建图距离太近 方法 简介 结构引导BEV特征提取 一种新的层次稀疏地图表示方法 设计渐进解码机制和基于此表示的监督策略 组件 结构引导BEV表征 通过车载摄像头捕捉的环绕视图图像,利用Res…

我终于体会到了:代码竟然不可以运行,为什么呢?代码竟然可以运行,为什么呢?

废话不多说,直接上图 初看只当是段子,再看已是段中人 事情经过: 我在写动态顺序表的尾插函数时,写出了如下代码,可以跑,但是这段代码有一个bug暂时先不提 //动态顺序表的尾插 void SLPushBack(SL* psl, …

庖丁解牛:NIO核心概念与机制详解 02 _ 缓冲区的细节实现

文章目录 PreOverview状态变量概述Position 访问方法 Pre 庖丁解牛:NIO核心概念与机制详解 01 接下来我们来看下缓冲区内部细节 Overview 接下来将介绍 NIO 中两个重要的缓冲区组件:状态变量和访问方法 (accessor) 状态变量是"内部统计机制&quo…

【Spring Boot 源码学习】Banner 信息打印流程

Spring Boot 源码学习系列 Banner 信息打印流程 引言往期内容主要内容1. printBanner 方法2. 关闭 Banner 信息打印3. SpringApplicationBannerPrinter 类3.1 LOG 模式打印3.1.1 getBanner 方法3.1.1.1 新建 Banners3.1.1.2 添加 ImageBanner3.1.1.3 添加 ResourceBanner3.1.1.…

【IPC】消息队列

1、IPC对象 除了最原始的进程间通信方式信号、无名管道和有名管道外,还有三种进程间通信方式,这 三种方式称之为IPC对象 IPC对象分类:消息队列、共享内存、信号量(信号灯集) IPC对象也是在内核空间开辟区域,每一种IPC对象创建好…

酷柚易汛ERP - 序列号盘点操作指南

1、应用场景 将系统中开启序列号的商品数量与与实际存放的数量进行对比。 2、主要操作 2.1 录入序列号 打开【盘点】-【序列号盘点】,新增序列号盘点单,点击【SN】按钮,在弹框中输入序列号。 支持扫描枪录入序列号支持复制粘贴序列号录入…

Linux中系统时间同步

在Windwos中,系统时间的设置很简单,界面操作,通俗易懂,而且设置后,重启,关机都没关系。系统时间会自动保存在BIOS时钟里面,启动计算机的时候,系统会自动在BIOS里面取硬件时间&#x…