【霹雳吧啦】手把手带你入门语义分割の番外12:U2-Net 源码讲解(PyTorch)—— 网络的搭建

目录

前言

Preparation

一、U2-Net 网络结构图

二、U2-Net 网络源代码

1、model.py

(1)ConvBNReLU 类

(2)DownConvBNReLU 类

(3)UpConvBNReLU 类

(4)RSU 类 & RSU4F 类

(5)U2Net 类

(6)model.py 源代码


前言

文章性质:学习笔记 📖

视频教程:U2-Net 源码解析(Pytorch)- 2 网络的搭建

主要内容:根据 视频教程 中提供的 U2-Net 源代码(PyTorch),对 model.py 文件进行具体讲解。

Preparation

源代码:https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_segmentation/u2net

├── src: 搭建网络相关代码
├── train_utils: 训练以及验证相关代码
├── my_dataset.py: 自定义数据集读取相关代码
├── predict.py: 简易的预测代码
├── train.py: 单GPU或CPU训练代码
├── train_multi_GPU.py: 多GPU并行训练代码
├── validation.py: 单独验证模型相关代码
├── transforms.py: 数据预处理相关代码
└── requirements.txt: 项目依赖

一、U2-Net 网络结构图

原论文提供的 U2-Net 网络结构图如下所示: 

【说明】在 Encoder 阶段,每通过一个 block 后都下采样 2 倍(maxpool),在 Decoder 阶段,每通过一个 block 后都上采样 2 倍(bilinear)。U2-Net 网络的核心 block 是 ReSidual U-block,分为具备上下采样的 block 和不具备上下采样的 block(Encoder5、Encoder6、Decoder5)。

二、U2-Net 网络源代码

1、model.py

(1)ConvBNReLU 类

这个 ConvBNReLU 类继承自 nn.Module 父类,传入了输入通道数 in_ch、输出通道数 out_ch、卷积核大小 kernel_size 和膨胀因子 dilation 等参数,当 dilation 设置为 1 时,代表当前这个卷积层是普通卷积,当 dilation 大于 1 时,代表当前这个卷积层是膨胀卷积。

【代码解析】对 ConvBNReLU 类代码进行具体讲解(结合上图):

  1.  根据传入的 kernel_size 和 dilation 计算对应的 padding
  2.  分别使用 Conv2d、BatchNorm2d、ReLU 实例化卷积层、BN 层、ReLU 层
  3.  定义前向传播函数,使得输入张量 x 依次通过卷积层、BN 层、ReLU 层

(2)DownConvBNReLU 类

这个 DownConvBNReLU 类继承自 ConvBNReLU 类。

【说明】Encoder 部分:使用 F.max_pool2d 方法进行下采样,再依次通过 Conv2d 层、BN 层、ReLU 层。 

(3)UpConvBNReLU 类

这个 UpConvBNReLU 类继承自 ConvBNReLU 类。

【说明】Decoder 部分:使用 F.interpolate 中的双线性插值方法进行上采样,这里注意,经过双线性插值法后,x2 的宽高将与 x1 的宽高相同,再用 torch.cat 将 Encoder 中的 x2 与 Decoder 中的 x1 连接,最后依次通过 Conv2d 层、BN 层、ReLU 层。

(4)RSU 类 & RSU4F 类

在之前的文章中已经具体介绍过 U2-Net 的两种 block 结构:Encoder1~Encoder4 与 Decoder1~Decoder4 采用的是同一种结构的 block ,只不过深度不同。Encoder5、Encoder6、Decoder5 采用的是另一种结构的 block 。 我们先来简单回忆下第一种 block 结构:

  • U2-Net 网络结构中的 Encoder1 和 Decoder1 采用的是 RSU-7 结构
  • U2-Net 网络结构中的 Encoder2 和 Decoder2 采用的是 RSU-6 结构
  • U2-Net 网络结构中的 Encoder3 和 Decoder3 采用的是 RSU-5 结构
  • U2-Net 网络结构中的 Encoder4 和 Decoder4 采用的是 RSU-4 结构

【说明】相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。

再来回忆第二种 block 结构,U2-Net 中的 Encoder5、Encoder6、Decoder5 采用的都是 RSU-4F 结构,注意 RSU-4F 与 RSU-4 的结构并不相同,在 RSU-4F 中未进行下采样和上采样,而是将 上下采样 全部替换成了 膨胀卷积 ,整个过程中特征图的宽高不变。

