VGAN实现视网膜图像血管分割(基于pytorch)

背景介绍

VGAN(Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks)出自2018年的一篇论文,尝试使用生成性对抗网络实现视网膜血管分割的任务,原论文地址:https://arxiv.org/abs/1706.09318
在github上有相应的源码仓库,不过由于版本的原因也会出现一些bug,本篇博客在复现项目的过程中也对源码进行了相应的修改,源码地址: https://github.com/guyuchao/Vessel-wgan-pytorch?tab=readme-ov-file

另一方面,刚好这个项目作为我2023年的最后一个项目,就斗胆当作是2023年编程之旅的回顾,博主是在茫茫知识海洋漂泊的一叶小舟,还有许多的知识尚未学习,希望可以和大家互相交流学习!

2024年,冲鸭!!!!!


前言

生成对抗网络(Generative Adversarial Networks)在提出的时候是为了实现模型创造性的能力,如今在AI图像生成领域已经有非常广阔的应用,例如知名的Midjourney网站,就是通过用户输入的prompt提示,利用GAN的框架生成对应用户想要生成的图片;我自己对于这个模型的名声也是早有耳闻,刚好前一段时间看到了《Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks》这篇文章,内容里探究了把GAN模型应用到视网膜血管分割的领域,刚好可以与我本学期的生物医学创新实践联系在一起。

生成对抗网络介绍

生成对抗网络(Generative Adversarial Network)简称GAN,是深度学习领域的一种重要模型,由Ian Goodfellow在2014年提出。

GAN模型包括两部分:一个是生成器(Generator),另一个是判别器(Discriminator)。这两部分模型相互博弈,共同训练,赋予网络生成特定分布的数据的能力。

1. 生成器(Generator):该部分的目标是生成尽可能真实的数据。例如,如果我们想让网络生成一张风景图片,生成器的目标就是生成一张看上去就像是某个摄影师拍摄的风景照片。

2. 判别器(Discriminator):该部分的目标是尽可能好地区分出真实的和生成的数据。在风景图片的例子中,判别器需要区分出哪些图片是真实的风景照片,哪些是生成器生成的假照片。

两者相互博弈的过程中,判别器会不断提高对真假数据的判断能力,生成器也会不断提高生成数据的逼真度,理想状态下,生成器生成的数据将和真实数据无法区分,判别器对生成器的生成结果的判断是50%,即做出了随机猜测。这样,就完成了GAN的训练过程。

视网膜分割的GAN模型(VGAN)

从上图中可以看出,模型中的GAN的generator是一个U-net形状的网络模型,每一层上采样层都与对称的下采样的输出进行连接,能够很好的处理图像的边缘及其他的细节特征;discriminator是一个多层的下采样的网络模型,最后是输出是实现一个二分类的效果,即{(0,1)}^N,接近0表示判断机器生成的(generator),接近1表示判断为真实的血管分割标签,每层的generator和discriminator都是由基本的block卷积神经网络组成,block的代码构建为:

class block(nn.Module):def __init__(self,in_filters,n_filters):super(block,self).__init__()self.deconv1 = nn.Sequential(nn.Conv2d(in_filters, n_filters, 3, stride=1, padding=1),nn.BatchNorm2d(n_filters),nn.ReLU())def forward(self, x):x=self.deconv1(x)return x

 是用nn.Sequential()连接的包含卷积,标准化和池化的经典卷积层,卷积核为3\times3,步长为1,边缘补充为1,处理的视网膜图片是三通道彩色图盘,第一层的intput_channel为3;下面用pytorch的tensorboard工具对搭建的generator和discriminator进行网络结构的可视化

这里需要注意的是,在原论文中,Discriminator的输入并不是Generator直接生成的图片或者原数据集中的label,而是需要在C通道上与进行分割的原视网膜图片进行合并再进行输入。

损失函数

在训练Generator和训练Discriminator时,使用不同的损失函数,我们最后是使用Generator进行mask的生成,也就是使用Generator输入需要进行视网膜血管分割的图片,输出分割的结果,所以我们更加注重Generator损失函数的设计。

GAN整体的损失函数可以定义为

