贝叶斯增量式跨域适应:少样本 + 跨模态学习 + 知识保留和推断【fundus + OCT】,做视网膜病变

贝叶斯深度学习:增量式少样本学习跨域适应 + 贝叶斯多目标函数 + 跨模态学习 fundus + OCT,做视网膜病变

    • 核心思想
      • 设计网络:寻找分类模型、损失函数
      • 实验结果
      • 混淆矩阵与注意力图评估
    • 总结

 


核心思想

论文:https://arxiv.org/pdf/2110.09319.pdf

代码:https://github.com/taimurhassan/continual_learning/blob/17d286a40fd5f17c9cdc8a3eac2129857a5397ac/README.md

 

怎样在只有少量训练样本的情况下,逐步准确识别出视网膜病变(如糖尿病性视网膜病变、黄斑变性),同时还要保证已经学过的知识不会被忘记 ?

论文提出了一种新型的增量式跨域适应方法

这个问题可以分为以下几个子问题和对应的解决方案:

  1. 首先,子问题就是在只有少量训练数据的情况下怎么识别视网膜病变。

    解法就是增量式跨域适应

    这个方法让深度学习模型能够利用少量的训练样本逐步学习光学相干断层扫描(OCT)和眼底图像中的异常视网膜病理。

    这在医学影像领域特别重要,因为这个领域很难获得大量的标注数据。

  2. 接着,第二个子问题是如何在学习新知识的同时,还能保留以前学过的知识。

    这个问题的解决方案是使用一种特殊的贝叶斯多目标函数

    这个方法不仅能让分类网络在增量训练过程中保持其先前的知识,还能帮助网络理解之前学过的病理和新加的疾病类别之间的结构和语义关系,这样在实际应用时就能更有效地识别它们。

    这一点对医学诊断来说非常关键。

  3. 最后一个子问题是如何处理不同成像技术(比如OCT和眼底摄影)之间的差异。

    解决这个问题的办法是跨模态学习

    这样可以让模型处理并识别不同成像技术中的疾病特征,对于提高视网膜病变筛查的准确性非常重要。

这篇论文的解决方案就是结合增量式跨域适应、贝叶斯多目标函数和跨模态学习,这样不仅能提升模型在少量数据和多模态数据下的学习和适应能力,还能保持对以前学习知识的记忆,从而大大提高了视网膜病变筛查的准确性和效率。

 


设计网络:寻找分类模型、损失函数

  1. 首先,研究团队面临的一个挑战是选出最适合用于跨域视网膜病变筛查的分类网络。

    为此,他们比较了好几种模型,包括MobileNet、ResNet-50、ResNet-101和VGG-16。

    最后,他们选定了ResNet-101,因为它的残差特征融合机制能更好地保留图像的细节特征。

  2. 接着,为了证明他们提出的框架的有效性,研究人员将这个框架与其他现有的视网膜诊断系统进行了比较。

    他们发现,尽管采用了增量式学习的方法,但他们的框架在多个数据集上都超越了其他方法,这说明了他们的方法不仅新颖,而且非常有效。

·

  1. 最后,他们还探索了如何在不同成像技术(比如OCT和眼底图像)之间实现有效的视网膜病变筛查。

    在一系列的实验中,他们发现使用 Lcl 损失函数 的 ResNet-101 能够有效地同时学习这两种不同模态的视网膜分类任务,并且在防止灾难性遗忘方面表现得非常好。

 
假设我们有一个用于视网膜病变筛查的深度学习模型,我们想要让它能够识别和区分多种不同的视网膜病变,比如糖尿病性视网膜病变、黄斑变性等,而且我们想要这个模型能够处理两种不同的图像数据:OCT和眼底图像。

跨域视网膜病变筛查的增量式训练框架,整个流程分为三个主要部分:

