使用Pytorch从零开始构建DCGAN

在本文中,我们将深入研究生成建模的世界,并使用流行的 PyTorch 框架探索 DCGAN(生成对抗网络 (GAN) 的一种变体)的实现。具体来说,我们将使用 CelebA 数据集(名人面部图像的集合)来生成逼真的合成面部。在深入了解实现细节之前,我们首先了解 GAN 是什么以及 DCGAN 与它们有何不同,并详细探讨其架构。
在这里插入图片描述
那么,什么是 GAN?

简而言之,生成对抗网络(GAN)是一种令人着迷的机器学习模型,其中两个玩家(生成器和鉴别器)参与竞争性游戏。生成器的目的是创建真实的合成样本,例如图像或文本,而鉴别器的工作是区分真实样本和假样本。通过这场猫鼠游戏,GAN 学会生成高度令人信服且真实的输出,最终突破人工智能在创建新的多样化数据方面的界限。

在这里插入图片描述
理解DCGAN:

DCGAN(深度卷积生成对抗网络)是一种令人兴奋的机器学习模型,可以创建极其逼真和详细的图像。想象一下,一个系统可以通过分析数千个示例来学习生成全新的图片,例如面孔或风景。DCGAN 通过巧妙地结合两个网络来实现这一目标——一个网络生成图像,另一个网络试图区分真假图像。通过竞争过程,DCGAN 成为生成令人信服且难以与真实图像区分开的图像的大师,展示了人工智能的巨大创造潜力。

DCGAN的架构:

DCGAN(深度卷积生成对抗网络)的架构​​由两个基本组件组成:生成器和判别器。

生成器将随机噪声作为输入,并逐渐将其转换为类似于训练数据的合成样本。它通过一系列层来实现这一点,包括转置卷积、批量归一化和激活函数。这些层使生成器能够学习复杂的模式和结构,从而生成捕获真实数据的复杂细节的高维样本。

另一方面,鉴别器充当二元分类器,区分真实样本和生成样本。它接收输入样本并将其传递给卷积层、批量归一化和激活函数。鉴别器的作用是评估样本的真实性并向生成器提供反馈。通过对抗性训练过程,生成器和鉴别器不断竞争并提高性能,最终生成越来越真实的样本。

让我们开始实现我们的第一个 DCGAN:

假设您的环境中安装了 PyTorch 和 CUDA,我们首先导入必要的库。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from tqdm import tqdm
import torchvision.datasets as datasets
from torchvision.datasets import ImageFolder
from torchvision.utils import make_grid
import torchvision.utils as vutils
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torch.utils.data import Subset
import numpy as np

导入这些库后,您就可以使用 CelebA 数据集在 PyTorch 中继续实施 DCGAN。

为了设置训练 DCGAN 模型所需的配置,我们可以定义设备、学习率、批量大小、图像大小、输入图像中的通道数、潜在空间维度、历元数以及特征数判别器和生成器。以下是配置:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3
Z_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

“device”变量确保代码在可用的 GPU 上运行,否则返回到 CPU。“LEARNING_RATE”决定了模型在优化过程中学习的速率。“BATCH_SIZE”表示每次迭代中处理的样本数量。“IMAGE_SIZE”表示输入图像的所需尺寸。“CHANNELS_IMG”指定输入图像中颜色通道的数量(例如,RGB 为 3)。“Z_DIM”表示潜在空间的维度,它是生成器的输入。“NUM_EPOCHS”决定了训练期间遍历整个数据集的次数。“FEATURES_DISC”和“FEATURES_GEN”分别表示鉴别器和生成器网络中的特征数量。

这些配置对于训练 DCGAN 模型至关重要,可以根据具体要求和可用资源进行调整。

要加载 CelebA 数据集并准备进行训练,我们可以定义数据转换、创建数据集对象并设置数据加载器:

transform = transforms.Compose([transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])dataset = datasets.ImageFolder('<path_to_celeba_dataset_in_your_directoty>', transform=transform)
dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

要可视化 CelebA 数据集中的一批训练图像:

real_batch = next(iter(dataloader))
plt.figure(figsize=(7,7))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:49], padding=2, normalize=True).cpu(),(1,2,0)))

