先前从大脑MRI诊断阿尔茨海默病的工作表明,卷积神经网络(CNNs)可以利用图像信息进行分类。然而,很少有研究关注这些模型如何利用表格信息,如患者统计数据或实验测量数据。这里介绍了动态仿射特征图变换(DAFT,Dy-namic Affine Feature Map Transform),这是一种用于表格神经网络的通用模块,它根据患者的表格临床信息动态地重新缩放和移动卷积层的特征图。实验表明,DAFT在结合3D图像进行预测方面非常有效,其平均balanced accuracy为0.622,平均c-index为0.748,优于其他模型。
来自:Combining 3D Image and Tabular Data via the Dynamic Affine Feature Map Transform
工程地址:https://github.com/ai-med/DAFT
目录
- 背景概述
- 相关工作
- 方法
- 实验
- 数据预处理
- 评估
背景概述
近年来,CNN已经成为从MRI中分类阿尔茨海默病的标准。CNN擅长从MRI中提取神经解剖学的高级信息。然而,大脑的MRI只能提供部分关于导致认知能力下降的潜在变化的观点。因此,临床医生和研究人员还需要依靠表格数据,如患者人口统计、家族史或脑脊液的实验测量来进行诊断。与图像信息相比,表格数据通常是低维的,单个变量捕获了丰富的临床知识。
由于图像和表格数据是相互补充的,因此我们希望在单个模型中合并两个信息源,以便一个信息源可以通知另一个信息源。有效的集成是具有挑战性的,因为图像和表格数据之间的维数不匹配。大多数现有的深度学习方法通过将潜在图像表示与网络最后一层的表格数据连接起来集成图像和表格数据。在这样的网络中,图像和表格部分只有最小的交互,并且在一个部分通知另一个部分的方式上受到限制。
作者提出通过动态仿射特征映射变换来增加CNN融合患者3D脑部MRI和表格数据信息的能力。DAFT是一个通用模块,可以集成到任何CNN架构中,在从3D图像和表格生物标志物中学到的高级概念之间建立双向信息交换。DAFT使用一个辅助神经网络,在图像和表格信息的条件下动态地激发或抑制卷积层的每个特征映射。在关于AD预测的实验中,DAFT比单独使用图像或表格数据具有更好的预测性能,并且在很大程度上优于之前在单个神经网络中结合图像和表格数据的方法。
相关工作
结合图像和表格数据的方法是首先在图像数据上训练CNN,并在第二个通常是线性的模型中结合表格数据预测(或者获得潜在表示)。通过这种方式,已经有人将脑MRI中提取的感兴趣区域与常规临床标志物相结合,以预测AD。
当使用单个网络将临床信息与最后一个FC层之前的潜在图像表示连接起来时,这种情况得到了缓解,比如:
- 使用组织病理学图像、基因组数据和人口统计学进行了生存预测;
- 使用海马体形状和临床标志物进行了痴呆的预测。
这种方法的缺点是表格数据只能线性地对最终预测做出贡献。如果串联之后是多层感知器MLP而不是单个FC层。
与上述密切相关的是,后来有人在连接之前对表格数据使用MLP,并在连接之后对组合表示使用MLP。
后续又有人使用一个辅助网络,该网络获取表格数据,并为其CNN的每个其他卷积层的每个特征图输出标量权重。因此,患者的表格数据可以在多个层面上放大或抑制图像的潜在表征的贡献。这种方法的缺点是,辅助网络中的权重数量随着CNN的深度呈二次增长,这很快就变得不切实际。
方法
DAFT希望CNN利用高维3D图像信息,并在其预测中无缝地考虑互补的低维表格信息。根据患者的图像和临床表格信息,通过动态缩放和移动3D卷积层的特征图,实现两个信息源的紧密集成。由于表格信息通常包含描述患者整体状态的统计数据,因此需要在表格数据和图像数据之间进行一定程度的信息交换。因此,DAFT提出在最后一个ResNet残差块中对卷积层的输出进行仿射变换,这样可以用高级概念而不是原始概念(如边缘)来描述图像。图1总结了网络。
- 图1:作者提出了在最后的残差块中使用动态仿射特征图变换的网络结构。DAFT结合了卷积层的 C C C个特征图张量 F i ∈ R C × D × H × W F_{i}\in R^{C\times D\times H\times W} Fi∈RC×D×H×W, x i ∈ R P x_{i}\in R^{P} xi∈RP为表格数据的向量, α i \alpha_{i} αi为缩放因子, β i \beta_{i} βi为平移因子。DAFT架构极其简单,本质是用表格数据获得通道注意力,然后修正特征图。
实验
作者使用来自阿尔茨海默病神经影像学提出的T1 brain MRI(The Alzheimer’s disease neuroimaging initiative (ADNI): MRI methods)在两个任务上评估DAFT:
- 诊断患者为认知正常(CN)、轻度认知受损(MCI)或痴呆;
- 预测MCI患者痴呆发病时间。
将诊断任务作为一个分类问题,将痴呆时间任务作为一个生存分析问题。比如仅在一小部分患者中观察到痴呆发作,而其余患者在观察期间保持稳定。
数据预处理
首先使用FreeSurfer进行分段扫描,并在左侧海马体周围提取一个大小为 6 4 3 64^3 643的感兴趣区域,因为已知该区域受到AD的强烈影响。接下来,预处理管道对图像进行归一化。表格数据包括9个变量:年龄、性别、受教育程度、ApoE4、脑脊液生物标志物Aβ42、P-tau181和T-tau,以及来自18f -氟脱氧葡萄糖和florbetapir PET扫描的两个汇总测量。
为了解释缺失值,通过附加二进制变量来表示所有特征的缺失,除了年龄、性别和教育程度,这些特征总是存在的。这允许网络使用不完整的数据并从缺失的模式中学习。表格数据共包含 P = 15 P=15 P=15个特征。
为了避免年龄和性别的混杂影响导致数据泄露,将数据分成5个不重叠的折叠,以便在折叠之间平衡诊断、年龄和性别。使用一个折叠作为测试集,并将其余折叠组合起来,使其80%组成训练集,20%组成验证集。对于诊断任务,作者扩展了训练集,但没有验证或测试。对于痴呆时间预测任务,作者只包括基线预测时患有轻度认知障碍的患者,这样所有患者都保持轻度认知障碍或进展为痴呆。
评估
考虑了两个基线:
- 一个仅使用基于图1架构的图像信息的ResNet,不使用DAFT;
- 仅使用表格信息的线性模型;
对于诊断,使用平衡精度bACC,对痴呆时间的分析,使用一致性指数c-index。