【论文复现】Conditional Generative Adversarial Nets(CGAN)

文章目录

  • GAN基础理论
  • 2.1 算法来源
  • 2.2 算法介绍
  • 2.3 基于CGAN的手写数字生成实验
    • 2.3.1 网络结构
    • 2.3.2 训练过程
      • 一、 D的loss (discriminator_train_step)
      • 二、 G的loss (generator_train_step)
  • 2.4 实验分析
    • 2.4.1 超参数调整
      • 一、batch size
      • 二、 epochs
      • 三、 Adam:learning rate
      • 四、 Adam:weight_decay
      • 五、 n_critic
    • 2.4.2 模型改进
      • 一、 超参数优化
      • 二、 逐层归一化
      • 三、 损失函数改进
      • 四、 激活函数选择
      • 五、 优化器改进
      • 六、 噪声z的分布
      • 七、 其余设想
    • 2.4.3 模型测试

GAN基础理论

具体内容详见:【论文复现】Generative Adversarial Nets(GAN基础理论)

2.1 算法来源

作者:Mehdi Mirza, Simon Osindero
摘要
  Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.
日期:6 Nov 2014
论文链接
https://arxiv.org/pdf/1411.1784.pdf
实验数据
https://github.com/MrHeadbang/machineLearning/blob/main/mnist.zip
代码链接
https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/cgan/cgan.py

2.2 算法介绍

  CGAN(条件式生成对抗网络)是对原始GAN的一种变形,其生成器和判别器都增加额外信息C作为条件条件可以是类别信息、或其他模态数据。通过将额外信息C输送给判别模型和生成模型,作为输入层的一部分,其架构图如下:
在这里插入图片描述

  和原始GAN一样,CGAN还是基于多层感知器。在原始GAN中,判别器的输入是训练样本x,生成器的的输入是噪声z,而在CGAN中,生成器和判别器的输入都多了一个y,这个y就是那个额外条件信息。

  • 把噪声z和条件y作为输入同时送进生成器生成跨域向量,再通过非线性函数映射到数据空间。
  • 把数据x和条件y作为输入同时送进判别器生成跨域向量,并进一步判断x是真实训练数据的概率。
  • 二元极小极大博弈转变为:
    在这里插入图片描述

2.3 基于CGAN的手写数字生成实验

2.3.1 网络结构

  原始的GAN是无监督的,包括之前实验课上的DCGAN,其输出是完全随机的,在人脸上训练好的网络,最后生成什么样的人脸是完全没办法控制的。而CGAN则是有监督的GAN,在MNIST上以数字类别标签为约束条件,最终根据类别标签信息,生成对应的数字图像
在这里插入图片描述
  本实验使用MNIST(手写数字体)数据集,生成器的输入是100维服从均匀分布的噪声向量,以类别标签(one-hot编码)为条件来训练CGAN,生成器经过sigmoid生成784维(28x28)的单通道图像(每张图片的shape是[1, 28, 28]),判别器的输入为784维的图像和类别标签(one-hot编码),输出是该样本来自训练集的的概率。

2.3.2 训练过程

  CGAN的损失函数即BCELoss:及Adam优化器

# Define the Binary Cross Entropy Loss criterion for the GAN
criterion = nn.BCELoss()# Set up optimizers for the discriminator and generator models
# Use Adam optimizer for updating discriminator's parameters
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
# Use Adam optimizer for updating generator's parameters
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)

  CGAN中有三个loss,一个是D的real概率的loss,一个是D的fake概率的loss(二者相加得到d_loss),最后是G的real的loss。

for epoch in range(num_epochs):print('Starting epoch {}...'.format(epoch), end=' ')# Iterate through the data loaderfor i, (images, labels) in enumerate(data_loader):step = epoch * len(data_loader) + i + 1real_images = Variable(images).to(device)labels = Variable(labels).to(device)generator.train()d_loss = 0# Perform multiple discriminator training stepsfor _ in range(n_critic):d_loss = discriminator_train_step(len(real_images), discriminator,generator, d_optimizer, criterion,real_images, labels,device)# Perform a single generator training stepg_loss = generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device)# Write the losses to TensorBoardwriter.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': (d_loss / n_critic)}, step)  

