第G5周:Pix2Pix理论与实战

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

一、背景知识

1.背景知识

1.1 图像翻译

图像翻译指的是将图像从源域转换到目标域的过程,同时保持图像内容的一致性。具体解释如下:

  1. 图像内容(Content):这是图像的固有属性,指的是图像展示的对象、场景或任何其他可视化信息。图像内容是区分不同图像的主要依据。
  2. 图像域(Domain):在图像翻译的背景下,一个域可以被认为是一组具有共同特征的图像。例如,所有带有蓝色天空的照片可以属于同一个域。在图像翻译中,通常涉及至少两个域:源域和目标域。域内的图像可以认为其内容被赋予了某些相同的风格、纹理或其他视觉特性。
  3. 图像翻译(Image-to-Image Translation, I2I):这是一个过程,目的是将图像从一个域(源域)转换到另一个域(目标域),同时尽可能保留原始图像的内容。这涉及到一系列复杂的算法和模型,如生成对抗网络(GANs),它们能够捕捉并学习不同域之间的映射关系。这个过程在计算机视觉和图像处理领域有着广泛的应用,包括图像风格转换、草图着色、照片卡通化等。

1.2 U-Net

U-Net是一种专为图像分割任务设计的深度学习网络结构,具有以下特点:

  • 编码器-解码器(Encoder-Decoder)结构:U-Net由一个收缩路径(编码器)和一个对称的扩展路径(解码器)组成。编码器部分主要负责通过卷积层提取特征,而解码器部分则用于上采样特征图,逐步恢复到原始图像的尺寸。
  • 跳跃连接(Skip Connections):在编码和解码阶段之间存在跳跃连接,即从编码器到解码器的深层特征图会与解码器相应层次的输出进行拼接。这种设计可以帮助保持图像的细节信息,并有助于更好地进行精确的分割。
  • 多层次特征融合:U-Net结构允许在不同层级的特征之间进行融合,这样可以让网络同时学习到浅层次的细节特征和深层次的语义特征,从而增强模型对不同尺度结构的识别能力。

2. pix2pix解析

Pix2Pix是一种基于条件生成对抗网络(cGAN)的图像翻译模型。它能够将输入的图像转换为对应的输出图像,通常用于解决图像到图像的转换问题。

  • 原理结构
  1. 编码器-解码器结构(U-Net):Pix2Pix的生成器G采用的是U-Net结构,这种结构通过跳跃连接(skip connections)使得网络能够更好地学习输入图像和输出图像之间的对应关系,保留图像的细节信息。
  2. 条件判别器(PatchGAN):判别器D使用的是PatchGAN结构,它的作用是在给定输入图像的条件下,判断输出图像是真实图像还是生成图像。PatchGAN不是对整个图像进行判别,而是对图像的局部区域(patches)进行判别,这样可以提高判别的效率和准确性。
  3. 条件对抗性网络(cGAN):Pix2Pix是基于cGAN的扩展,它在传统的GAN基础上增加了条件变量,使得生成器在生成图像时能够参考额外的信息,如输入图像或者其他条件信息。
  • 优势
  1. 端到端训练:Pix2Pix可以实现端到端的训练,不需要复杂的图像预处理或后处理步骤。
  2. 通用性强:虽然Pix2Pix是为特定的图像转换任务设计的,但由于其基于cGAN的结构,它具有很好的通用性,可以应用于多种图像到图像的转换任务。
  3. 高质量的输出:Pix2Pix能够生成高分辨率、高质量的图像,这得益于其精细的网络结构和训练过程。
  • 劣势
  1. 计算资源要求高:由于Pix2Pix使用了深度学习模型,尤其是GAN,它的训练过程需要大量的计算资源和时间。
  2. 模型调优难度大:GAN类模型通常较难训练,需要精心设计的网络架构和合适的超参数设置。
  3. 可能的模式崩溃(Mode Collapse):在某些情况下,GAN可能会生成非常相似的输出,而忽略了输入数据中的多样性。

