UNet网络学习记录

如下图所示,整个网络结构包括两部分,编码结构和解码结构,编码结构是对特征进行提取,解码结构是对特征进行还原;如下右图所示,这个步骤包括数据集的加载,网络的搭建,训练网络(调用网络)
在这里插入图片描述
网络的结构解析:
输入图片,进行两次卷积,进行一次下采样,重复4次
最下面1层开始,经过两次卷积,再上采样
从图中可以看出,上采样后,通道数发生了改变,这里使用了1x1的卷积
在每次的上采样之后,都与对应的下采样部分进行cat,进行拼接
在整个网络的搭建过程中,是通过模块化编程来进行实现,
比如分为:模块1定义卷积,其中包括两次卷积;模块2定义下采样;模块3定义上采样
data.py的代码如下所示:主要用于对训练集测试集的图像进行处理等

import osfrom torch.utils.data import Dataset  #导入所需要的包(utils下的dataset)
from utils1 import *
from torchvision import transforms #这里的归一化在pytorch中封装成了包
#所有的图片都需要处理,归一化处理
transform=transforms.Compose([transforms.ToTensor
])   #这里仅仅使用了totensorclass MyDataset(Dataset): #定义一个自己数据集的类,继承自Datasetdef __init__(self,path): #初始化,并传入地址self.path=path #把实际地址传入到变量上,下一步是获取文件所有的文件名self.name=os.listdir(os.path.join(path,'seg'))#这部分是使用os下的path指令拿到path路径并进行拼接(path/seg) #w外部的这个os是拿到每一个文件夹下的图片,因为文件放在多个文件夹def __len__(self): #返回文件名的数量,那也就是数据集的数量return len(self.name)def __getitem__(self, index): #这部分内容是数据集的制作,传入的是数据集的下标索引segment_name=self.name[index] #xx.png 这里是获取名字的下标segment_path=os.path.join(self.path,'seg',segment_name) #这里是拿到每个分割后图片的路径,数据集路径+文件夹路径+每张图片的名字image_path=os.path.join(self.path,'JPG',segment_name.replace('png','jpg')) #拿到原图的路径,并把原图png转为jpg这种格式,保持格式一致#输出的图片大小一般是不一致的,需要对图片进行缩放,这部分代码封装成了模块utilssegment_image=keep_image_size_open(segment_path)  #这里是把分割图片统一大小image=keep_image_size_open(image_path) #把真实图片统一大小return transforms(image),transform(segment_image) #把分割图片和原图都进行归一化,转换为0-1的数,这里返回的是对应的一组图片if __name__ == '__main__': #上面是定义函数,这里是主函数,程序的入口,如果直接运行这个代码,就会执行,是这些代码的入口data=MyDataset('给个地址')print(data[0][0].shape)  #这里是打印第0张原图的形状,设置的是3x256X256#注意:在计算机中data返回的数据类型是元组(不可以增删改)#Data[0][0].shape  打印出来的结果是[3,256,256]

注意:在计算机中data返回的数据类型是元组(不可以增删改)
Data[0][0].shape 打印出来的结果是[3,256,256].第一个0表示第0张,第二个0表示原图,第二个0更改为1就是分割的图片形状;第一个[]表示索引,也就是图片的序号,第二个[]表示pair中的第一个,即返回值的第一个
**utils.py部分的代码如下图所示:**这部分代码一般是定义用于图像处理的工具

#注意,utils表示的是一个工具类
#在这里表示的是对图片进行处理
from PIL import Imagedef keep_image_size_open(path,size=(256,256)): #这部分统一图像大小,需要的参数是图像路径,缩放图片大小为256x256img=Image.open(path) #使用image打开图片并且传入到img变量temp=max(img.size) #读取图片的最长边mask=Image.new('RGB',(temp,temp),(0,0,0)) #做一个掩码,就是读取图片的最长边,做一个黑色的正方形mask.paste(img,(0,0)) #把数据的图片粘贴上去,从(0,0)开始mask=mask.resize(size) #把掩码图片缩放为想要的尺寸(256x256)return mask #返回mask

net.py代码如下图所示:这部分代码是网络的搭建过程

