深度学习之GAN网络

目录

关于GAN网络

关于生成模型和判别模型

GAN网路的特性和搭建步骤(以手写字体识别数据集为例)

搭建步骤

特性

GAN的目标函数(损失函数)

目标函数原理

torch.nn.BCELoss(实际应用的损失函数)

代码

生成器

判别器

后向传播

完整代码

运行结果

图像显示


关于GAN网络

GAN包含有两个模型,一个是生成模型(generative model),一个是判别模型(discriminative model)。生成模型的任务是生成看起来自然真实的、和原始数据相似的实例。判别模型的任务是判断给定的实例看起来是自然真实的还是人为伪造的

关于生成模型和判别模型

通俗理解:生成模型像“一个造假币的机器”,而判别模型像“检测假币的机器”。

在判别模型这里只有“是假币”,“不是假币”。

而生成器的目的就是让自己的假币越来越逼真以至于能够“骗过”判断模型,判别器则努力不被生成器欺骗。模型经过交替优化训练,两种模型都能得到提升,但最终我们要得到的是效果提升到很高很好的生成模型(造假币的机器),这个生成模型(造假币的机器)所生成的产品能达到真假难分的地步。

(虽然最后判别器也能得到提升,可是随着我们不断地训练,判别器的准确率会停留在0.5左右,这是由于两个模型的特性决定的,也是之所以叫生成对抗神经网络的原因。)

GAN网路的特性和搭建步骤(以手写字体识别数据集为例)

搭建步骤

#GAN生成对抗网络,步骤:
#首先编写生成器和判别器
#然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
#接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
#生成器的代码(针对手写字体识别)

特性

我们第一开始后向传播的时候,我们先更新判断器参数,然后更新生成器的,然后接着这样重复,后面随着循环的进行,我们生成器生成的的图片大体上肯定是越来越接近真实的图片的。
假设我们已经训练了50轮,我们假设这个时候已经很接近正确的真实图片。
当他开始下一轮的时候,此时我们更新判别器在判0方面的参数更新的依据是:把这个很接近真实图片的伪造图片放进判断模型判断,通过损失函数计算和判断结果和0的误差来进行参数的更新
此时,就会出现一个矛盾点,这个矛盾点有两个主要方面:
第一点,伪造图片数据用来判别器判为0的优化(过损失函数计算和判断结果和0的误差),而我们的真实图片用来判别器判为1的优化(过损失函数计算和判断结果和1的误差
第二点,因为第一点的原因也会影响到我们生成器,因为我们生成一是将生成的伪造图片进行判别器判断后,通过损失函数计算和判断结果和1的误差来进行参数的更新

这样就有了个矛盾点:生成器是让伪造数据判断结果趋向于1,判别器是让伪造数据的判断结果趋向于0,这样就形成了‘对抗’
我们生成器一开始肯定是损失下降的。
可慢慢的不论我们怎么优化生成器的参数,我们生成的图片经过判断器判断后和1的区别似乎小到某个临界点之后越来越大了,以至于后面这个区别似乎在某个区间反复的“振荡”,并不是愈来愈小。
这表示如果把每一轮的损失都分为生成器损失和判别器损失,这两个损失都应该是波动型的。

按我的理解,这个模型的特点是体现在一开始我们生成器生成的图片是愈来愈“真”,但是如果次数很多的话,就会导致判别器很难判断0或者1,以至于影响到后面生成器生成的图片

GAN的目标函数(损失函数)

目标函数原理

这个是判别器的目标函数,D(x)表示生成器判1的概率,D(G(x))表示生成器判0的概率,光从这里看这个应当是愈来愈大的。

这是我们生成器 目标函数,可以看出来是愈来愈小的。

可我们实际应用的话也有一些区别。

torch.nn.BCELoss(实际应用的损失函数)

(以下内容来自网络)

 可以看出来在计算我们生成模型的损失值得时候,我们对我们伪造的图片进行判0得时候,这一块得损失值应当是增大的,因为实际上预测他的概率或者说最后输出层那一个神经元里面的数是应该不断接近一的。

代码

生成器

class Generator(torch.nn.Module):def __init__(self):super(Generator,self).__init__()self.main=torch.nn.Sequential(torch.nn.Linear(100,256),torch.nn.ReLU(),torch.nn.Linear(256,512),torch.nn.ReLU(),torch.nn.Linear(512,28*28),torch.nn.Tanh())def forward(self,x):img=self.main(x)img=img.reshape(-1,28,28)return img

判别器

class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator,self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28*28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256,1),torch.nn.Sigmoid())def forward(self,x):x=x.view(-1,28*28)x=self.mainf(x)return x