(A) 训练阶段 I

  • 这个阶段分类网络如何在第一个目标域(OCT)上逐步接受训练,以识别 k+1 种视网膜病变。

  • 网络通过批次1到批次k的数据,使用 Lcl损失函数 进行训练,其中Lce代表分类交叉熵损失,用于训练模型初始学习;Lcl代表持续学习损失,用于后续的增量学习。

    我们从一些OCT扫描的数据集开始,这些数据已经被标记为含有或不含有特定类型的视网膜病变。

    我们首先使用一批数据来训练我们的模型(例如,使用了300个ci-DME样本和200个正常样本)。

    我们使用分类交叉熵损失(Lce)来进行这一阶段的训练。

    随着时间的推移,我们增量地添加新批次的数据,每个新批次都可能包含之前没见过的病变类型(例如,新添加了Dry AMD样本)。

    在这些新的迭代中,我们使用持续学习损失(Lcl)来保持模型在学习新知识的同时,不会忘记旧的知识。

(B) 训练阶段 II

  • 在这个阶段,分类网络在第二个目标域(眼底图像)上逐步适应,以识别 m+2 种视网膜病变。

  • 这意味着网络不仅能识别OCT图像中的病变,也能识别眼底图像中的病变。

    现在我们开始将模型适应第二种类型的图像数据:眼底图像。

    这个过程同样是增量式的。我们先用少量的眼底图像(比如,40个CSME样本和20个正常样本)来进行初步训练。

    接着,我们继续增量地添加新的眼底图像样本,每一批都包含新类型的病变,同时使用 Lcl损失函数 来继续训练模型。

© 跨域视网膜病变筛查(推理阶段)

  • 最终,当k=6,m=11时,提出的框架通过增量式训练能够同时识别 OCT 和眼底图像中的 13 种视网膜病变,不受扫描器规格的影响。

  • 在这个阶段,训练好的模型被用来对输入的测试扫描进行分类。

    经过上述两个阶段的训练后,我们的模型现在已经准备好进行实际的病变筛查了。

    当有新的OCT或眼底图像输入时,我们的模型可以识别出这是13种已知类型中的哪一种视网膜病变,无论这些图像来自哪种扫描器。

    如果我们的模型接受到一个新的OCT扫描,并且它判断这个扫描显示了患有ci-DME,那么它就会将这个案例归类为ci-DME。

图中的颜色代码对应了不同的病变类别,显示了模型是如何被训练来识别这些病变的。

这张图展示的是,一个深度学习模型如何分两个阶段通过增量式训练来识别多种视网膜病变,并在最后应用于实际的病变筛查任务中。

 


实验结果

任务:快速准确地从成百上千的患者眼底照片中筛查出不同类型的视网膜病变,且病变类型繁多,且每种扫描设备产生的图像都有微妙的差异。

模型在实际的临床环境进行视网膜病变筛查:

  1. 首先,研究人员面对的问题是怎样在真实的临床环境里实时筛查视网膜病变。

    为了解决这个问题,他们进行了一系列的临床实验

    这些实验特别重要,因为它们证明了这个框架不只是理论上可行,实际应用起来也很有效果。

    虽然这些实验的细节因为篇幅限制放在了补充材料里,但这部分是验证框架实用性的关键。

  2. 接着,他们还研究了 Lcl 损失函数里一个特别的部分——互相蒸馏损失(Lmd)对分类性能的影响。

    他们发现,包含了 Lmd 的 Lcl 损失函数在学习新的分类任务的同时,还能很好地保持对已知任务的高性能。

    原因在于 Lmd 通过贝叶斯推理增强了网络学习新类别的能力。

    比如一个患者的眼底图像显示了一些典型的DME症状,但同时又有一些不那么明显的AMD迹象,Lmd 能帮助模型在学习识别新的 AMD 病例的同时,不会忘记已经学到的关于 DME 的知识。

  3. 最后一个问题是如何处理类别不平衡的情况。

    在实际应用中,会遇到大量的正常眼底图像和一小部分显示出病变的图像。

    如果没有特别的处理,模型可能会倾向于将所有图像都分类为正常,从而忽视那些较少见的病变类别。

    他们发现,尽管 Lcl 损失函数 在防止遗忘已学知识方面做得很好,但在类别不平衡的情况下表现出了一些脆弱性。

    为了克服这个问题,研究人员提出了一些方法,比如引入最大边界约束、焦点损失函数和高斯亲和优化,来提高对不平衡类别的预测能力

 