下面是 RSU 类和 RSU4F 类的代码截图,在 RSU 类的初始化 __init__ 方法中,传入的 height 参数是指 RSU 结构的深度。

(5)U2Net 类

这个 U2Net 类继承自 nn.Module 父类,在其初始化 __init__ 方法中,传入参数包括 cfg ,而在我们搭建 U2-Net 标准版以及轻量级的版本时,都会传入 cfg ,分别可以在 u2net_full 和 u2net_lite 函数中查看:

除此之外,在 U2-Net 中,默认是针对显著性目标检测任务去做的,只区分前景和背景,因此整个网络最终输出的预测概率图的通道数 out_ch 将设置为 1 ,也就是说预测的每一个像素的概率分数都是从 0 到 1 的,趋于 1 则说明代表二维前景的概率大,趋于 0 则说明代表背景的概率大。 

【说明】经过 concat 拼接后,将得到的特征图通过一个 1x1 的卷积层,融合来自不同尺度的信息,最终得到只有一个通道的特征图。

【说明】使用 decode_outputs.insert(0, x) 将处理后的结果 x 插入到 decode_outputs 列表的第一个位置,以 保持解码器输出的顺序

(6)model.py 源代码

from typing import Union, List
import torch
import torch.nn as nn
import torch.nn.functional as Fclass ConvBNReLU(nn.Module):def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1):super().__init__()padding = kernel_size // 2 if dilation == 1 else dilationself.conv = nn.Conv2d(in_ch, out_ch, kernel_size, padding=padding, dilation=dilation, bias=False)self.bn = nn.BatchNorm2d(out_ch)self.relu = nn.ReLU(inplace=True)def forward(self, x: torch.Tensor) -> torch.Tensor:return self.relu(self.bn(self.conv(x)))class DownConvBNReLU(ConvBNReLU):def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):super().__init__(in_ch, out_ch, kernel_size, dilation)self.down_flag = flagdef forward(self, x: torch.Tensor) -> torch.Tensor:if self.down_flag:x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)return self.relu(self.bn(self.conv(x)))class UpConvBNReLU(ConvBNReLU):def __init__(self, in_ch: int, out_ch: int, kernel_size: int = 3, dilation: int = 1, flag: bool = True):super().__init__(in_ch, out_ch, kernel_size, dilation)self.up_flag = flagdef forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:if self.up_flag:x1 = F.interpolate(x1, size=x2.shape[2:], mode='bilinear', align_corners=False)return self.relu(self.bn(self.conv(torch.cat([x1, x2], dim=1))))class RSU(nn.Module):def __init__(self, height: int, in_ch: int, mid_ch: int, out_ch: int):super().__init__()assert height >= 2self.conv_in = ConvBNReLU(in_ch, out_ch)encode_list = [DownConvBNReLU(out_ch, mid_ch, flag=False)]decode_list = [UpConvBNReLU(mid_ch * 2, mid_ch, flag=False)]for i in range(height - 2):encode_list.append(DownConvBNReLU(mid_ch, mid_ch))decode_list.append(UpConvBNReLU(mid_ch * 2, mid_ch if i < height - 3 else out_ch))encode_list.append(ConvBNReLU(mid_ch, mid_ch, dilation=2))self.encode_modules = nn.ModuleList(encode_list)self.decode_modules = nn.ModuleList(decode_list)def forward(self, x: torch.Tensor) -> torch.Tensor:x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop()for m in self.decode_modules:x2 = encode_outputs.pop()x = m(x, x2)return x + x_inclass RSU4F(nn.Module):def __init__(self, in_ch: int, mid_ch: int, out_ch: int):super().__init__()self.conv_in = ConvBNReLU(in_ch, out_ch)self.encode_modules = nn.ModuleList([ConvBNReLU(out_ch, mid_ch),ConvBNReLU(mid_ch, mid_ch, dilation=2),ConvBNReLU(mid_ch, mid_ch, dilation=4),ConvBNReLU(mid_ch, mid_ch, dilation=8)])self.decode_modules = nn.ModuleList([ConvBNReLU(mid_ch * 2, mid_ch, dilation=4),ConvBNReLU(mid_ch * 2, mid_ch, dilation=2),ConvBNReLU(mid_ch * 2, out_ch)])def forward(self, x: torch.Tensor) -> torch.Tensor:x_in = self.conv_in(x)x = x_inencode_outputs = []for m in self.encode_modules:x = m(x)encode_outputs.append(x)x = encode_outputs.pop()for m in self.decode_modules:x2 = encode_outputs.pop()x = m(torch.cat([x, x2], dim=1))return x + x_inclass U2Net(nn.Module):def __init__(self, cfg: dict, out_ch: int = 1):super().__init__()assert "encode" in cfgassert "decode" in cfgself.encode_num = len(cfg["encode"])encode_list = []side_list = []for c in cfg["encode"]:# c: [height, in_ch, mid_ch, out_ch, RSU4F, side]assert len(c) == 6encode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))self.encode_modules = nn.ModuleList(encode_list)decode_list = []for c in cfg["decode"]:# c: [height, in_ch, mid_ch, out_ch, RSU4F, side]assert len(c) == 6decode_list.append(RSU(*c[:4]) if c[4] is False else RSU4F(*c[1:4]))if c[5] is True:side_list.append(nn.Conv2d(c[3], out_ch, kernel_size=3, padding=1))self.decode_modules = nn.ModuleList(decode_list)self.side_modules = nn.ModuleList(side_list)self.out_conv = nn.Conv2d(self.encode_num * out_ch, out_ch, kernel_size=1)def forward(self, x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:_, _, h, w = x.shape# collect encode outputsencode_outputs = []for i, m in enumerate(self.encode_modules):x = m(x)encode_outputs.append(x)if i != self.encode_num - 1:x = F.max_pool2d(x, kernel_size=2, stride=2, ceil_mode=True)# collect decode outputsx = encode_outputs.pop()decode_outputs = [x]for m in self.decode_modules:x2 = encode_outputs.pop()x = F.interpolate(x, size=x2.shape[2:], mode='bilinear', align_corners=False)x = m(torch.concat([x, x2], dim=1))decode_outputs.insert(0, x)# collect side outputsside_outputs = []for m in self.side_modules:x = decode_outputs.pop()x = F.interpolate(m(x), size=[h, w], mode='bilinear', align_corners=False)side_outputs.insert(0, x)x = self.out_conv(torch.concat(side_outputs, dim=1))if self.training:# do not use torch.sigmoid for amp safereturn [x] + side_outputselse:return torch.sigmoid(x)def u2net_full(out_ch: int = 1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side"encode": [[7, 3, 32, 64, False, False],      # En1[6, 64, 32, 128, False, False],    # En2[5, 128, 64, 256, False, False],   # En3[4, 256, 128, 512, False, False],  # En4[4, 512, 256, 512, True, False],   # En5[4, 512, 256, 512, True, True]],   # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 1024, 256, 512, True, True],   # De5[4, 1024, 128, 256, False, True],  # De4[5, 512, 64, 128, False, True],    # De3[6, 256, 32, 64, False, True],     # De2[7, 128, 16, 64, False, True]]     # De1}return U2Net(cfg, out_ch)def u2net_lite(out_ch: int = 1):cfg = {# height, in_ch, mid_ch, out_ch, RSU4F, side"encode": [[7, 3, 16, 64, False, False],  # En1[6, 64, 16, 64, False, False],  # En2[5, 64, 16, 64, False, False],  # En3[4, 64, 16, 64, False, False],  # En4[4, 64, 16, 64, True, False],  # En5[4, 64, 16, 64, True, True]],  # En6# height, in_ch, mid_ch, out_ch, RSU4F, side"decode": [[4, 128, 16, 64, True, True],  # De5[4, 128, 16, 64, False, True],  # De4[5, 128, 16, 64, False, True],  # De3[6, 128, 16, 64, False, True],  # De2[7, 128, 16, 64, False, True]]  # De1}return U2Net(cfg, out_ch)def convert_onnx(m, save_path):m.eval()x = torch.rand(1, 3, 288, 288, requires_grad=True)# export the modeltorch.onnx.export(m,  # model being runx,  # model input (or a tuple for multiple inputs)save_path,  # where to save the model (can be a file or file-like object)export_params=True,opset_version=11)if __name__ == '__main__':# n_m = RSU(height=7, in_ch=3, mid_ch=12, out_ch=3)# convert_onnx(n_m, "RSU7.onnx")## n_m = RSU4F(in_ch=3, mid_ch=12, out_ch=3)# convert_onnx(n_m, "RSU4F.onnx")u2net = u2net_full()convert_onnx(u2net, "u2net_full.onnx")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/328917.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