在这里插入图片描述
创建生成器网络:
在这里插入图片描述
在 DCGAN 架构中,噪声数据最初表示为形状为 100x1x1 的张量,经过一系列转置卷积运算,将其转换为大小为 3x64x64 的输出图像。

以下是重塑过程的逐步分解:

  1. 输入噪声:100x1x1
  2. 第一个转置卷积层:
    — 输出大小:1024x4x4
    — 内核大小:4x4,步长:1,填充:0
  3. 第二个转置卷积层:
    — 输出大小:512x8x8
    — 内核大小:4x4,步长:2,填充:1
  4. 第三转置卷积层:
    — 输出大小:256x16x16
    — 内核大小:4x4,步长:2,填充:1 5.
  5. 第四转置卷积层:
    — 输出大小:128x32x32
    — 内核大小:4x4,步长:2,填充:1
  6. 最终转置卷积层:
    — 输出大小:3x64x64
    — 内核大小:4x4,步幅:2,填充:1

通过将噪声数据传递到这些转置卷积层,生成器逐渐将低维噪声放大为与所需大小 3x64x64 匹配的高维图像。重塑过程涉及增加空间维度,同时减少通道数量,从而产生具有代表 RGB 颜色通道的三个通道和 64x64 像素尺寸的最终输出图像。

class Generator(nn.Module):def __init__(self, z_dim, channels_img, features_g):super(Generator, self).__init__()self.net = nn.Sequential(self._block(z_dim, features_g * 16, 4, 1, 0),self._block(features_g * 16, features_g * 8, 4, 2, 1),self._block(features_g * 8, features_g * 4, 4, 2, 1),self._block(features_g * 4, features_g * 2, 4, 2, 1),nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),nn.Tanh())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),nn.BatchNorm2d(out_channels),nn.ReLU())def forward(self, x):return self.net(x)

Generator 类代表 DCGAN 架构中的生成器网络。它将潜在空间的维度 (z_dim)、输出图像中的通道数 (channels_img) 和特征数 (features_g) 作为输入。生成器被定义为顺序模块。

该_block方法在生成器中定义了一个块,其中包含转置卷积层、批量归一化和 ReLU 激活函数。该块重复多次,逐渐增加输出图像的空间维度。

在该forward方法中,输入噪声 (x) 通过生成器的连续层,从而生成生成的图像。

生成器的架构旨在将低维潜在空间转换为高维图像,并逐渐将其放大到所需的输出大小。然后输出图像通过双曲正切激活函数 ( nn.Tanh()) 以确保其像素值在 [-1, 1] 范围内。

通过以这种方式定义生成器,当提供随机噪声作为输入时,它学会生成类似于训练数据的合成图像。

创建鉴别器网络:
在这里插入图片描述
在 DCGAN 架构中,判别器采用大小为 3x64x64 的输入图像,并通过一系列卷积层对其进行处理,从而产生 1x1x1 的输出。以下是重塑过程的逐步分解:

  1. 输入图像:3x64x64
  2. 第一个卷积层:
    — 输出大小:64x32x32
    — 内核大小:4x4,步长:2,填充:1
  3. 第二个卷积层:
    — 输出大小:128x16x16
    — 内核大小:4x4,步长:2 ,填充:1
  4. 第三个卷积层:
    — 输出大小:256x8x8
    — 内核大小:4x4,步长:2,填充:1 5.
  5. 第四个卷积层:
    — 输出大小:512x4x4
    — 内核大小:4x4,步长:2,填充:1
  6. 最终卷积层:
    — 输出大小:1x1x1
    — 内核大小:4x4,步长:2,填充:0

通过将输入图像传递给这些卷积层,鉴别器逐渐减小空间维度,同时增加通道数量。这种转换允许鉴别器评估图像并对其真实性做出决定。1x1x1 的输出大小表示单个值,该值表示输入图像为真或假的概率。

class Discriminator(nn.Module):def __init__(self, channels_img, features_d):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),nn.LeakyReLU(0.2),self._block(features_d, features_d * 2, 4, 2, 1),self._block(features_d * 2, features_d * 4, 4, 2, 1),self._block(features_d * 4, features_d * 8, 4, 2, 1),nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),nn.Sigmoid())def _block(self, in_channels, out_channels, kernel_size, stride, padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),nn.BatchNorm2d(out_channels),nn.LeakyReLU(0.2))def forward(self, x):return self.disc(x)