import torch
from torch import nn
from  torch.nn import functional as Fclass Conv_Block(nn.Module):  #定义卷积模块,继承自己moduledef __int__(self,in_channel,out_channel): #初始化,需要传入输入和输出super(Conv_Block,self).__init__() #初始化父类方法self.layer=nn.Sequential( #定义层结构,使用Sequ实现连续性,即经过的每一步卷积过程nn.Conv2d(in_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU(),nn.Conv2d(out_channel,out_channel,3,1,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(out_channel),nn.Dropout2d(0.3),nn.LeakyReLU(),)def forward(self,x): #定义前向传播,传入xreturn self.layer(x)class DownSample(nn.Module): #定义下采样模块def __int__(self,channel):super(DownSample,self).__init__()self.layer=nn.Sequential(nn.Conv2d(channel,channel,3,2,1,padding_mode='reflect',bias=False),nn.BatchNorm2d(channel),nn.LeakyReLU())def forward(self,x):return self.layer(x)class UpSample(nn.Module): #定义上采样模块def __int__(self,channel):super(UpSample,self).__int__()self.layer=nn.Conv2d(channel.channel//2,1,1)def forward(self,x,feature_map): #外加拼接部分up=F.interpolate(x,scale_factor=2,mode='nearest')out=self.layer(up)return torch.cat((out,feature_map),dim=1)#上述的过程都是定义相关的模块,下面开始搭建网络,就是把模块给组合起来
class UNet(nn.Module):def __int__(self):super(UNet,self).__init__()self.c1=Conv_Block(3,64)self.d1=DownSample(64)self.c2=Conv_Block(64,128)self.d2=DownSample(128)self.c3=Conv_Block(128,256)self.d3=DownSample(256)self.c4=Conv_Block(256,512)self.d4=DownSample(512)self.c5=Conv_Block(512,1024)self.u1=UpSample(1024)self.c6=Conv_Block(1024,512)self.u2 = UpSample(512)self.c7 = Conv_Block(512, 256)self.u3 = UpSample(256)self.c8 = Conv_Block(256, 128)self.u4 = UpSample(128)self.c9 = Conv_Block(128, 64)self.out=nn.Conv2d(64,3,3,1,1)self.Th=nn.Sigmoid()def forward(self,x):R1=self.c1(x)R2=self.c2(self.d1(R1))R3 = self.c3(self.d1(R2))R4 = self.c4(self.d1(R3))R5 = self.c5(self.d1(R4))o1=self.c6(self.u1(R5,R4))o2 = self.c7(self.u2(o1, R3))o3 = self.c8(self.u3(o2, R2))o4 = self.c9(self.u4(o3, R1))return self.Th(self.out(o4))if __name__ == '__main__':x=torch.randn(2,3,256,256)net=UNet()print(net(x).shape)

**train.py的代码如下图所示:**这部分代码主要是用来调用数据集,网络结构,定义训练优化器,损失函数,得到的结果以及可视化等等

from torch import nn,optim
import torch
from torch.utils.data import DataLoader
from data import *
from net import *
import os
from torchvision.utils import save_image
from utils1 import *device=torch.device('cuda'if torch.cuda.is_available()else'cup') #指定数据集
weight_path='params/unet.pth' #设置保存权重路径
data_path=r'数据集地址' #设置数据集路径
save_path='train_image'if __name__ == '__main__':data_loader=DataLoader(MyDataset(data_path),batch_size=4,shuffle=True) #调用原来写的data处理方式模块net=UNet().to(device) #把网络放在设备上if os.path.exists(weight_path):net.load_state_dict(torch.load(weight_path))#如果权重存在打印加载成功print('sucessful load weight!')else:print('not sucessful load weight')opt=optim.Adam(net.parameters()) #定义优化器,选择ADAM,把网络的参数放进去loss_fun=nn.BCELoss() #定义损失计算方式epoch=1 #定义训练的轮数while True: #一直训练for i, (image, segment_image) in enumerate(data_loader):image,segment_image=image.to(device),segment_image.to(device) #把原图和分割图都放在设备上面out_image=net(image) #经过网络输出的图片为out_imagetrain_loss=loss_fun(out_image,segment_image) #传入输出图和标签图,计算损失opt.zero_grad()#更新梯度train_loss.backward()#把损失反向传播,指导网络往损失小的方向下降opt.step()#使用优化器if i%5==0:#打印训练过程损失的变化print(f'训练的轮数:train_loss====>>{train_loss.item}')if i%50==0:#保存训练过程的参数torch.save(net.state_dict(),weight_path)#就是刚刚设置的保存路径#为了看训练过程中图的变化情况,对比原图+标签图+输出图,_image=image[0] #取第一张图作为,_image_segment_image=segment_image[0]_out_image=out_image[0]img=torch.stack([_image,_segment_image,_out_image],dim=0) #把得到的三张图进行拼接save_image(img,f'{save_path}/i.png’)epoch=epoch+1 #循环轮数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/610941.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络常见面试总结