后向传播

#定义损失函数和优化函数
device='cuda' if torch.cuda.is_available() else 'cpu'
gen=Generator().to(device)
dis=Discraiminator().to(device)
#定义优化器
gen_opt=torch.optim.Adam(gen.parameters(),lr=0.0001)
dis_opt=torch.optim.Adam(dis.parameters(),lr=0.0001)
loss_fn=torch.nn.BCELoss()#损失函数
#图像显示
def gen_img_plot(model,testdata):pre=np.squeeze(model(testdata).detach().cpu().numpy())
#tensor.detach()
#返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4,4,i+1)plt.imshow(pre[i])plt.show()#后向传播
dis_loss=[]#判别器损失值记录
gen_loss=[]#生成器损失值记录
lun=[]#轮数
for epoch in range(60):d_epoch_loss=0g_epoch_loss=0cout=len(trainload)#938批次for step, (img, _) in enumerate(trainload):img=img.to(device) #图像数据#print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size=img.size(0)#一批次的图片数量64#随机生成一批次的100维向量样本,或者说100个像素点random_noise=torch.randn(size,100,device=device)#先进性判断器的后向传播dis_opt.zero_grad()real_output=dis(img)d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#真实数据的损失函数值d_real_loss.backward()gen_img=gen(random_noise)fake_output=dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))#人造的数据的损失函数值d_fake_loss.backward()d_loss=d_real_loss+d_fake_lossdis_opt.step()#生成器的后向传播gen_opt.zero_grad()fake_output=dis(gen_img)g_loss=loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_loss

完整代码

