Paddle 实现DCGAN

传统GAN

传统的GAN可以看我的这篇文章:Paddle 基于ANN(全连接神经网络)的GAN(生成对抗网络)实现-CSDN博客

DCGAN

DCGAN是适用于图像生成的GAN,它的特点是:

  • 只采用卷积层和转置卷积层,而不采用全连接层
  • 在每个卷积层或转置卷积层之间,插入一个批归一化层和ReLU激活函数

转置卷积层

转置卷积层执行的是转置卷积或反卷积的操作,即它是常规卷积层的反向操作。它接收一个低分辨率的输入,然后将其通过转置滤波器升采样到更高的分辨率。

对于一个卷积层,它的输出大小公式是:

o = \frac{i + 2p - k}{s} + 1

其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示卷积核大小(kernel_size),s表示步长(stride)。也就是说:输出大小 = (输入大小 - 卷积核大小 + 2 × 填充数) ÷ 步长 + 1

而对于一个转置卷积层,它的输出大小公式是:

o = s(i-1)-2p+k+u

 其中,o表示输出大小,i表示输入大小,p表示填充(padding),k表示反卷积核大小(kernel_size),s表示步长(stride),u表示输出填充(output padding)。也就是说:输出大小 = (输入大小 - 1) * 步长 - 2*填充 + 反卷积大小 + 输出填充

在paddle中,转置卷积层可以这么定义:

paddle.nn.Conv2DTranspose(in_channels, out_channels, kernel_size, stride, padding)

像卷积层一样,反卷积层的in_channels表示输入通道数(如形如(3, 32, 32)的图片张量的通道数就是3),out_channels表示输出通道数(如把(64, 32, 32)变成3通道的彩色图像(3, 32, 32))。 

代码实现

这里我们采用NWPU-RESISC45数据集,从中选择“freeway”(高速公路)作为训练数据,让机器生成高速公路的图片。这个训练数据内有700张256x256的图片,但由于我的电脑显存不足,因此将图片大小设置为64x64.

先写dataset.py:

import paddle
import numpy as np
from PIL import Image
import osdef getAllPath(path):return [os.path.join(path, f) for f in os.listdir(path)]class FreewayDataset(paddle.io.Dataset):def __init__(self, transform=None):super().__init__()self.data = []for path in getAllPath('./freeway'):img = Image.open(path)img = img.resize((64, 64))img = np.array(img, dtype=np.float32).transpose((2, 1, 0))if transform is not None:img = transform(img)self.data.append(img)self.data = np.array(self.data, dtype=np.float32)def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)

然后写训练脚本:

from dataset import FreewayDataset
import paddle
from models import Generator, Discriminator
import numpy as npdataset = FreewayDataset()
dataloader = paddle.io.DataLoader(dataset, batch_size=32, shuffle=True)netG = Generator()
netD = Discriminator()if 1:try:mydict = paddle.load('generator.params')netG.set_dict(mydict)mydict = paddle.load('discriminator.params')netD.set_dict(mydict)except:print('fail to load model')loss = paddle.nn.BCELoss()optimizerD = paddle.optimizer.Adam(parameters=netD.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)
optimizerG = paddle.optimizer.Adam(parameters=netG.parameters(), learning_rate=0.0002, beta1=0.5, beta2=0.999)# 最大迭代epoch
max_epoch = 1000for epoch in range(max_epoch):now_step = 0for step, data in enumerate(dataloader):############################# (1) 更新鉴别器############################ 清除D的梯度optimizerD.clear_grad()# 传入正样本,并更新梯度pos_img = datalabel = paddle.full([pos_img.shape[0], 1, 1, 1], 1, dtype='float32')pre = netD(pos_img)loss_D_1 = loss(pre, label)loss_D_1.backward()# 通过randn构造随机数,制造负样本,并传入D,更新梯度noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')neg_img = netG(noise)label = paddle.full([pos_img.shape[0], 1, 1, 1], 0, dtype='float32')pre = netD(neg_img.detach())  # 通过detach阻断网络梯度传播,不影响G的梯度计算loss_D_2 = loss(pre, label)loss_D_2.backward()# 更新D网络参数optimizerD.step()optimizerD.clear_grad()loss_D = loss_D_1 + loss_D_2############################# (2) 更新生成器############################ 清除D的梯度optimizerG.clear_grad()noise = paddle.randn([pos_img.shape[0], 100, 1, 1], 'float32')fake = netG(noise)label = paddle.full((pos_img.shape[0], 1, 1, 1), 1, dtype=np.float32, )output = netD(fake)# 这个写法没有问题,因为这个loss既会影响到netG(output=netD(netG(noise)))的梯度,也会影响到netD的梯度,但是之后的代码并没有更新netD的参数,而循环开头就清除了netD的梯度loss_G = loss(output, label)loss_G.backward()# 更新G网络参数optimizerG.step()optimizerG.clear_grad()now_step += 1############################ 输出日志###########################if now_step % 10 == 0:print(f'Epoch ID={epoch} Batch ID={now_step} \n\n D-Loss={float(loss_D)} G-Loss={float(loss_G)}')paddle.save(netG.state_dict(), "generator.params")
paddle.save(netD.state_dict(), "discriminator.params")

 最后编写图片生成脚本:

