Pytorch Advanced(一) Generative Adversarial Networks

生成对抗神经网络GAN,发挥神经网络的想象力,可以说是十分厉害了

参考

1、AI作家
2、将模糊图变清晰(去雨,去雾,去抖动,去马赛克等),这需要AI具有“想象力”,能脑补情节;
3、进行数据增强,根据已有数据生成更多新数据供以feed,可以减缓模型过拟合现象。

那到底是怎么实现的呢?


GAN中有两大组成部分G和D

G是generator,生成器: 负责凭空捏造数据出来

D是discriminator,判别器: 负责判断数据是不是真数据

示例图如下:

给一个随机噪声z,通过G生成一张假图,然后用D去分辨是真图还是假图。假设G生成了一张图,在D那里的得分很高,那么G就很成功的骗过了D,如果D很轻松的分辨出了假图,那么G的效果不好,那么就需要调整参数了。


G和D是两个单独的网络,那么他们的参数都是训练好的吗?并不是,两个网络的参数是需要在博弈的过程中分别优化的。

下面就是一个训练的过程:

GAN在一轮反向传播中分为两步,先训练D在训练G。

训练D时,上一轮G产生的图片,和真实图片一起作为x进行输入,假图为0,真图标签为1,通过x生成一个score,通过score和标签y计算损失,就可以进行反向传播了。

训练G时,G和D是一个整体,取名为D_on_G。输入随机噪声,G产生一个假图,D去分辨,score = 1就是需要我们需要优化的目标,意思就是我们要让生成的图片变成真的。这里的D是不需要参与梯度计算的,我们通过反向传播来优化G,让他生成更加真实的图片。这就好比:如果你参加考试,你别指望能改变老师的评分标准


GAN无监督学习,(cGAN是有监督的),以后会学习的。怎么理解无监督学习呢?这里给的真图是没有经过人工标注的,只知道这是真的,D是不知道这是什么的,只需要分辨真假。G也不知道生成了什么,只需要学真图去骗D。


具体如何实施呢?

import os
import torch
import torchvision
import torch.nn as nn 
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
latent_size = 64
hidden_size = 256
image_size = 784
num_epochs = 200
batch_size = 100
sample_dir = 'samples'

注意这里有个归一化的过程,MNIST是单通道,但是如果mean=(0.5,0.5,0.5)会报错,因为是对3通道操作 。

if not os.path.exists(sample_dir):os.makedirs(sample_dir)transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=(0.5,),   # 3 for RGB channelsstd=(0.5,))])# MNIST dataset
mnist = torchvision.datasets.MNIST(root='./data/',train=True,transform=transform,download=True)
# Data loader
data_loader = torch.utils.data.DataLoader(dataset=mnist,batch_size=batch_size, shuffle=True)

定义生成器和判别器:

生成器:可以看到输入的维度为64,是一组噪声图像,通过生成器将特征扩大到了MNIST图像大小784。

判别器:输入维度为图像大小,最后输出特征个数为1,采用sigmoid激活(不用softmax的)

# Discriminator
D = nn.Sequential(nn.Linear(image_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, hidden_size),nn.LeakyReLU(0.2),nn.Linear(hidden_size, 1),nn.Sigmoid())# Generator 
G = nn.Sequential(nn.Linear(latent_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, hidden_size),nn.ReLU(),nn.Linear(hidden_size, image_size),nn.Tanh())
# Device setting
D = D.to(device)
G = G.to(device)# Binary cross entropy loss and optimizer
criterion = nn.BCELoss()
d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)def denorm(x):out = (x + 1) / 2return out.clamp(0, 1)def reset_grad():d_optimizer.zero_grad()g_optimizer.zero_grad()

 重点看训练部分,我们到底是如何来训练GAN的。

判别器部分:判别器的损失值分为两部分,(一)将mini_batch定义为正样本,告诉他我是正品,所以设置标签为1。优化判别器判断正品的能力;(二)生成一幅赝品,再给判别器判别,这时候赝品的标签为0,优化判断赝品的能力。所以总损失为这两部分之和,计算梯度,优化判别器参数。

G_on_D:输入一个噪声,让生成器生成一幅图像,然后让D去判别,计算和正品之间的距离,即损失。反向传播,优化G的参数。

# Start training
total_step = len(data_loader)
for epoch in range(num_epochs):for i, (images, _) in enumerate(data_loader):images = images.reshape(batch_size, -1).to(device)# Create the labels which are later used as input for the BCE lossreal_labels = torch.ones(batch_size, 1).to(device)fake_labels = torch.zeros(batch_size, 1).to(device)# ================================================================== ##                      Train the discriminator                       ## ================================================================== ## Compute BCE_Loss using real images where BCE_Loss(x, y): - y * log(D(x)) - (1-y) * log(1 - D(x))# Second term of the loss is always zero since real_labels == 1outputs = D(images)d_loss_real = criterion(outputs, real_labels)real_score = outputs# Compute BCELoss using fake images# First term of the loss is always zero since fake_labels == 0z = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)d_loss_fake = criterion(outputs, fake_labels)fake_score = outputs# Backprop and optimized_loss = d_loss_real + d_loss_fakereset_grad()d_loss.backward()d_optimizer.step()# ================================================================== ##                        Train the generator                         ## ================================================================== ## Compute loss with fake imagesz = torch.randn(batch_size, latent_size).to(device)fake_images = G(z)outputs = D(fake_images)# We train G to maximize log(D(G(z)) instead of minimizing log(1-D(G(z)))# For the reason, see the last paragraph of section 3. https://arxiv.org/pdf/1406.2661.pdfg_loss = criterion(outputs, real_labels)# Backprop and optimizereset_grad()g_loss.backward()g_optimizer.step()if (i+1) % 200 == 0:print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}' .format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(), real_score.mean().item(), fake_score.mean().item()))# Save real imagesif (epoch+1) == 1:images = images.reshape(images.size(0), 1, 28, 28)save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))# Save sampled imagesfake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