import matplotlib.pyplot as plt
from matplotlib import font_manager
import torch
from torch.utils.data import DataLoader
import torchvision
from torchvision import transforms
import numpy as np
#导入数据集并且进行数据处理
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5,0.5)])
traindata=torchvision.datasets.MNIST(root='D:\learn_pytorch\数据集',train=True,download=True,transform=transform)#训练集60,000张用于训练
#利用DataLoader加载数据集
trainload=DataLoader(dataset=traindata,shuffle=True,batch_size=64)
#GAN生成对抗网络,步骤:
#首先编写生成器和判别器
#然后固定生成器,用我们的数据优化判别器,试得我们最开始生成器生成的图片判断为0,真实图片判断为1
#接着固定判别器,利用我们的判别器判断生成器生成的图片,以判断的尽可能接近一为目的优化我们的生成器
#生成器的代码(针对手写字体识别)
class Generator(torch.nn.Module):def __init__(self):super(Generator,self).__init__()self.main=torch.nn.Sequential(torch.nn.Linear(100,256),torch.nn.ReLU(),torch.nn.Linear(256,512),torch.nn.ReLU(),torch.nn.Linear(512,28*28),torch.nn.Tanh())def forward(self,x):img=self.main(x)img=img.reshape(-1,28,28)return img
#判别器,最后判断0,1,这意味着最后可以是一个神经元或者两个神经元
class Discraiminator(torch.nn.Module):def __init__(self):super(Discraiminator,self).__init__()self.mainf = torch.nn.Sequential(torch.nn.Linear(28*28, 512),torch.nn.LeakyReLU(),torch.nn.Linear(512, 256),torch.nn.LeakyReLU(),torch.nn.Linear(256,1),torch.nn.Sigmoid())def forward(self,x):x=x.view(-1,28*28)x=self.mainf(x)return x
#定义损失函数和优化函数
device='cuda' if torch.cuda.is_available() else 'cpu'
gen=Generator().to(device)
dis=Discraiminator().to(device)
#定义优化器
gen_opt=torch.optim.Adam(gen.parameters(),lr=0.0001)
dis_opt=torch.optim.Adam(dis.parameters(),lr=0.0001)
loss_fn=torch.nn.BCELoss()#损失函数
#图像显示
def gen_img_plot(model,testdata):pre=np.squeeze(model(testdata).detach().cpu().numpy())
#tensor.detach()
#返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad。
# 即使之后重新将它的requires_grad置为true,它也不会具有梯度grad
# 这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播plt.figure()for i in range(16):plt.subplot(4,4,i+1)plt.imshow(pre[i])plt.show()#后向传播
dis_loss=[]#判别器损失值记录
gen_loss=[]#生成器损失值记录
lun=[]#轮数
for epoch in range(60):d_epoch_loss=0g_epoch_loss=0cout=len(trainload)#938批次for step, (img, _) in enumerate(trainload):img=img.to(device) #图像数据#print('img.size:',img.shape)#img.size: torch.Size([64, 1, 28, 28])size=img.size(0)#一批次的图片数量64#随机生成一批次的100维向量样本,或者说100个像素点random_noise=torch.randn(size,100,device=device)#先进性判断器的后向传播dis_opt.zero_grad()real_output=dis(img)d_real_loss=loss_fn(real_output,torch.ones_like(real_output))#真实数据的损失函数值d_real_loss.backward()gen_img=gen(random_noise)fake_output=dis(gen_img.detach())d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))#人造的数据的损失函数值d_fake_loss.backward()d_loss=d_real_loss+d_fake_lossdis_opt.step()#生成器的后向传播gen_opt.zero_grad()fake_output=dis(gen_img)g_loss=loss_fn(fake_output, torch.ones_like(fake_output))g_loss.backward()gen_opt.step()d_epoch_loss += d_lossg_epoch_loss += g_lossdis_loss.append(float(d_epoch_loss))gen_loss.append(float(g_epoch_loss))print(f'第{epoch+1}轮的生成器损失值:{g_epoch_loss},判别器损失值{d_epoch_loss}')lun.append(epoch+1)
font = font_manager.FontProperties(fname="C:\\Users\\ASUS\\Desktop\\Fonts\\STZHONGS.TTF")
plt.figure()
plt.plot(lun,dis_loss,'r',label='判别器损失值')
plt.plot(lun,gen_loss,'b',label='生成器损失值')
plt.xlabel('训练轮数', fontproperties=font, fontsize=12)
plt.ylabel('损失值', fontproperties=font, fontsize=12)
plt.title('损失值随着训练轮数得变化情况:',fontproperties=font, fontsize=18)
plt.legend(prop=font)
plt.show()
random_noise=torch.randn(16,100,device=device)
gen_img_plot(gen,random_noise)

运行结果

