深度学习——常见注意力机制

1.SENet

SENet属于通道注意力机制。2017年提出,是imageNet最后的冠军

SENet采用的方法是对于特征层赋予权值。

重点在于如何赋权

1.将输入信息的所有通道平均池化。
2.平均池化后进行两次全连接,第一次全连接链接的神经元较少,第二次全连接神经元数和通道数一致
3.将Sigmoid的值固定为0-1之间
4.将权值和特征层相乘。

在这里插入图片描述

import torch
import torch.nn as nn
import mathclass se_block(nn.Module):def __init__(self, channel, ratio=16):super(se_block, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Sequential(nn.Linear(channel, channel // ratio, bias=False),nn.ReLU(inplace=True),nn.Linear(channel // ratio, channel, bias=False),nn.Sigmoid())def forward(self, x):b, c, _, _ = x.size()y = self.avg_pool(x).view(b, c)y = self.fc(y).view(b, c, 1, 1)return x * y

2.ECANet

细心的人会发现,全连接其实是一个非常耗费算力的东西,对于边缘设备的压力非常大,所以ECANet觉得SENet并不需要那么多的全连接,我们直接在GAP后做一维卷积,而后取sigmoid为0-1来获取权值即可。

ECANet认为SE的全通道信息捕获是多此一举,而卷积就有很好的跨通道信息获取能力。
在这里插入图片描述

class eca_block(nn.Module):def __init__(self, channel, b=1, gamma=2):super(eca_block, self).__init__()kernel_size = int(abs((math.log(channel, 2) + b) / gamma))kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid()def forward(self, x):y = self.avg_pool(x)y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)y = self.sigmoid(y)return x * y.expand_as(x)

4.GCNet

GCNet是我们项目的模型中使用的一种注意力机制

GCNet主要借鉴了SENet和NLNet的优点,主要基于NLNet,把NLNet的计算量削减了数倍

先看他是怎么用NLNet的

NLNet原公式
在这里插入图片描述
改进后的NLNet公式
在这里插入图片描述

改进的区别就是去掉了Wz系数。

Wz系数的削减主要是对图像中的观察得出的创意。
在这里插入图片描述
作者说,attention map在不同位置上计算的结果几乎一致,那么我们只需要计算一次然后共享attention map应该也可以获得很好的效果,并且计算量可以下降到1/(W*H)。

Simple NL Block和NL Block的结构对比如图所示,并且经过文章的实验表明,简化后的性能与原本的性能相当。
在这里插入图片描述

接着,作者基于S-NLNet和SENet的有点提出了GCNet

(1) 相比于SNL,SNL中的transform的1x1卷积在res5中是2048x1x1x2048,其计算量较大,所以借鉴SE的方法,加入压缩因子,为了更好的优化,还加入了layernorm。
(2)相比于SE,一方面是提取的全局信息更加充分(其实在后续的实验中说服力不是很强,单独avg pooling+add,只掉了0.3个点,但是更加简洁),另一方面则是加号和乘号的区别,而且在实验结果上,加号比乘号有显著的优势。

import torch
import torch.nn as nn
import torchvisionclass GlobalContextBlock(nn.Module):def __init__(self,inplanes,ratio,pooling_type='att',fusion_types=('channel_add', )):super(GlobalContextBlock, self).__init__()assert pooling_type in ['avg', 'att']assert isinstance(fusion_types, (list, tuple))valid_fusion_types = ['channel_add', 'channel_mul']assert all([f in valid_fusion_types for f in fusion_types])assert len(fusion_types) > 0, 'at least one fusion should be used'self.inplanes = inplanesself.ratio = ratioself.planes = int(inplanes * ratio)self.pooling_type = pooling_typeself.fusion_types = fusion_typesif pooling_type == 'att':self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)self.softmax = nn.Softmax(dim=2)else:self.avg_pool = nn.AdaptiveAvgPool2d(1)if 'channel_add' in fusion_types:self.channel_add_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_add_conv = Noneif 'channel_mul' in fusion_types:self.channel_mul_conv = nn.Sequential(nn.Conv2d(self.inplanes, self.planes, kernel_size=1),nn.LayerNorm([self.planes, 1, 1]),nn.ReLU(inplace=True),  # yapf: disablenn.Conv2d(self.planes, self.inplanes, kernel_size=1))else:self.channel_mul_conv = Nonedef spatial_pool(self, x):batch, channel, height, width = x.size()if self.pooling_type == 'att':input_x = x# [N, C, H * W]input_x = input_x.view(batch, channel, height * width)# [N, 1, C, H * W]input_x = input_x.unsqueeze(1)# [N, 1, H, W]context_mask = self.conv_mask(x)# [N, 1, H * W]context_mask = context_mask.view(batch, 1, height * width)# [N, 1, H * W]context_mask = self.softmax(context_mask)# [N, 1, H * W, 1]context_mask = context_mask.unsqueeze(-1)# [N, 1, C, 1]context = torch.matmul(input_x, context_mask)# [N, C, 1, 1]context = context.view(batch, channel, 1, 1)else:# [N, C, 1, 1]context = self.avg_pool(x)return contextdef forward(self, x):# [N, C, 1, 1]context = self.spatial_pool(x)out = xif self.channel_mul_conv is not None:# [N, C, 1, 1]channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))out = out * channel_mul_termif self.channel_add_conv is not None:# [N, C, 1, 1]channel_add_term = self.channel_add_conv(context)out = out + channel_add_termreturn outif __name__=='__main__':model = GlobalContextBlock(inplanes=16, ratio=0.25)print(model)input = torch.randn(1, 16, 64, 64)out = model(input)print(out.shape)

4.CA注意力机制

CA机制也是和之前的GCNet一样对两个已有注意力(SENet和CBAM)进行了改进。

CA提出

1.SENet作为通道注意力机制,侧重通道之前的依赖关系,忽略了空间特征的作用。
2.CBAM可以一定程度弥补,但是CBAM对于长程依赖有待改进。

经过融合改进后,CA机制有以下优点

1、不仅考虑了通道信息,还考虑了方向相关的位置信息。
2、足够的灵活和轻量,能够简单的插入到轻量级网络的核心模块中。

CA机制的算法流程图如下

在这里插入图片描述
1.CA机制为了避免将空间特征全都压缩到通道中,放弃了全局平均池化,转为分别对x和y方向进行
别生成尺寸为C ∗ H ∗ 1 和C ∗ 1 ∗ W 的attention map
在这里插入图片描述
2.将生成的两个attention map进行池化,然后concat,然后进行F1操作(利用1*1卷积核进行降维,如SE注意力中操作)和激活操作,生成特征图f

在这里插入图片描述
这图怎么这么大?
3.沿着空间维度,再将f进行split操作,分别得到h和w的特征图后再用1 × 1卷积进行升维度操作,结合sigmoid激活函数得到最后的注意力向量gh和gw

代码

class CoordAtt(nn.Module):def __init__(self, inp, oup, groups=32):super(CoordAtt, self).__init__()self.pool_h = nn.AdaptiveAvgPool2d((None, 1))self.pool_w = nn.AdaptiveAvgPool2d((1, None))mip = max(8, inp // groups)self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)self.bn1 = nn.BatchNorm2d(mip)self.conv2 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.conv3 = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)self.relu = h_swish()def forward(self, x):identity = xn,c,h,w = x.size()x_h = self.pool_h(x)x_w = self.pool_w(x).permute(0, 1, 3, 2)y = torch.cat([x_h, x_w], dim=2)y = self.conv1(y)y = self.bn1(y)y = self.relu(y) x_h, x_w = torch.split(y, [h, w], dim=2)x_w = x_w.permute(0, 1, 3, 2)x_h = self.conv2(x_h).sigmoid()x_w = self.conv3(x_w).sigmoid()x_h = x_h.expand(-1, -1, h, w)x_w = x_w.expand(-1, -1, h, w)y = identity * x_w * x_hreturn y

明日:ODConv,数据结构复习,套磁老师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/54638.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

LeetCode[207]课程表

难度:Medium 题目: 你这个学期必须选修 numCourses 门课程,记为 0 到 numCourses - 1 。 在选修某些课程之前需要一些先修课程。 先修课程按数组 prerequisites 给出,其中 prerequisites[i] [ai, bi] ,表示如果要学习…

Linux文本三剑客---grep、sed、awk

目录标题 1、grep1.1 命令格式1.2命令功能1.3命令参数1.4grep实战演练 2、sed2.1 认识sed2.2命令格式2.3常用选项options2.4地址定界2.5 编辑命令command2.6用法演示2.6.1常用选项options演示2.6.2地址界定演示2.6.3编辑命令command演示 3、awk3.1认识awk3.2常用命令选项3.3awk…

函数的学习

函数学习 最后附上全部java源码,可自行下载学习 文章目录 函数入门函数重载函数可变个数参数foreach输出传参 基本数据类型传参_引用数据类型文件夹展示所有里面的文件使用递归算法展示文件夹下所有文件1加到100的递归调用下载链接 函数入门 函数重载 public class…

【雕爷学编程】Arduino动手做(184)---快餐盒盖,极低成本搭建机器人实验平台

吃完快餐粥,除了粥的味道不错之外,我对个快餐盒的圆盖子产生了兴趣,能否做个极低成本的简易机器人呢?也许只需要二十元左右 知识点:轮子(wheel) 中国词语。是用不同材料制成的圆形滚动物体。简…

ceil(),floor(),round()函数C++详解

ceil&#xff08;&#xff09; ceil()函数是这样的&#xff1a; double ceil(double x) ceil函数可以把x上取整。 例子&#xff1a; #include <bits/stdc.h> using namespace std; int main() {double a, b;cin >> a >> b;printf("ceil(%.2f) %.2…

基于51单片机的点亮LED灯

目录 前言 一、整体目录结构 二、代码展示 三、main.c代码解析 四、下载到单片机中 总结 前言 首先我们先来了解一下LED发光二极管&#xff0c;二极管有两个极&#xff08;正极和负极&#xff09;&#xff0c;要想发光二极管导通点亮&#xff0c;必须要让正极电压&#xff1e;…

机器学习、人工智能、深度学习三者的区别

目录 1、三者的关系 2、能做些什么 3、阶段性目标 1、三者的关系 机器学习、人工智能&#xff08;AI&#xff09;和深度学习之间有密切的关系&#xff0c;它们可以被看作是一种从不同层面理解和实现智能的方法。 人工智能&#xff08;AI&#xff09;&#xff1a;人工智能是一…

K8S系列文章 之 编写自动化部署K8S脚本

介绍 通过ansible脚本shell实现自动化部署k8s基础集群(v1.25.0) 部署结构 1. 通过二进制部署包镜像安装k8s集群、目录etcd节点只支持1-3个节点、最多三个etcd节点 2. 因k8s版本相对较新、需要升级内核来支持后台程序、当前版本只支持Cento7&#xff0c;内核版本(5.19.4-1.el7…

管理类联考——写作——论说文——实战篇——行文篇——通用性强,解释多种现象的经典理论——析原因

前言 本节内容涉及“经济人假设”“自利性偏差”“机会成本”“沉没成本”“信息不对称”“科斯定理”“路径依赖”等理论。这些理论一般用在“现象分析式结构”中“析原因”的部分。 有时候也可以反过来使用&#xff0c;用于“提建议”的部分。 例如&#xff1a; 合作中&…

UML-构件图

目录 1.概述 2.构件的类型 3.构件和类 4.构件图 1.概述 构件图主要用于描述各种软件之间的依赖关系&#xff0c;例如&#xff0c;可执行文件和源文件之间的依赖关系&#xff0c;所设计的系统中的构件的表示法及这些构件之间的关系构成了构件图 构件图从软件架构的角度来描述…

【力扣每日一题】2023.8.5 合并两个有序链表

目录 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 代码&#xff1a; 题目&#xff1a; 示例&#xff1a; 分析&#xff1a; 题目给我们两个有序的链表&#xff0c;要我们保持升序的状态合并它们。 我们可以马上想要把两个链表都遍历一遍&#xff0c;把所有节点的…

Rocky(centos) jar 注册成服务,能开机自启动

概述 涉及&#xff1a;1&#xff09;sh 无法直接运行java命令&#xff0c;可以软连&#xff0c;此处是直接路径 2&#xff09;sh脚本报一堆空格换行错误&#xff1a;需将转成unix标准格式&#xff1b; #切换到上传的脚本路径 dos2unix 脚本文件名.sh 2&#xff09;SELINUX …