对于D(Discriminator),在代码中不设计具体的损失函数,从任务设计中可以得出,当输入到D的是真实的标签图像时,我们期望D输出越接近1越好;当输入到D的是Generator生成的图像时,我们期望D输出越接近0越好,基于这种关系,我们直接把D的输出作为损失函数,同时,为了避免GAN模型中常见的断层问题,引入了经典的WGAN方法(Gradient Penalty),即获得一个1-Lipschitz函数,保证GAN模型的训练曲线是足够平滑从而生成稳定的图片,在梯度计算中引入作为梯度惩罚,则D的损失函数可以表示为:

对于G,为适应分割的任务不需要使用隐含空间的向量而是直接获取img的输入,具体地,在代码中使用二分类交叉熵损失函数获取与对应标签的loss值,同时加上来自D的反馈,具体的损失函数为:

数据集预处理

本次项目使用的数据集是经典的视网膜血管分割数据集,含有20张训练集和20张测试集,每张视网膜眼底的图片是584\times565\times3像素的三通道彩色tif格式图片,对应的标签是584\times565像素单通道灰度tif图片

对img的预处理包括随机改变图片的亮度、对比度和色相,图片像素标准化、转换成tensor数据的经典图片训练格式[B,C,H,W];

对label的预处理包括图片像素标准化、转换成tensor数据的经典图片训练格式[B,C,H,W];

同时对img和label的预处理包括随机裁剪图片高和宽为512\times512像素大小,随机水平翻转和垂直翻转;

下面是对训练数据中的进行预处理后的结果可视化

训练过程

一开始打算把本项目放到colab上跑或者服务器上跑,在此之前抱着试一试的态度先用本地的显卡1080ti加4G显存跑了50个epoch,结果竟然能跑得动!于是就先跑了300个epoch,显卡没崩,结果保存在./pth中。

对于GAN的结果,除了使用传统的评估方法外,也会对训练过程的结果进行输出可视化看看结果有没有生成奇奇怪怪的图像从而停止训练重新调整,不过对于本次项目,Generator的输入不包含隐含空间z,而且加入了WGAN进行约束,所以生成的图像基本上是比较完整的;作为视觉上直观地对GAN的效果进行评估,我们每跑50个epoch对应输出测试集生成的分割图像,保存在对应的文件夹名称里面,跑完全部epoch后,对应把D和G训练好的checkpoint也保存在./pth路径下。

300个epoch跑了一个多小时,显卡的散热吹风机感觉可以起飞了,不过结果其实可圈可点,为了不浪费训练好的资料,后面又写了一个re_train脚本,读取前一次训练好的模型权重再进行训练,又花了一个小时跑了300个epoch,结果保存在./pth2目录下,结构与./pth中文件相同,所以综合起来一共训练了600个epoch。

结果分析

 先对比每50个epoch生成的图像,选取测试集中的第一张图片

对于测试集的预处理也包括随机裁剪像素512\times512像素大小与随机水平翻转与垂直翻转,所以生成的图像包含有不同的方向和裁剪风格。

从第50个epoch到第600个epoch结果,可以明显的看出图片质量提升的效果,说明Generator的学习是非常有效果的,没有出现GAN中经常出现的图片断层效果。

对比一下对应的眼底原图和血管分割标签图:

从视觉上来看,GAN生成图像与原图像的分割标签是比较接近的。

接下来我们使用训练好的Generator模型进行量化的对比:

绘制对应的PR曲线与ROC曲线:

从两个曲线的效果可以看出,训练出来的模型在测试集上也具有比较好的效果,本次项目使用的VGAN在处理视网膜血管分割的任务上体现出比较好的性能。

参考文献

  1. Son, J., Park, S.J., & Jung, K. (2017). Retinal Vessel Segmentation in Fundoscopic Images with Generative Adversarial Networks. ArXiv, abs/1706.09318.
  2. Arjovsky, M., Chintala, S., & Bottou, L. (2017). Wasserstein GAN. ArXiv, abs/1701.07875.

彩蛋

我们都知道GAN以创作能力而闻名,那我们试一下用上面训练好的模型接受随机初始化满足正态分布的z隐含空间的数据会输出怎么样的图像

嗯...看来想要GAN生成像样的图片,还是需要再训练机制里面下手

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/345332.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【K8s学习】