Discriminator 类代表 DCGAN 架构中的鉴别器网络。它将输入图像中的通道数 (channels_img) 和特征数 (features_d) 作为输入。鉴别器被定义为一个顺序模块。

该_block方法在判别器内定义了一个块,由卷积层、批量归一化和 LeakyReLU 激活组成。该块重复多次,逐渐增加特征数量并减少输入图像的空间维度。

在该forward方法中,输入图像 (x) 通过鉴别器的连续层,由于最后的 sigmoid 激活,输出表示输入为真 (1) 或假 (0) 的概率。

鉴别器的架构使其能够区分真假图像,使其成为 DCGAN 对抗训练过程的重要组成部分。

创建一个函数来初始化权重:

def initialize_weights(model):classname = model.__class__.__name__if classname.find('Conv') != -1:nn.init.normal_(model.weight.data, 0.0, 0.02)elif classname.find('BatchNorm') != -1:nn.init.normal_(model.weight.data, 1.0, 0.02)nn.init.constant_(model.bias.data, 0)

在此函数中,我们接收 amodel作为输入,它可以引用生成器或鉴别器网络。我们迭代模型的各层并根据层类型初始化权重。这些权重是由 DCGAN 论文建议的。

对于卷积层 ( Conv),我们使用nn.init.normal_平均值 0 和标准差 0.02 来初始化权重。

对于批量归一化层 ( BatchNorm),我们使用 来初始化权重,平均值为 1,标准差为 0.02 nn.init.normal_,并使用 来将偏差设置为 0 nn.init.constant_。

通过调用此函数并将生成器和鉴别器网络作为参数传递,您可以确保网络的权重得到适当初始化以训练 DCGAN 模型。

gen = Generator(Z_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)initialize_weights(gen)
initialize_weights(disc)

gen现在,我们通过传递潜在空间的维度 ( Z_DIM)、输出图像中的通道数 ( CHANNELS_IMG) 以及生成器中的特征数量 ( FEATURES_GEN)来创建生成器网络 ( ) 的实例。disc类似地,我们通过指定输入通道的数量 ( CHANNELS_IMG) 和鉴别器中的特征数量 ( )创建鉴别器网络 ( ) 的实例FEATURES_DISC。

创建网络实例后,我们调用initialize_weights生成器和鉴别器网络上的函数,根据 DCGAN 论文中建议的权重初始化技术来初始化它们的权重。

通过执行此代码,您将准备好生成器和鉴别器网络,并正确初始化它们的权重,用于训练 DCGAN 模型。

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()fixed_noise = torch.randn(32, Z_DIM, 1, 1).to(device)

我们为生成器和鉴别器网络定义了优化器。我们使用模块中的 Adam 优化器torch.optim。生成器优化器 ( opt_gen) 使用生成器参数、学习率LEARNING_RATE以及 Adam 优化器的 beta 参数(0.5 和 0.999)进行初始化。

类似地,判别器优化器 ( opt_disc) 使用判别器的参数、相同的学习率和 beta 参数进行初始化。

接下来,我们定义对抗训练过程的损失标准。在这里,我们使用二元交叉熵损失 (Binary Cross Entropy Loss nn.BCELoss()),它在 GAN 中常用来将判别器的预测与真实标签进行比较。

最后,我们创建一个固定噪声张量 ( fixed_noise),用于在训练过程中生成样本图像。该torch.randn函数根据正态分布生成随机数。

通过设置优化器、损失准则和固定噪声张量,您就可以开始训练 DCGAN 模型了。

def show_tensor_images(image_tensor, num_images=32, size=(1, 64, 64)):image_tensor = (image_tensor + 1) / 2image_unflat = image_tensor.detach().cpu()image_grid = make_grid(image_unflat[:num_images], nrow=4)plt.imshow(image_grid.permute(1, 2, 0).squeeze())plt.show()

“show_tensor_images”函数是一个实用函数,它获取图像张量,对它们进行标准化,创建图像网格,并使用 matplotlib 显示它们,以便在训练过程中轻松可视化。