并发(5)

目录 26.说说final类型的类如何扩展&#xff1f; 27.final方法可以被重载吗&#xff1f; 28.父类的final方法能不能被子类重写&#xff1f; 29.说说基本类型的final域重排序规则&#xff1f; 30.说说final的原理&#xff1f; 26.说说final类型的类如何扩展&#xff1f; 比…

IDEA 每次新建工程都要重新配置 Maven的解决方案

文章目录 IDEA 每次新建工程都要重新配置 Maven 解决方案一、选择 File -> New Projects Setup -> Settingsfor New Projects…二、选择 Build,Execution,Deployment -> Build Tools -> Maven IDEA 每次新建工程都要重新配置 Maven 解决方案 DEA 每次新建工程都要…

STM32之模拟IIC总线控制SHT20温湿度芯片

一、IIC总线概述 1、IIC总线介绍 I2C (Inter&#xff0d;Integrated Circuit)总线产生于在80年代&#xff0c; 由PHILIPS公司开发的两线式串行总线&#xff0c;用于连接微控制器及其外围设备, 最初为音频和视频设备开发。I2C总线两线制包括&#xff1a;串行数据SDA&#xff08…

基于JAVA的服装店库存管理系统 开源项目

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 角色管理模块2.3 服装档案模块2.4 服装入库模块2.5 服装出库模块 三、系统设计3.1 用例设计3.2 数据库设计3.2.1 角色表3.2.2 服装档案表3.2.3 服装入库表3.2.4 服装出库表 四、系统展示五、核心代码5.…

大数据分析案例-基于LinearRegression回归算法构建房屋价格预测模型

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

JS事件循环

目录 概述1. 堆栈&#xff08;Call Stack&#xff09;2. 堆&#xff08;Heap&#xff09;3. 事件队列&#xff08;Event Queue&#xff09;4. 宿主环境&#xff08;Host Environment&#xff09; 事件循环&#xff08;Event Loop&#xff09;微任务和宏任务&#xff08;Microta…

王国维的人生三境界,这一生至少当一次傻瓜

一、人生三境界 古今之成大事业、大学问者&#xff0c;必经过三种之境界。“昨夜西风凋碧树&#xff0c;独上高楼&#xff0c;望尽天涯路。”此第一境也。“衣带渐宽终不悔&#xff0c;为伊消得人憔悴。”此第二境也。“众里寻他千百度&#xff0c;蓦然回首&#xff0c;那人却…

阿里云服务器地域怎么选择?2024地域选择方法

阿里云服务器地域和可用区怎么选择&#xff1f;地域是指云服务器所在物理数据中心的位置&#xff0c;地域选择就近选择&#xff0c;访客距离地域所在城市越近网络延迟越低&#xff0c;速度就越快&#xff1b;可用区是指同一个地域下&#xff0c;网络和电力相互独立的区域&#…

Windows11下载安装nacos(2.3.0)详解

一、环境要求 windows7以上 jdk8及以上版本&#xff0c;并且配置了JAVA_HOME环境变量 二、nacos下载解压 release版本地址:Releases alibaba/nacos GitHub 下载后解压即可&#xff0c;上面的tar.gz是linux版本 解压后如下 nacos自己内置有数据库derby&#xff0c;我用的是…

Java技术栈 —— Hadoop入门(一)

Java技术栈 —— Hadoop入门&#xff08;一&#xff09; 一、Hadoop第一印象二、安装Hadoop三、Hadoop解析3.1 Hadoop生态介绍3.1.1 MapReduce - 核心组件3.1.2 HDFS - 核心组件3.1.3 YARN - 核心组件3.1.4 其它组件3.1.4.1 HBase3.1.4.2 Hive3.1.4.3 Spark 一、Hadoop第一印象…

2024年【化工自动化控制仪表】报名考试及化工自动化控制仪表考试技巧

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 化工自动化控制仪表报名考试是安全生产模拟考试一点通生成的&#xff0c;化工自动化控制仪表证模拟考试题库是根据化工自动化控制仪表最新版教材汇编出化工自动化控制仪表仿真模拟考试。2024年【化工自动化控制仪表】…

npm、pnpm和yarn 的区别

包管理工具是JavaScript开发中不可或缺的一部分&#xff0c;它们可以帮助我们方便地安装、更新、删除和管理项目所依赖的各种库和模块。 目前&#xff0c;最流行的包管理工具有npm、yarn和pnpm&#xff0c;它们各有各的特点和优劣势。 本文将试着对这三个工具进行全面的对比。…