混淆矩阵与注意力图评估

如何评估模型在各类别上的具体表现?

  • 解决方案:使用混淆矩阵和注意力图进行细致的性能分析。
  • 原因:混淆矩阵能够详细展示模型在每个类别上的准确性和误分类情况,注意力图可以显示模型在做出判断时关注的区域,这有助于理解模型的决策过程。


注意力图通过颜色的强度来表示模型关注的程度,颜色越亮表示模型越关注那一部分区域。

这些影像可能包括 OCT图像和眼底图像,其中 OCT 图像通常用于检测视网膜厚度和结构的变化,眼底图像用于观察视网膜表面的特征。

注意力集中在某些区域,是由于这些区域在视网膜病变的上下文中具有重要的诊断特征。

 

这个混淆矩阵提供了使用 Lcl损失函数 训练的 增量式ResNet-101模型,在跨域视网膜病变筛查任务中的分类效果的详细视图。

从矩阵可以看出:

  • 模型在某些类别上具有很高的准确性,例如,ci-DME 类别中有8730个样本被正确分类。
  • 然而,模型在区分某些相似的类别时存在混淆,如将 ci-DME 误分类为 nci-DME 或 Dry AMD OCT。
  • 大部分的预测都集中在混淆矩阵的对角线上,表明多数类别都被正确分类。

 


总结

问题与解法组成:

  1. 子问题1:如何让模型能够在识别新的视网膜病变类型时不忘记已学知识?

    • 解决方案:采用增量学习策略,使用Lcl损失函数进行训练。
    • 原因:增量学习能够通过连续训练来逐步扩展模型的识别能力,而 Lcl损失函数 专门设计来平衡新旧知识,防止过去学习的知识被新信息覆盖,即解决了“灾难性遗忘”问题。
  2. 子问题2:如何处理从 OCT 图像到眼底图像的跨域适应?

    • 解决方案:使用同一模型进行跨域训练,逐步引入新域的数据。
    • 原因:跨域适应需要模型在不同类型的数据上都具有良好的性能,逐步引入新域的数据允许模型调整其参数以适应新环境的特征。
  3. 子问题3:如何提高模型对不平衡数据的分类性能?

    • 解决方案:引入额外的技术,如最大边界约束、焦点损失函数和高斯亲和优化。
    • 原因:类别不平衡会导致模型偏向于多数类,这些技术可以帮助模型更关注少数类,改善对不平衡数据的处理。
  4. 子问题4:知识保留,使其能够在每一步学习中同时保留以前学到的知识,并且有效地学习新的数据和类别。神经网络在学习新任务时往往会遗忘之前学习的任务的情况,当网络参数为了适应新的数据而调整时,之前针对旧数据设置的权重可能会被覆盖,导致网络在旧任务上的性能下降。

    • 解决方案:贝叶斯多目标函数
    • 原因:通过量化不确定性并通过后验概率整合新信息,使模型能够在添加新类别数据时,维持对旧类别的识别能力。
    • 原因:贝叶斯方法还帮助模型理解新旧类别之间的结构和语义关系,这对于跨模态学习尤为重要。在眼底图像和OCT图像之间可能存在一些共通的病变特征,贝叶斯目标函数使模型能够理解这些共通性,即使这些特征在不同的模态中以不同的形式呈现。
    • 原因:贝叶斯方法适合少样本学习情境,因为它可以利用先验知识来指导模型的学习过程,这样即使在只有有限数据的情况下,模型也能做出合理的推断。

贝叶斯增量式跨域适应 = 增量学习与Lcl损失函数 + 跨域适应性训练 + 处理类别不平衡 + 贝叶斯实现知识保留和推断

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/435002.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

paddlepaddle在执行loss.item()的时候,报错an illegal memory access was encountered.