综上所述,Pix2Pix以其强大的图像转换能力和较高的通用性在图像处理领域得到了广泛的应用,但同时也面临着计算资源要求高和模型调优难度大等挑战。

华为昇思的解释还可以,这里给出链接,也有动手实践的部分
Pix2Pix实现图像转换

二、代码运行

这里有一个小点。当你想在jupyter里面运行.py文件时,使用下面的语句就可以直接运行了,它会把.py文件里面的代码直接导入当前cell里面

%load 文件的名字.py

还有一个问题是,如果导入直接运行会遇到以下报错
在这里插入图片描述

这点网上的解释大多都是jupyter不支持argparse这个库
对于这个情况,要不然转战pycharm,不过要是嫌弃现配环境太麻烦的话,就按照我找到的这个方法去试试,我反正一试就好了

在代码中找到这句

opt = parser.parse_args()

改成

opt = parser.parse_args([])

怎么说呢,这个原理我也不是很懂,不过运行即正义嘛。
原解释文章链接我给到这里,非常感谢这位知乎大佬
【报错】use %tb to see the full traceback.作者:叫我刘五环
然后根据我们卑微的算力慢慢等就可以了

2.1 models.py

  1. 定义一个函数weights_init_normal,用于初始化网络层的权重。这个函数接收一个参数m,表示网络层。
  2. 获取网络层的名称,存储在变量classname中。
  3. 如果网络层名称中包含"Conv",则对卷积层的权重进行正态分布初始化,均值为0.0,标准差为0.02。
  4. 如果网络层名称中包含"BatchNorm2d",则对批量归一化层的权重进行正态分布初始化,均值为1.0,标准差为0.02,并将偏置项初始化为0.0。
  5. 定义一个名为UNetDown的类,表示U-Net的下采样部分。这个类继承自nn.Module
  6. 定义__init__方法,接收输入通道数in_size、输出通道数out_size、是否进行归一化normalize和Dropout概率dropout作为参数。
  7. 调用父类的__init__方法。
  8. 创建一个列表layers,用于存储下采样部分的网络层。
  9. 添加一个卷积层,输入通道数为in_size,输出通道数为out_size,卷积核大小为4,步长为2,填充为1,不使用偏置。
  10. 如果需要进行归一化,添加一个实例归一化层,输入通道数为out_size
  11. 添加一个LeakyReLU激活函数,负斜率为0.2。
  12. 如果设置了Dropout概率,添加一个Dropout层,丢弃概率为dropout
  13. layers列表中的网络层组合成一个序列模型,赋值给self.model
  14. 定义forward方法,接收输入张量x作为参数。
  15. 返回经过下采样部分处理后的张量。
  16. 定义一个名为UNetUp的类,表示U-Net的上采样部分。这个类继承自nn.Module
  17. 定义__init__方法,接收输入通道数in_size、输出通道数out_size和Dropout概率dropout作为参数。
  18. 调用父类的__init__方法。
  19. 创建一个列表layers,用于存储上采样部分的网络层。
  20. 添加一个反卷积层,输入通道数为in_size,输出通道数为out_size,卷积核大小为4,步长为2,填充为1,不使用偏置。
  21. 添加一个实例归一化层,输入通道数为out_size
  22. 添加一个ReLU激活函数。
  23. 如果设置了Dropout概率,添加一个Dropout层,丢弃概率为dropout
  24. layers列表中的网络层组合成一个序列模型,赋值给self.model
  25. 定义forward方法,接收输入张量x和跳跃连接的输入张量skip_input作为参数。
  26. 对输入张量x进行上采样部分的处理。
  27. 将处理后的张量与跳跃连接的输入张量进行拼接。
  28. 返回拼接后的张量。
