Module source
[NeurIPS 2022] SegNeXt: Rethinking Convolutional Attention Design for Semantic Segmentation
Module name
Multi-Scale Convolutional Attention (MSCA)
Module purpose
Multi-scale feature extraction with a larger receptive field.
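The following sketch makes "larger receptive field" concrete. It computes the receptive field of each branch that the module code sums together. It assumes stride 1 and no dilation, which matches the convolutions in the code. Under those assumptions, each stacked layer adds (k - 1) along each axis. The helper `receptive_field` and the branch labels are hypothetical names used only for illustration.

```python
def receptive_field(kernels):
    """Receptive field (height, width) of stacked stride-1, undilated convs.

    kernels: list of (kh, kw) kernel sizes, applied in order.
    Each layer grows the field by (k - 1) along each axis.
    """
    h, w = 1, 1
    for kh, kw in kernels:
        h += kh - 1
        w += kw - 1
    return h, w


# Kernel sequences of the four branches summed in MSCA.forward;
# every branch starts from the shared 5x5 depth-wise conv (conv0).
branches = {
    "identity (conv0 only)": [(5, 5)],
    "7-strip": [(5, 5), (1, 7), (7, 1)],
    "11-strip": [(5, 5), (1, 11), (11, 1)],
    "21-strip": [(5, 5), (1, 21), (21, 1)],
}

for name, ks in branches.items():
    print(name, receptive_field(ks))
```

The branches cover 5×5, 11×11, 15×15 and 25×25 windows. The final 1×1 convolution and the element-wise multiplication do not enlarge these further.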
Module structure
Module code
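import torch
import torch.nn as nn


class MSCA(nn.Module):
    """Multi-Scale Convolutional Attention (MSCA) from SegNeXt."""

    def __init__(self, dim):
        super(MSCA, self).__init__()
        # 5x5 depth-wise conv: aggregates local information
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # Multi-branch depth-wise strip convs (1xk followed by kx1), k = 7, 11, 21
        self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
        self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        # 1x1 conv: mixes information across channels; its output is the attention map
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()  # keep the input for the final reweighting
        attn = self.conv0(x)

        # Branch with 7-wide strips
        attn_0 = self.conv0_1(attn)
        attn_0 = self.conv0_2(attn_0)
        # Branch with 11-wide strips
        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)
        # Branch with 21-wide strips
        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)

        # Sum the identity branch (conv0 output) with the three strip branches
        attn = attn + attn_0 + attn_1 + attn_2
        attn = self.conv3(attn)
        # Use the attention map to reweight the input element-wise
        return attn * u


if __name__ == '__main__':
    x = torch.randn([1, 512, 16, 16])
    msca = MSCA(512)
    out = msca(x)
    print(out.shape)  # torch.Size([1, 512, 16, 16])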
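The module code uses pairs of 1×k and k×1 depth-wise strip convolutions in place of single k×k depth-wise kernels. One motivation is cost. A pair keeps the same k×k receptive field as a square kernel but uses 2k weights per channel instead of k². The sketch below counts weights plus biases with plain arithmetic. It assumes `bias=True`, the `nn.Conv2d` default that the module code relies on. The function names are hypothetical.

```python
def depthwise_params(dim, kh, kw, bias=True):
    """Parameter count of nn.Conv2d(dim, dim, (kh, kw), groups=dim).

    Depth-wise: each of the `dim` channels has its own kh x kw kernel,
    plus one bias per output channel when bias=True.
    """
    return dim * kh * kw + (dim if bias else 0)


def strip_pair_params(dim, k, bias=True):
    """Parameter count of a (1, k) depth-wise conv followed by a (k, 1) one."""
    return depthwise_params(dim, 1, k, bias) + depthwise_params(dim, k, 1, bias)


dim = 512  # channel count used in the module's __main__ test
for k in (7, 11, 21):
    strip = strip_pair_params(dim, k)
    square = depthwise_params(dim, k, k)
    print(f"k={k}: strip pair {strip}, square {square}, ratio {square / strip:.1f}x")
```

You can cross-check the strip counts against PyTorch. With the `msca` instance from the module code, `sum(p.numel() for p in msca.conv2_1.parameters())` should equal `depthwise_params(512, 1, 21)`.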
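Original description (from the paper)
As shown in Figure 2(a), MSCA consists of three parts: a depth-wise convolution that aggregates local information, multi-branch depth-wise strip convolutions that capture multi-scale context, and a 1×1 convolution that models the relationships between channels. The output of the 1×1 convolution is used directly as attention weights to reweight the input of MSCA.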
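The three parts described above can be written compactly as equations. This is a transcription of the module code, not a quotation of the paper. F is the input feature map, A is the output of the 5×5 depth-wise convolution, and Scale_1 to Scale_3 are the strip-convolution pairs with k = 7, 11, 21. The ⊗ symbol denotes element-wise multiplication. These symbol names are chosen here for illustration.

```latex
\begin{aligned}
A &= \mathrm{DWConv}_{5\times 5}(F) \\
\mathrm{Scale}_i &= \mathrm{DWConv}_{k_i\times 1}\circ \mathrm{DWConv}_{1\times k_i},
  \qquad (k_1, k_2, k_3) = (7, 11, 21) \\
\mathrm{Att} &= \mathrm{Conv}_{1\times 1}\Big(A + \sum_{i=1}^{3} \mathrm{Scale}_i(A)\Big) \\
\mathrm{Out} &= \mathrm{Att} \otimes F
\end{aligned}
```

The lone A term is the identity branch (`attn` in `forward`). The last line corresponds to `return attn * u`, where `u` is a copy of the input.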