Pytorch从零开始实战21

Pytorch从零开始实战——Pix2Pix理论与实战

本系列来源于365天深度学习训练营

原作者K同学

文章目录

  • Pytorch从零开始实战——Pix2Pix理论与实战
    • 内容介绍
    • 数据集加载
    • 模型实现
    • 开始训练
    • 总结

内容介绍

Pix2Pix是一种用于用于图像翻译的通用框架,即图像到图像的转换。它在生成对抗网络的框架下进行训练。Pix2Pix的目标是将输入图像转换为输出图像,例如将黑白线稿转换为彩色图像,或者将地图转换为卫星图像等。Pix2Pix模型的训练通常需要大量的配对数据,即包含输入图像与相应输出图像的数据集。

图像内容:指的是图像的固有内容,它是区分不同图像的依据。

图像域:指在特定上下文中所涵盖的一组图像的集合,这些图像通常具有某种相似性或共同特征。图像域可以用来表示一类具有共同属性或内容的图像。在图像处理和计算机视觉领域,图像域常常被用于描述参与某项任务或问题的图像集合。

图像翻译:是将一个物体的图像表征转换为该物体的另一个表征,例如根据皮包的轮廓图得到皮包的彩色图。也就是找到一个函数,能让域A的图像映射到域B,从而实现图像的跨域转换。

Pix2Pix的三个核心技术:

基于CGAN的损失函数:CGAN是Conditional Generative Adversarial Network的缩写,它将条件信息(如输入图像)作为生成器和判别器的输入,以帮助生成器生成更加逼真的输出。Pix2Pix使用了基于条件GAN的损失函数来指导生成器生成与目标图像更加接近的输出。

基于U-Net的生成器:U-Net是一种用于图像分割的卷积神经网络结构,它由编码器和解码器组成,通过跳跃连接将低级特征与高级特征相结合,有助于保留更多的图像细节。Pix2Pix中的生成器采用了U-Net结构,以实现图像到图像的转换。
下图为Pix2Pix的生成器的网络结构。
在这里插入图片描述

基于PatchGAN的判别器:PatchGAN是一种判别器的设计,对图像中的局部区域进行分类。Pix2Pix提出了将输入图像分成个图像块,这有助于提高模型对细节的感知,并使得判别器更能够捕捉到图像的局部结构信息。Pix2Pix使用了基于PatchGAN的判别器来评估生成的图像的真实度。

数据集加载

定义了一个自定义的数据集类ImageDataset,用于加载图像数据集。

import glob
import random
import os
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transformsclass ImageDataset(Dataset):def __init__(self, root, transforms_=None, mode="train"):self.transform = transforms.Compose(transforms_)self.files = sorted(glob.glob(os.path.join(root, mode) + "/*.*"))if mode == "train":self.files.extend(sorted(glob.glob(os.path.join(root, "test") + "/*.*")))def __getitem__(self, index):img = Image.open(self.files[index % len(self.files)])w, h = img.sizeimg_A = img.crop((0, 0, w / 2, h))img_B = img.crop((w / 2, 0, w, h))if np.random.random() < 0.5:img_A = Image.fromarray(np.array(img_A)[:, ::-1, :], "RGB")img_B = Image.fromarray(np.array(img_B)[:, ::-1, :], "RGB")img_A = self.transform(img_A)img_B = self.transform(img_B)return {"A": img_A, "B": img_B}def __len__(self):return len(self.files)

模型实现

下面代码定义了一个U-Net生成器和一个PatchGAN鉴别器。

import torch.nn as nn
import torch.nn.functional as F
import torch
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)

UNetDown 类和 UNetUp 类:这些类定义了 U-Net 结构中的下采样和上采样部分的层次。UNetDown 类用于定义 U-Net 的下采样部分,它包括卷积层、归一化层、激活函数层和丢弃层。UNetUp 类定义了 U-Net 的上采样部分,它包括转置卷积层、归一化层、激活函数层和丢弃层。

GeneratorUNet 类:这是整个 U-Net 生成器的定义。它利用了之前定义的 UNetDown 和 UNetUp 类来构建一个完整的 U-Net 网络。在 init方法中,它初始化了 U-Net 的各个层次,并在 forward 方法中定义了数据在网络中的传播方式。