【深度学习实验】TensorBoard使用教程【SCALARS、IMAGES、TIME SERIES】

一、 D的loss (discriminator_train_step)

  • D的real概率的loss
  	# Train the discriminator with real imagesreal_validity = discriminator(real_images, labels)# Calculate loss on real images; discriminator's goal: classify real images as real (1)real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).to(device))

输入的是真实的MNIST数据集的图像,是real

  • D的fake概率的loss
	# Train the discriminator with fake imagesz = Variable(torch.randn(batch_size, 100)).to(device)fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)fake_images = generator(z, fake_labels)fake_validity = discriminator(fake_images, fake_labels)# Calculate loss on fake images; discriminator's goal: classify fake images as fake (0)fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).to(device))

输入的是G生成的假的图像,是fake,要让判别器知道

  • D的total loss
    # Total discriminator loss is the sum of losses on real and fake imagesd_loss = real_loss + fake_loss# Backpropagation: Compute gradients and update discriminator's weightsd_loss.backward()

整合:

def discriminator_train_step(batch_size, discriminator, generator, d_optimizer, criterion, real_images, labels, device):# Zero out the gradients from the previous iterationd_optimizer.zero_grad()# Train the discriminator with real imagesreal_validity = discriminator(real_images, labels)# Calculate loss on real images; discriminator's goal: classify real images as real (1)real_loss = criterion(real_validity, Variable(torch.ones(batch_size)).to(device))# Train the discriminator with fake imagesz = Variable(torch.randn(batch_size, 100)).to(device)fake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)fake_images = generator(z, fake_labels)fake_validity = discriminator(fake_images, fake_labels)# Calculate loss on fake images; discriminator's goal: classify fake images as fake (0)fake_loss = criterion(fake_validity, Variable(torch.zeros(batch_size)).to(device))# Total discriminator loss is the sum of losses on real and fake imagesd_loss = real_loss + fake_loss# Backpropagation: Compute gradients and update discriminator's weightsd_loss.backward()d_optimizer.step()# Return the discriminator's loss as a Python floatreturn d_loss.item()

二、 G的loss (generator_train_step)

  生成器要骗过判别器,生成较为逼真的图像。怎么骗判别器?那就是在做一个real的loss,用的还是G生成的图像数据。

def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device):# Zero out the gradients from the previous iterationg_optimizer.zero_grad()# Generate random noise vector zz = Variable(torch.randn(batch_size, 100)).to(device)# Generate random labels for the fake imagesfake_labels = Variable(torch.LongTensor(np.random.randint(0, 10, batch_size))).to(device)# Generate fake images using the generatorfake_images = generator(z, fake_labels)# Get the discriminator's prediction on the generated fake imagesvalidity = discriminator(fake_images, fake_labels)# Calculate the generator's loss using the discriminator's prediction# Generator's goal: Make the discriminator classify generated images as real (1)g_loss = criterion(validity, Variable(torch.ones(batch_size)).to(device))# Backpropagation: Compute gradients and update generator's weightsg_loss.backward()g_optimizer.step()# Return the generator's loss as a Python floatreturn g_loss.item()

2.4 实验分析

2.4.1 超参数调整

具体内容详见:【论文复现】基于CGAN的手写数字生成实验——超参数调整

一、batch size

二、 epochs

三、 Adam:learning rate

四、 Adam:weight_decay

五、 n_critic

2.4.2 模型改进

具体内容详见:【论文复现】基于CGAN的手写数字生成实验——模型改进

一、 超参数优化

二、 逐层归一化

三、 损失函数改进

四、 激活函数选择

五、 优化器改进

六、 噪声z的分布

七、 其余设想

2.4.3 模型测试

Batch Normalization + PReLU激活函数+AdamW优化器
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/443133.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

java+springboot企业员工工作日志审批管理系统ssm+vue

企业OA管理系统具有管理员角色,用户角色,这两个操作权限。 ①管理员 管理员在企业OA管理系统里面查看并管理人事信息,工作审批信息,部门信息,通知公告信息以及内部邮件信息。 管理员功能结构图如下: ide工具…

isctf---web