import torch.nn as nn
import torch.nn.functional as F
import torchdef weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)##############################
#           U-NET
##############################class UNetDown(nn.Module):def __init__(self, in_size, out_size, normalize=True, dropout=0.0):super(UNetDown, self).__init__()layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)]if normalize:layers.append(nn.InstanceNorm2d(out_size))layers.append(nn.LeakyReLU(0.2))if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x):return self.model(x)class UNetUp(nn.Module):def __init__(self, in_size, out_size, dropout=0.0):super(UNetUp, self).__init__()layers = [nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),nn.InstanceNorm2d(out_size),nn.ReLU(inplace=True),]if dropout:layers.append(nn.Dropout(dropout))self.model = nn.Sequential(*layers)def forward(self, x, skip_input):x = self.model(x)x = torch.cat((x, skip_input), 1)return x

这段代码定义了两个类:GeneratorUNetDiscriminator,它们分别表示 U-Net 生成器和判别器。

  1. GeneratorUNet 类继承自 nn.Module,用于生成图像。它包含多个下采样(down)和上采样(up)层,以及一个最终的输出层。在 __init__
    方法中,它初始化了这些层,并在 forward 方法中实现了前向传播过程。

  2. Discriminator 类也继承自 nn.Module,用于判别图像。它包含多个判别器块,每个块包含卷积层、实例归一化层和 LeakyReLU 激活函数。在 __init__ 方法中,它初始化了这些层,并在 forward 方法中实现了前向传播过程。

  3. discriminator_block 函数是一个辅助函数,用于创建判别器块中的层。它接受输入通道数、输出通道数和一个布尔值,表示是否使用实例归一化。根据这些参数,它返回一个包含卷积层、实例归一化层和
    LeakyReLU 激活函数的列表。

  4. GeneratorUNet 类的 forward 方法首先通过下采样层处理输入图像,然后通过上采样层将特征图恢复到原始尺寸,并最终通过输出层生成生成图像。

  5. Discriminator 类的 forward 方法将两个输入图像(例如真实图像和生成图像)按通道拼接起来,然后将拼接后的图像传递给判别器模型,最后返回一个标量作为判别结果。

class GeneratorUNet(nn.Module):def __init__(self, in_channels=3, out_channels=3):super(GeneratorUNet, self).__init__()self.down1 = UNetDown(in_channels, 64, normalize=False)self.down2 = UNetDown(64, 128)self.down3 = UNetDown(128, 256)self.down4 = UNetDown(256, 512, dropout=0.5)self.down5 = UNetDown(512, 512, dropout=0.5)self.down6 = UNetDown(512, 512, dropout=0.5)self.down7 = UNetDown(512, 512, dropout=0.5)self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5)self.up1 = UNetUp(512, 512, dropout=0.5)self.up2 = UNetUp(1024, 512, dropout=0.5)self.up3 = UNetUp(1024, 512, dropout=0.5)self.up4 = UNetUp(1024, 512, dropout=0.5)self.up5 = UNetUp(1024, 256)self.up6 = UNetUp(512, 128)self.up7 = UNetUp(256, 64)self.final = nn.Sequential(nn.Upsample(scale_factor=2),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(128, out_channels, 4, padding=1),nn.Tanh(),)def forward(self, x):# U-Net generator with skip connections from encoder to decoderd1 = self.down1(x)d2 = self.down2(d1)d3 = self.down3(d2)d4 = self.down4(d3)d5 = self.down5(d4)d6 = self.down6(d5)d7 = self.down7(d6)d8 = self.down8(d7)u1 = self.up1(d8, d7)u2 = self.up2(u1, d6)u3 = self.up3(u2, d5)u4 = self.up4(u3, d4)u5 = self.up5(u4, d3)u6 = self.up6(u5, d2)u7 = self.up7(u6, d1)return self.final(u7)##############################
#        Discriminator
##############################class Discriminator(nn.Module):def __init__(self, in_channels=3):super(Discriminator, self).__init__()def discriminator_block(in_filters, out_filters, normalization=True):"""Returns downsampling layers of each discriminator block"""layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1)]if normalization:layers.append(nn.InstanceNorm2d(out_filters))layers.append(nn.LeakyReLU(0.2, inplace=True))return layersself.model = nn.Sequential(*discriminator_block(in_channels * 2, 64, normalization=False),*discriminator_block(64, 128),*discriminator_block(128, 256),*discriminator_block(256, 512),nn.ZeroPad2d((1, 0, 1, 0)),nn.Conv2d(512, 1, 4, padding=1, bias=False))def forward(self, img_A, img_B):# Concatenate image and condition image by channels to produce inputimg_input = torch.cat((img_A, img_B), 1)return self.model(img_input)