class UNetDown(nn.Module):def __init__(self, in_size, out_size, normalize=True, dropout=0.0):super(UNetDown, self).__init__()layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]if normalize:layers.append(nn.InstanceNorm2d(out_size))layers.append(nn.LeakyReLU(0.2))if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)class UNetUp(nn.Module):def __init__(self, in_size, out_size, dropout=0.0):super(UNetUp, self).__init__()layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),nn.InstanceNorm2d(out_size),nn.ReLU(inplace=True),]if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x, skip_input):x = self.model(x)x = torch.cat((x, skip_input), 1)return xclass GeneratorUNet(nn.Module):def __init__(self, in_channels=3, out_channels=3):super(GeneratorUNet, self).__init__()self.down1 = UNetDown(in_channels, 64, normalize=False)self.down2 = UNetDown(64, 128)self.down3 = UNetDown(128, 256)self.down4 = UNetDown(256, 512, dropout=0.5)self.down5 = UNetDown(512, 512, dropout=0.5)self.down6 = UNetDown(512, 512, dropout=0.5)self.down7 = UNetDown(512, 512, dropout=0.5)self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)self.up1 = UNetUp(512, 512, dropout=0.5)self.up2 = UNetUp(1024, 512, dropout=0.5)self.up3 = UNetUp(1024, 512, dropout=0.5)self.up4 = UNetUp(1024, 512, dropout=0.5)self.up5 = UNetUp(1024, 256)self.up6 = UNetUp(512, 128)self.up7 = UNetUp(256, 64)self.final = nn.Sequential(nn.Upsample(scale_factor=2),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(128, out_channels, 4, padding=1),nn.Tanh(),)def forward(self, x):# U-Net generator with skip connections from encoder to decoderd1 = self.down1(x)d2 = self.down2(d1)d3 = self.down3(d2)d4 = self.down4(d3)d5 = self.down5(d4)d6 = self.down6(d5)d7 = self.down7(d6)d8 = self.down8(d7)u1 = self.up1(d8, d7)u2 = self.up2(u1, d6)u3 = self.up3(u2, d5)u4 = self.up4(u3, d4)u5 = self.up5(u4, d3)u6 = self.up6(u5, d2)u7 = self.up7(u6, d1)return self.final(u7)

discriminator_block 函数:这个函数定义了 PatchGAN 判别器中的一个“块”,包括一个卷积层、一个归一化层和一个 LeakyReLU 激活函数。

Discriminator 类:这是整个 PatchGAN 判别器的定义。它由一系列卷积层和 LeakyReLU 激活函数层组成,用于从图像对中提取特征,并输出一个判别值,表示输入图像对是真实对还是生成对。

class Discriminator(nn.Module):def __init__(self, in_channels=3):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, normalization=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalization:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(in_channels * 2, 64, normalization=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1, bias=False))def forward(self, img_A, img_B):# Concatenate image and condition image by channels to produce inputimg_input = torch.cat((img_A, img_B), 1)return self.model(img_input)

开始训练

主要功能包括:

1.解析命令行参数:使用argparse.ArgumentParser()解析命令行参数,包括训练所需的超参数、数据集名称等。

2.定义生成器和判别器模型:使用之前定义的GeneratorUNet和Discriminator类创建模型。

3.配置优化器:使用Adam优化器来优化生成器和判别器的参数。

4.加载数据集:使用PyTorch的DataLoader加载训练和验证数据集。

5.定义损失函数:定义了GAN损失和像素级别的L1损失。

6.训练过程:使用双重循环进行训练,其中外层循环遍历每个epoch,内层循环遍历每个batch。在每个batch内,首先训练生成器,然后训练判别器。训练过程中,会打印损失信息以及训练的进度。

7.保存模型和生成示例图像:在每个epoch结束时,会保存生成器和判别器的模型参数,并周期性地生成一些示例图像以供可视化和评估模型效果。

