深度学习(UNet)

news/2024/12/24 11:18:51/文章来源:https://www.cnblogs.com/tiandsp/p/17922362.html

和FCN类似,UNet是另一个做语义分割的网络,网络从输入到输出中间呈一个U型而得名。

相比于FCN,UNet增加了更多的中间连接,能够更好处理不同尺度上的特征。

网络结构如下:

下面代码是用UNet对VOC数据集做的语义分割。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms
import os
from PIL import Image
import numpy as nptransform = transforms.Compose([transforms.Resize((256, 256)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])device = torch.device("cuda" if torch.cuda.is_available() else "cpu")colormap = [[0,0,0],[128,0,0],[0,128,0], [128,128,0], [0,0,128],[128,0,128],[0,128,128],[128,128,128],[64,0,0],[192,0,0],[64,128,0],[192,128,0],[64,0,128],[192,0,128],[64,128,128],[192,128,128],[0,64,0],[128,64,0],[0,192,0],[128,192,0],[0,64,128]]class VOCData(Dataset):def __init__(self, root):super(VOCData, self).__init__()self.lab_path = root + 'VOC2012/SegmentationClass/'self.img_path = root + 'VOC2012/JPEGImages/'self.lab_names = self.get_file_names(self.lab_path)self.img_names=[]for file in self.lab_names:self.img_names.append(file.replace('.png', '.jpg'))self.cm2lbl = np.zeros(256**3) for i,cm in enumerate(colormap): self.cm2lbl[cm[0]*256*256+cm[1]*256+cm[2]] = iself.image = []self.label = []for i in range(len(self.lab_names)):image = Image.open(self.img_path+self.img_names[i]).convert('RGB')image = transform(image)label = Image.open(self.lab_path+self.lab_names[i]).convert('RGB').resize((256,256))label = torch.from_numpy(self.image2label(label))self.image.append(image)self.label.append(label)def __len__(self):return len(self.image)def __getitem__(self, idx):return self.image[idx], self.label[idx]def get_file_names(self,directory):file_names = []for file_name in os.listdir(directory):if os.path.isfile(os.path.join(directory, file_name)):file_names.append(file_name)return file_namesdef image2label(self,im):data = np.array(im, dtype='int32')idx = data[:, :, 0] * 256 * 256 + data[:, :, 1] * 256 + data[:, :, 2]return np.array(self.cm2lbl[idx], dtype='int64')class convblock(nn.Module):def __init__(self, in_channels, out_channels):super(convblock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.conv2(x)x = self.bn2(x)x = self.relu(x)        return xclass Unet(nn.Module):def __init__(self, num_classes):super(Unet, self).__init__()self.conv_block1 = convblock(3,64)self.conv_block2 = convblock(64,128)self.conv_block3 = convblock(128,256)self.conv_block4 = convblock(256,512)self.conv_block5 = convblock(512,1024)self.upsample1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2, padding=0)self.upsample2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2, padding=0)self.upsample3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2, padding=0)self.upsample4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2, padding=0)self.conv_block6 = convblock(1024,512)self.conv_block7 = convblock(512,256)self.conv_block8 = convblock(256,128)self.conv_block9 = convblock(128,64)self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv_out = convblock(64,num_classes)def forward(self, x):x1 = self.conv_block1(x)  x = self.maxpool(x1) x2 = self.conv_block2(x) x = self.maxpool(x2)x3 = self.conv_block3(x)x = self.maxpool(x3)x4 = self.conv_block4(x)x = self.maxpool(x4)x = self.conv_block5(x)x = self.upsample1(x)x = torch.cat([x4,x],dim=1)x = self.conv_block6(x)x = self.upsample2(x)x = torch.cat([x3,x],dim=1)x = self.conv_block7(x)x = self.upsample3(x)x = torch.cat([x2,x],dim=1)x = self.conv_block8(x)x = self.upsample4(x)x = torch.cat([x1,x],dim=1)x = self.conv_block9(x)x = self.conv_out(x)return xdef train():train_dataset = VOCData(root='./VOCdevkit/')train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)net = Unet(21)optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)criterion = nn.CrossEntropyLoss()net.to(device)net.train()num_epochs = 100for epoch in range(num_epochs):loss_sum = 0img_sum = 0for inputs, labels in train_loader:inputs =  inputs.to(device)labels =  labels.to(device)outputs = net(inputs)loss = criterion(outputs, labels)   optimizer.zero_grad()loss.backward()optimizer.step()loss_sum += loss.item()img_sum += inputs.shape[0]print('epochs:',epoch,loss_sum / img_sum )torch.save(net.state_dict(), 'unet.pth')def val():net = Unet(21)net.load_state_dict(torch.load('unet.pth'))net.to(device)net.eval()image = Image.open('./VOCdevkit/VOC2012/JPEGImages/2012_001064.jpg').convert('RGB')image = transform(image).unsqueeze(0).to(device)out = net(image).squeeze(0)ToPIL= transforms.ToPILImage()maxind = torch.argmax(out,dim=0)outimg = torch.zeros([3,256,256])for y in range(256):for x in range(256):outimg[:,x,y] = torch.from_numpy(np.array(colormap[maxind[x,y]]))re = ToPIL(outimg)re.show()if __name__ == "__main__":train()#val()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/806603.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【防忘笔记】测试过程与技术