文章目录 1. 计算机网络基础1.1 网络分层模型1. OSI 七层模型是什么?每一层的作用是什么?2.TCP/IP 四层模型是什么?每一层的作用是什么?3. 为什么网络要分层? 1.2 常见网络协议1. 应用层有哪些常见的协议?2…

力扣HOT100 - 160. 相交链表

解题思路: /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode(int x) {* val x;* next null;* }* }*/ public class Solution {public ListNode getIntersectionNode(ListNode headA, ListNode headB) {if…

无人新零售引领的创新浪潮

无人新零售引领的创新浪潮 在数字化时代加速演进的背景下,无人新零售作为商业领域的一股新兴力量,正以其独特的高效性和便捷性重塑着传统的购物模式,开辟了一条充满创新潜力的发展道路。 依托人脸识别、物联网等尖端技术,无人新…

题目 2349: 信息学奥赛一本通T1437-扩散【二分答案】

信息学奥赛一本通T1437-扩散 - C语言网 (dotcpp.com) #include <iostream> #include <algorithm> using namespace std; #define int long long const int N2e2; int n; struct node{int x,y; }a[N]; int fa[N]; int dist[N][N];//记录两坐标间的曼哈顿距离 int fi…

玩机进阶教程------手机定制机 定制系统 解除系统安装软件限制的一些步骤解析

定制机 在于各工作室与商家合作定制rom中有一些定制机。限制用户私自安装第三方软件。或者限制解锁 。无法如正常机登陆账号等等。定制机一般用于固定行业或者一些部门。专机专用。例如很多巴枪扫描机型等等。或者一些小牌机型。对于没有官方包的机型首先要导出各个分区来制作…

代码随想录算法训练营Day52|LC300 最长递增子序列LC 674 最长连续递增序列LC 718 最长重复子数组

一句话总结&#xff1a;动规做多了就豁然开朗了。 原题链接&#xff1a;300 最长递增子序列 按照动规五部曲&#xff1a; 首先确定dp数组及下标的含义&#xff1a;dp[i]表示以nums[i]结尾的最长子序列的长度&#xff1b;确定状态转移方程&#xff1a;如果nums[i] > nums[j…

风储微网虚拟惯性控制系统simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 风储微网虚拟惯性控制系统simulink建模与仿真。风储微网虚拟惯性控制系统是一种模仿传统同步发电机惯性特性的控制策略&#xff0c;它通过集成风力发电系统、储能系统和其他分…

Java工具类:批量发送邮件(带附件)

​ 不好用请移至评论区揍我 原创代码&#xff0c;请勿转载&#xff0c;谢谢&#xff01; 一、介绍 用于给用户发送特定的邮件内容&#xff0c;支持附件、批量发送邮箱账号必须要开启 SMTP 服务&#xff08;具体见下文教程&#xff09;本文邮箱设置示例以”网易邮箱“为例&…

微服务学习(黑马)

学习黑马的微服务课程的笔记 导学 微服务架构 认识微服务 SpringCloud spring.io/projects/spring-cloud/ 服务拆分和远程调用 根据订单id查询订单功能 存在的问题 硬编码 eureka注册中心 搭建eureka 服务注册 在order-service中完成服务拉取 Ribbon负载均衡 Nacos注册中心…

ELK-Kibana 部署

目录 一、在 node1 节点上操作 1.1.安装 Kibana 1.2.设置 Kibana 的主配置文件 1.3.启动 Kibana 服务 1.4.验证 Kibana 1.5.将 Apache 服务器的日志&#xff08;访问的、错误的&#xff09;添加到 ES 并通过 Kibana 显示 1.6. 浏览器访问 二、部署FilebeatELK&…

使用 Axios 处理 AxiosError 的三种常见方法

在使用 Axios 时处理 AxiosError 有几种常见的方法: 使用 try-catch 语句捕获异常: try {const response await axios.get(/api/data);// 处理响应数据 } catch (error) {if (error.response) {// 请求成功但状态码不在 2xx 范围console.log(error.response.data);console.l…

学习JavaEE的日子 Day33 File类,IO流

Day33 1.File类 File是文件和目录路径名的抽象表示 File类的对象可以表示文件&#xff1a;C:\Users\Desktop\hhy.txt File类的对象可以表示目录路径名&#xff1a;C:\Users\Desktop File只关注文件本身的信息&#xff08;文件名、是否可读、是否可写…&#xff09;&#xff0c…