2. pix2pix.py

  1. import argparse:导入argparse模块,用于处理命令行参数。
  2. import time:导入time模块,用于处理时间相关的操作。
  3. import datetime:导入datetime模块,用于处理日期和时间相关的操作。
  4. import sys:导入sys模块,用于处理与Python解释器和它的环境有关的函数。
  5. import torchvision.transforms as transforms:导入torchvision.transforms模块,并将其重命名为transforms,用于图像预处理。
  6. from torchvision.utils import save_image:从torchvision.utils模块中导入save_image函数,用于保存生成的图像。
  7. from torch.utils.data import DataLoader:从torch.utils.data模块中导入DataLoader类,用于加载数据集。
  8. from torch.autograd import Variable:从torch.autograd模块中导入Variable类,用于自动求导。
  9. from models import *:从models模块中导入所有内容,通常用于导入自定义的模型结构。
  10. from datasets import *:从datasets模块中导入所有内容,通常用于导入自定义的数据集。
  11. import torch.nn as nn:导入torch.nn模块,并将其重命名为nn,用于定义神经网络模型。
  12. import torch.nn.functional as F:导入torch.nn.functional模块,并将其重命名为F,用于定义激活函数等。
  13. import torch:导入torch模块,用于实现张量计算和神经网络。
import argparse
import time
import datetime
import sysimport torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torch.autograd import Variablefrom models import *
from datasets import *import torch.nn as nn
import torch.nn.functional as F
import torch

这段代码是使用argparse模块来解析命令行参数的。
这段代码是使用argparse模块来解析命令行参数的。下面是每行代码的解释:

  1. parser = argparse.ArgumentParser():创建一个ArgumentParser对象,用于解析命令行参数。
  2. parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from"):添加一个名为"–epoch"的命令行参数,类型为整数,默认值为0,帮助信息为"epoch
    to start training from"。
    以下都是类似的
    parser.parse_args([])会返回一个命名空间,其中包含所有定义的命令行选项和它们的值。然后,print(opt)将打印这个命名空间的内容。
parser = argparse.ArgumentParser()
parser.add_argument("--epoch", type=int, default=0, help="epoch to start training from")
parser.add_argument("--n_epochs", type=int, default=100, help="number of epochs of training")
parser.add_argument("--dataset_name", type=str, default="data_facades", help="name of the dataset")
parser.add_argument("--batch_size", type=int, default=1, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--decay_epoch", type=int, default=100, help="epoch from which to start lr decay")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--img_height", type=int, default=256, help="size of image height")
parser.add_argument("--img_width", type=int, default=256, help="size of image width")
parser.add_argument("--channels", type=int, default=3, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=500, help="interval between sampling of images from generators"
)
parser.add_argument("--checkpoint_interval", type=int, default=-1, help="interval between model checkpoints")
opt = parser.parse_args([])
print(opt)

