Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现

什么是GAN

GAN是生成对抗网络,将会根据一个随机向量,实现数据的生成(如生成手写数字、生成文本等)。

GAN的训练过程中,需要有一个生成器G和一个鉴别器D.

生成器用于生成数据,鉴定器用于鉴定数据的准确性,其实就是在鉴别数据是人生成的还是机器生成的,因为生成器需要以假乱真。

鉴别器将会与生成器一起训练。鉴别器将会先训练,这样才有适当的能力去鉴定生成器生成数据的准确性。

鉴别器的训练过程中,需要先给它准确的数据,和通过随机向量传入生成器产生的数据(一律视为负样本),并通过损失函数对其进行训练;生成器训练过程中,会先给它一个随机向量进行前向传播,然后让鉴别器判断其正确性,并通过损失函数(不正确的数据意味着有损失)进行训练:

生成器训练过程中,需要先通过随机向量获取其结果,然后让鉴别器进行鉴别,在通过鉴别器的鉴别结果计算损失(如果鉴别器认为这是生成器生成的,则产生损失),最后更新梯度和参数:

训练过程直到生成器拟合训练集(收敛),判别器的输出总是0.5(均方误差损失函数应为0.25)为止.

形象的GAN的例子

想象一场由一位“名画伪造者”和一位“艺术鉴定家”参与的猫捉老鼠游戏。

在这个场景中,名画伪造者(即GAN中的生成器)的目标是创造出一幅足以欺骗艺术鉴定家(即GAN中的判别器)的假画。开始时,伪造者的技艺并不精湛,他制作的假画充满了破绽,很容易被鉴定家一眼识破。

然而,随着伪造者不断尝试和失败,他逐渐从每一次的失败中学习,逐渐提升了自己的技艺。他开始注意到真画的每一个细节,从笔触、色彩到构图,都尽量模仿得惟妙惟肖。每一次的失败都让他更接近成功,他制作的假画也越来越难以辨别真伪。

而艺术鉴定家也不甘示弱。他开始时能够轻易地识别出伪造者的假画,但随着伪造者技艺的提升,他也需要不断提升自己的鉴定能力。他开始深入研究真画的每一个特点,以便更准确地识别出伪造者的假画。

这个过程就像GAN中的训练过程一样。生成器不断尝试生成新的数据(在这里是假画),而判别器则不断尝试区分这些数据是真实的还是生成的。两者在相互竞争的过程中不断提升自己的能力,最终达到了一个平衡状态。

在这个例子中,名画伪造者就是GAN中的生成器,他负责生成新的数据;而艺术鉴定家则是GAN中的判别器,他负责区分数据的真伪。两者在相互竞争的过程中共同进步,使得生成的数据越来越接近真实的数据。

代码实现

本文将以基于MNIST(手写数据集)为数据集,实现一个生成手写数字的GAN模型:

首先创建models.py,用于定义判别器和生成器:

import paddle# Generator Code
class Generator(paddle.nn.Layer):def __init__(self, ):super(Generator, self).__init__()self.gen = paddle.nn.Sequential(paddle.nn.Linear(in_features=100, out_features=256),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=256, out_features=512),paddle.nn.ReLU(True),paddle.nn.Linear(in_features=512, out_features=1024),paddle.nn.Tanh(),)def forward(self, x):x = self.gen(x)out = paddle.reshape(x,[-1,1,32,32])return out# Discriminator Code
class Discriminator(paddle.nn.Layer):def __init__(self, ):super(Discriminator, self).__init__()self.dis = paddle.nn.Sequential(paddle.nn.Linear(in_features=1024, out_features=512),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=512, out_features=256),paddle.nn.LeakyReLU(0.2),paddle.nn.Linear(in_features=256, out_features=1),paddle.nn.Sigmoid())def forward(self, x):x = paddle.reshape(x, [-1, 1024])out = self.dis(x)return out

其中,生成器将接收一个长度为100的张量(随机向量),输出一个长度为1024的张量(生成的图片);鉴别器将接收一个长度为1024的张量(图片) ,输出长度为1的张量(鉴别结果)

然后创建main.py,用于训练:

import paddle
import matplotlib.pyplot as plt
from models import Generator, Discriminator
import numpy as npdataset = paddle.vision.datasets.MNIST(mode='train',transform=paddle.vision.transforms.Compose([paddle.vision.transforms.Resize((32, 32)),paddle.vision.transforms.Normalize([0], [255])]))dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 10for epoch in range(max_epoch):now_step = 0for step, (data, label) in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = paddle.nn.functional.mse_loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = paddle.nn.functional.mse_loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个mse_loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = paddle.nn.functional.mse_loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 100 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

如果是第一次训练或不使用原有训练参数,可以将if 1改成if 0.

接下来创建use.py,用于生成图片:

import paddle
from models import Generator
import matplotlib.pyplot as pltimport paddle
from models import Generator
import matplotlib.pyplot as plt
import numpy as np# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100], 'float32')img = netG(noise)img = img.numpy()[0][0]  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0img = np.clip(img, 0, 1)  # 将img中所有小于0的元素设为0,大于1的设为1(如果需要)# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

进行多轮训练后,生成结果:

 

可以看到,它很好的生成了我们想要的图片。