D:\Anaconda3\envs\pytorch\python.exe D:\learn_pytorch\学习过程\第七周的代码\GAN.py 
第1轮的生成器损失值:2528.59814453125,判别器损失值353.325927734375
第2轮的生成器损失值:2921.286376953125,判别器损失值248.41244506835938
第3轮的生成器损失值:3441.375732421875,判别器损失值159.46466064453125
第4轮的生成器损失值:3337.6513671875,判别器损失值215.8180694580078
第5轮的生成器损失值:3971.089111328125,判别器损失值125.399658203125
第6轮的生成器损失值:4233.75341796875,判别器损失值117.00439453125
第7轮的生成器损失值:4556.33935546875,判别器损失值124.78838348388672
第8轮的生成器损失值:4610.49658203125,判别器损失值130.80592346191406
第9轮的生成器损失值:4639.509765625,判别器损失值141.80921936035156
第10轮的生成器损失值:4648.3037109375,判别器损失值148.4178466796875
第11轮的生成器损失值:4764.65380859375,判别器损失值158.76829528808594
第12轮的生成器损失值:5018.9208984375,判别器损失值138.4495391845703
第13轮的生成器损失值:5489.87939453125,判别器损失值142.05093383789062
第14轮的生成器损失值:5771.021484375,判别器损失值122.76760864257812
第15轮的生成器损失值:5327.2373046875,判别器损失值142.30992126464844
第16轮的生成器损失值:5096.5966796875,判别器损失值131.3242950439453
第17轮的生成器损失值:5312.2607421875,判别器损失值143.69143676757812
第18轮的生成器损失值:5721.6220703125,判别器损失值135.46119689941406
第19轮的生成器损失值:4894.25341796875,判别器损失值178.1021728515625
第20轮的生成器损失值:4638.419921875,判别器损失值182.18136596679688
第21轮的生成器损失值:4944.3798828125,判别器损失值167.8157958984375
第22轮的生成器损失值:4811.68798828125,判别器损失值185.08151245117188
第23轮的生成器损失值:4419.439453125,判别器损失值213.2805633544922
第24轮的生成器损失值:4265.69873046875,判别器损失值224.0750732421875
第25轮的生成器损失值:3908.226318359375,判别器损失值263.6612243652344
第26轮的生成器损失值:3829.155517578125,判别器损失值275.2219543457031
第27轮的生成器损失值:3941.054443359375,判别器损失值260.74066162109375
第28轮的生成器损失值:4242.72216796875,判别器损失值236.7626953125
第29轮的生成器损失值:3621.1337890625,判别器损失值297.2765808105469
第30轮的生成器损失值:3010.93310546875,判别器损失值355.7358703613281
第31轮的生成器损失值:2710.3935546875,判别器损失值389.6776428222656
第32轮的生成器损失值:2636.367919921875,判别器损失值399.9930725097656
第33轮的生成器损失值:2637.282470703125,判别器损失值402.30859375
第34轮的生成器损失值:2988.23974609375,判别器损失值357.68804931640625
第35轮的生成器损失值:5289.29541015625,判别器损失值199.2281951904297
第36轮的生成器损失值:5623.095703125,判别器损失值194.9387969970703
第37轮的生成器损失值:3149.01416015625,判别器损失值348.37677001953125
第38轮的生成器损失值:2634.451904296875,判别器损失值405.61566162109375
第39轮的生成器损失值:2547.442138671875,判别器损失值418.43597412109375
第40轮的生成器损失值:2448.359130859375,判别器损失值432.43096923828125
第41轮的生成器损失值:2429.5537109375,判别器损失值443.5609130859375
第42轮的生成器损失值:2452.239013671875,判别器损失值449.2001953125
第43轮的生成器损失值:2652.614990234375,判别器损失值427.6022033691406
第44轮的生成器损失值:3612.08935546875,判别器损失值330.52484130859375
第45轮的生成器损失值:5931.68701171875,判别器损失值155.98318481445312
第46轮的生成器损失值:2881.45751953125,判别器损失值420.7775573730469
第47轮的生成器损失值:2272.14111328125,判别器损失值491.7653503417969
第48轮的生成器损失值:2222.366943359375,判别器损失值504.3684387207031
第49轮的生成器损失值:2263.885498046875,判别器损失值493.1470031738281
第50轮的生成器损失值:2237.902587890625,判别器损失值507.2246398925781
第51轮的生成器损失值:2207.36962890625,判别器损失值512.4918823242188
第52轮的生成器损失值:2241.222900390625,判别器损失值509.4316711425781
第53轮的生成器损失值:2264.532958984375,判别器损失值501.825927734375
第54轮的生成器损失值:2370.49365234375,判别器损失值487.0404052734375
第55轮的生成器损失值:2869.4560546875,判别器损失值425.4640197753906
第56轮的生成器损失值:3764.148681640625,判别器损失值317.8245849609375
第57轮的生成器损失值:3606.946533203125,判别器损失值340.7588195800781
第58轮的生成器损失值:2236.921875,判别器损失值521.7313842773438
第59轮的生成器损失值:2077.1865234375,判别器损失值551.8837890625
第60轮的生成器损失值:2065.66943359375,判别器损失值555.025634765625进程已结束,退出代码0

图像显示

 果然,从本次实验来看,我们得生成器损失值是波动的,判别器损失值也是,很难说他们的趋势走向(当然估计和我的训练轮数有关)

这是我们生成器生成的“伪造的图片”,从这里可以看出来已经很不错了。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/691477.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【电子学会】2023年12月图形化一级 -- 魔法变变变

魔法变变变 1. 准备工作 (1)删除默认角色小猫,添加角色Wizard、Wizard Hat、Bananas; (2)添加背景Theater; (3)为角色Bananas添加任意五个造型,如下图所示。 2. 功能实现 (1)设置角色的位置、大小和方向,如下图所示; (2)按下空格键,魔法师换成“wizard-b…