训练模型:

gen.train()
disc.train()for epoch in range(NUM_EPOCHS):for batch_idx, (real, _ ) in enumerate(dataloader):real = real.to(device)### create noise tensornoise = torch.randn((BATCH_SIZE, Z_DIM, 1, 1)).to(device)fake = gen(noise)### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))disc_real = disc(real).reshape(-1)loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake.detach()).reshape(-1)loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))loss_disc = (loss_disc_real + loss_disc_fake) / 2disc.zero_grad()loss_disc.backward()opt_disc.step()### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))output = disc(fake).reshape(-1)loss_gen = criterion(output, torch.ones_like(output))gen.zero_grad()loss_gen.backward()opt_gen.step()### Print losses occasionally and fake images occasionallyif batch_idx % 50 == 0:print(f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}")with torch.no_grad():fake = gen(fixed_noise)img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)show_tensor_images(img_grid_fake)

我们可以将本次培训分为四个部分以便更好地理解。

  1. 噪声生成:随机噪声 ( noise) 使用形状为 的正态分布生成(BATCH_SIZE, Z_DIM, 1, 1)。
  2. 判别器训练:判别器网络是通过评估真实图像 ( ) 并根据判别器与地面真实标签(全部)的预测real计算损失 ( ) 来训练的。loss_disc_real假图像 ( fake) 是通过将噪声传递到生成器网络来生成的。评估鉴别器对假图像 ( ) 的预测,并根据与地面真实标签(全零)相比的预测来计算disc_fake损失 ( ) 。loss_disc_fake平均损失 ( ) 计算为和loss_disc的平均值。判别器参数的梯度设置为零,反向传播损失,并更新判别器的优化器。loss_disc_realloss_disc_fake
  3. 生成器训练:生成器网络是通过fake使用噪声生成假图像( )来训练的。获得判别器对假图像 ( ) 的预测,并根据与地面真实标签(全部)相比的预测来计算output损失 ( )。loss_gen生成器参数的梯度设置为零,反向传播损失,并更新生成器的优化器。
  4. 进度跟踪和图像可视化:每 50 个批次后,打印当前纪元、批次索引、鉴别器损失 ( loss_disc) 和生成器损失 ( )。loss_gen样本图像是通过将固定噪声传递到生成器网络来生成的。真实图像和生成图像都被转换为图像网格,然后显示图像网格以可视化生成器网络的学习进度。

每个时期生成的假图像:

Starting of the first epoch:Starting of the first epoch
After first epoch:
在这里插入图片描述
After second epoch:
在这里插入图片描述
After third epoch:
在这里插入图片描述
After fourth epoch:
在这里插入图片描述
End results ( After 5 epochs):
在这里插入图片描述

结论: 如果您能够获得更好的结果,

  1. 增加模型容量:在生成器和鉴别器网络中添加更多层或增加滤波器数量,以增强模型学习复杂模式并生成更高质量图像的能力。
  2. 利用更深的卷积层:实施ResNet 或 DenseNet 等更深层次的架构来捕获更复杂的特征和纹理,从而提高图像质量。
  3. 使用更大的数据集:在更大、更多样化的数据集上训练 DCGAN 可以帮助模型学习更广泛的图像变化,并提高其生成高分辨率图像的能力。
  4. 调整训练参数:试验学习率、批量大小和训练迭代等超参数,以优化训练过程并提高模型生成更高分辨率图像的能力。

博文译自Manohar Kedamsetti的博客。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/209364.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

振南技术干货集:制冷设备大型IoT监测项目研发纪实(2)

注解目录 1.制冷设备的监测迫在眉睫 1.1 冷食的利润贡献 1.2 冷设监测系统的困难 &#xff08;制冷设备对于便利店为何如何重要&#xff1f;了解一下你所不知道的便利店和新零售行业。关于电力线载波通信的论战。&#xff09; 2、电路设计 2.1 防护电路 2.1.1 强电防护 …

《DApp开发:开启全新数字时代篇章》

随着区块链技术的日益成熟&#xff0c;去中心化应用&#xff08;DApp&#xff09;逐渐成为数字世界的新焦点。在这个充满无限可能的全新领域&#xff0c;DApp开发为创新者们提供了开启数字时代新篇章的钥匙。 一、DApp&#xff1a;区块链创新成果 DApp是建立在区块链技术基础之…