GANs

但是,我们这个模型只能随机产生数字,还不能生成指定的数字(如让机器生成一个1).为了解决这个问题,我们可以针对每一个数字生成一个对应的GAN,所有这样的GAN组合起来,就是GANs. 这里不展开讲解。

参考

MNIST数据集下用Paddle框架的动态图模式玩耍经典对抗生成网络(GAN)-使用文档-PaddlePaddle深度学习平台

【飞桨PaddlePaddle】四天搞懂生成对抗网络(一)——通俗理解经典GAN_四天搞懂生成对抗网络(一)-CSDN博客

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/685896.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

国产开源物联网操作系统

软件介绍 RT-Thread是一个开源、中立、社区化发展的物联网操作系统&#xff0c;采用C语言编写&#xff0c;具有易移植的特性。该项目提供完整版和Nano版以满足不同设备的资源需求。 功能特点 1.内核层 RT-Thread内核包括多线程调度、信号量、邮箱、消息队列、内存管理、定时器…

[c++]多态的分析

多态详细解读 多态的概念多态的构成条件 接口继承和实现继承: 多态的原理:动态绑定和静态绑定 多继承中的虚函数表 多态的概念 -通俗的来说&#xff1a;当不同的对象去完成某同一行为时&#xff0c;会产生不同的状态。 多态的构成条件 必须通过基类的指针或者引用调用虚函数1虚…

合并两个有序链表(C语言)———链表经典算法题

题目描述​​​​​​21. 合并两个有序链表 - 力扣&#xff08;LeetCode&#xff09;&#xff1a; 答案展示: 迭代&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* mergeTwoLis…

【docker 】push 镜像到私服

查看镜像 docker images把这个hello-world 推送到私服 docker push hello-world:latest 报错了。不能推送。需要标记镜像 标记Docker镜像 docker tag hello-world:latest 192.168.2.1:5000/hello-world:latest 将Docker镜像推送到私服 docker push 192.168.2.1:5000/hello…

【吃透Java手写】2-Spring(下)-AOP-事务及传播原理

【吃透Java手写】Spring&#xff08;下&#xff09;AOP-事务及传播原理 6 AOP模拟实现6.1 AOP工作流程6.2 定义dao接口与实现类6.3 初始化后逻辑6.4 原生Spring的方法6.4.1 实现类6.4.2 定义通知类&#xff0c;定义切入点表达式、配置切面6.4.3 在配置类中进行Spring注解包扫描…

python中如何把list变成字符串

python中如何把list变成字符串&#xff1f;方法如下&#xff1a; python中list可以直接转字符串&#xff0c;例如&#xff1a; data ["hello", "world"] print(data1:,str(data)) 得到结果&#xff1a; (data1:, "[hello, world]") 这里将整个…

芋道----工作流中添加邮件通知

1、配置邮件发送的账号 2、编辑邮件的内容模板 如何新建邮箱&#xff0c;直接查看芋道官网即可&#xff0c;已经讲解的很详细了&#xff0c;可以直接点击下方链接 邮件配置 | ruoyi-vue-pro 开发指南 (iocoder.cn)https://doc.iocoder.cn/mail/#_3-1-%E6%96%B0%E5%BB%BA%E9%82…

Android项目转为鸿蒙,真就这么简单?

最近做了一个有关Android转换成鸿蒙的项目。经不少开发者的反馈&#xff1b;许多公司的业务都增加了鸿蒙板块。 对此想分享一下这个项目转换的流程结构&#xff0c;希望能够给大家在工作中带来一些帮助。转换流程示意图如下&#xff1a; 下面我就给大家介绍&#xff0c;Android…

26、Qt使用QFontDatabase类加载ttf文件更改图标颜色

一、图标下载 iconfont-阿里巴巴矢量图标库 点击上面的链接&#xff0c;在打开的网页中搜索自己要使用的图标&#xff0c;如&#xff1a;最大化 找到一个自己想用图标&#xff0c;选择“添加入库” 点击“购物车”图标 能看到刚才添加的图标&#xff0c;点击“下载代码”(需要…

手撕C语言题典——移除链表元素(单链表)

目录 前言 一.思路 1&#xff09;遍历原链表&#xff0c;找到值为 val 的节点并释放 2&#xff09;创建新链表 二.代码实现 1)大胆去try一下思路 2&#xff09;竟然报错了&#xff1f;&#xff01; 3&#xff09;完善之后的成品代码 搭配食用更佳哦~~ 数据结构之单…

双向链表(详解)

在单链表专题中我们提到链表的分类&#xff0c;其中提到了带头双向循环链表&#xff0c;今天小编将详细讲下双向链表。 话不多说&#xff0c;直接上货。 1.双向链表的结构 带头双向循环链表 注意 这几的“带头”跟前面我们说的“头节点”是两个概念&#xff0c;实际前面的在…

【机器学习与实现】线性回归分析

目录 一、相关和回归的概念&#xff08;一&#xff09;变量间的关系&#xff08;二&#xff09;Pearson&#xff08;皮尔逊&#xff09;相关系数 二、线性回归的概念和方程&#xff08;一&#xff09;回归分析概述&#xff08;二&#xff09;线性回归方程 三、线性回归模型的损…