用于生成和保存图像以及训练一个图像生成模型。

  1. os.makedirs("images/%s" % opt.dataset_name, exist_ok=True): 创建一个名为 “images/数据集名称” 的目录,如果该目录已经存在,则不会引发错误。

  2. os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True): 创建一个名为 “saved_models/数据集名称” 的目录,如果该目录已经存在,则不会引发错误。

  3. cuda = True if torch.cuda.is_available() else False: 检查是否有可用的CUDA设备,如果有,则将变量 cuda 设置为 True,否则设置为 False

  4. criterion_GAN = torch.nn.MSELoss(): 定义一个均方误差损失函数(Mean Squared Error Loss),用于计算生成器的损失。

  5. criterion_pixelwise = torch.nn.L1Loss(): 定义一个平均绝对误差损失函数(Mean Absolute Error Loss),用于计算像素级别的损失。

  6. lambda_pixel = 100: 设置像素级别损失的权重为100。

  7. patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4): 定义一个元组 patch,表示图像判别器的输出大小。

  8. generator = GeneratorUNet(): 初始化一个生成器对象,使用 U-Net 架构。

  9. discriminator = Discriminator(): 初始化一个判别器对象。

  10. if cuda:: 如果CUDA可用,则执行以下代码块。

  11. generator = generator.cuda(): 将生成器模型移动到GPU上进行计算。

  12. discriminator = discriminator.cuda(): 将判别器模型移动到GPU上进行计算。

  13. criterion_GAN.cuda(): 将GAN损失函数移动到GPU上进行计算。

  14. criterion_pixelwise.cuda(): 将像素级别损失函数移动到GPU上进行计算。

  15. if opt.epoch != 0:: 如果指定的训练轮数不为0,则执行以下代码块。

  16. generator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch))): 加载指定轮数的预训练生成器模型参数。

  17. discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch))): 加载指定轮数的预训练判别器模型参数。

  18. else:: 如果指定的训练轮数为0,则执行以下代码块。

  19. generator.apply(weights_init_normal): 对生成器的权重进行初始化。

  20. discriminator.apply(weights_init_normal): 对判别器的权重进行初始化。

os.makedirs("images/%s" % opt.dataset_name, exist_ok=True)
os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True)cuda = True if torch.cuda.is_available() else False# Loss functions
criterion_GAN = torch.nn.MSELoss()
criterion_pixelwise = torch.nn.L1Loss()# Loss weight of L1 pixel-wise loss between translated image and real image
lambda_pixel = 100# Calculate output of image discriminator (PatchGAN)
patch = (1, opt.img_height // 2 ** 4, opt.img_width // 2 ** 4)# Initialize generator and discriminator
generator = GeneratorUNet()
discriminator = Discriminator()if cuda:generator = generator.cuda()discriminator = discriminator.cuda()criterion_GAN.cuda()criterion_pixelwise.cuda()if opt.epoch != 0:# Load pretrained modelsgenerator.load_state_dict(torch.load("saved_models/%s/generator_%d.pth" % (opt.dataset_name, opt.epoch)))discriminator.load_state_dict(torch.load("saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, opt.epoch)))
else:# Initialize weightsgenerator.apply(weights_init_normal)discriminator.apply(weights_init_normal)
  1. optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)):
    创建一个Adam优化器对象,用于更新生成器的参数。学习率由opt.lr指定,beta值由opt.b1opt.b2指定。

  2. optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2)):
    创建一个Adam优化器对象,用于更新判别器的参数。学习率由opt.lr指定,beta值由opt.b1opt.b2指定。

  3. transforms_ = [ ... ]: 定义一个图像变换列表,包括调整图像大小、转换为张量以及归一化操作。

  4. dataloader = DataLoader( ... ): 创建一个数据加载器对象,用于从指定路径加载图像数据集,并应用上述定义的图像变换。

  5. val_dataloader = DataLoader( ... ): 创建一个验证集的数据加载器对象,用于加载验证集的图像数据。

  6. Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor: 根据是否有可用的CUDA设备,选择使用GPU上的浮点张量类型或CPU上的浮点张量类型。

  7. def sample_images(batches_done):: 定义一个函数,用于保存生成的样本图像。

  8. imgs = next(iter(val_dataloader)): 从验证集的数据加载器中获取一批图像数据。

  9. real_A = Variable(imgs["B"].type(Tensor)): 将真实图像A转换为张量,并将其封装为一个变量对象。

  10. real_B = Variable(imgs["A"].type(Tensor)): 将真实图像B转换为张量,并将其封装为一个变量对象。

  11. fake_B = generator(real_A): 使用生成器模型生成虚假图像B。

  12. img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2): 将真实图像A、生成的虚假图像B和真实图像B在通道维度上进行拼接。

  13. save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True):
    将拼接后的图像保存为PNG格式的文件,文件名包含数据集名称和批次数。

optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))# Configure dataloaders
transforms_ = [transforms.Resize((opt.img_height, opt.img_width), Image.BICUBIC),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]dataloader = DataLoader(ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_),batch_size=opt.batch_size,shuffle=True,num_workers=opt.n_cpu,
)val_dataloader = DataLoader(ImageDataset("./%s" % opt.dataset_name, transforms_=transforms_, mode="val"),batch_size=10,shuffle=True,num_workers=1,
)# Tensor type
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensordef sample_images(batches_done):"""Saves a generated sample from the validation set"""imgs = next(iter(val_dataloader))real_A = Variable(imgs["B"].type(Tensor))real_B = Variable(imgs["A"].type(Tensor))fake_B = generator(real_A)img_sample = torch.cat((real_A.data, fake_B.data, real_B.data), -2)save_image(img_sample, "images/%s/%s.png" % (opt.dataset_name, batches_done), nrow=5, normalize=True)
# ----------
#  Training
# ----------if __name__ == '__main__':prev_time = time.time()for epoch in range(opt.epoch, opt.n_epochs):for i, batch in enumerate(dataloader):# Model inputsreal_A = Variable(batch["B"].type(Tensor))real_B = Variable(batch["A"].type(Tensor))# Adversarial ground truthsvalid = Variable(Tensor(np.ones((real_A.size(0), *patch))), requires_grad=False)fake = Variable(Tensor(np.zeros((real_A.size(0), *patch))), requires_grad=False)# ------------------#  Train Generators# ------------------optimizer_G.zero_grad()# GAN lossfake_B = generator(real_A)pred_fake = discriminator(fake_B, real_A)loss_GAN = criterion_GAN(pred_fake, valid)# Pixel-wise lossloss_pixel = criterion_pixelwise(fake_B, real_B)# Total lossloss_G = loss_GAN + lambda_pixel * loss_pixelloss_G.backward()optimizer_G.step()# ---------------------#  Train Discriminator# ---------------------optimizer_D.zero_grad()# Real losspred_real = discriminator(real_B, real_A)loss_real = criterion_GAN(pred_real, valid)# Fake losspred_fake = discriminator(fake_B.detach(), real_A)loss_fake = criterion_GAN(pred_fake, fake)# Total lossloss_D = 0.5 * (loss_real + loss_fake)loss_D.backward()optimizer_D.step()# --------------#  Log Progress# --------------# Determine approximate time leftbatches_done = epoch * len(dataloader) + ibatches_left = opt.n_epochs * len(dataloader) - batches_donetime_left = datetime.timedelta(seconds=batches_left * (time.time() - prev_time))prev_time = time.time()# Print logsys.stdout.write("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f, pixel: %f, adv: %f] ETA: %s"% (epoch,opt.n_epochs,i,len(dataloader),loss_D.item(),loss_G.item(),loss_pixel.item(),loss_GAN.item(),time_left,))# If at sample interval save imageif batches_done % opt.sample_interval == 0:sample_images(batches_done)if opt.checkpoint_interval != -1 and epoch % opt.checkpoint_interval == 0:# Save model checkpointstorch.save(generator.state_dict(), "saved_models/%s/generator_%d.pth" % (opt.dataset_name, epoch))torch.save(discriminator.state_dict(), "saved_models/%s/discriminator_%d.pth" % (opt.dataset_name, epoch))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/538424.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C语言之文件操作(万字详解)

