【深度学习注意力机制系列】—— SKNet注意力机制(附pytorch实现)

SKNet(Selective Kernel Network)是一种用于图像分类和目标检测任务的深度神经网络架构,其核心创新是引入了选择性的多尺度卷积核(Selective Kernel)以及一种新颖的注意力机制,从而在不增加网络复杂性的情况下提升了特征提取的能力。SKNet的设计旨在解决多尺度信息融合的问题,使网络能够适应不同尺度的特征。

1. 核心思想

SKNet的核心思想是**通过选择性地应用不同尺度的卷积核,从而在不同层级上捕捉多尺度特征。**为了实现这一点,SKNet引入了一个选择模块,用于自适应地决定在每个通道上使用哪些尺度的卷积核。这种选择性的多尺度卷积核有助于提升特征表示的能力,使网络更具适应性和泛化能力。

2. 结构

SKNet的结构如下:

在这里插入图片描述

实现机制:

  • split:对特征图进行多分支分离卷积,各分支使用不同的卷积核(感受野不同)进行特征提取。(并未对原始特征图进行拆解分离,只是使用不同的卷积核对原始特征图进行卷积操作)。假设分支为n,则特征图维度变换为 (c, h, w) -> (n, c, h, w),原文中n=2。

  • Fuse:将多个分支的特征图提取结果相加。特征图维度变换为 (n, c, h, w) -> (c, h, w)。再通过全局平均池,特征图维度变换为 (c, h, w) -> (c, 1, 1),然后利用全连接层进行降维(限制了最低维度,通过全连接层生成d×1的向量(图中的z),公式如图中所示(δ表示ReLU激活函数,B表示Batch Noramlization,W是一个d×C的维的)。d的取值是由公式d = max(C/r,L)确定,r是一个缩小的比率(与SENet中相似),L表示d的最小值,原文实验中L的值为32。),再利用两个(或多个,和分支数目相同,原论文中为两个)全连接层进行升维,得到两个(多个)维度同降维前相同的特征图(向量)。在对两个特征向量进行softmax处理。假设分支为n,则特征图维度为 n个(c, 1, 1) ,原文中n=2,即a->(c, 1, 1), b->(c, 1, 1)。

  • select:利用softmax处理后的多个特征向量分别乘以第一步中的多分支提取的特征图结果。特征维度变化为n个(c, 1 ,1) * n 个(c, h ,w) = (n, c, h, w)。最后将n个特征图进行相加。

3. 优势

SKNet的设计在以下几个方面具有优势:

  • 多尺度信息融合

通过选择性地应用不同尺度的卷积核,SKNet能够有效地融合多尺度的特征信息。这有助于网络捕捉不同层次的视觉特征,提高了特征的表征能力。

  • 自适应性

选择模块使网络能够自适应地选择卷积核的尺度,从而适应不同任务和图像的特点。这种自适应性能够使网络在各种场景下都能表现出色。

  • 减少计算成本

尽管引入了多尺度卷积核,但由于选择模块的存在,SKNet只会选择一部分卷积核进行计算,从而减少了计算成本,保持了网络的高效性。

4.代码实现

