PyTorch训练简单的生成对抗网络GAN

文章目录

    • 原理
    • 代码
    • 结果
    • 参考

原理

同时训练两个网络:辨别器Discriminator 和 生成器Generator
Generator是 造假者,用来生成假数据。
Discriminator 是警察,尽可能的分辨出来哪些是造假的,哪些是真实的数据。

目的:使得判别模型尽量犯错,无法判断数据是来自真实数据还是生成出来的数据。

GAN的梯度下降训练过程:

在这里插入图片描述
上图来源:https://arxiv.org/abs/1406.2661

Train 辨别器: m a x max max l o g ( D ( x ) ) + l o g ( 1 − D ( G ( z ) ) ) log(D(x)) + log(1 - D(G(z))) log(D(x))+log(1D(G(z)))

Train 生成器: m i n min min l o g ( 1 − D ( G ( z ) ) ) log(1-D(G(z))) log(1D(G(z)))

我们可以使用BCEloss来计算上述两个损失函数

BCEloss的表达式: m i n − [ y l n x + ( 1 − y ) l n ( 1 − x ) ] min -[ylnx + (1-y)ln(1-x)] min[ylnx+(1y)ln(1x)]
具体过程参加代码中注释

代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboardclass Discriminator(nn.Module):def __init__(self, img_dim):super(Discriminator, self).__init__()self.disc = nn.Sequential(nn.Linear(img_dim, 128),nn.LeakyReLU(0.1),nn.Linear(128, 1),nn.Sigmoid(),)def forward(self, x):return self.disc(x)class Generator(nn.Module):def __init__(self, z_dim, img_dim): # z_dim 噪声的维度super(Generator, self).__init__()self.gen = nn.Sequential(nn.Linear(z_dim, 256),nn.LeakyReLU(0.1),nn.Linear(256, img_dim), # 28x28 -> 784nn.Tanh(),)def forward(self, x):return self.gen(x)# Hyperparameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4 # 3e-4是Adam最好的学习率
z_dim = 64 # 噪声维度
img_dim = 784 # 28x28x1
batch_size = 32
num_epochs = 50disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)fixed_noise = torch.randn((batch_size, z_dim)).to(device)
transforms = transforms.Compose( # MNIST标准化系数:(0.1307,), (0.3081,)[transforms.ToTensor(), transforms.Normalize((0.1307, ), (0.3081,))] # 不同数据集就有不同的标准化系数
)dataset = datasets.MNIST(root='dataset/', transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# BCE 损失
criterion = nn.BCELoss()# 打开tensorboard:在该目录下,使用 tensorboard --logdir=runs
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0for epoch in range(num_epochs):for batch_idx, (real, _) in enumerate(loader):real = real.view(-1, 784).to(device) # view相当于reshapebatch_size = real.shape[0]### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))noise = torch.randn(batch_size, z_dim).to(device)fake = gen(noise) # G(z)disc_real = disc(real).view(-1) # flatten# BCEloss的表达式:min -[ylnx + (1-y)ln(1-x)]# max log(D(real)) 相当于 min -log(D(real))# ones_like:1填充得到y=1, 即可忽略  min -[ylnx + (1-y)ln(1-x)]中的后一项# 得到 min -lnx,这里的x就是我们的real图片lossD_real = criterion(disc_real, torch.ones_like(disc_real))disc_fake = disc(fake).view(-1)# max log(1 - D(G(z))) 相当于 min -log(1 - D(G(z)))# zeros_like用0填充,得到y=0,即可忽略  min -[ylnx + (1-y)ln(1-x)]中的前一项# 得到 min -ln(1-x),这里的x就是我们的fake噪声lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))lossD = (lossD_real + lossD_fake) / 2disc.zero_grad()lossD.backward(retain_graph=True)opt_disc.step()### Train Generator: min log(1-D(G(z))) <--> max log(D(G(z))) <--> min - log(D(G(z)))# 依然可使用BCEloss来做output = disc(fake).view(-1)lossG = criterion(output, torch.ones_like(output))gen.zero_grad()lossG.backward()opt_gen.step()if batch_idx == 0:print(f"Epoch [{epoch}/{num_epochs}] \ "f"Loss D: {lossD:.4f}, Loss G: {lossG:.4f}")with torch.no_grad():fake = gen(fixed_noise).reshape(-1, 1, 28, 28)data = real.reshape(-1, 1, 28, 28)img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)img_grid_real = torchvision.utils.make_grid(data, normalize=True)writer_fake.add_image("Mnist Fake Images", img_grid_fake, global_step=step)writer_real.add_image("Mnist Real Images", img_grid_real, global_step=step)step += 1

结果

训练50轮的的损失

Epoch [0/50] \ Loss D: 0.7366, Loss G: 0.7051
Epoch [1/50] \ Loss D: 0.2483, Loss G: 1.6877
Epoch [2/50] \ Loss D: 0.1049, Loss G: 2.4980
Epoch [3/50] \ Loss D: 0.1159, Loss G: 3.4923
Epoch [4/50] \ Loss D: 0.0400, Loss G: 3.8776
Epoch [5/50] \ Loss D: 0.0450, Loss G: 4.1703
...
Epoch [43/50] \ Loss D: 0.0022, Loss G: 7.7446
Epoch [44/50] \ Loss D: 0.0007, Loss G: 9.1281
Epoch [45/50] \ Loss D: 0.0138, Loss G: 6.2177
Epoch [46/50] \ Loss D: 0.0008, Loss G: 9.1188
Epoch [47/50] \ Loss D: 0.0025, Loss G: 8.9419
Epoch [48/50] \ Loss D: 0.0010, Loss G: 8.3315
Epoch [49/50] \ Loss D: 0.0007, Loss G: 7.8302

使用

tensorboard --logdir=runs

打开tensorboard:

在这里插入图片描述
可以看到效果并不好,这是由于我们只是采用了简单的线性网络来做辨别器和生成器。后面的博文我们会使用更复杂的网络来训练GAN。

参考

[1] Building our first simple GAN
[2] https://arxiv.org/abs/1406.2661

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/81284.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

机器学习笔记之优化算法(十六)梯度下降法在强凸函数上的收敛性证明

机器学习笔记之优化算法——梯度下降法在强凸函数上的收敛性证明 引言回顾&#xff1a;凸函数与强凸函数梯度下降法&#xff1a;凸函数上的收敛性分析 关于白老爹定理的一些新的认识梯度下降法在强凸函数上的收敛性收敛性定理介绍结论分析证明过程 引言 本节将介绍&#xff1a…

基于旗鱼算法优化的BP神经网络(预测应用) - 附代码

基于旗鱼算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码 文章目录 基于旗鱼算法优化的BP神经网络&#xff08;预测应用&#xff09; - 附代码1.数据介绍2.旗鱼优化BP神经网络2.1 BP神经网络参数设置2.2 旗鱼算法应用 4.测试结果&#xff1a;5.Matlab代码 摘要…

消息队列前世今生 字节跳动 Kafka #创作活动

消息队列前世今生 1.1 案例一&#xff1a; 系统崩溃 首先大家跟着我想象一下下面的这个的场景&#xff0c; 看到新出的游戏机&#xff0c;太贵了买不起&#xff0c;这个时候你突然想到&#xff0c;今天抖音直播搞活动&#xff0c;打开抖音搜索&#xff0c;找到直播间以后&am…

Node基础--Node简介以及安装教程

1.Node简介 Node.js发布于2009年5月&#xff0c;由Ryan Dahl开发&#xff0c;是一个基于Chrome V8引擎的JavaScript运行环境&#xff0c;使用了一个事件驱动、非阻塞式I/O模型&#xff0c;让JavaScript 运行在服务端的开发平台&#xff0c;它让JavaScript成为与PHP、Python、Pe…

RabbitMq-3入门案例

rabbitmq入门 1.生产者&#xff08;服务提供方&#xff09; //依赖<dependencies> <!-- rabbitmq客户端依赖--><dependency><groupId>com.rabbitmq</groupId><artifactId>amqp-client</artifactId><version>5.8.0<…

Docker打包JDK20镜像

文章目录 Docker 打包 JDK 20镜像步骤1.下载 jdk20 压缩包2.编写 dockerfile3.打包4.验证5.创建并启动容器6.检查 Docker 打包 JDK 20镜像 步骤 1.下载 jdk20 压缩包 https://www.oracle.com/java/technologies/downloads/ 2.编写 dockerfile #1.指定基础镜像&#xff0c;并…

Linux: 使用scp命令复制文件夹报 not a regular file 错误解决

使用scp 命令复制文件夹报 not a regular file 错误解决 解决办法&#xff1a; 加入参数 -r&#xff1a; 递归复制整个目录 scp命令参数详解

【C++小项目】实现一个日期计算器

目录 Ⅰ. 引入 Ⅱ. 列轮廓 Ⅲ. 功能的实现 构造函数 Print 判断是否相等 | ! ➡️: ➡️!: 判断大小 > | > | < | < ➡️>&#xff1a; ➡️<&#xff1a; ➡️>&#xff1a; ➡️<&#xff1a; 加减天数 | | - | - ➡️&#xff1a;…

TypeScript——类型系统与类型推导

前言 TypeScript 是由 Microsoft 开发的一种开放源代码语言。 它是 JavaScript 的一个超集&#xff0c;这意味着你可以在 TypeScript 中使用 JS 已存在的所有语法&#xff0c;并且所有 JavaScript 脚本都可以当作 TypeScript 脚本&#xff0c;此外它还增加了一些自己的语法。T…

JVM——类加载与字节码技术—字节码指令

2.字节码指令 2.1 入门 jvm的解释器可以识别平台无关的字节码指令&#xff0c;解释为机器码执行。 2a b7 00 01 b1 this . init&#xff08;&#xff09; return 准备了System.out对象&#xff0c;准备了参数“hello world”,准备了对象的方法println(String)V&#xff…

1.文章复现《热电联产系统在区域综合能源系统中的定容选址研究》(附matlab程序)

0.代码链接 文章复现《热电联产系统在区域综合能源系统中的定容选址研究》&#xff08;matlab程序&#xff09;-Matlab文档类资源-CSDN文库 1.简述 本文采用遗传算法的方式进行了下述文章的复现并采用电-热节点的方式进行了潮流计算以降低电网的网络损耗 分析了电网的基本数…

C语言练习1(巩固提升)

C语言练习1 选择题 前言 “人生在勤&#xff0c;勤则不匮。”幸福不会从天降&#xff0c;美好生活靠劳动创造。全面建成小康社会的奋斗目标&#xff0c;为广大劳动群众指明了光明的未来&#xff1b;全面建成小康社会的历史任务&#xff0c;为广大劳动群众赋予了光荣的使命&…