圣杯战争 php反序列 ?payloadO:6:"summon":2:{s:5:"Saber";O:8:"artifact":2:{s:10:"excalibuer";O:7:"prepare":1:{s:7:"release";O:5:"saber":1:{s:6:"weapon";s:52:"php://filter…

1 月 30 日算法练习-思维和贪心

文章目录 重复字符串翻硬币乘积最大 重复字符串 思路&#xff1a;判断是否能整除&#xff0c;如果不能整除直接退出&#xff0c;能整除每次从每组对应位置中找出出现最多的字母将其他值修改为它&#xff0c;所有修改次数即为答案。 #include<iostream> using namespace …

【JS逆向实战-入门篇】某gov网站加密参数分析与Python算法还原

文章目录 1. 写在前面2. 请求分析3. 断点分析4. 算法还原 【作者主页】&#xff1a;吴秋霖 【作者介绍】&#xff1a;Python领域优质创作者、阿里云博客专家、华为云享专家。长期致力于Python与爬虫领域研究与开发工作&#xff01; 【作者推荐】&#xff1a;对JS逆向感兴趣的朋…

如何对Ajax请求进行封装操作,解决跨域问题的方法,如何使用core解决跨域

目录 1.Ajax原理 2.为什么要封装 3.如何进行封装 4.如何请求 5.如何解决Ajax跨域问题 6.使用CORS解决Ajax跨域问题 1.服务端 2.客户端 1.Ajax原理 Ajax&#xff08;Asynchronous JavaScript and XML&#xff09;是一种通过在后台与服务器进行少量数据交换&…

微信小程序(二十八)网络请求数据进行列表渲染

注释很详细&#xff0c;直接上代码 上一篇 新增内容&#xff1a; 1.GET请求的规范 2.数据赋值的方法 源码&#xff1a; index.wxml <!-- 列表渲染基础写法&#xff0c;不明白的看上一篇 --> <view class"students"><view class"item">&…

Shell脚本之 -------------免交互操作

一、Here Document 1.Here Document概述 Here Document 使用I/O重定向的方式将命令列表提供给交互式程序 Here Document 是标准输 入的一种替代品&#xff0c;可以帮助脚本开发人员不必使用临时文件来构建输入信息&#xff0c;而是直接就地 生产出一个文件并用作命令的标准…

排序链表---归并--链表OJ

https://leetcode.cn/problems/sort-list/submissions/499363940/?envTypestudy-plan-v2&envIdtop-100-liked 这里我们直接进阶&#xff0c;用时间复杂度O(nlogn)&#xff0c;空间复杂度O(1)&#xff0c;来解决。 对于归并&#xff0c;如果自上而下的话&#xff0c;空间复…

Netty源码二:服务端创建NioEventLoopGroup

示例 还是拿之前启动源码的示例&#xff0c;来分析NioEventLoopGroup源码 NioEventLoopGroup构造函数 这里能看到会调到父类的MultiThread EventLoopGroup的构造方法 MultiThreadEventLoopGroup 这里我们能看到&#xff0c;如果传入的线程数目为0&#xff0c;那么就会设置2倍…

代码随想录 Leetcode222.完全二叉树的节点个数

题目&#xff1a; 代码&#xff08;首刷自解 2024年1月30日&#xff09;&#xff1a; class Solution { public:int countNodes(TreeNode* root) {int res 0;if (root nullptr) return res;queue<TreeNode*> deque;TreeNode* cur root;deque.push(cur);int size 0;w…

网络隔离场景下访问 Pod 网络

接着上文 VPC网络架构下的网络上数据采集 介绍 考虑一个监控系统&#xff0c;它的数据采集 Agent 是以 daemonset 形式运行在物理机上的&#xff0c;它需要采集 Pod 的各种监控信息。现在很流行的一个监控信息是通过 Prometheus 提供指标信息。 一般来说&#xff0c;daemonset …

数据中心IP代理是什么?有何优缺点?海外代理IP全解

海外代理IP中&#xff0c;数据中心代理IP是很热门的选择。这些代理服务器为用户分配不属于 ISP&#xff08;互联网服务提供商&#xff09;且来自第三方云服务提供商的 IP 地址&#xff0c;是分配给位于数据中心的服务器的 IP 地址&#xff0c;通常由托管和云公司拥有。 这些 I…