class SKNet(nn.Module):def __init__(self, in_channels, out_channels, stride=1, M=2, r=16, L=32):""":param in_channels:  输入通道维度:param out_channels: 输出通道维度   原论文中 输入输出通道维度相同:param stride:  步长,默认为1:param M:  分支数:param r: 特征Z的长度,计算其维度d 时所需的比率(论文中 特征S->Z 是降维,故需要规定 降维的下界):param L:  论文中规定特征Z的下界,默认为32采用分组卷积: groups = 32,所以输入channel的数值必须是group的整数倍"""super(SKNet, self).__init__()d = max(in_channels // r, L)  self.M = Mself.out_channels = out_channelsself.conv = nn.ModuleList() for i in range(M):self.conv.append(nn.Sequential(nn.Conv2d(in_channels, out_channels, 3, stride, padding=1 + i, dilation=1 + i, groups=32, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True)))self.global_pool = nn.AdaptiveAvgPool2d(output_size=1) self.fc1 = nn.Sequential(nn.Conv2d(out_channels, d, 1, bias=False),nn.BatchNorm2d(d),nn.ReLU(inplace=True))  # 降维self.fc2 = nn.Conv2d(d, out_channels * M, 1, 1, bias=False)  self.softmax = nn.Softmax(dim=1) def forward(self, input):batch_size = input.size(0)output = []for i, conv in enumerate(self.conv):output.append(conv(input))U = reduce(lambda x, y: x + y, output)  s = self.global_pool(U)  z = self.fc1(s)a_b = self.fc2(z) a_b = a_b.reshape(batch_size, self.M, self.out_channels, -1) a_b = self.softmax(a_b) a_b = list(a_b.chunk(self.M, dim=1))  a_b = list(map(lambda x: x.reshape(batch_size, self.out_channels, 1, 1),a_b))  V = list(map(lambda x, y: x * y, output,a_b))  V = reduce(lambda x, y: x + y,V)  return V

总结

SKNet是一种创新的深度神经网络架构,通过引入选择性的多尺度卷积核和注意力机制,提升了特征提取的能力。其核心结构包括选择模块和SK卷积层,能够有效地融合多尺度信息、自适应地调整卷积核的尺度,并减少计算成本。这使得SKNet在图像分类和目标检测等任务中取得了优越的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/60164.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

轻量级锁实现1——结构体解析、初始化

瀚高数据库 目录 环境 文档用途 详细信息 环境 系统平台:Linux x86-64 Red Hat Enterprise Linux 7 版本:14 文档用途 从底层理解轻量级锁的实现,从保护共享内存的角度理解轻量级锁的使用场景,包括上锁、等待、释放,理…

MongoDB SQL

Microsoft Windows [版本 6.1.7601] 版权所有 (c) 2009 Microsoft Corporation。保留所有权利。C:\Users\Administrator>cd C:\MongoDB\Server\3.4\binC:\MongoDB\Server\3.4\bin> C:\MongoDB\Server\3.4\bin> C:\MongoDB\Server\3.4\bin>net start MongoDB 请求的…

vue中使用this.$refs获取不到子组件的方法,属性方法都为undefined的解决方法

问题描述 vue2中refs获取不到子组件中的方法?,而获取到的是undefined 原因及解决方案: 第一种、在循环中注册了很多个ref 因为注册了多个ref,获取是不能单单知识refs.xxx,需要使用数组和索引来获取具体一个组件refs[…

游戏类APP的如何设置广告场景最大化用户价值?提升变现收益?

最近几年,游戏类市场的不断增长激发了广告主预算的不断投入,越来越多的游戏类APP通过广告变现的方式增收,并且获得持续上涨的eCPM。 具有潜力的游戏品类参考 根据游戏品类的增长数据和广告收益规模,休闲/模拟/街机均为收益正向的…

linuxARM裸机学习笔记(5)----定时器按键消抖和高精度延时实验

定时器按键消抖 之前的延时消抖,是直接借助delay函数进行的,但是这样会浪费CPU的性能。我们采用延时函数的方式实现,可以实现快进快出。 定时器消抖,必须是在t3的时间点才可以,当在t1,t2的时间点每次进入中断函数都要…

模仿火星科技 基于cesium+ 贴地测量+可编辑

当您进入Cesium的编辑贴地测量世界,下面是一个详细的操作过程,帮助您顺利使用这些功能: 1. 创建提示窗: 启动Cesium应用,地图场景将打开,欢迎您进入编辑模式。在屏幕的一角,一个友好的提示窗将…

【云原生】Kubernetes节点亲和性分配 Pod

目录 1 给节点添加标签 2 根据选择节点标签指派 pod 到指定节点[nodeSelector] 3 根据节点名称指派 pod 到指定节点[nodeName] 4 根据 亲和性和反亲和性 指派 pod 到指定节点 5 节点亲和性权重 6 pod 间亲和性和反亲和性及权重 7 污点和容忍度 8 Pod 拓扑分布约束 官方…

【基础IO】动静态库 {动静态库的创建和使用;动态库的加载;默认优先使用动态链接;为什么要有库;动态链接的优缺点;静态链接的优缺点;一些有趣的库}

动静态库 一、静态库(.a) 1.1 如何创建静态库? 编写源文件与头文件。注意:库的源文件没有main函数! 将所有的源文件编译生成目标文件。(如果只提供.h和.o给用户,用户也可以成功编译运行。但这样的做法太过麻烦,且容易…

Maven出现报错 ; Unable to import maven project: See logs for details错误的多种解决方法

问题现象; IDEA版本&#xff1a; Maven 版本 &#xff1a; 3.3.9 0.检查 maven 的设置 &#xff1a;F:\softeware\maven\apache-maven-3.9.3\conf 检查setting.xml 配置 本地仓库<localRepository>F:\softeware\maven\local\repository</localRepository>镜像…

Improved Deep Metric Learning with Multi-class N-pair Loss Objective

Improved Deep Metric Learning with Multi-class N-pair Loss Objective 来源&#xff1a; NIPS’2016NEC Laboratories America 文章目录 Improved Deep Metric Learning with Multi-class N-pair Loss ObjectiveDistance Metric LearningDeep Metric Learning with Multip…

redis的持久化

第一章、redis的持久化 1.1&#xff09;持久化概述 ①持久化可以理解为将数据存储到一个不会丢失的地方&#xff0c;Redis 的数据存储在内存中&#xff0c;电脑关闭数据就会丢失&#xff0c;所以放在内存中的数据不是持久化的&#xff0c;而放在磁盘就算是一种持久化。 ②为…

python实现简单的爬虫功能

前言 Python是一种广泛应用于爬虫的高级编程语言&#xff0c;它提供了许多强大的库和框架&#xff0c;可以轻松地创建自己的爬虫程序。在本文中&#xff0c;我们将介绍如何使用Python实现简单的爬虫功能&#xff0c;并提供相关的代码实例。 如何实现简单的爬虫 1. 导入必要的…