[Paper Reproduction] Conditional Generative Adversarial Nets (CGAN)


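  • GAN Fundamentals
  • 2.1 Source of the Algorithm
  • 2.2 Algorithm Overview
  • 2.3 Handwritten Digit Generation with CGAN
    • 2.3.1 Network Architecture
    • 2.3.2 Training Procedure
      • 1. Loss of D (discriminator_train_step)
      • 2. Loss of G (generator_train_step)
  • 2.4 Experimental Analysis
    • 2.4.1 Hyperparameter Tuning
      • 1. Batch size
      • 2. Epochs
      • 3. Adam: learning rate
      • 4. Adam: weight_decay
      • 5. n_critic
    • 2.4.2 Model Improvements
      • 1. Hyperparameter optimization
      • 2. Layer-wise normalization
      • 3. Loss function improvements
      • 4. Choice of activation function
      • 5. Optimizer improvements
      • 6. Distribution of the noise z
      • 7. Other ideas
    • 2.4.3 Model Testing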


2.1 Source of the Algorithm

Authors: Mehdi Mirza, Simon Osindero
  Generative Adversarial Nets were recently introduced as a novel way to train generative models. In this work we introduce the conditional version of generative adversarial nets, which can be constructed by simply feeding the data, y, we wish to condition on to both the generator and discriminator. We show that this model can generate MNIST digits conditioned on class labels. We also illustrate how this model could be used to learn a multi-modal model, and provide preliminary examples of an application to image tagging in which we demonstrate how this approach can generate descriptive tags which are not part of training labels.
Date: 6 Nov 2014

2.2 Algorithm Overview



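  • The noise z and the condition y are fed together into the generator, which combines them into a joint hidden representation and then maps it to the data space through a nonlinear function.
  • The data x and the condition y are fed together into the discriminator, which likewise combines them into a joint representation and outputs the probability that x is real training data rather than a generated sample.
  • The two-player minimax game becomes:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]$$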
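The conditional objective V(D, G) can be estimated on a minibatch by averaging log D(x|y) over real samples and log(1 - D(G(z|y))) over generated ones. The sketch below assumes the discriminator outputs sigmoid probabilities of shape [batch]; `minibatch_value` is a hypothetical helper and the probabilities are made-up numbers. It shows that the BCE-based discriminator loss used in the training code is exactly the negative of this estimate, so minimizing that loss is the same as D maximizing V.

```python
import torch
import torch.nn as nn

def minibatch_value(d_real, d_fake):
    """Monte Carlo estimate of V(D, G) on one minibatch (hypothetical helper).

    d_real: D(x|y) for real samples, shape [batch]
    d_fake: D(G(z|y)) for generated samples, shape [batch]
    """
    return torch.log(d_real).mean() + torch.log(1 - d_fake).mean()

# Made-up discriminator outputs, only for illustration
d_real = torch.tensor([0.9, 0.8, 0.7])
d_fake = torch.tensor([0.2, 0.1, 0.3])

v = minibatch_value(d_real, d_fake)
bce = nn.BCELoss()
# BCE with target 1 gives -log D(x|y); with target 0 it gives -log(1 - D(G(z|y)))
d_loss = bce(d_real, torch.ones(3)) + bce(d_fake, torch.zeros(3))
print(torch.allclose(d_loss, -v))  # True
```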

2.3 Handwritten Digit Generation with CGAN

2.3.1 Network Architecture

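  This experiment uses the MNIST handwritten-digit dataset and trains the CGAN conditioned on the class label (one-hot encoded). The generator takes a 100-dimensional noise vector drawn from a uniform distribution together with the label, and passes its output through a sigmoid to produce a 784-dimensional (28x28) single-channel image, i.e. each image has shape [1, 28, 28]. The discriminator takes the 784-dimensional image and the one-hot class label as input and outputs the probability that the sample comes from the training set. (Note that the training code below samples z from a standard normal distribution with torch.randn rather than from a uniform one.)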
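A minimal sketch of networks matching this description, assuming fully connected layers and label conditioning by concatenating the one-hot label with the input. The class names `Generator` and `Discriminator`, the hidden sizes, and the LeakyReLU activations in the discriminator are assumptions for illustration, not the author's exact architecture. The call signatures `generator(z, labels)` and `discriminator(images, labels)` match how the training code uses them, and the discriminator returns shape [batch] so it can be compared directly with `torch.ones(batch_size)` in `nn.BCELoss`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

NOISE_DIM, NUM_CLASSES, IMG_DIM = 100, 10, 28 * 28

class Generator(nn.Module):
    def __init__(self, hidden=256):  # hidden size is an assumption
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(NOISE_DIM + NUM_CLASSES, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2 * hidden),
            nn.ReLU(),
            nn.Linear(2 * hidden, IMG_DIM),
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, z, labels):
        # One-hot encode the class label and concatenate it with the noise
        y = F.one_hot(labels, NUM_CLASSES).float()
        x = self.model(torch.cat([z, y], dim=1))
        return x.view(-1, 1, 28, 28)  # shape [batch, 1, 28, 28]

class Discriminator(nn.Module):
    def __init__(self, hidden=256):  # hidden size is an assumption
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(IMG_DIM + NUM_CLASSES, 2 * hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(2 * hidden, hidden),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),  # probability that the input comes from the training set
        )

    def forward(self, images, labels):
        x = images.view(images.size(0), -1)  # flatten to 784 dims
        y = F.one_hot(labels, NUM_CLASSES).float()
        return self.model(torch.cat([x, y], dim=1)).squeeze(1)  # shape [batch]

G, D = Generator(), Discriminator()
z = torch.randn(4, NOISE_DIM)
labels = torch.tensor([0, 3, 7, 9])
fake = G(z, labels)
print(fake.shape)             # torch.Size([4, 1, 28, 28])
print(D(fake, labels).shape)  # torch.Size([4])
```

`squeeze(1)` rather than `squeeze()` keeps the output one-dimensional even when the batch size is 1.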

2.3.2 Training Procedure


```python
import torch
import torch.nn as nn

# Binary cross-entropy loss for the GAN (expects probabilities, i.e. sigmoid outputs)
criterion = nn.BCELoss()

# Set up one optimizer per network
# Adam optimizer for updating the discriminator's parameters
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
# Adam optimizer for updating the generator's parameters
g_optimizer = torch.optim.Adam(generator.parameters(), lr=1e-4)
```


```python
for epoch in range(num_epochs):
    print('Starting epoch {}...'.format(epoch), end=' ')
    # Iterate through the data loader
    for i, (images, labels) in enumerate(data_loader):
        # Global step index used for TensorBoard logging
        step = epoch * len(data_loader) + i + 1
        # Variable is deprecated since PyTorch 0.4; move tensors to the device directly
        real_images = images.to(device)
        labels = labels.to(device)
        generator.train()

        # Perform n_critic discriminator training steps and accumulate their losses
        d_loss = 0.0
        for _ in range(n_critic):
            d_loss += discriminator_train_step(len(real_images), discriminator,
                                               generator, d_optimizer, criterion,
                                               real_images, labels, device)

        # Perform a single generator training step
        g_loss = generator_train_step(batch_size, discriminator, generator,
                                      g_optimizer, criterion, device)

        # Write the losses to TensorBoard (d_loss is averaged over the n_critic steps)
        writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss / n_critic}, step)
```

For TensorBoard usage, see: [Deep Learning Experiments] TensorBoard Tutorial [SCALARS, IMAGES, TIME SERIES]

1. Loss of D (discriminator_train_step)

  • Loss of D on real samples (target probability 1)

```python
# Train the discriminator with real images
real_validity = discriminator(real_images, labels)
# Loss on real images; the discriminator's goal is to classify real images as real (1)
real_loss = criterion(real_validity, torch.ones(batch_size, device=device))
```

  • Loss of D on fake samples (target probability 0)

```python
# Train the discriminator with fake images
z = torch.randn(batch_size, 100, device=device)
fake_labels = torch.randint(0, 10, (batch_size,), device=device)
# detach() keeps this step from computing gradients for the generator
fake_images = generator(z, fake_labels).detach()
fake_validity = discriminator(fake_images, fake_labels)
# Loss on fake images; the discriminator's goal is to classify fake images as fake (0)
fake_loss = criterion(fake_validity, torch.zeros(batch_size, device=device))
```

  • Total loss of D

```python
# Total discriminator loss is the sum of the losses on real and fake images
d_loss = real_loss + fake_loss
# Backpropagation: compute the gradients of the discriminator's loss
d_loss.backward()
```
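To make the three terms concrete, here is a small worked example with made-up discriminator outputs: D assigns probability 0.9 to a real image and 0.2 to a fake one. The helper `bce` is a hypothetical scalar version of what `nn.BCELoss` computes per element.

```python
import math

def bce(p, target):
    """Binary cross-entropy for a single probability p and a 0/1 target (hypothetical helper)."""
    return -(target * math.log(p) + (1 - target) * math.log(1 - p))

# Made-up discriminator outputs for one real and one fake sample
real_loss = bce(0.9, 1)   # -log(0.9), about 0.105
fake_loss = bce(0.2, 0)   # -log(0.8), about 0.223
d_loss = real_loss + fake_loss
print(round(d_loss, 3))   # 0.329
```

As a reference point when reading the TensorBoard curves: a discriminator that outputs 0.5 for every input has d_loss = 2·ln 2 ≈ 1.386, while a perfect one approaches 0.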


```python
import torch

def discriminator_train_step(batch_size, discriminator, generator, d_optimizer,
                             criterion, real_images, labels, device):
    # Zero out the gradients from the previous iteration
    d_optimizer.zero_grad()

    # Train the discriminator with real images
    real_validity = discriminator(real_images, labels)
    # Loss on real images; the discriminator's goal is to classify real images as real (1)
    real_loss = criterion(real_validity, torch.ones(batch_size, device=device))

    # Train the discriminator with fake images
    z = torch.randn(batch_size, 100, device=device)
    fake_labels = torch.randint(0, 10, (batch_size,), device=device)
    # detach() keeps this step from computing gradients for the generator
    fake_images = generator(z, fake_labels).detach()
    fake_validity = discriminator(fake_images, fake_labels)
    # Loss on fake images; the discriminator's goal is to classify fake images as fake (0)
    fake_loss = criterion(fake_validity, torch.zeros(batch_size, device=device))

    # Total discriminator loss is the sum of the losses on real and fake images
    d_loss = real_loss + fake_loss
    # Backpropagation: compute gradients and update the discriminator's weights
    d_loss.backward()
    d_optimizer.step()
    # Return the discriminator's loss as a Python float
    return d_loss.item()
```

2. Loss of G (generator_train_step)


```python
import torch

def generator_train_step(batch_size, discriminator, generator, g_optimizer, criterion, device):
    # Zero out the gradients from the previous iteration
    g_optimizer.zero_grad()
    # Sample a random noise vector z
    z = torch.randn(batch_size, 100, device=device)
    # Sample random class labels for the fake images
    fake_labels = torch.randint(0, 10, (batch_size,), device=device)
    # Generate fake images with the generator
    fake_images = generator(z, fake_labels)
    # Get the discriminator's prediction on the generated images
    validity = discriminator(fake_images, fake_labels)
    # Generator's goal: make the discriminator classify generated images as real (1)
    g_loss = criterion(validity, torch.ones(batch_size, device=device))
    # Backpropagation: compute gradients and update the generator's weights
    g_loss.backward()
    g_optimizer.step()
    # Return the generator's loss as a Python float
    return g_loss.item()
```
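Using target 1 for the generator means G minimizes -log D(G(z|y)) instead of literally minimizing log(1 - D(G(z|y))) from the minimax objective. This is the "non-saturating" heuristic suggested in the original GAN paper. The sketch below, assuming D outputs p = sigmoid(a) for a logit a, compares the gradient each form sends back through the logit when D confidently rejects a fake (p = 0.01), which is the typical situation early in training.

```python
import math

def sigmoid(a):
    return 1 / (1 + math.exp(-a))

# Logit for which D outputs p = 0.01 on a fake sample
a = math.log(0.01 / 0.99)
p = sigmoid(a)

# Minimax form: G minimizes log(1 - p);        d/da log(1 - sigmoid(a)) = -p
grad_minimax = -p
# Form used in the code: G minimizes -log(p);  d/da [-log sigmoid(a)] = -(1 - p)
grad_nonsaturating = -(1 - p)

print(round(grad_minimax, 2), round(grad_nonsaturating, 2))  # -0.01 -0.99
```

Both gradients point in the same direction, but the minimax form nearly vanishes exactly when G is doing badly, while the BCE-with-target-1 form stays large.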

2.4 Experimental Analysis

2.4.1 Hyperparameter Tuning


1. Batch size

2. Epochs

3. Adam: learning rate

4. Adam: weight_decay

5. n_critic

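2.4.2 Model Improvements


1. Hyperparameter optimization

2. Layer-wise normalization

3. Loss function improvements

4. Choice of activation function

5. Optimizer improvements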

6. Distribution of the noise z
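The experiment description specifies uniform noise, while the training code draws z with torch.randn. A minimal sketch of a switchable sampler for comparing the two, assuming "uniform" means the unit hypercube [0, 1)^100 (our reading of the original paper's setup). `sample_noise` is a hypothetical helper; to use it, the `torch.randn(batch_size, 100, device=device)` calls in both training steps would be replaced by it.

```python
import torch

def sample_noise(batch_size, dim=100, dist="normal", device="cpu"):
    """Sample the generator's input noise z (hypothetical helper)."""
    if dist == "uniform":
        # Uniform on the unit hypercube [0, 1)^dim
        return torch.rand(batch_size, dim, device=device)
    if dist == "normal":
        # Standard normal, matching torch.randn in the training code
        return torch.randn(batch_size, dim, device=device)
    raise ValueError(f"unknown distribution: {dist}")

z_uniform = sample_noise(8, dist="uniform")
z_normal = sample_noise(8)
print(z_uniform.shape, z_normal.shape)  # torch.Size([8, 100]) torch.Size([8, 100])
```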

7. Other ideas

2.4.3 Model Testing

Batch Normalization + PReLU activation + AdamW optimizer
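A minimal sketch of what this configuration could look like for the generator, together with a helper that produces a fixed number of samples per digit for visual testing. The class `BNPReLUGenerator`, the helper `sample_per_class`, the layer sizes, and the `lr`/`weight_decay` values are all hypothetical placeholders, not the settings used in the experiments above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BNPReLUGenerator(nn.Module):
    """Hypothetical generator variant: Linear -> BatchNorm1d -> PReLU blocks."""

    def __init__(self, noise_dim=100, num_classes=10, hidden=256):
        super().__init__()
        self.num_classes = num_classes
        self.model = nn.Sequential(
            nn.Linear(noise_dim + num_classes, hidden),
            nn.BatchNorm1d(hidden),
            nn.PReLU(),
            nn.Linear(hidden, 2 * hidden),
            nn.BatchNorm1d(2 * hidden),
            nn.PReLU(),
            nn.Linear(2 * hidden, 28 * 28),
            nn.Sigmoid(),  # pixel values in [0, 1]
        )

    def forward(self, z, labels):
        y = F.one_hot(labels, self.num_classes).float()
        return self.model(torch.cat([z, y], dim=1)).view(-1, 1, 28, 28)

@torch.no_grad()
def sample_per_class(generator, n_per_class=8, noise_dim=100, device="cpu"):
    """Generate n_per_class images for every class, ordered 0, 0, ..., 9, 9."""
    generator.eval()  # BatchNorm uses its running statistics at test time
    labels = torch.arange(generator.num_classes, device=device).repeat_interleave(n_per_class)
    z = torch.randn(labels.size(0), noise_dim, device=device)
    return generator(z, labels), labels

generator = BNPReLUGenerator()
# AdamW applies decoupled weight decay; lr and weight_decay are placeholder values
g_optimizer = torch.optim.AdamW(generator.parameters(), lr=1e-4, weight_decay=1e-2)
images, labels = sample_per_class(generator)
print(images.shape)  # torch.Size([80, 1, 28, 28])
```

Because `sample_per_class` switches the model to eval mode, the `generator.train()` call at the start of each iteration of the training loop is what puts BatchNorm back into training mode.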