测试人员应该想些什么 我自己是做后端的,对于模棱两可的需求和莫名其妙的测试case是深恶痛绝的,所以有时候我就会想测试人员应该会需要注意什么?以他们的角度,他们更在乎什么 最近有机会了解相关的知识,遂整理记录一下,以便之后在工作中更好的理解发生的各种事情 以客户为…

论文总结1--基于深度强化学习的四足机器人步态分析--2024.10.01

四足机器人的运动控制方法研究 1.传统运动控制 - 基于模型的控制方法目前,在四足机器人研究领域内应用最广泛的控制方法就是基于模型的控制方法,其中主要包括基于虚拟模型控制(Virtual Model Control,VMC)方法 、基于零力矩点(Zero Moment Point,ZMP) 的控制方法、弹簧…

Linux系统密码忘记

Linux系统密码忘记 1.故障背景误删除或修改/etc/passwd导致无法远程登录. 禁止root远程登录,没有添加普通用户,无法远程登录. root密码忘记,无法远程登录. linux无法启动.2.解决方法 root密码,恢复有备份的系统文件,都要重启系统,才能进入救援模式.解决方案 应用场景系统自带的…

应用中的错误处理概述

title: 应用中的错误处理概述 date: 2024/10/1 updated: 2024/10/1 author: cmdragon excerpt: 摘要:本文介绍了Nuxt中的错误处理机制,包括全局错误处理器和组件层级错误捕获,以及错误传递规则和生产环境下的处理方式 categories:前端开发tags:错误处理 Nuxt应用 全局处理…

TypeScrip在vue中的使用----defineEmits

向父元素发送消息 之前的语法: 在TS语法中,我们既要对defineEmits做类型约束,又要对emits做类型约束。 最主要是对defineEmits做一个泛型的约束。//在泛型对象中,有几个事件就写几个约束 type emitsType = {//()中有n个参数,第一个固定的是e,其他有具体参数决定。具体的写…

电影《749局》迅雷BT下载/百度云下载资源[MP4/2.12GB/5.35GB]超清版

电影《749局》:近未来的冒险与成长之旅电影《749局》是一部融合了科幻、冒险与奇幻元素的电影,由陆川编剧并执导,王俊凯、苗苗、郑恺、任敏、辛柏青领衔主演,李晨特邀主演,张钧甯、李梦、杨皓宇特别主演。该片于2024年10月1日在中国大陆上映,以其独特的科幻设定、宏大的视…

电影《749局》迅雷百度云下载资源4K分享[1.16GB/2.72GBMKV]高清加长版【1280P已完结】

电影《749局》的深度剖析与全面解读电影《749局》是一部集科幻、冒险、动作与奇幻元素于一体的力作,由陆川编剧并执导,王俊凯、苗苗、郑恺、任敏、辛柏青领衔主演,李晨特邀主演,张钧甯、李梦、杨皓宇特别主演。影片于2024年国庆档在中国大陆上映,以其独特的科幻设定、宏大…

南沙C++信奥赛陈老师解一本通题 1983:【19CSPJ普及组】公交换乘

​【题目描述】著名旅游城市 B 市为了鼓励大家采用公共交通方式出行,推出了一种地铁换乘公交车的优惠方案: 1、在搭乘一次地铁后可以获得一张优惠票,有效期为 4545 分钟,在有效期内可以消耗这张优惠票,免费搭乘一次票价不超过地铁票价的公交车。在有效期内指开始乘公交车的…

Flutter 实现骨架屏CE

什么是骨架屏 在客户端开发中,我们总是需要等待拿到服务端的响应后,再将内容呈现到页面上,那么在用户发起请求到客户端成功拿到响应的这段时间内,应该在屏幕上呈现点什么好呢? 答案是:骨架屏 那么什么是骨架屏呢,来问下 GPT:骨架屏(Skeleton Screen)是一种现代的用户…

[rCore学习笔记 028] Rust 中的动态内存分配

引言 想起我们之前在学习C的时候,总是提到malloc,总是提起,使用malloc现场申请的内存是属于堆,而直接定义的变量内存属于栈. 还记得当初学习STM32的时候CubeIDE要设置stack 和heap的大小. 但是我们要记得,这么好用的功能,实际上是操作系统在负重前行. 那么为了实现动态内存分配…

解决MacOS 13.0.1 苹果M1芯片 导入pyaudio报错的问题

【问题】 如果正常按照网上的教程,在terminal先使用brew安装portaudio(brew install portaudio),再使用pip在conda环境里安装pyaudio(pip install pyaudio),然后python直接导入pyaudio(import pyaudio)会报错如下:【分析】 可知报错来自于portaudio动态库。网上搜索解…

值班脱岗智能监测识别系统

值班脱岗智能监测识别系统通过AI视频智能分析技术,值班脱岗智能监测识别系统对办公工作岗位区域、岗亭、值班室、生产线岗位等进行7*24小时不间断实时监测,当超过后台规定时间没有人员在规定工作区域,无需人为干预系统立即抓拍告警提醒后台值班人员及时处理。值班脱岗智能监…