遇到的问题是“an illegal memory access was encountered”,这是一个常见的内存错误。 首先,要确保你的代码中没有其他明显的内存错误,例如: 确保你没有试图访问超出数组边界的元素。确保你没有试图释放已经释放的内存。确保你没…

尚无忧球馆助教系统源码,助教小程序源码,助教源码,陪练系统源码

特色功能: 不同助教服务类型选择 助教申请,接单,陪练师入住,赚取外快 线下场馆入住 设置自己服务 城市代理 分销商入住 优惠券 技术栈:前端uniapp后端thinkphp 独立全开源

C#学习(十一)——Array和Collection

一、集合 集合重要且常用 孤立的数据是没有意义的,集合可以作为大量数据的处理,可进行数据的搜索、迭代、添加、删除。 C#中,所有集合都必须实现ICollection接口(数组Array除外) 集合说明Array数组,固定长…

网络安全02--负载均衡下的webshell连接

目录 一、环境准备 1.1ubentu虚拟机一台,docker环境,蚁剑 1.2环境压缩包(文件已上传资源): 二、开始复原 2.1上传ubentu: 2.2解压缩 2.3版本20没有docker-compose手动下载,包已上传资源 …

第2章-神经网络的数学基础——python深度学习

第2章 神经网络的数学基础 2.1 初识神经网络 我们来看一个具体的神经网络示例,使用 Python 的 Keras 库 来学习手写数字分类。 我们这里要解决的问题是, 将手写数字的灰度图像(28 像素28 像素)划分到 10 个类别 中(0…

32GPIO输入LED闪烁蜂鸣器

一.GPIO简介 所有的GPIO都挂载到APB2上,每个GPIO有16个引脚 内核可以通过APB2对寄存器进行读写,寄存器都是32位的,但每个引脚端口只有16位 驱动器用于增加信号的驱动能力 二.具体…

数据结构-线性表

文章目录 数据结构—线性表1.线性表的定义和基本操作线性表的定义线性表的特点线性表的基本操作 2.线性表的顺序存储和链式存储表示顺序存储链式存储单链表循环链表双向链表 数据结构—线性表 1.线性表的定义和基本操作 线性表的定义 定义:线性表是具有相同数据类…

SpringSecurity(16)——OAuth2客户端授权模式

工作流程 基本使用 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-security</artifactId><version>2.3.12.RELEASE</version> </dependency> <dependency><groupId>…

Android发展历程及安装

目录 发展历程 下载网址 安装过程 发展历程 安卓基于Linux内核&#xff0c;Linux内核相当于房屋的地基 开源不等于免费&#xff0c;不能商用 安卓一般每半年小更新&#xff0c;一年大更新 对应API相当于别名 现在安卓安全性越来越高&#xff0c;性能越来越快&#xff0c…

LeNet跟LeNet5详解

1 LeNet结构 主要是为了手写数字识别 具体结构讲解&#xff1a;从图中例子可得 1 先传入一个灰度图像尺寸为1x28x28&#xff0c;通道数为1&#xff0c;尺寸为28x28的灰度图像 2 第一层5x5卷积&#xff0c;经过公式 输入图像尺寸-卷积核尺寸2padding/步长1&#xff0c;&#…

JasperReports渲染报表文件时候,读取图片报错:Byte data not found at:xxx.png.\r\n\tat

目录【知识星球】 1.1、错误描述 1.2、解决方案 1.1、错误描述 最近在工作中&#xff0c;使用JasperReports报表比较多一些&#xff0c;有次线上环境里面运行报错&#xff0c;查看报错日志&#xff0c;如下所示&#xff1a; net.sf.jasperreports.engine.JRException: Byte…

Android源码设计模式解析与实战第2版笔记(五)

第七章 时势造英雄 – 策略模式 策略模式的定义 策略模式定义了一系列的算法&#xff0c;并将每一个算法封装起来&#xff0c;而且使它们还可以相互替换。策略模式让算法独立于使用它的客户而独立变化。 策略模式的使用场景 针对同一类型问题的多种处理方式&#xff0c;仅仅…