import argparse
import time
import datetime
import sysimport torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variablefrom models import *
from datasets import *import torch.nn as nn
import torch.nn.functional as F
import torchparser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="data_facades", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between sampling of images from generators"
)
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")
opt = parser.parse_args()
print(opt)os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)cuda = True if torch.cuda.is_available() else False# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100# Calculate output of image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()if cuda:generator = generator.cuda()discriminator = discriminator.cuda()criterion_GAN.cuda()criterion_pixelwise.cuda()if opt.epoch != 0:# Load pretrained modelsgenerator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
else:# Initialize weightsgenerator.apply(weights_init_normal)discriminator.apply(weights_init_normal)# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# Configure dataloaders
transforms_ = [transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]dataloader = DataLoader(ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_),batch_size=opt.batch_size,shuffle=True,num_workers=opt.n_cpu,
)val_dataloader = DataLoader(ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_, mode="val"),batch_size=10,shuffle=True,num_workers=1,
)# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensordef sample_images(batches_done):"""Saves a generated sample from the validation set"""imgs = next(iter(val_dataloader))real_A = Variable(imgs["B"].type(Tensor))real_B = Variable(imgs["A"].type(Tensor))fake_B = generator(real_A)img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)# ----------
#  Training
# ----------if __name__ == '__main__':prev_time = time.time()for epoch in range(opt.epoch, opt.n_epochs):for i, batch in enumerate(dataloader):# Model inputsreal_A = Variable(batch["B"].type(Tensor))real_B = Variable(batch["A"].type(Tensor))# Adversarial ground truthsvalid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()# GAN lossfake_B = generator(real_A)pred_fake = discriminator(fake_B, real_A)loss_GAN = criterion_GAN(pred_fake, valid)# Pixel-wise lossloss_pixel = criterion_pixelwise(fake_B, real_B)# Total lossloss_G = loss_GAN + lambda_pixel * loss_pixelloss_G.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Real losspred_real = discriminator(real_B, real_A)loss_real = criterion_GAN(pred_real, valid)# Fake losspred_fake = discriminator(fake_B.detach(), real_A)loss_fake = criterion_GAN(pred_fake, fake)# Total lossloss_D = 0.5 * (loss_real + loss_fake)loss_D.backward()optimizer_D.step()# --------------#  Log Progress# --------------# Determine approximate time leftbatches_done = epoch * len(dataloader) + ibatches_left = opt.n_epochs * len(dataloader) - batches_donetime_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))prev_time = time.time()# Print logsys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"% (epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),loss_pixel.item(),loss_GAN.item(),time_left,))# If at sample interval save imageif batches_done % opt.sample_interval == 0:sample_images(batches_done)if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:# Save model checkpointstorch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))

训练结果
在这里插入图片描述

总结

Pix2Pix模型在许多领域都有广泛的应用,可以用于许多任务,如将黑白图像转换为彩色图像、将语义标签转换为真实图像、图像超分辨率等。 由于Pix2Pix使用了生成对抗网络,生成的图像可以在一定程度上理解为真实图像与生成图像之间的差异,使得生成结果更具可解释性。

Pix2Pix的训练通常需要大量的配对数据,即包含输入图像与相应输出图像的数据集,以获得更好的性能和生成效果。

并且,Pix2Pix模型的训练较为复杂,需要精细调节超参数、选择合适的损失函数以及处理训练不稳定等问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539646.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

PostgreSQL开发与实战(6.3)体系结构3

作者&#xff1a;太阳 四、物理结构 4.1 软件安装目录 bin //二进制可执行文件 include //头文件目录 lib //动态库文件 share //文档以及配置模版文件4.2 数据目录 4.2.1 参数文件 pg_hba.conf //认证配置文件 p…

目标检测——YOLOv3算法解读

论文&#xff1a;YOLOv3&#xff1a;An Incremental Improvement 作者&#xff1a;Joseph Redmon, Ali Farhadi 链接&#xff1a;https://arxiv.org/abs/1804.02767 代码&#xff1a;http://pjreddie.com/yolo/ YOLO系列其他文章&#xff1a; YOLOv1通俗易懂版解读SSD算法解读…

ArcGIS全系列实战视频教程——9个单一课程组合+系列直播回放

