论文学习——U-Net: Convolutional Networks for Biomedical Image Segmentation

UNet的特点

  • 采用端到端的结构,通过FCN(最后一层仍然是通过卷积完成),最后输出图像。
  • 通过编码(下采样)-解码(上采样)形成一个“U”型结构。每次下采样时,先进行两次卷积(通道数不变),然后通过一次池化层(也可以通过卷积)处理(长宽减半,通道数加倍);在每次上采样时,同样先进行两次卷积操作,再通过反卷积函数进行上采样(长宽加倍,通道不变),然后与编码过程中对应层进行拼接(通道加倍)。到最后一层时,通过1x1的卷积核修改通道数,最后输出目标图像。编码操作逐层提取图像特征,解码操作则逐层恢复图像信息。
  • 通过跳跃连接,将编码器结构中的底层信息与解码器结构中的高层信息融合,从而提高了分割精度。

网络结构如图所示:
在这里插入图片描述
代码实现(基于pytorch):
相关包的引入:

from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F

定义卷积块:
定义了两个卷积操作,分别使用大小为3x3的卷积核进行卷积,步长为1,并且对卷积后的输出进行批量归一化(批量归一化的作用),激活函数采用ReLU。使用卷积模块时,需要指明输入通道数(in_channel)和输出通道数(out_channel)。

class Conv_Block(nn.Module):def __init__(self, in_channel, out_channel):super(Conv_Block, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channel, out_channel, 3, 1, 1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True))self.conv2 = nn.Sequential(nn.Conv2d(out_channel, out_channel, 3, 1, 1),nn.BatchNorm2d(out_channel),nn.ReLU(inplace=True))def forward(self, input):outputs = self.conv1(input)outputs = self.conv2(outputs)return outputs

编码操作(下采样):将卷积模块的输出进行池化处理。

class UnetDown(nn.Module):def __init__(self, in_channel, out_channel):super(UnetDown, self).__init__()self.conv = Conv_Block(in_channel, out_channel)self.down = nn.MaxPool2d(2, 2, ceil_mode=True)def forward(self, inputs):outputs = self.conv(inputs)outputs = self.down(outputs)return outputs

解码操作(上采样):这里的上采样操作提出了两种——ConvTranspose2d和UpsamplingBilinear2d,两者的区别见这里。另外,由于要进行拼接操作,所以在拼接前对上采样的输出进行填充,避免拼接出错。
(ps:代码里面的解码操作是先进行上采样,然后拼接数据,最后进行卷积的,但是在UnetModel中的最后一个编码操作后,单独进行了一次卷积操作,最后的网络结构还是没有变的。)

class UnetUp(nn.Module):def __init__(self, in_channel, out_channel, is_deconv=True):super(UnetUp, self).__init__()self.conv = Conv_Block(in_channel, out_channel)if is_deconv:self.up = nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2)else:self.up = nn.UpsamplingBilinear2d(scale_factor=2)def forward(self, inputs1, inputs2):outputs2 = self.up(inputs2)offset1 = (outputs2.size()[2] - inputs1.size()[2])offset2 = (outputs2.size()[3] - inputs1.size()[1])# pad传入四个元素时,指的是左填充,右填充,上填充,下填充;前两个元素作用在第一四维,后两个元素作用在第三维padding = [offset2 // 2, (offset2 + 1) // 2, offset1 // 2, (offset1 + 1) // 2]# Skip and concatenateoutputs1 = F.pad(inputs1, padding)return self.conv(torch.cat([outputs1, outputs2], 1))

最后定义整个UNet模块:将代码和网络结构的图结合起来看就很容易理解了。

class UnetModel(nn.Module):def __init__(self, n_classes, in_channels, is_deconv):super(UnetModel, self).__init__()self.is_deconv = is_deconvself.in_channels = in_channelsself.n_classes = n_classesfilters = [64, 128, 256, 512, 1024]self.down1 = UnetDown(self.in_channels, filters[0])self.down2 = UnetDown(filters[0], filters[1])self.down3 = UnetDown(filters[1], filters[2])self.down4 = UnetDown(filters[2], filters[3])self.center = Conv_Block(filters[3], filters[4])self.up4 = UnetUp(filters[4], filters[3], self.is_deconv)self.up3 = UnetUp(filters[3], filters[2], self.is_deconv)self.up2 = UnetUp(filters[2], filters[1], self.is_deconv)self.up1 = UnetUp(filters[1], filters[0], self.is_deconv)self.final = nn.Conv2d(filters[0], self.n_classes, 1)def forward(self, inputs, label_dsp_dim):down1 = self.down1(inputs)down2 = self.down2(down1)down3 = self.down3(down2)down4 = self.down4(down3)center = self.center(down4)up4 = self.up1(down4, center)up3 = self.up2(down3, up4)up2 = self.up3(down2, up3)up1 = self.up4(down1, up2)up1 = up1[:, :, 1:1 + label_dsp_dim[0], 1:1 + label_dsp_dim[1]].contiguous()return self.final(up1)# Initialization of parametersdef _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()elif isinstance(m, nn.BatchNorm2d):m.weight.data.fill_(1)m.bias.data.zero_()elif isinstance(m, nn.ConvTranspose2d):n = m.kernel_size[0] * m.kernel_size[1] * m.out_channelsm.weight.data.normal_(0, sqrt(2. / n))if m.bias is not None:m.bias.data.zero_()