训练完了怎么用?

只要用我们的生成器就可以随意生成了。

import matplotlib.pyplot as plt
z = torch.randn(1,latent_size).to(device)
output = G(z)
plt.imshow(output.cpu().data.numpy().reshape(28,28),cmap='gray') 
plt.show()

 下面就是随机生成的图像了!

  

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/106512.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用Pyarmor保护Python脚本不被反向工程

Python可读性强,使用广泛。虽然这种可读性有利于协作,但也增加了未授权访问和滥用的风险。如果未采取适当的保护,竞争对手或恶意攻击者可以复制您的算法和专有逻辑,这将对您软件的完整性和用户的信任产生负面影响。 实施可靠的安…

Java拓展--空间复杂度和时间复杂度

空间复杂度和时间复杂度 文章目录 空间复杂度和时间复杂度空间复杂度时间复杂度**评价排序算法****时间频度****什么是时间频度****忽略常数项****忽略低次项****忽略系数** **时间复杂度****什么是时间复杂度****计算时间复杂度的方法****常见的时间复杂度** **常见的时间复杂…

学习笔记-静态路由配置有来无回导致无法访问目标IP

配置拓扑图: 已经在R1、R2、R3相应端口配置了相应IP。 在R2上配置静态路由: [R2]ip route-static 10.0.3.0 24 10.0.23.3 [R2]ip route-static 10.0.13.0 24 10.0.23.3执行tracert 10.0.3.3,可以到达目标IP 执行tracert 10.0.13.3&#xff…

电子行业云MES解决方案

电子行业MES解决方案主要是针对目前电子生产制造企业面临的产品迭代升级中多品种小批量混线生产、存呆滞问题多;质量检查标准多、售后问题难追溯;生产进度难追踪、车间物料难管控、实际成本难计算等问题,提出的一种切实可行且能降低成本、提高效率的有效…

关于Java的类加载机制

1、概述 类会在运行期间第一次使用时,被类加载器动态加载至JVM。JVM不会一次性加载所有类。因为如果一次性加载,会占用很多的内存。 2、类的生命周期 类的生命周期包括以下 7 个阶段: 加载(Loading)验证(…

【uni-app】

准备工作(Hbuilder) 1.下载hbuilder,插件使用Vue3的uni-app项目 2.需要安装编译器 3.下载微信开发者工具 4.点击运行->微信开发者工具 5.打开微信开发者工具的服务端口 效果图 准备工作(VScode) 插件 uni-cr…

Codeforces Round 827 (Div. 4) D 1e5+双重for循环技巧

Codeforces Round 827 (Div. 4) D 做题链接:Codeforces Round 827 (Div. 4) 给定一个由 n个正整数 a1,a2,…,an(1≤ai≤1000)组成的数组。求ij的最大值,使得ai和aj共质,否则−1,如果不存在这样的i&#…

5.9.Webrtc线程事件处理

在前面的课程中呢,我已经向你介绍了事件处理的一些基础知识,那今天呢,我们再来看一下外边儿rtc下事件处理的基本逻辑是什么? 那首先呢,我们来看一下事件是如何协调线程工作的,那就如果这张图所展示的有两个…

直播进入新风口:XR虚拟直播市场火爆,未来发展势不可挡

 近年来,直播行业随着技术的不断发展,呈现出了蓬勃的发展态势。在这个竞争日益激烈的直播行业中,XR虚拟直播成为了最新的风口。XR虚拟直播是一种新型的直播形式,通过虚拟现实技术,让用户置身于直播现场&a…

【HELLO NEW WORLD】一封来自开放自动化时代的邀请函

​ 施耐德电气开放自动化平台,迈向开放、高效与韧性、可持续、以人为本的未来工业。 HELLO WORLD 是人类在信息世界开启的第一行 也是我们走进自动化领域迎来的第一句问候 如今 面临向数字化与自动化加速转型的新变局 工业领域迫切地需要一场变革 走向更加高效…

ClickHouse进阶(十三):Clickhouse数据字典-3-文件数据源及Mysql数据源

进入正文前,感谢宝子们订阅专题、点赞、评论、收藏!关注IT贫道,获取高质量博客内容! 🏡个人主页:含各种IT体系技术,IT贫道_大数据OLAP体系技术栈,Apache Doris,Kerberos安全认证-CSDN博客 📌订阅…

【洛谷算法题】P5704-字母转换【入门1顺序结构】

👨‍💻博客主页:花无缺 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P5704-字母转换【入门1顺序结构】🌏题目描述🌏输入格式&a…