《ArcGIS全系列实战视频教程》是由9个单一课程组合合成。组成一条ArcGIS入门实战各项专题深入应用学习全链条&#xff0c;让你学有方向、学有目的&#xff0c;系统全面掌握ArcGIS。 ArcGIS全系列实战视频教程——9个单一课程组合https://edu.csdn.net/combo/detail/2569 《Ar…

vue 基于elementUI/antd-vue, h函数实现message中嵌套链接跳转到指定路由 (h函数点击事件的写法)

效果如图&#xff1a; 点击message 组件中的 工单管理&#xff0c; 跳转到工单管理页面。 以下是基于vue3 antd-vue 代码如下&#xff1a; import { message } from ant-design-vue; import { h, reactive, ref, watch } from vue; import { useRouter } from vue-router; c…

jenkins+maven+gitlab自动化构建打包、部署

Jenkins自动化部署实现原理 环境准备 1、jenkins已经安装好 docker安装jenkins 2、gitlab已经安装好 docker安装gitlab 一、Jenkins系统配置 1.Global Tool Configuration 任务构建所用到的编译环境等配置&#xff0c;配置参考&#xff1a; jdk配置&#xff08;jenkins自带…

数据资产在制造行业的应用:释放潜在价值,驱动产业升级

随着信息技术的飞速发展&#xff0c;数据已成为各行各业的重要资产。在制造行业中&#xff0c;数据资产的应用更是日益广泛&#xff0c;为企业的生产、管理、决策等各个环节带来了前所未有的变革。本文将深入探讨数据资产在制造行业的应用&#xff0c;以及如何通过数据驱动实现…

递归类C++

1、汉诺塔 面试题 08.06. 汉诺塔问题 - 力扣(LeetCode) Why?为什么这个汉诺塔问题可以用递归来解决? 如何来解决汉诺塔问题? 如果N == 1时,A[0] B C,直接将A上的盘子转移到C上面。 如果N == 2时,A[1,0] B C,先将上面所有的盘子放到B,然后将A上最大的那个放到C,再把…

深度学习_20_卷积中的填充与步幅

如果图片本身比较小&#xff0c;卷积之后输出也会很小&#xff0c;那么可以在图片与卷积核相乘之前先填充一下&#xff0c;让输出为预期大小 一般填充后输入&#xff0c;输出相同 当图片比较大的时候&#xff0c;如果利用卷积核去得到我们想要的大小的话&#xff0c;得用到多层…

爬虫3_爬取翻页URL不变的网站

之前实现了对大学排数据爬取&#xff1a;爬虫2_2019年549所中国大学排名. 近期复现代码&#xff0c;发现原网站升级&#xff0c;在翻页时&#xff0c;发现URL不改变&#xff0c;修改代码&#xff0c;使用网页自动化工具selenium实现对该类网站数据获取。 #-*- coding: UTF-8 -…

【数据结构和算法初阶(C语言)】队列实操(概念实现+oj题目栈和队列的双向实现,超级经典!!!)

1. 队列的概念及结构 队列&#xff1a;只允许在一端进行插入数据操作&#xff0c;在另一端进行删除数据操作的特殊线性表&#xff0c; 队列具有先进先出 FIFO(First In First Out) 入队列&#xff1a;进行插入操作的一端称为队尾 出队列&#xff1a;进行删除操作的一端称为…

BUGKU-WEB shell

题目描述 题目截图如下&#xff1a; 描述&#xff1a; $poc "a#s#s#e#r#t";$poc_1 explode("#", $poc);$poc_2 $poc_1[0].$poc_1[1].$poc_1[2].$poc_1[3].$poc_1[4].$poc_1[5];$poc_2($_GET[s])进入场景看看&#xff1a;是一个空白的界面 解题思路 …

从根到叶:深度理解哈希表

​​​​​​​ 一.哈希表的概念 关于查找元素时&#xff1a; 在顺序结构以及平衡树 中&#xff0c;元素关键码与其存储位置之间没有对应的关系&#xff0c;因此在 查找一个元素时&#xff0c;必须要经过关键 码的多次比较 。 顺序查找时间复杂度为 O(N) &#xff0c;平衡树中…