import paddle
from models import Generator
import matplotlib.pyplot as plt# 加载模型
netG = Generator()
mydict = paddle.load('generator.params')
netG.set_dict(mydict)# 设置matplotlib的显示环境
fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(15, 6))  # 创建一个2x5的子图网格# 生成10个噪声向量
for i, ax in enumerate(axs.flatten()):noise = paddle.randn([1, 100, 1, 1], 'float32')img = netG(noise)img = img.numpy()[0].transpose((2, 1, 0))  # img.numpy():张量转np数组img[img < 0] = 0  # 将img中所有小于0的元素赋值为0# 显示图片ax.imshow(img)ax.axis('off')  # 不显示坐标轴# 显示图像
plt.show()

经过数次训练,最终的效果如下:

这样看来,至少有点高速公路的感觉了。 

参考

通过DCGAN实现人脸图像生成-使用文档-PaddlePaddle深度学习平台

卷积层和反卷积层输出特征图大小计算_输出特征图大小的计算方法-CSDN博客 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/686069.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

预约咨询小程序源码搭建/部署/上线/运营/售后/更新

包含在线咨询、视频咨询、电话咨询、面询多种咨询方式&#xff0c;适用于心理、法律、宠物等预约咨询问诊场景 分类预览&#xff1a;小程序提供清晰的分类选项&#xff0c;使用户能够迅速找到所需的咨询服务类型&#xff0c;如法律咨询、心理咨询、医疗咨询等。预约时间选择&a…

多目标跟踪入门介绍

多目标跟踪算法 我们也可以称之为 Multi-Target-Tracking &#xff08;MTT&#xff09;。 那么多目标跟踪是什么&#xff1f; 不难看出&#xff0c;跟踪算法同时会为每个目标分配一个特定的 id 。 由此得出了目标跟踪与目标检测的区别&#xff08;似乎都是用方框来框出目标捏…

java编程中,实现分页对象的类型转换

一、背景 当数据库分页查询返回的对象与接口要返回的对象类型不一致时&#xff0c;不可避免需要进行类型转换。 示例&#xff1a;数据库分页查询返回的对象是PageDTO&#xff0c;而接口返回的对象类型是PageVO。 PageDTO Data public class PageDTO<T> {/*** Current…

同一局域网内互传文件

1. 打开要共享的文件夹&#xff0c;然后在地址框内输入cmd 2. 弹出的命令框内输入python -m http.server &#xff08;这么就创建好了共享服务器&#xff09; 3.win R输入cmd运行 4.输入ipconfig找到IP地址 5.另一台同一局域网内的机子就可以在网页浏览器输入ip和端口号…

五金建材微信小程序商城系统开发搭建指南

如今&#xff0c;随着移动互联网的发展&#xff0c;小程序成为了商家们开拓新市场、增加收益的重要途径。特别是对于五金店这类实体店铺来说&#xff0c;通过小程序开设线上商城&#xff0c;不仅可以提升品牌影响力&#xff0c;还能够实现线上线下的无缝对接&#xff0c;为店家…

SpringBoot 实现 RAS+AES 自动接口解密

接口安全老生常谈了 目前常用的加密方式就对称性加密和非对称性加密&#xff0c;加密解密的操作的肯定是大家知道的&#xff0c;最重要的使用什么加密解密方式&#xff0c;制定什么样的加密策略&#xff1b;考虑到我技术水平和接口的速度&#xff0c;采用的是RAS非对称加密和AE…

FilterListener详解

文章目录 MVC模式和三层架构MVC模式三层架构MVC和三层架构 JavaWeb的三大组件Filter概述快速入门过滤器API介绍过滤器开发步骤配置过滤器俩种方式修改idea的过滤器模板 使用细节生命周期拦截路径过滤器链 案例统一解决全站乱码问题登录权限校验验 ServletContextServletContext…

回溯算法—组合问题

文章目录 介绍应用问题基本流程算法模版例题&#xff08;1&#xff09;组合&#xff08;2&#xff09;电话号码的字母组合 介绍 回溯算法实际上是 一个类似枚举的搜索尝试过程&#xff0c;主要是在搜索尝试过程中寻找问题的解&#xff0c;当发现已不满足求解条件时&#xff0c;…

Spring添加注解读取和存储对象

5大注解 Controller 控制器 Service 服务 Repository 仓库 Componet 组件 Configuration 配置 五大类注解的使用 //他们都是放在同一个目录下&#xff0c;不同的类中 只不过这里粘贴到一起//控制器 Controller public class UserController {public void SayHello(){System.ou…

C++进阶 | [3] 搜索二叉树

摘要&#xff1a;什么是搜索二叉树&#xff0c;实现搜索二叉树&#xff08;及递归版本&#xff09; 什么是搜索二叉树 搜索二叉树/二叉排序树/二叉查找树BST&#xff08;Binary Search Tree&#xff09;&#xff1a;特征——左小右大&#xff08;不允许重复值&#xff09;。即…

pydev debugger: process **** is connecting

目录 解决方案一解决方案二 1、调试时出现pydev debugger: process **** is connecting 解决方案一 File->settings->build,execution,deployment->python debugger 下面的attach to subprocess automatically while debugging取消前面的勾选&#xff08;默认状态为勾…

rbac权限和多级请假设计的流程演示和前端页面实现

登录账号&#xff1a;t6普通用户 t7部门经理 m8总经理 密码都为&#xff1a;test 多级请假&#xff1a;7级及以下申请请假需要部门经理审核&#xff0c;若是请假时长超过72小时&#xff0c;则需要总经理审核&#xff0c;7级申请请将需要总经理审核&#xff0c;总经理请假自动审…