2023年度openGauss标杆应用实践案例征集

标杆应用实践案例征集 2023 openGauss 数据库作为企业IT系统的核心组成部分&#xff0c;是数字基础设施建设的关键&#xff0c;是实现数据安全稳定的保障。openGauss顺应开源发展趋势&#xff0c;强化核心技术突破&#xff0c;着力打造自主根社区&#xff0c;携手产业伙伴共同…

如何写好科研论文

写好科研论文需要遵循以下步骤&#xff1a; 确定研究主题和目标&#xff1a;在开始撰写论文之前&#xff0c;你需要明确你的研究主题和目标。这有助于你更好地组织论文的内容&#xff0c;并确保你的论文能够准确地传达你的研究成果。做好文献调研&#xff1a;在撰写论文之前&a…

关于数据摆渡 你关心的5个问题都在这儿!

数据摆渡&#xff0c;这个词语的概念源自于网络隔离和数据交换的场景和需求。不管是物理隔离、协议隔离、应用隔离还是逻辑隔离&#xff0c;最终目的都是为了保护内部核心数据的安全。而隔离之后&#xff0c;又必然会存在文件交换的需求。 传统的跨网数据摆渡方式经历了从人工U…

postman定义公共函数这样写,测试组长直呼牛逼!!!

postman定义公共函数 在postman中&#xff0c;如下面的代码&#xff1a; 1、返回元素是否与预期值一致 var assertEqual(name,actual,expected)>{tests[${name}&#xff1a;实际结果&#xff1a; ${actual} &#xff0c; 期望结果&#xff1a;${expected}]actualexpected…

Odoo16系统忘记Master密码的解决方法

1 打开项目配置文件../Odoo 16.0.20231119/server/odoo.conf 2 找到admin_passwd 开头的行&#xff0c;删除该行&#xff0c;或者在该行前面添加英文半角分号;注释掉本行 3 重启odoo服务&#xff0c;然后访问页面如&#xff1a;http://localhost:8069/web 4 选择数据库是&am…

金融众筹模式系统源码 适合创业孵化机构+天使投资机构+投资基金会等 附带完整的搭建教程

随着互联网技术的发展和金融市场的开放&#xff0c;金融众筹模式逐渐成为一种新型的融资方式。这种模式通过互联网平台聚集大量投资者&#xff0c;共同参与到一个项目中&#xff0c;为项目提供资金支持&#xff0c;最终获得投资回报。今天罗峰给大家分享一款金融众筹模式系统源…

iOS合并代码后解决冲突

合并主干和分支代码后有冲突&#xff0c;xcode无法运行&#xff0c;如下图&#xff1a;文件显示不了&#xff0c;项目名也显示不了 解决冲突&#xff1a; 1.选中左边目录栏的项目名。鼠标右键--> Show in Finder 2.选中项目文件 xxxx.xcodeproj。鼠标右键--> 显示包内容…

【TypeScrpt算法】算法的复杂度分析

算法的复杂度分析 什么是算法复杂度&#xff1f; 不同的算法&#xff0c;其实效率是不一样的 让我举一个案例来比较两种不同的算法在查找数组中给定元素的时间复杂度 [1,2,3,4,5,6,7,...9999,n] 顺序查找 这种方法从头到尾遍历整个数组&#xff0c;依次比较每个元素和给定元…

感恩节99句祝福语,感恩父母老师朋友亲人朋友们,永久快乐幸福

1、流星让夜空感动&#xff0c;生死让人生感动&#xff0c;爱情让生活感动&#xff0c;你让我感动&#xff0c;在感恩节真心祝福你比所有的人都开心快乐。 2、感恩节到了&#xff0c;想问候你一下&#xff0c;有太多的话语想要说&#xff0c;但是不知从何说起&#xff0c;还是用…

经典百搭女童加绒卫衣,看的见的时尚

经典版型套头卫衣 宽松百搭不挑人穿 单穿内搭都可以 胸口处有精美的小熊印花 面料是复合柔软奥利绒 暖和又不显臃肿哦&#xff01;&#xff01;