文章目录
- 前言
- Abstract
- Introduction
- Methods
- Problem Definition
- Network Overview
- Mask Incorporated Feature Extraction
- Cross Masked Attention Transformer
- Self-Attention Module
- Cross Masked Attention Module
- Prototypical Segmentation Module
- Iterative Refinement Framework
- 总结
前言
本文来自港科陈浩老师组发表在 MICCAI23 上的一篇有关 few-shot 在医学图像上的应用。方法简洁高效,可供参考。插一句题外话,医学图像做 few-shot 主要基于腹部器官的 3 个数据集,期待后续可以见到一些在更多数据集上更通用更有效的方法。
原论文链接:Few Shot Medical Image Segmentation with Cross Attention Transformer
Abstract
本文提出了一种基于交叉掩码注意力 Transformer
的少样本医学图像分割新框架 CAT-Net
:
- 通过挖掘
support
和query
图像之间的相关性,并限制模型仅关注有用的前景信息,来提高support
和query
特征的表达能力 - 同时,本文还进一步设计了一个迭代细化训练框架来优化查询
query
图像分割
Introduction
大多数 few-shot
分割方法都在学习如何学习(旨在学习元学习器),根据 support
图像及其相应的分割标签的知识预测 query
图像的分割,而这里的核心是:如何有效地将知识从 support
图像传递到 query
图像。现有的少样本分割方法主要集中在以下两个方面:
- 如何学习一个元学习器
- 如何更好地将知识从
support
图像传递到query
图像
尽管基于原型的方法效果已经不错,但它们通常忽略了训练过程中 support
和 query
特征之间的交互。
本文提出了一种名为 CAT-Net
的新型网络结构,其基于交叉注意力 Transformer:
- 可以更好地捕捉
support
图像和query
图像之间的相关性,促进support
和query
特征之间的相互作用,同时减少无用像素信息,提高特征表达能力和分割性能 - 此外,本文还提出了一个迭代训练框架,将先前的
support
分割结果反馈到注意力 Transformer 中,以有效增强并细化特征和分割结果。
Methods
Problem Definition
少样本分割(Few-shot segmentation,FSS
)的目的是通过只有少量标注的样本来分割新类别。在FSS
中,数据集被分为训练集 Dtrain
和测试集 Dtest
,其中训练集包含基类别 Ctrain
,测试集包含新类别 Ctest
,且 Ctrain
和 Ctest
没有交集。为了获得用于 FSS
的分割模型,采用了通常使用的 episode
训练方法。每个训练 / 测试 e p i s o d e ( S i , Q i ) \mathrm{episode(S_i,Q_i)} episode(Si,Qi) 实例化一个 N-way, K-shot
分割学习任务。具体而言:support 集 S i \mathrm{S_i} Si 包含 N 个类别的 K 个样本,而 query 集 Q i \mathrm{Q_i} Qi 包含同一类别的一个样本。FSS
模型通过 episode
训练以预测 query
图像的新类别。在模型推理测试时,模型直接在 Dtest
上进行评估,无需重新训练。
Network Overview
CAT-Net
主要由三部分组成:
- 带有mask的特征提取
MIFE
子网络,用于提取初始query
和support
特征以及query mask
- 交叉 mask 注意力
Transformer
模块CMAT
,其中query
和support
特征相互促进,从而提高query
预测的准确性 - 迭代细化框架,顺序应用
CMAT
模块以持续促进分割性能,整个框架以端到端的方式进行训练
Mask Incorporated Feature Extraction
MIFE
子网络接收查询和支持图像作为输入,生成它们各自的特征,同时集成支持掩膜。然后,使用一个简单的分类器来预测查询图像的分割结果:
- 具体地,首先使用一个特征提取器网络(即 ResNet-50)将查询和支持图像对 I q I^q Iq 和 I s I^s Is 映射到特征空间中,分别产生查询图像的多层特征图 F q F^q Fq 和支持图像的特征图 F s F^s Fs
- 接下来,将支持掩膜与 F s F^s Fs 进行池化,然后将其扩展并与 F q F^q Fq 和 F s F^s Fs 进行连接
- 此外,还将一个先验掩膜进一步与查询特征进行连接,通过像素级相似度图来增强查询和支持特征之间的相关性
- 最后,使用一个简单的分类器来处理查询特征,得到查询掩膜
Cross Masked Attention Transformer
CMAT
模块包括三个主要组成部分:自注意力模块、交叉掩码注意力模块,和原型分割模块
- 自注意力模块用于提取查询
query
特征和支持support
特征中的全局信息 - 交叉掩码注意力模块用于在传递前景信息的同时消除冗余的背景信息
- 原型分割模块用于生成查询图像的最终预测结果
Self-Attention Module
自注意力模块首先将查询特征 F 0 q F_0^q F0q 和支持特征 F 0 s F_0^s F0s 展平为 1D 序列,然后输入到两个相同的自注意力模块中。每个自注意力模块由一个多头注意力层(MHA)和一个多层感知器层(MLP)组成。给定一个输入序列 S S S,MHA
层首先使用不同的权重将序列投影为三个序列 Q Q Q, K K K 和 V V V。然后计算注意力矩阵 A A A,公式为:
其中, d d d 是输入序列的维度。注意力矩阵通过 softmax
函数归一化,并乘以值序列 V V V 以获得输出序列 O O O。MLP
层是一个简单的 1 × 1 1 \times 1 1×1 卷积层,将输出序列 O O O 映射到与输入序列 S S S 相同的维度。最终,将输出序列 O O O 添加到输入序列 S S S 中,并使用层归一化(LN
)对其进行规范化,以获得最终的输出序列 X X X。自注意力对齐编码器的输出特征序列分别表示为 X q X^q Xq 和 X s X^s Xs,分别对应于查询和支持特征
Cross Masked Attention Module
用于将查询特征和支持特征按照它们的前景信息结合起来
具体来说,给定查询特征 X q X^q Xq 和来自自注意力模块的支持特征 X s X^s Xs,首先使用不同的权重将输入序列投影到三个序列 K K K, Q Q Q 和 V V V 中,从而得到 K q K^q Kq、 Q q Q^q Qq、 V q V^q Vq 和 K s K^s Ks、 Q s Q^s Qs、 V s V^s Vs。以查询特征为例,交叉注意力矩阵通过下面的公式计算得到:
其中, d d d 表示查询特征的维度。这里使用的是点积注意力的形式,通过 K q K^q Kq 和 Q s Q^s Qs 的点积计算查询和支持之间的相关性。通过 d \sqrt{d} d 来缩放点积,防止在较高维度时点积的大小对注意力分布的影响过大
Prototypical Segmentation Module
通过 MAP
建立每个类别的原型 p c p_c pc,用于表示该类别的特征分布:
- K K K 是支持集中图像的数量
- m ( k , x , y , c ) s m_{(k,x,y,c)}^s m(k,x,y,c)s 是一个二进制掩模,表示位置 ( x , y ) (x,y) (x,y) 在支持特征 k k k 中是否属于类别 c c c
- F 1 s F_1^s F1s 是支持特征
对于每个类别 c c c,该原型是在所有支持图像中该类别对应位置的特征平均值,这样可以得到每个类别的原型 p c p_c pc
接着使用非参数度量学习方法进行分割。原型网络计算查询特征向量与原型 P = P c ∣ c ∈ C P=P_c|c \in C P=Pc∣c∈C 之间的距离。对所有类别应用 softmax
函数,生成查询分割结果:
- c o s ( ⋅ ) cos(·) cos(⋅) 表示余弦距离
- α \alpha α 是一个缩放因子,有助于在训练中反向传播梯度,本文中设置为 20
Iterative Refinement Framework
该模块的设计目的是优化查询和支持特征以及查询分割掩模。因此可通过迭代优化的思路进行精细化分割,第 i i i 次迭代后的结果由以下公式给出:
每个步骤的细分可表示如下:
其中 C M A ( ⋅ ) CMA(·) CMA(⋅) 表示自注意力和交叉掩码注意力模块, P r o t o ( ⋅ ) Proto(·) Proto(⋅) 代表原型分割模块,该公式表示通过多次迭代应用 CMA 和 Proto 模块,来获得增强的特征和优化的分割结果
总结
本文提出了一种用于 few-shot 医学图像分割的交叉注意力 Transformer 网络 CAT-Net。通过交叉掩码注意力模块实现了查询和支持特征之间的交互,增强了特征表达能力。此外,所提出的 CMAT 模块可以通过迭代优化的方式以持续提高分割性能,实验结果表明了每个模块的有效性以及模型相对于 SOTA 方法的卓越性能。其中论文中的各个组件属于即插即用模块,可很好的嵌入到 few-shot 任务中,以提高少样本分割的性能。