WHAT - CSS Animationtion 动画系列(一)

目录 一、介绍二、animation-name三、animation-duration四、animation-timing-function4.1 介绍4.2 具体实践:抛物线 五、animation-delay六、animation-iteration-count七、animation-direction八、animation-fill-mode九、animation-play-state 今天我们主要学习…

基于鹈鹕优化算法POA的复杂城市地形下无人机避障三维航迹规划,可以修改障碍物及起始点(Matlab代码)

复杂城市地形下无人机避障三维航迹规划是指在城市等高密度区域内,通过无人机的传感器和导航系统来实现飞行路径的规划和调整,从而避免无人机与建筑物、其他无人机、地面障碍物等发生碰撞和冲突。具体来说,无人机需要实时感知周围环境&#xf…

QuickBooks 2024 for Mac 激活版:智慧管理,财务无忧

想要轻松掌控财务,实现高效管理吗?QuickBooks 2024 for Mac,您的智慧财务管理专家,为您带来前所未有的便利和体验。无论是账务、工资还是销售和库存,它都能一手搞定。直观易用的界面,让您轻松上手&#xff…

ctfshow web274

web274 thinkphp框架序列化漏洞 EXP <?php namespace think; abstract class Model{protected $append[];private $data[];function __construct(){$this->append["lin">["ctf","show"]];$this->data["lin">new Req…

5.10.4 Vision Transformer的条件位置编码(CPE)

用于视觉 Transformer 的条件位置编码&#xff08;CPE&#xff09;方案与之前预定义且独立于输入标记的固定或可学习位置编码不同&#xff0c;CPE 是动态生成的&#xff0c;并以输入标记的局部邻域为条件。 CPE 可以轻松泛化到比模型在训练期间见过的输入序列更长的输入序列。…

使用Ollama本地部署 Llama3大模型!最简单的方法,无需GPU也能使用

文章目录 前言一、安装Ollama客户端二、安装webUI1、安装Docker Desktop2、安装webUI3、设置语言4、下载模型 总结 前言 开源大模型领域当前最强的无疑是 LLaMA 3&#xff01;Meta 这次不仅免费公布了两个性能强悍的大模型&#xff08;8B 和 70B&#xff09;&#xff0c;还计划…

《QT实用小工具·六十三》QT实现微动背景,界面看似静态实则动态

1、概述 源码放在文章末尾 该项目实现了微动背景&#xff0c;界面看似静态实则动态&#xff0c;风动&#xff0c;幡动&#xff0c;仁者心动&#xff0c;所以到底是什么在动&#xff1f;哈哈~ 界面会偷偷一点一点改动文字颜色的颜色填充。 虽然是动态&#xff0c;但是慢到难以…

【牛客】SQL201 查找薪水记录超过15条的员工号emp_no以及其对应的记录次数t

1、描述 有一个薪水表&#xff0c;salaries简况如下&#xff1a; 请你查找薪水记录超过15条的员工号emp_no以及其对应的记录次数t&#xff0c;以上例子输出如下&#xff1a; 2、题目建表 drop table if exists salaries ; CREATE TABLE salaries ( emp_no int(11) NOT N…

当代 Qt 正确的 安装方法 及 多版本切换

此文写于 20240511 首先去网站Index of /official_releases/online_installers下载一个安装器 安装器有什么用? 可以浏览安装版本 安装组件 安装器版本越能 能装的东西越多 现在只能选Qt5 和 Qt6 至于你公司用的Qt4 我也没招 见招时再拆招 安装器 默认国外源 可以换国内…

Go-Zero自定义goctl实战:定制化模板,加速你的微服务开发效率(四)

前言 上一篇文章带你实现了Go-Zero和goctl&#xff1a;解锁微服务开发的神器&#xff0c;快速上手指南&#xff0c;本文将继续深入探讨Go-Zero的强大之处&#xff0c;并介绍如何使用goctl工具实现模板定制化&#xff0c;并根据实际项目业务需求进行模板定制化实现。 通过本文…

SpringBoot实现图片验证码

引入依赖 <dependency><groupId>com.github.whvcse</groupId><artifactId>easy-captcha</artifactId><version>1.6.2</version> </dependency>代码实现 package com.qiangesoft.captcha.controller;import com.wf.captcha.*…