第G2周:人脸图像生成(DCGAN)

🍨 本文为[🔗365天深度学习训练营学习记录博客\n🍦 参考文章:365天深度学习训练营\n🍖 原作者:[K同学啊 | 接辅导、项目定制]\n🚀 文章来源:[K同学的学习圈子](https://www.yuque.com/mingtian-fkmxf/zxwb45)

一、设置超参数、导入数据 

import os
import random
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets as dset
import torchvision.utils as vutils
from torchvision.utils import save_image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTMLmanualSeed = 999  # 随机种子
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results# 超参数配置
dataroot   = "D:/GAN-Data"  # 数据路径
batch_size = 128                   # 训练过程中的批次大小
n_epochs   = 5                     # 训练的总轮数
img_size   = 64                    # 图像的尺寸(宽度和高度)
nz         = 100                   # z潜在向量的大小(生成器输入的尺寸)
ngf        = 64                    # 生成器中的特征图大小
ndf        = 64                    # 判别器中的特征图大小
beta1      = 0.5                   # Adam优化器的Beta1超参数
beta2      = 0.2                   # Adam优化器的Beta1超参数
lr         = 0.0002                # 学习率# 创建数据集
dataset = dset.ImageFolder(root=dataroot,transform=transforms.Compose([transforms.Resize(img_size),        # 调整图像大小transforms.CenterCrop(img_size),    # 中心裁剪图像transforms.ToTensor(),                # 将图像转换为张量transforms.Normalize((0.5, 0.5, 0.5), # 标准化图像张量(0.5, 0.5, 0.5)),]))
# 创建数据加载器
dataloader = torch.utils.data.DataLoader(dataset,batch_size=batch_size,  # 批量大小shuffle=True)           # 是否打乱数据集
# 选择要在哪个设备上运行代码
device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")
print("使用的设备是:",device)
# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:24],padding=2,normalize=True).cpu(),(1,2,0)))

f1c970f54ef84bafbb319b386e9b4a85.png

二、定义模型、可视化 

# 自定义权重初始化函数,作用于netG和netD
def weights_init(m):# 获取当前层的类名classname = m.__class__.__name__# 如果类名中包含'Conv',即当前层是卷积层if classname.find('Conv') != -1:# 使用正态分布初始化权重数据,均值为0,标准差为0.02nn.init.normal_(m.weight.data, 0.0, 0.02)# 如果类名中包含'BatchNorm',即当前层是批归一化层elif classname.find('BatchNorm') != -1:# 使用正态分布初始化权重数据,均值为1,标准差为0.02nn.init.normal_(m.weight.data, 1.0, 0.02)# 使用常数初始化偏置项数据,值为0nn.init.constant_(m.bias.data, 0)'''
定义生成器 Generator
'''class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.main = nn.Sequential(# 输入为Z,经过一个转置卷积层nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),nn.BatchNorm2d(ngf * 8),  # 批归一化层,用于加速收敛和稳定训练过程nn.ReLU(True),  # ReLU激活函数# 输出尺寸:(ngf*8) x 4 x 4nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 4),nn.ReLU(True),# 输出尺寸:(ngf*4) x 8 x 8nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf * 2),nn.ReLU(True),# 输出尺寸:(ngf*2) x 16 x 16nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),nn.BatchNorm2d(ngf),nn.ReLU(True),# 输出尺寸:(ngf) x 32 x 32nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),nn.Tanh()  # Tanh激活函数# 输出尺寸:3 x 64 x 64)def forward(self, x):return self.main(x)# 创建生成器
netG = Generator().to(device)
# 使用 "weights_init" 函数对所有权重进行随机初始化,
# 平均值(mean)设置为0,标准差(stdev)设置为0.02。
netG.apply(weights_init)
# 打印生成器模型
print(netG)'''
定义判别器 Discriminator
'''class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器的主要结构,使用Sequential容器将多个层按顺序组合在一起self.main = nn.Sequential(# 输入大小为3 x 64 x 64nn.Conv2d(3, ndf, 4, 2, 1, bias=False),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf) x 32 x 32nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 2),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*2) x 16 x 16nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 4),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*4) x 8 x 8nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),nn.BatchNorm2d(ndf * 8),nn.LeakyReLU(0.2, inplace=True),# 输出大小为(ndf*8) x 4 x 4nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),nn.Sigmoid())def forward(self, x):# 将输入通过判别器的主要结构进行前向传播return self.main(x)# 创建判别器对象
netD = Discriminator().to(device)
# 应用 "weights_init" 函数来随机初始化所有权重
# 使用 mean=0, stdev=0.2 的方式进行初始化
netD.apply(weights_init)
# 打印判别器模型
print(netD)# 初始化“BCELoss”损失函数
criterion = nn.BCELoss()
# 创建用于可视化生成器进程的潜在向量批次
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
real_label = 1.
fake_label = 0.
# 为生成器(G)和判别器(D)设置Adam优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, beta2))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, beta2))img_list = []  # 用于存储生成的图像列表
G_losses = []  # 用于存储生成器的损失列表
D_losses = []  # 用于存储判别器的损失列表
iters = 0  # 迭代次数

e548707830e646fd8e48defa581a6bd2.png

三、训练模型 

print("Starting Training Loop...")  # 输出训练开始的提示信息
# 进行多个epoch的训练
for epoch in range(n_epochs):# 对于dataloader中的每个batchfor i, data in enumerate(dataloader, 0):############################# (1) 更新判别器网络:最大化 log(D(x)) + log(1 - D(G(z)))############################## 使用真实图像样本训练netD.zero_grad()  # 清除判别器网络的梯度# 准备真实图像的数据real_cpu = data[0].to(device)b_size = real_cpu.size(0)label = torch.full((b_size,), real_label, dtype=torch.float, device=device)  # 创建一个全是真实标签的张量# 将真实图像样本输入判别器,进行前向传播output = netD(real_cpu).view(-1)# 计算真实图像样本的损失errD_real = criterion(output, label)# 通过反向传播计算判别器的梯度errD_real.backward()D_x = output.mean().item()  # 计算判别器对真实图像样本的输出的平均值## 使用生成图像样本训练# 生成一批潜在向量noise = torch.randn(b_size, nz, 1, 1, device=device)# 使用生成器生成一批假图像样本fake = netG(noise)label.fill_(fake_label)  # 创建一个全是假标签的张量# 将所有生成的图像样本输入判别器,进行前向传播output = netD(fake.detach()).view(-1)# 计算判别器对生成图像样本的损失errD_fake = criterion(output, label)# 通过反向传播计算判别器的梯度errD_fake.backward()D_G_z1 = output.mean().item()  # 计算判别器对生成图像样本的输出的平均值# 计算判别器的总损失,包括真实图像样本和生成图像样本的损失之和errD = errD_real + errD_fake# 更新判别器的参数optimizerD.step()############################# (2) 更新生成器网络:最大化 log(D(G(z)))############################netG.zero_grad()  # 清除生成器网络的梯度label.fill_(real_label)  # 对于生成器成本而言,将假标签视为真实标签# 由于刚刚更新了判别器,再次将所有生成的图像样本输入判别器,进行前向传播output = netD(fake).view(-1)# 根据判别器的输出计算生成器的损失errG = criterion(output, label)# 通过反向传播计算生成器的梯度errG.backward()D_G_z2 = output.mean().item()  # 计算判别器对生成器输出的平均值# 更新生成器的参数optimizerG.step()# 输出训练统计信息if i % 400 == 0:print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'% (epoch, n_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))# 保存损失值以便后续绘图G_losses.append(errG.item())D_losses.append(errD.item())# 通过保存生成器在固定噪声上的输出来检查生成器的性能if (iters % 500 == 0) or ((epoch == n_epochs - 1) and (i == len(dataloader) - 1)):with torch.no_grad():fake = netG(fixed_noise).detach().cpu()img_list.append(vutils.make_grid(fake, padding=2, normalize=True))iters += 1# 可视化
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()# 创建一个大小为8x8的图形对象
fig = plt.figure(figsize=(8, 8))
# 不显示坐标轴
plt.axis("off")
# 将图像列表img_list中的图像转置并创建一个包含每个图像的单个列表ims
ims = [[plt.imshow(np.transpose(i, (1, 2, 0)), animated=True)] for i in img_list]
# 使用图形对象、图像列表ims以及其他参数创建一个动画对象ani
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# 将动画以HTML形式呈现
HTML(ani.to_jshtml())# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))
# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# 绘制上一个时期生成的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

 训练结果:

[Epoch 0/50][Batch 0/36][D loss: 1.446340][G loss: 5.497820][D : 0.496465][G : 0.006384]
[Epoch 1/50][Batch 0/36][D loss: 0.198702][G loss: 32.366798][D : 0.000000][G : 0.000000]
[Epoch 2/50][Batch 0/36][D loss: 0.007939][G loss: 39.840797][D : 0.000000][G : 0.000000]
[Epoch 3/50][Batch 0/36][D loss: 0.008718][G loss: 39.420380][D : 0.000000][G : 0.000000]
[Epoch 4/50][Batch 0/36][D loss: 0.000432][G loss: 39.375351][D : 0.000000][G : 0.000000]
[Epoch 5/50][Batch 0/36][D loss: 0.000377][G loss: 39.141502][D : 0.000000][G : 0.000000]
[Epoch 6/50][Batch 0/36][D loss: 0.000066][G loss: 38.554665][D : 0.000000][G : 0.000000]
[Epoch 7/50][Batch 0/36][D loss: 0.000161][G loss: 37.076347][D : 0.000000][G : 0.000000]
[Epoch 8/50][Batch 0/36][D loss: 0.236551][G loss: 5.515038][D : 0.126019][G : 0.009809]
[Epoch 9/50][Batch 0/36][D loss: 0.774763][G loss: 4.037993][D : 0.041982][G : 0.032798]
[Epoch 10/50][Batch 0/36][D loss: 1.355027][G loss: 7.484296][D : 0.627779][G : 0.001169]
[Epoch 11/50][Batch 0/36][D loss: 1.026440][G loss: 3.390290][D : 0.480961][G : 0.066138]
[Epoch 12/50][Batch 0/36][D loss: 0.698196][G loss: 2.289851][D : 0.117281][G : 0.149754]
[Epoch 13/50][Batch 0/36][D loss: 0.407120][G loss: 3.295501][D : 0.169919][G : 0.056703]
[Epoch 14/50][Batch 0/36][D loss: 0.858621][G loss: 4.627818][D : 0.297173][G : 0.028583]
[Epoch 15/50][Batch 0/36][D loss: 1.068889][G loss: 4.085044][D : 0.314014][G : 0.029605]
[Epoch 16/50][Batch 0/36][D loss: 0.761256][G loss: 1.878336][D : 0.122635][G : 0.189217]
[Epoch 17/50][Batch 0/36][D loss: 0.946410][G loss: 5.986092][D : 0.486197][G : 0.005545]
[Epoch 18/50][Batch 0/36][D loss: 0.607918][G loss: 8.022884][D : 0.355339][G : 0.000997]
[Epoch 19/50][Batch 0/36][D loss: 0.387959][G loss: 5.217168][D : 0.148431][G : 0.012128]
[Epoch 20/50][Batch 0/36][D loss: 0.502083][G loss: 3.828265][D : 0.124887][G : 0.032919]
[Epoch 21/50][Batch 0/36][D loss: 0.341051][G loss: 5.217510][D : 0.129790][G : 0.010647]
[Epoch 22/50][Batch 0/36][D loss: 0.305131][G loss: 3.878963][D : 0.118515][G : 0.034831]
[Epoch 23/50][Batch 0/36][D loss: 0.326738][G loss: 3.092067][D : 0.048084][G : 0.073658]
[Epoch 24/50][Batch 0/36][D loss: 1.001996][G loss: 7.870810][D : 0.531534][G : 0.001848]
[Epoch 25/50][Batch 0/36][D loss: 0.646764][G loss: 5.994369][D : 0.328600][G : 0.005999]
[Epoch 26/50][Batch 0/36][D loss: 1.305306][G loss: 3.512106][D : 0.027197][G : 0.060318]
[Epoch 27/50][Batch 0/36][D loss: 0.230971][G loss: 6.018190][D : 0.160877][G : 0.005384]
[Epoch 28/50][Batch 0/36][D loss: 0.479868][G loss: 2.851458][D : 0.012263][G : 0.132684]
[Epoch 29/50][Batch 0/36][D loss: 1.190969][G loss: 6.840727][D : 0.560059][G : 0.003298]
[Epoch 30/50][Batch 0/36][D loss: 1.005036][G loss: 6.322803][D : 0.486148][G : 0.005413]
[Epoch 31/50][Batch 0/36][D loss: 0.407194][G loss: 5.357150][D : 0.025872][G : 0.012775]
[Epoch 32/50][Batch 0/36][D loss: 0.715868][G loss: 4.764071][D : 0.410440][G : 0.018443]
[Epoch 33/50][Batch 0/36][D loss: 0.525104][G loss: 4.291232][D : 0.187566][G : 0.026254]
[Epoch 34/50][Batch 0/36][D loss: 0.363458][G loss: 4.643357][D : 0.184744][G : 0.021312]
[Epoch 35/50][Batch 0/36][D loss: 0.550998][G loss: 3.245662][D : 0.190560][G : 0.078518]
[Epoch 36/50][Batch 0/36][D loss: 0.686132][G loss: 5.602957][D : 0.362706][G : 0.007369]
[Epoch 37/50][Batch 0/36][D loss: 0.556991][G loss: 3.656791][D : 0.147552][G : 0.046845]
[Epoch 38/50][Batch 0/36][D loss: 0.459933][G loss: 4.163424][D : 0.245844][G : 0.033957]
[Epoch 39/50][Batch 0/36][D loss: 0.232279][G loss: 4.535916][D : 0.114630][G : 0.016447]
[Epoch 40/50][Batch 0/36][D loss: 0.479002][G loss: 5.497972][D : 0.263936][G : 0.012047]
[Epoch 41/50][Batch 0/36][D loss: 0.720815][G loss: 3.263973][D : 0.259178][G : 0.061856]
[Epoch 42/50][Batch 0/36][D loss: 0.703234][G loss: 6.425527][D : 0.400735][G : 0.003400]
[Epoch 43/50][Batch 0/36][D loss: 0.741217][G loss: 2.052215][D : 0.048953][G : 0.209300]
[Epoch 44/50][Batch 0/36][D loss: 0.658782][G loss: 3.800625][D : 0.272119][G : 0.040041]
[Epoch 45/50][Batch 0/36][D loss: 0.402264][G loss: 5.260798][D : 0.185568][G : 0.009509]
[Epoch 46/50][Batch 0/36][D loss: 0.753039][G loss: 4.797507][D : 0.406285][G : 0.022727]
[Epoch 47/50][Batch 0/36][D loss: 0.301918][G loss: 4.467443][D : 0.173592][G : 0.022788]
[Epoch 48/50][Batch 0/36][D loss: 0.638086][G loss: 1.768839][D : 0.072733][G : 0.227529]
[Epoch 49/50][Batch 0/36][D loss: 0.576230][G loss: 2.268032][D : 0.082981][G : 0.151779]

c46f116c90bd47d88d0a6ae2b2857d37.png 314562a5064a467aa98fd81bf10d6287.png

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/307566.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SourceTree的安装和使用

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、安装:二、使用步骤1.获取地址2.放入sourceTree 3.点击推送 前言 提示:这里可以添加本文要记录的大概内容: 简单讲解一…

【网络技术】【Kali Linux】Wireshark嗅探(三)用户数据报(UDP)协议

一、实验目的 本次实验使用wireshark流量分析工具进行网络嗅探,旨在了解UDP协议的报文格式。 二、网络环境设置 本次实验使用Kali Linux虚拟机完成,主机操作系统为Windows 11,虚拟化平台选择Oracle VM VirtualBox,组网模式选择…

【Python排序算法系列】—— 选择排序

​ 🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 💫个人格言:"没有罗马,那就自己创造罗马~" 目录 选择排序 过程演示: 选择排序实现代码: 分析选择排序&#xff1a…

虚拟机和电脑如何传送文件

一.桥接 (实现电脑和虚拟机在同一网段) 虚拟机上网盘设置 二.属性---文件共享设置 1打开属性,点击共享 2.添加共享人为全部人,并修改权限为读写模式 3.点击高级共享,选定此文件夹 4.点击网络和共享中心,划…

函数式编程的妙用

前言 我们平常项目中维护的比较多的就是实体类中的数量问题,我们最常见的做法就是通过get方法读取旧数据,然后进行新数据的set 。这套方法相对来说是比较统一固定的,如果有多处地方使用,我们可以想着通过Function和BiConsumer的函…

java设计模式学习之【模板方法模式】

文章目录 引言模板方法模式简介定义与用途实现方式 使用场景优势与劣势在Spring框架中的应用游戏设计示例代码地址 引言 设想你正在准备一顿晚餐,无论你想做意大利面、披萨还是沙拉,制作过程中都有一些共同的步骤:准备原料、加工食物、摆盘。…

毅速:3D打印技术传统模具行业影响深远

随着3D打印技术的不断发展和完善,一系列的优势使其在模具制造领域的应用越来越广泛,这一技术在模具行业的应用将为整个行业带来变革。 首先,3D打印技术将大幅提高模具制造的精度和效率。传统的模具制造过程中,由于加工设备的限制和…

Typora使用PicGo+Gitee上传图片报错403 Forbidden

Typora使用PicGoGitee上传图片报错403 Forbidden Typora使用PicGoGitee上传图片,上传失败了,错误信息如下 打开PicGo的日志文件查看,可以看到错误详情如下 换了一个插件github-plus重新配置,解决了这个问题 再打开日志查看&…

gem5学习(7):内存系统中创建 SimObjects--Creating SimObjects in the memory system

目录 一、gem5 master and slave ports 二、Packets 三、Port interface 1、主设备发送请求时从设备忙 2、从设备发送响应时主设备忙 四、Simple memory object example 1、Declare the SimObject 2、Define the SimpleMemobj class 3、Define the SimpleMemobj class…

python 异步Web框架sanic

我们继续学习Python异步编程,这里将介绍异步Web框架sanic,为什么不是tornado?从框架的易用性来说,Flask要远远比tornado简单,可惜flask不支持异步,而sanic就是类似Flask语法的异步框架。 github&#xff1…

1200PLC连接分布式IO组态编程应用

SMART PLC作为S7-1200PLC的智能IO从站设备组态和编程应用详细介绍请参考下面链接文章: https://rxxw-control.blog.csdn.net/article/details/130257474https://rxxw-control.blog.csdn.net/article/details/130257474这篇博客我们介绍S7-1200PLC和分布式IO连接组…

【LangChain】与文档聊天:将OpenAI与LangChain集成的终极指南

欢迎来到人工智能的迷人世界,在那里,人与机器之间的通信越来越模糊。在这篇博客文章中,我们将探索人工智能驱动交互的一个令人兴奋的新前沿:与您的文本文档聊天!借助OpenAI模型和创新的LangChain框架的强大组合&#x…