k8s的简单执行流程: Kubernetes Master(API Server、Scheduler等组件)负责调度Pod到合适的Node上。 当Pod被调度到某个Node时,该Node上的kubelet代理会收到指令并开始执行Pod的生命周期管理任务,包括创建、监控和终止P…

React16源码: React中的schedule调度整体流程

schedule调度的整体流程 React Fiber Scheduler 是 react16 最核心的一部分,这块在 react-reconciler 这个包中这个包的核心是 fiber reconciler,也即是 fiber 结构fiber的结构帮助我们把react整个树的应用,更新的流程,能够拆成每…

c++学习:智能指针的底层作用原理+用法

目录 智能指针作用原理 作用 原理 模仿int*类型的智能指针 模仿所有类型的智能指针(模板) 共享智能指针类 思考;如果多个智能指针同时指向同一个堆空间,怎么只执行一次析构函数进行释放空间 (共享智能指针类&…

ubuntu 20.04下 Tesla P100加速卡使用

1.系统环境:系统ubuntu 20.04, python 3.8 2.查看cuDNN/CUDA与tensorflow的版本关系如下: Build from source | TensorFlow 从上图可以看出,python3.8 对应的tensorflow/cuDNN/CUDA版本。 3.安装tensorflow #pip3 install tensorflow 新版…

MES生产执行系统在生产车间的主要作用

MES生产执行系统提供从生产订单下达到产品完成全流程的优化管理。实现现场设备、执行系统及管理系统的集成,实时监控生产管理各项绩效指标。 如果说ERP是上层决策,生产车间是下层执行,那么MES就是连接管理软件和一线生产的中间桥梁。 MES也…

gitee创建远程仓库并克隆远程仓库到电脑

1、首先点加号新建一个仓库 2、输入仓库名,路径会自动填充,填写简单的仓库介绍,先选择私有,在仓库创建之后,可以改为开源 3、打开建好的仓库 4、复制仓库链接 5、打开一个文件夹(想要存储远程仓库的地址),在…

Mac M2芯片pycharm配置conda python环境

Mac M2芯片pycharm配置conda python环境 详细步骤如下 1、pycharm界面右上方的小齿轮⚙️,进入Setting…状态 2、进入setting界面后,选择左边栏的Project-->python Interpreter,然后选择右边的Add Interpreter 3、进入Add Interpreter后&#xff0c…

四种无监督聚类算法说明

目录 一、K-Means无监督学习(K-Means)的认识-CSDN博客​​​​​​ 二、Mini-Batch K-Means -- Centroid models 三、AffinityPropagation (Hierarchical) -- Connectivity models 四、Mean Shift -- Centroid models 无监督聚类是一种机器学习技术&…

11Spring IoC注解式开发(上)(元注解/声明Bean的注解/注解的使用/负责实例化Bean的注解)

注解的存在主要是为了简化XML的配置。Spring6倡导全注解开发。 注解开发的优点:提高开发效率 注解开发的缺点:在一定程度上违背了OCP原则,使用注解的开发的前提是需求比较固定,变动较小。 1 注解的注解称为元注解 自定义一个注解: package com.sunspl…

Unity中的异步编程【7】——在一个异步方法里播放了animation动画,取消任务时,如何停止动画播放

用一个异步方法来播放一个动画,正常情况是:动画播放结束时,异步方法宣告结束。那如果我提前取消这个异步任务,那在这个异步方法里面,我要怎么停止播放呢?! 一、播放animation动画的异步实现 1…

html+css+Jquery 实现 文本域 文字数量限制、根据输入字数自适应高度

先看效果&#xff1a;初始的效果&#xff0c;样式多少有点问题&#xff0c;不重要&#xff01;重要的是功能&#xff01; 输入后&#xff1a; 根据文字长度&#xff0c;决定文本域长度 限制文字数量 话不多说&#xff0c;直接上代码&#xff01; <!DOCTYPE html> <h…

特征工程-特征处理(一)

特征处理-&#xff08;离散型特征处理&#xff09; 完成特征理解和特征清洗之后&#xff0c;我们要进行特征工程中最为重要和复杂的一步了——特征处理 离散型特征处理 离散型特征通常为非连续值或以字符串形式存在的特征&#xff0c;离散型特征通常来讲是不能直接喂入模型中…