C# 深度学习:对抗生成网络(GAN)训练头像生成模型

news/2025/2/11 8:36:16/文章来源:https://www.cnblogs.com/whuanle/p/18708861

通过生成对抗网络(GAN)训练和生成头像

目录
  • 通过生成对抗网络(GAN)训练和生成头像
      • 说明
      • 简介
      • 什么是 GAN
      • 什么是 DCGAN
      • 参数说明
      • 数据集处理
      • 权重初始化
      • 生成器
      • 判别器
      • 损失函数和优化器
      • 训练

说明

https://torch.whuanle.cn
电子书仓库:https://github.com/whuanle/cs_pytorch
Maomi.Torch 项目仓库:https://github.com/whuanle/Maomi.Torch

本文根据 Pytorch 官方文档的示例移植而来,部分文字内容和图片来自 Pytorch 文档,文章后面不再单独列出引用说明。

官方文档地址:

https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html

社区中文翻译版本:https://pytorch.ac.cn/tutorials/beginner/dcgan_faces_tutorial.html


Pytorch 示例项目仓库:

https://github.com/pytorch/examples

对应 Python 版本示例:https://github.com/pytorch/tutorials/blob/main/beginner_source/dcgan_faces_tutorial.py


本文项目参考 dcgan 项目:https://github.com/whuanle/Maomi.Torch/tree/main/examples/dcgan

简介

本教程将通过一个示例介绍生成对抗网络(DCGAN),在教程中,我们将训练一个生成对抗网络 (GAN) 模型来生成新的名人头像。这里的大部分代码来自 pytorch/examples 中的 DCGAN 实现,然后笔者通过 C# 移植了代码实现,本文档将对该实现进行详尽的解释,并阐明该模型的工作原理和原因,阅读本文不需要 GAN 的基础知识,原理部分比较难理解,不用将精力放在这上面,主要是根据代码思路走一遍即可。


生成式对抗网络,简单来说就像笔者喜欢摄影,但是摄影水平跟专业摄影师有差距,然后不断苦练技术,每拍一张照片就让朋友判断是笔者拍的还是专业摄影师拍的,如果朋友一眼就发现是我拍的,说明水平还不行。然后一直练,一直拍,直到朋友区分不出照片是笔者拍的,还是专业摄影师拍的,这就是生成式对抗网络。

设计生成式对抗网络,需要设计生成网络和判断网络,生成网络读取训练图片并训练转换生成输出结果,然后由判断器识别,检查生成的图片和训练图片的差异,如果判断器可以区分出生成的图片和训练图片的差异,说明还需要继续训练,直到判断器区分不出来。

什么是 GAN

GANs 是一种教深度学习模型捕捉训练数据分布的框架,这样我们可以从相同的分布生成新的数据。GANs 由 Ian Goodfellow 于 2014 年发明,并首次在论文 Generative Adversarial Nets 中描述。它们由两个不同的模型组成,一个是生成器,另一个是判别器。生成器的任务是生成看起来像训练图像的“假”图像。判别器的任务是查看图像,并输出它是否是真实训练图像或来自生成器的假图像。在训练期间,生成器不断尝试通过生成越来越好的假图像来欺骗判别器,而判别器则努力成为一名更好的侦探,正确分类真实图像和假图像。这场博弈的平衡点是生成器生成完美的假图像,看起来似乎直接来自训练数据,而判别器总是以 50% 的置信度猜测生成器的输出是真实的还是假的。

现在,让我们定义一些将在整个教程中使用的符号,从判别器开始。设 \(x\) 为表示图像的数据。\(D(x)\) 是判别器网络,输出 \(x\) 来自训练数据而不是生成器的(标量)概率。这里,由于我们处理的是图像,\(D(x)\) 的输入是 CHW 尺寸为 3x64x64 的图像。直观上,当 \(x\) 来自训练数据时,\(D(x)\) 应该是高的,而当 \(x\) 来自生成器时,\(D(x)\) 应该是低的。\(D(x)\) 也可以视为传统的二分类器。