总结:UNet 是一种经典的图像分割网络,它通过编码器-解码器结构、跳跃连接和多尺度特征融合等设计,能够在图像分割任务中取得优秀的性能。基于UNet还衍生出了很多网络,例如 U-Net++, ResUNet, Dense U-Net等,接下来就学习它的衍生网络吧,学习大佬是怎么魔改网络的~另外,刚开始写深度学习的代码时,我不知道从何下手,通过学习大佬实现代码的过程,我发现结合两点就能轻松实现代码:1)写代码时结合网络结构的图片,2)百度相关操作的函数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/14116.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CygWin:windows上运行类linux命令

CygWin是一个在Windows平台上运行的类UNIX模拟环境,是Cygnus Solutions公司开发的自由软件。它提供了类似于Linux系统的终端环境和工具,使用户可以在Windows平台上运行Unix-like的程序,如Bash、awk、sed和grep等 。 下载setup.exe 安装Cygwin…

【Excel经验】日期时间处理方法

概览-公式汇总 公式功能公式公式说明提取时间中的日期TEXT(A2,“yyyy-mm-dd”)A2列数据格式样例:2023/7/5 6:20:10计算耗时得到单位:秒数VALUE(TEXT(B2-A2,“[ss]”))A2、B2列数据格式样例:2023/7/5 6:20:10计算耗时得到格式化显示年月日B2-…

ModaHub魔搭社区:向量数据库Milvus Lite 的优势和安装教程

想要体验开源向量数据库MIlvus,缺少专业的工程师团队作为支撑?Milvus 安装环境受限? 别担心,轻量版 Milvus 来啦! 有用户反馈刚开始接触 Milvus 或者想要在 Notebook 中进行快速实验时,安装或部署 Milvus 有些力不从心 。开发了 Milvus 的轻量级版本Milvus Lite ,方…

Vue3使用element-plus实现弹窗效果-demo

使用 <ShareDialog v-model"isShow" onChangeDialog"onChangeDialog" /> import ShareDialog from ./ShareDialog.vue; const isShow ref(false); const onShowDialog (show) > {isShow.value show; }; const onChangeDialog (val) > {co…

【ElasticSearch】JavaRestClient实现索引库和文档的增删改查

文章目录 一、RestClient1、什么是RestClient2、导入demo工程3、数据结构分析与索引库创建4、初始化JavaRestClient 二、RestClient操作索引库1、创建索引库2、删除索引库3、判断索引库是否存在 三、RestClient操作文档1、新增文档2、查询文档3、修改文档4、删除文档5、批量导入…

[SWPUCTF 2022 新生赛]js_sign

[SWPUCTF 2022 新生赛]js_sign 查看源码js文件 hint的意思是敲击码 解出flag&#xff0c;记得去掉nssctf&#xff0c;包上NSSCTF{} Tap Code - 许愿星 (wishingstarmoye.com)

Linux(Ubuntu)+Qt+C++与OpenCV窗体程序使用

程序示例精选 Linux(Ubuntu)QtC与OpenCV窗体程序使用 如需安装运行环境或远程调试&#xff0c;见文章底部个人QQ名片&#xff0c;由专业技术人员远程协助&#xff01; 前言 这篇博客针对<<Linux(Ubuntu)QtC与OpenCV窗体程序使用>>编写代码&#xff0c;代码整洁&am…

基于Python爬虫+KNN数字验证码识别系统——机器学习算法应用(含全部工程源码)+训练数据集

目录 前言总体设计系统整体结构图系统流程图 运行环境Python 环境 模块实现1. 数据爬取2. 去噪与分割3. 模型训练及保存4. 准确率验证 系统测试工程源代码下载其它资料下载 前言 本项目利用Python爬虫技术&#xff0c;通过网络爬取验证码图片&#xff0c;并通过一系列的处理步…

javassist implements interface 模拟mybatis 生成代理类

动态创建代理对象的工具类 package com.wsd.util;import org.apache.ibatis.javassist.ClassPool; import org.apache.ibatis.javassist.CtClass; import org.apache.ibatis.javassist.CtMethod; import org.apache.ibatis.session.SqlSession;import java.lang.reflect.Const…

【C++ OJ练习】4.字符串中的第一个唯一字符

1.题目链接 力扣 2.解题思路 利用计数排序的思想 映射进行计数 最后计数为1的那个字符就是唯一字符 从前往后遍历 可以得到 第一个唯一字符 3.代码 class Solution { public:int firstUniqChar(string s) {//使用映射的方式统计次数 计数排序思想int count[26] { 0 };fo…

rpm包安装mysql8.0

一、环境准备 1查看本机IP地址&#xff0c;使用Xshell工具登录 [rootmysql ~]# ip a 1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00inet 127.0.0.1/8 scope ho…

基于java+swing+mysql商城购物系统

基于javaswingmysql商城购物系统 一、系统介绍二、功能展示1.项目骨架2.主界面3.用户登陆4.添加商品类别5、添加商品6、商品管理 四、其它1.其他系统实现五.获取源码 一、系统介绍 项目类型&#xff1a;Java SE项目 项目名称&#xff1a;商城购物系统 用户类型&#xff1a;双…