目录
•1、网络整体框架
•2 、Patch Merging
•3 、W-MSA
MSA模块计算量
W-MSA模块计算量
•4、 SW-MSA
•5 、Relative Position Bias
•1、网络整体框架
•2 、Patch Merging
这里看着挺复杂,其实就相当于先对特征图进行LayerNorm,然后再进行一个卷积核大小为2×2,步距为2的深度可分离卷积。
•3 、W-MSA
MSA模块计算量
W-MSA模块计算量
•4、 SW-MSA
采用W-MSA模块时,只会在每个窗口内进行自注意力计算,所以窗口与窗口之间是无法进行信息传递的。为了解决这个问题,作者引入了SW-MSA模块,即进行偏移的W-MSA。可以理解成窗口从左上角分别向右侧和下方各偏移了M/2
可以发现通过将窗口进行偏移后,由原来的4个窗口变成9个窗口了。后面又要对每个窗口内部进行MSA,为了避免进行太多的窗口多头自注意力
为了防止不同窗口之间的信息乱窜,在实际计算中使用的是masked MSA即带蒙板mask的MSA,这样就能够通过设置蒙板来隔绝不同区域的信息。
•5 、Relative Position Bias
这里描述的是相对位置索引,也就是相对位置关系,并不是相对位置偏置参数。可以根据相对位置索引去获取对应的参数。关键是怎么根据位置索引获取相对位置偏置参数?
为了方便把二维索引转成一维索引。但如果将行标和列表直接简单相加会出现问题。比如相对位置索引中有(0 , -1)和(-1 , 0) 在二维的相对位置索引中明显是代表不同的位置,但如果简单相加都等于-1那就出问题了。
这样每个位置就得到了自己唯一的相对位置索引
我们可以创建一个可训练的相对位置偏置列表,在列表之找到对应的相对位置偏置。