对于生成器的符号,设 \(z\) 为从标准正态分布中采样的潜在空间向量。\(G(z)\) 表示将潜在向量 \(z\) 映射到数据空间的生成器函数。\(G\) 的目标是估计训练数据来自的分布 (\(p_{data}\)),以便从该估计分布中生成假样本 (\(p_g\))。

因此,\(D(G(z))\) 是生成器输出 \(G\) 为真实图像的概率(标量)。如 Goodfellow 的论文 中所描述,\(D\)\(G\) 进行一个极小极大博弈,其中 \(D\) 尽量最大化它正确分类真实和假的概率 (\(logD(x)\)),而 \(G\) 尽量最小化 \(D\) 预测其输出为假的概率 (\(log(1-D(G(z)))\))。在这篇论文中,GAN 损失函数为

\[\underset{G}{\text{min}} \underset{D}{\text{max}}V(D,G) = \mathbb{E}_{x\sim p_{data}(x)}\big[logD(x)\big] + \mathbb{E}_{z\sim p_{z}(z)}\big[log(1-D(G(z))\big] \]

理论上,这个极小极大博弈的解是 \(p_g = p_{data}\),而判别器随机猜测输入是真实的还是假的。然而,GANs 的收敛理论仍在积极研究中,实际上模型并不总是能够训练到这一点。

什么是 DCGAN

DCGAN 是上述 GAN 的直接扩展,不同之处在于它在判别器和生成器中明确使用了卷积层和反卷积层。Radford 等人在论文《利用深度卷积生成对抗网络进行无监督表示学习》中首次描述了这种方法。判别器由步幅卷积层、批量归一化层以及LeakyReLU激活函数组成。输入是一个 3x64x64 的输入图像,输出是一个标量概率,表示输入是否来自真实的数据分布。生成器由反卷积层、批量归一化层和ReLU激活函数组成。输入是从标准正态分布中抽取的潜在向量 \(z\),输出是一个 3x64x64 的 RGB 图像。步幅的反卷积层允许将潜在向量转换为具有与图像相同形状的体积。在论文中,作者还提供了一些如何设置优化器、如何计算损失函数以及如何初始化模型权重的建议,这些将在后续章节中解释。


然后引入依赖并配置训练参数:

using dcgan;
using Maomi.Torch;
using System.Diagnostics;
using TorchSharp;
using TorchSharp.Modules;
using static TorchSharp.torch;// 使用 GPU 启动
Device defaultDevice = MM.GetOpTimalDevice();
torch.set_default_device(defaultDevice);// Set random seed for reproducibility
var manualSeed = 999;// manualSeed = random.randint(1, 10000) # use if you want new results
Console.WriteLine("Random Seed:" + manualSeed);
random.manual_seed(manualSeed);
torch.manual_seed(manualSeed);Options options = new Options()
{Dataroot = "E:\\datasets\\celeba",// 设置这个可以并发加载数据集,加快训练速度Workers = 10,BatchSize = 128,
};

稍后讲解如何下载图片数据集。


用于训练的人像图片数据集大概是 22万张,不可能一次性全部加载,所以需要设置 BatchSize 参数分批导入、分批训练,如果读者的 GPU 性能比较高,则可以设置大一些。

参数说明

前面提到了 Options 模型类定义训练模型的参数,下面给出每个参数的详细说明。

注意字段名称略有差异,并且移植版本并不是所有参数都用上。

  • dataroot - 数据集文件夹根目录的路径。我们将在下一节中详细讨论数据集。
  • workers - 用于使用 DataLoader 加载数据的工作线程数。
  • batch_size - 训练中使用的批大小。DCGAN 论文使用 128 的批大小。
  • image_size - 用于训练的图像的空间大小。此实现默认为 64x64。如果需要其他大小,则必须更改 D 和 G 的结构。有关更多详细信息,请参阅 此处。
  • nc - 输入图像中的颜色通道数。对于彩色图像,此值为 3。
  • nz - 潜在向量的长度。
  • ngf - 与通过生成器传递的特征图的深度有关。
  • ndf - 设置通过判别器传播的特征图的深度。
  • num_epochs - 要运行的训练 epoch 数。训练时间越长可能会带来更好的结果,但也会花费更长的时间。
  • lr - 训练的学习率。如 DCGAN 论文中所述,此数字应为 0.0002。
  • beta1 - Adam 优化器的 beta1 超参数。如论文中所述,此数字应为 0.5。
  • ngpu - 可用的 GPU 数量。如果此值为 0,则代码将在 CPU 模式下运行。如果此数字大于 0,则它将在那几个 GPU 上运行。

首先定义一个全局参数模型类,并设置默认值:

public class Options
{/// <summary>/// Root directory for dataset/// </summary>public string Dataroot { get; set; } = "data/celeba";/// <summary>/// Number of workers for dataloader/// </summary>public int Workers { get; set; } = 2;/// <summary>/// Batch size during training/// </summary>public int BatchSize { get; set; } = 128;/// <summary>/// Spatial size of training images. All images will be resized to this size using a transformer./// </summary>public int ImageSize { get; set; } = 64;/// <summary>/// Number of channels in the training images. For color images this is 3/// </summary>public int Nc { get; set; } = 3;/// <summary>/// Size of z latent vector (i.e. size of generator input)/// </summary>public int Nz { get; set; } = 100;/// <summary>/// Size of feature maps in generator/// </summary>public int Ngf { get; set; } = 64;/// <summary>/// Size of feature maps in discriminator/// </summary>public int Ndf { get; set; } = 64;/// <summary>/// Number of training epochs/// </summary>public int NumEpochs { get; set; } = 5;/// <summary>/// Learning rate for optimizers/// </summary>public double Lr { get; set; } = 0.0002;/// <summary>/// Beta1 hyperparameter for Adam optimizers/// </summary>public double Beta1 { get; set; } = 0.5;/// <summary>/// Number of GPUs available. Use 0 for CPU mode./// </summary>public int Ngpu { get; set; } = 1;
}

数据集处理

本教程中,我们将使用 Celeb-A Faces 数据集 来训练模型,可以从链接网站或在 Google Drive 下载。

数据集官方地址:https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html

可以通过 Google 网盘或百度网盘下载:

https://drive.google.com/drive/folders/0B7EVK8r0v71pWEZsZE9oNnFzTm8?resourcekey=0-5BR16BdXnb8hVj6CNHKzLg&usp=sharing

https://pan.baidu.com/s/1CRxxhoQ97A5qbsKO7iaAJg

提取码:rp0s


注意,本文只需要用到图片,不需要用到标签,不用下载所有文件,只需要下载 CelebA/Img/img_align_celeba.zip 即可。下载后解压到一个空目录中,其目录结构示例:

/path/to/celeba-> img_align_celeba  -> 188242.jpg-> 173822.jpg-> 284702.jpg-> 537394.jpg...

然后在 Options.Dataroot 参数填写 /path/to/celeba 即可,导入数据集时会自动搜索该目录下的子目录,将子目录作为图像的分类名称,然后向子目录加载所有图像文件。


这是一个重要步骤,因为我们将使用 ImageFolder 数据集类,该类要求数据集根文件夹中有子目录。现在,我们可以创建数据集,创建数据加载器,设置运行设备,并最终可视化一些训练数据。


// 创建一个 samples 目录用于输出训练过程中产生的输出效果
if(Directory.Exists("samples"))
{Directory.Delete("samples", true);
}Directory.CreateDirectory("samples");// 加载图像并对图像做转换处理
var dataset = MM.Datasets.ImageFolder(options.Dataroot, torchvision.transforms.Compose(torchvision.transforms.Resize(options.ImageSize),torchvision.transforms.CenterCrop(options.ImageSize),torchvision.transforms.ConvertImageDtype(ScalarType.Float32),torchvision.transforms.Normalize(new double[] { 0.5, 0.5, 0.5 }, new double[] { 0.5, 0.5, 0.5 }))
);// 分批加载图像
var dataloader = torch.utils.data.DataLoader(dataset, batchSize: options.BatchSize, shuffle: true, num_worker: options.Workers, device: defaultDevice);var netG = new dcgan.Generator(options).to(defaultDevice);

在设置好输入参数并准备好数据集后,我们现在可以进入实现部分。我们将从权重初始化策略开始,然后详细讨论生成器、判别器、损失函数和训练循环。


权重初始化

根据 DCGAN 论文,作者指出所有模型权重应从均值为 0,标准差为 0.02 的正态分布中随机初始化。weights_init 函数以已初始化的模型为输入,重新初始化所有卷积层、转置卷积层和批量归一化层以满足此标准。此函数在模型初始化后立即应用于模型。


static void weights_init(nn.Module m)
{var classname = m.GetType().Name;if (classname.Contains("Conv")){if (m is Conv2d conv2d){nn.init.normal_(conv2d.weight, 0.0, 0.02);}}else if (classname.Contains("BatchNorm")){if (m is BatchNorm2d batchNorm2d){nn.init.normal_(batchNorm2d.weight, 1.0, 0.02);nn.init.zeros_(batchNorm2d.bias);}}
}

网络模型会有多层结构,模型训练时到不同的层时会自动调用 weights_init 函数初始化,作用对象不是模型本身,而是网络模型的层。

1739107119297

生成器

生成器 \(G\) 旨在将潜在空间向量 ( \(z\) ) 映射到数据空间。由于我们的数据是图像,将 \(z\) 转换为数据空间意味着最终要创建一个与训练图像具有相同大小的 RGB 图像 (即 3x64x64)。在实践中,这是通过一系列步幅为二维的卷积转置层来实现的,每一层都配有一个 2d 批量规范化层和一个 relu 激活函数。生成器的输出通过一个 tanh 函数返回到输入数据范围 \([-1,1]\) 。值得注意的是在 conv-transpose 层之后存在批量规范化函数,因为这是 DCGAN 论文的重要贡献之一。这些层有助于训练期间梯度的流动。下图显示了 DCGAN 论文中的生成器。

dcgan_generator


请注意,我们在输入部分设置的输入(nzngf,和 nc)如何影响代码中生成器的架构。nz 是 z 输入向量的长度,ngf 与在生成器中传播的特征图的大小有关,而 nc 是输出图像中的通道数(对于 RGB 图像设置为 3)。下面是生成器的代码。


定义图像生成的网络模型:

public class Generator : nn.Module<Tensor, Tensor>, IDisposable
{private readonly Options _options;public Generator(Options options) : base(nameof(Generator)){_options = options;main = nn.Sequential(// input is Z, going into a convolutionnn.ConvTranspose2d(options.Nz, options.Ngf * 8, 4, 1, 0, bias: false),nn.BatchNorm2d(options.Ngf * 8),nn.ReLU(true),// state size. (ngf*8) x 4 x 4nn.ConvTranspose2d(options.Ngf * 8, options.Ngf * 4, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ngf * 4),nn.ReLU(true),// state size. (ngf*4) x 8 x 8nn.ConvTranspose2d(options.Ngf * 4, options.Ngf * 2, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ngf * 2),nn.ReLU(true),// state size. (ngf*2) x 16 x 16nn.ConvTranspose2d(options.Ngf * 2, options.Ngf, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ngf),nn.ReLU(true),// state size. (ngf) x 32 x 32nn.ConvTranspose2d(options.Ngf, options.Nc, 4, 2, 1, bias: false),nn.Tanh()// state size. (nc) x 64 x 64);RegisterComponents();}public override Tensor forward(Tensor input){return main.call(input);}Sequential main;
}

初始化模型:

var netG = new dcgan.Generator(options).to(defaultDevice);
netG.apply(weights_init);
Console.WriteLine(netG);

判别器

如前所述,判别器 \(D\) 是一个二分类网络,它以图像为输入并输出一个标量概率,即输入图像是真实的(而非伪造的)的概率。这里,\(D\) 接受一个 3x64x64 的输入图像,通过一系列的 Conv2d、BatchNorm2d 和 LeakyReLU 层进行处理,并通过 Sigmoid 激活函数输出最终的概率。根据问题的需要,可以扩展这一架构以包含更多层数,但使用跨步卷积、BatchNorm 和 LeakyReLUs 是有意义的。DCGAN 论文提到,使用跨步卷积而非池化来进行下采样是一个好习惯,因为它使网络能够学习其自己的池化函数。此外,批量规范化和 leaky relu 函数促进了健康的梯度流动,这对 \(G\)\(D\) 的学习过程至关重要。


定义判别器网络模型:

public class Discriminator : nn.Module<Tensor, Tensor>, IDisposable
{private readonly Options _options;public Discriminator(Options options) : base(nameof(Discriminator)){_options = options;main = nn.Sequential(// input is (nc) x 64 x 64nn.Conv2d(options.Nc, options.Ndf, 4, 2, 1, bias: false),nn.LeakyReLU(0.2, inplace: true),// state size. (ndf) x 32 x 32nn.Conv2d(options.Ndf, options.Ndf * 2, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ndf * 2),nn.LeakyReLU(0.2, inplace: true),// state size. (ndf*2) x 16 x 16nn.Conv2d(options.Ndf * 2, options.Ndf * 4, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ndf * 4),nn.LeakyReLU(0.2, inplace: true),// state size. (ndf*4) x 8 x 8nn.Conv2d(options.Ndf * 4, options.Ndf * 8, 4, 2, 1, bias: false),nn.BatchNorm2d(options.Ndf * 8),nn.LeakyReLU(0.2, inplace: true),// state size. (ndf*8) x 4 x 4nn.Conv2d(options.Ndf * 8, 1, 4, 1, 0, bias: false),nn.Sigmoid());RegisterComponents();}public override Tensor forward(Tensor input){var output =  main.call(input);return output.view(-1, 1).squeeze(1);}Sequential main;
}

初始化模型:

var netD = new dcgan.Discriminator(options).to(defaultDevice);
netD.apply(weights_init);
Console.WriteLine(netD);

损失函数和优化器

设置好 \(D\)\(G\) 后,我们可以通过损失函数和优化器指定它们的学习方式。我们将使用二元交叉熵损失函数(BCELoss),它在 PyTorch 中定义如下:

\[\ell(x, y) = L = \{l_1,\dots,l_N\}^\top, \quad l_n = - \left[ y_n \cdot \log x_n + (1 - y_n) \cdot \log (1 - x_n) \right] \]


请注意,这个函数提供了目标函数中两个对数分量,即 \(log(D(x))\)\(log(1-D(G(z)))\) 的计算。我们可以通过 \(y\) 输入来指定使用 BCE 方程的哪一部分。这将在即将到来的训练循环中完成,但是了解我们可以通过改变 \(y\)(即 GT 标签)选择希望计算的分量非常重要。

接下来,我们将真实标签定义为 1,假的标签定义为 0。这些标签将在计算 \(D\)\(G\) 的损失时使用,这也是原始 GAN 论文中使用的约定。最后,我们设置两个独立的优化器,一个用于 \(D\),另一个用于 \(G\)。根据 DCGAN 论文的规定,两者都是 Adam 优化器,学习率为 0.0002,Beta1 = 0.5。为了追踪生成器的学习进展,我们将生成一个从高斯分布中抽取的固定批次的潜在向量(即 fixed_noise)。在训练循环中,我们将定期将这个 fixed_noise 输入 \(G\),并且在迭代过程中,我们将看到图像从噪声中形成。


var criterion = nn.BCELoss();
var fixed_noise = torch.randn(new long[] { options.BatchSize, options.Nz, 1, 1 }, device: defaultDevice);
var real_label = 1.0;
var fake_label = 0.0;
var optimizerD = torch.optim.Adam(netD.parameters(), lr: options.Lr, beta1: options.Beta1, beta2: 0.999);
var optimizerG = torch.optim.Adam(netG.parameters(), lr: options.Lr, beta1: options.Beta1, beta2: 0.999);

训练

最后,在我们定义了GAN框架的所有部分之后,我们可以开始训练它了。请注意,训练GANs在某种程度上是一门艺术,因为不正确的超参数设置会导致模式崩溃,并且很难解释出了什么问题。在这里,我们将紧密遵循 Goodfellow的论文 中的算法1,同时遵循一些在ganhacks中显示的最佳实践。具体来说,我们将“为真实和虚假的图像构建不同的小批量”,并调整G的目标函数以最大化 \(log(D(G(z)))\) 。训练分为两个主要部分:第一部分更新判别器,第二部分更新生成器。

第1部分 - 训练判别器

回顾一下,训练判别器的目标是最大化正确分类给定输入为真实或虚假的概率。根据Goodfellow的说法,我们希望“通过上升随机梯度来更新判别器”。实际上,我们希望最大化 \(log(D(x)) + log(1-D(G(z)))\) 。根据 ganhacks 的独立小批量建议,我们将分两步计算这一点。首先,我们将从训练集中构建一个真实样本的小批量,前向传递通过 \(D\),计算损失 ( \(log(D(x))\) ) ,然后反向传递计算梯度。其次,我们将使用当前的生成器构建一个虚假样本的小批量,将此批次前向传递通过 \(D\),计算损失 ( \(log(1-D(G(z)))\) ),并通过反向传递累积梯度。现在,随着从全真和全假批次累积的梯度,我们调用判别器优化器的一步。

第2部分 - 训练生成器

如原论文所述,我们希望通过最小化 \(log(1-D(G(z)))\) 来训练生成器,以便生成更好的虚假样本。如前所述,Goodfellow 显示这在学习过程中尤其是早期不会提供足够的梯度。作为解决方案,我们希望最大化 \(log(D(G(z)))\) 。在代码中,我们通过以下方法实现这一点:使用判别器对第1部分生成器的输出进行分类,使用真实标签作为GT计算G的损失,在反向传递中计算G的梯度,最后用优化器一步更新G的参数。使用真实标签作为损失函数的GT标签可能看起来违反直觉,但这允许我们使用 BCELoss\(log(x)\) 部分(而不是 \(log(1-x)\) 部分),这正是我们所需要的。

最后,我们将进行一些统计报告,并且在每个epoch结束时,我们将通过生成器推送我们的固定噪声批次,以便直观地跟踪G的训练进度。报告的训练统计数据包括:

  • Loss_D - 判别器损失,计算为全真和全假批次损失的总和 ( \(log(D(x)) + log(1 - D(G(z)))\) )。
  • Loss_G - 生成器损失,计算为 \(log(D(G(z)))\)
  • D(x) - 判别器对全真批次的平均输出(跨批次)。这应该从接近 1 开始,然后在G变好时理论上收敛到 0.5。想想这是为什么。
  • D(G(z)) - 判别器对全假批次的平均输出。第一个数字是 D 更新之前的,第二个数字是 D 更新之后的。这些数字应该从接近0开始,并在 G 变好时收敛到 0.5。想想这是为什么。

注意: 这一步可能需要一段时间,具体取决于你运行了多少个epochs以及是否从数据集中删除了一些数据。


var img_list = new List<Tensor>();
var G_losses = new List<double>();
var D_losses = new List<double>();Console.WriteLine("Starting Training Loop...");Stopwatch stopwatch = new();
stopwatch.Start();
int i = 0;
// For each epoch
for (int epoch = 0; epoch < options.NumEpochs; epoch++)
{foreach (var item in dataloader){var data = item[0];netD.zero_grad();// Format batchvar real_cpu = data.to(defaultDevice);var b_size = real_cpu.size(0);var label = torch.full(new long[] { b_size }, real_label, dtype: ScalarType.Float32, device: defaultDevice);// Forward pass real batch through Dvar output = netD.forward(real_cpu);// Calculate loss on all-real batchvar errD_real = criterion.call(output, label);// Calculate gradients for D in backward passerrD_real.backward();var D_x = output.mean().item<float>();// Train with all-fake batch// Generate batch of latent vectorsvar noise = torch.randn(new long[] { b_size, options.Nz, 1, 1 }, device: defaultDevice);// Generate fake image batch with Gvar fake = netG.call(noise);label.fill_(fake_label);// Classify all fake batch with Doutput = netD.call(fake.detach());// Calculate D's loss on the all-fake batchvar errD_fake = criterion.call(output, label);// Calculate the gradients for this batch, accumulated (summed) with previous gradientserrD_fake.backward();var D_G_z1 = output.mean().item<float>();// Compute error of D as sum over the fake and the real batchesvar errD = errD_real + errD_fake;// Update DoptimizerD.step();////////////////////////////// (2) Update G network: maximize log(D(G(z)))////////////////////////////netG.zero_grad();label.fill_(real_label);  // fake labels are real for generator cost// Since we just updated D, perform another forward pass of all-fake batch through Doutput = netD.call(fake);// Calculate G's loss based on this outputvar errG = criterion.call(output, label);// Calculate gradients for GerrG.backward();var D_G_z2 = output.mean().item<float>();// Update GoptimizerG.step();// ex: [0/25][4/3166] Loss_D: 0.5676 Loss_G: 7.5972 D(x): 0.9131 D(G(z)): 0.3024 / 0.0007Console.WriteLine($"[{epoch}/{options.NumEpochs}][{i%dataloader.Count}/{dataloader.Count}] Loss_D: {errD.item<float>():F4} Loss_G: {errG.item<float>():F4} D(x): {D_x:F4} D(G(z)): {D_G_z1:F4} / {D_G_z2:F4}");// 每处理 100 批,输出一次图片效果if (i % 100 == 0){real_cpu.SaveJpeg("samples/real_samples.jpg");fake = netG.call(fixed_noise);fake.detach().SaveJpeg("samples/fake_samples_epoch_{epoch:D3}.jpg");}i++;}netG.save("samples/netg_{epoch}.dat");netD.save("samples/netd_{epoch}.dat");
}

最后打印训练结果和输出:

Console.WriteLine("Training finished.");
stopwatch.Stop();
Console.WriteLine("Training Time: {stopwatch.Elapsed}");netG.save("samples/netg.dat");
netD.save("samples/netd.dat");

按照官方示例推荐进行 25 轮训练,由于笔者使用使用 4060TI 8G 机器训练,训练 25 轮大概时间:

Training finished.
Training Time: 00:49:45.6976041

每轮训练结果的图像:

image-20250210183009016


第一轮训练生成:

image-20250210205818778


第 25 轮生成的:

image-20250210214910109


虽然还是有些抽象,但生成结果比之前好一些了。

在 dcgan_out 项目中开业看到,使用 5 轮训练结果输出的模型,生成图像:

Device defaultDevice = MM.GetOpTimalDevice();
torch.set_default_device(defaultDevice);// Set random seed for reproducibility
var manualSeed = 999;// manualSeed = random.randint(1, 10000) # use if you want new results
Console.WriteLine("Random Seed:" + manualSeed);
random.manual_seed(manualSeed);
torch.manual_seed(manualSeed);Options options = new Options()
{Dataroot = "E:\\datasets\\celeba",Workers = 10,BatchSize = 128,
};var netG = new dcgan.Generator(options);
netG.to(defaultDevice);
netG.load("netg.dat");// 生成随机噪声
var fixed_noise = torch.randn(64, options.Nz, 1, 1, device: defaultDevice);// 生成图像
var fake_images = netG.call(fixed_noise);fake_images.SaveJpeg("fake_images.jpg");


虽然还是有些抽象,但确实还行。

fake_images

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/881954.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

开源的 DeepSeek-R1「GitHub 热点速览」

春节假期回来,一睁眼全是王炸级的开源模型 DeepSeek-R1!GitHub 地址→github.com/deepseek-ai/DeepSeek-R1DeepSeek-R1 开源还不到一个月,Star 数就飙升至冲破天际的 70k。虽然目前仅开源了模型权重,但同时发布的技术论文详细地介绍了 DeepSeek-R1 所采用的训练技术,如模型…

C#/.NET/.NET Core优秀项目和框架2025年1月简报

前言 公众号每月定期推广和分享的C#/.NET/.NET Core优秀项目和框架(每周至少会推荐两个优秀的项目和框架当然节假日除外),公众号推文中有项目和框架的详细介绍、功能特点、使用方式以及部分功能截图等(打不开或者打开GitHub很慢的同学可以优先查看公众号推文,文末一定会附…

Palo Alto Cortex XSOAR 6.13 for Linux - 安全编排、自动化和响应 (SOAR) 平台

Palo Alto Cortex XSOAR 6.13 for Linux - 安全编排、自动化和响应 (SOAR) 平台Palo Alto Cortex XSOAR 6.13 for Linux - 安全编排、自动化和响应 (SOAR) 平台 Security Orchestration, Automation and Response (SOAR) platform 请访问原文链接:https://sysin.org/blog/cort…

2025年01月总结及随笔之年前撞车

2025年01月总结及随笔之年前撞车1. 回头看 日更坚持了762天。读《数据保护:工作负载的可恢复性》更新完成 读《量子霸权》开更并更新完成 读《算法简史:从美索不达米亚到人工智能时代》开更并持续更新2023年至2025年01月底累计码字1936092字,累计日均码字2540字。 2025年01月…

[网络] 跨域问题及解决方案

同源策略及跨域问题 同源策略是一套浏览器安全机制,当一个源的文档和脚本,与另一个源的资源进行通信时,同源策略就会对这个通信做出不同程度的限制。 简单来说,同源策略对 同源资源 放行,对 异源资源 限制 因此限制造成的开发问题,称之为跨域(异源)问题 同源和异源 源(…

macos安装三方windows

mac需时intel芯 m芯片貌似不支持windows系统,甚至虚拟机的方式都支持不友好 准备好鼠标 因为在windows安装界面或第一次打开的系统时,此时时没有注入驱动的,macbook笔记本自带的键盘触摸板将都不可用! 格式必须是iso 如果你下载到的三方windows系统是esd 或者其他格式,则必…

[大模型/AI/GPT] Chatbox:大模型可视化终端应用

序 概述:Chatbox AIChatbox AI 是一款 AI 客户端应用和智能助手,支持众多先进的 AI 模型和 API,可在 Windows、MacOS、Android、iOS、Linux 和网页版上使用。User-friendly Desktop Client App for AI Models/LLMs (GPT, Claude, Gemini, Ollama...)https://github.com/Bin-…

【AI安全】大模型越狱探索

本文皆在探讨大模型越狱攻击手法,能实操落地非学术化的,所有案例用于技术分享交流,在后文中尽量会用最精简的语言来讲解 开篇点题:越狱追溯于早期 IOS,用户为了突破设备的封闭生态系统,自由操作自己的IOS,不被限制,而在大模型中,越狱同理,规避大模型的限制,执行那些…

【洛谷P1229】遍历问题

这道题好巧 遍历问题 题目描述 我们都很熟悉二叉树的前序、中序、后序遍历,在数据结构中常提出这样的问题:已知一棵二叉树的前序和中序遍历,求它的后序遍历,相应的,已知一棵二叉树的后序遍历和中序遍历序列你也能求出它的前序遍历。然而给定一棵二叉树的前序和后序遍历,你…

第二课 经济金融案例实战

目录导入数据并观察合并数据提取出标签并对标签进行处理合并训练集和测试集变量转化正确化变量属性对分类型特征进行独热编码填写数值型特征的缺失值标准化数值型特征建立模型分出训练集和测试集集成提交结果 导入数据并观察 合并数据这里可能有个问题。我们说不要让模型提前见…

STM32学习笔记【电赛历险记嵌入式学习心得】

关于STM32F103C8T6的学习笔记,除基础介绍外,包含标准库与HAL库,涉及蓝牙、电机、超声波、红外等模块,涉及GPIO、中断、定时器、IC输入捕获、ADC、DMA等基础模块,含有CubeMX学习,C语言预编译知识,含有推荐学习项目链接。前言 此篇随笔是博主在打电赛(全国大学生电子设计…

【K8S安全】浅析K8S各种未授权攻击方法

免责声明: 本篇文章仅用于技术交流,请勿利用文章内的相关技术从事非法测试,由于传播、利用本文所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,本文作者不为此承担任何责任,一旦造成后果请自行承担!如有侵权烦请告知,我们会立即删除并致歉。谢谢…