个人主页(找往期文章包括但不限于本期文章中不懂的知识点): 我要学编程(ಥ_ಥ)-CSDN博客 目录 前言 文件的打开和关闭 流和标准流 文件指针 文件的打开和关闭 文件的顺序读写 顺序读写函数介绍 fputc的使用 fgetc的使用 fput…

压缩json字符串

GZIPOutputStream 需要关闭,而 ByteArrayOutputStream 不需要关闭。具体原因如下: GZIPOutputStream:GZIPOutputStream是一种过滤流,它提供了将数据压缩为GZIP格式的功能。当使用此类的实例写入数据时,它会对数据进行压…

Linux的一些常用指令

一、文件中 r w x - 的含义 r(read)是只读权限, w(write)是写的权限, x(execute)是可执行权限, -是没有任何权限。 二、一些指令 # 解压压缩包 tar [-zxvf] 压缩包名…

从政府工作报告探计算机行业发展(在医疗健康领域)

从政府工作报告探计算机行业发展 政府工作报告作为政府工作的全面总结和未来规划,不仅反映了国家整体的发展态势,也为各行各业提供了发展的指引和参考。随着信息技术的快速发展,计算机行业已经成为推动经济社会发展的重要引擎之一。因此&…

C++训练营:引用传递

大家好: 衷心希望各位点赞。 您的问题请留在评论区,我会及时回答。 一、引用传递 简单来说,“引用”就是给已有的变量起一个别名。引用并没有自己单独的内存空间,作为引用,它和原变量共用一段内存空间。引用的定义格…

Linux系统Docker部署Plik系统结合内网穿透实现公网访问本地文件

文章目录 1. Docker部署Plik2. 本地访问Plik3. Linux安装Cpolar4. 配置Plik公网地址5. 远程访问Plik6. 固定Plik公网地址7. 固定地址访问Plik 本文介绍如何使用Linux docker方式快速安装Plik并且结合Cpolar内网穿透工具实现远程访问,实现随时随地在任意设备上传或者…

LeetCode刷题记录:(9)从中序与后序遍历序列构造二叉树

leetcode传送通道 /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right) {* …

力扣串题:字符串中的第二大数字

此题的精妙之处在于char类型到int类型的转化&#xff0c;需要运算来解决 int secondHighest(char * s) {int max1-1;int max2-1;int szstrlen(s);int i 0 ;for(i0;i<sz;i){if(s[i]>0&&s[i]<9){if((s[i]-0)>max1){max2max1;max1s[i]-0;}else if((s[i]-0)&l…

全栈之路-新坑就绪-星野空间

感觉自己的技术栈一直没有形成一个很好的闭环 开新坑&#xff0c;准备把自己的技术栈链路打通&#xff0c; Don‘t think too much&#xff0c; just act&#xff01;[得意]

Springboot中Redis的配置使用

新建 向pom.xml中添加依赖&#xff0c;这个可以不用标注版本号 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId></dependency> 配置yml文件&#xff08;文件名不可以错…

freemarker模板引擎结合node puppeteer库实现html生成图片

效果图&#xff1a; 先看效果图&#xff0c;以下是基于freemarker模板渲染数据&#xff0c;puppeteer加载html中的js及最后图片生成&#xff1a; 背景&#xff1a; 目前为止&#xff0c;后台java根据html模板或者一个网页路径生成图片&#xff0c;都不支持flex布局及最新的c…

GPT出现Too many requests in 1 hour. Try again later.

换节点 这个就不用多说了&#xff0c;你都可以上GPT帐号了&#xff0c;哈…… 清除cooki 然后退出账号&#xff0c;重新登录即可