G8-ACGAN理论

 本文为🔗365天深度学习训练营 中的学习记录博客
 原作者:K同学啊|接辅导、项目定制

我的环境:

1.语言:python3.7

2.编译器:pycharm

3.深度学习框架Pytorch 1.8.0+cu111


 一、对比分析

前面的文章介绍了CGAN(条件生成对抗网络),本文的ACGAN,是在CGAN与SGAN基础上的扩展,通过对判别器进行改进实现了图像分类的功能。

原始GAN网络的功能比较简单:输入噪声数据,输出伪造图片。而后CGAN发现可以通过给GAN的生成器添加辅助信息(比如类别标签),来实现生成图片类别的精确控制。。

  SGAN鉴别器与原始GAN实现有很大不同。它接收3种输入:生成器生成的伪样本X*、训练数据集中无标签的真实样本X和有标签的真实样本X,y。 

  ACGAN是在CGAN基础上更近一步的改进,将判别器的功能扩展为判别真假以及类别区分,可以认为ACGAN的判别器多出一个分类的功能 。

 ACGAN的损失函数也分为了判别损失和分类损失两个部分,其中判别损失和CGAN并没有区别,形式如下:

比较新的损失函数如下:

上面的分类损失就是ACGAN的核心贡献了,对于真实图片Xreal和生成器伪造的图片Xfake,判别器(或者说判别器中的分类器)应该能够预测它所属的类别。 

二、网络结构方面(原文链接:https://blog.csdn.net/qq_35692819/article/details/106684339)

相同的是ACGAN和CGAN在生成器输入时候,噪音z都拼接了采集的labels。
不同的是,ACGAN在判别器输入时,真假数据集都没有拼接labels,labels只是用来在辅助分类器中作为target_labels。而CGAN的判别器输入,真假数据集都拼接了labels。
网络结构上,生成网络和鉴别网络的网络层不再是CGAN的全连接,而是ACGAN的深层卷积网络(这是在DCGAN开始引入的改变),卷积能够更好的提取图片的特征值,所有ACGAN生成的图片边缘更具有连续性,感觉更真实。

代码部分:
 

import argparse
import os
import numpy as npimport torchvision.transforms as transforms
from torchvision.utils import save_imagefrom torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variableimport torch.nn as nn
import torch# 创建用于存储生成图像的目录
os.makedirs("images", exist_ok=True)# 解析命令行参数
parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=200, help="训练的总轮数")
parser.add_argument("--batch_size", type=int, default=64, help="每个批次的大小")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam优化器的学习率")
parser.add_argument("--b1", type=float, default=0.5, help="Adam优化器的一阶动量衰减")
parser.add_argument("--b2", type=float, default=0.999, help="Adam优化器的二阶动量衰减")
parser.add_argument("--n_cpu", type=int, default=8, help="用于批次生成的CPU线程数")
parser.add_argument("--latent_dim", type=int, default=100, help="潜在空间的维度")
parser.add_argument("--n_classes", type=int, default=10, help="数据集的类别数")
parser.add_argument("--img_size", type=int, default=32, help="每个图像的尺寸")
parser.add_argument("--channels", type=int, default=1, help="图像通道数")
parser.add_argument("--sample_interval", type=int, default=400, help="图像采样间隔")
opt = parser.parse_args()
print(opt)# 检查是否支持GPU加速
cuda = True if torch.cuda.is_available() else False# 初始化神经网络权重的函数
def weights_init_normal(m):classname = m.__class__.__name__if classname.find("Conv") != -1:torch.nn.init.normal_(m.weight.data, 0.0, 0.02)elif classname.find("BatchNorm2d") != -1:torch.nn.init.normal_(m.weight.data, 1.0, 0.02)torch.nn.init.constant_(m.bias.data, 0.0)# 生成器网络类
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()# 为类别标签创建嵌入层self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim)# 计算上采样前的初始大小self.init_size = opt.img_size // 4  # Initial size before upsampling# 第一层线性层self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2))# 卷积层块self.conv_blocks = nn.Sequential(nn.BatchNorm2d(128),nn.Upsample(scale_factor=2),nn.Conv2d(128, 128, 3, stride=1, padding=1),nn.BatchNorm2d(128, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Upsample(scale_factor=2),nn.Conv2d(128, 64, 3, stride=1, padding=1),nn.BatchNorm2d(64, 0.8),nn.LeakyReLU(0.2, inplace=True),nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),nn.Tanh(),)def forward(self, noise, labels):# 将标签嵌入到噪声中gen_input = torch.mul(self.label_emb(labels), noise)# 通过第一层线性层out = self.l1(gen_input)# 重新整形为合适的形状out = out.view(out.shape[0], 128, self.init_size, self.init_size)# 通过卷积层块生成图像img = self.conv_blocks(out)return img# 判别器网络类
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()# 定义判别器块的函数def discriminator_block(in_filters, out_filters, bn=True):"""返回每个判别器块的层"""block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1), nn.LeakyReLU(0.2, inplace=True), nn.Dropout2d(0.25)]if bn:block.append(nn.BatchNorm2d(out_filters, 0.8))return block# 判别器的卷积层块self.conv_blocks = nn.Sequential(*discriminator_block(opt.channels, 16, bn=False),*discriminator_block(16, 32),*discriminator_block(32, 64),*discriminator_block(64, 128),)# 下采样后图像的高度和宽度ds_size = opt.img_size // 2 ** 4# 输出层self.adv_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())self.aux_layer = nn.Sequential(nn.Linear(128 * ds_size ** 2, opt.n_classes), nn.Softmax())def forward(self, img):out = self.conv_blocks(img)out = out.view(out.shape[0], -1)validity = self.adv_layer(out)label = self.aux_layer(out)return validity, label# 损失函数
adversarial_loss = torch.nn.BCELoss()
auxiliary_loss = torch.nn.CrossEntropyLoss()# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()if cuda:generator.cuda()discriminator.cuda()adversarial_loss.cuda()auxiliary_loss.cuda()# 初始化权重
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)# 配置数据加载器
os.makedirs("../../data/mnist", exist_ok=True)
dataloader = torch.utils.data.DataLoader(datasets.MNIST("../../data/mnist",train=True,download=True,transform=transforms.Compose([transforms.Resize(opt.img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),),batch_size=opt.batch_size,shuffle=True,
)# 优化器
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor# 保存生成图像的函数
def sample_image(n_row, batches_done):"""保存从0到n_classes的生成数字的图像网格"""# 采样噪声z = Variable(FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim))))# 为n行生成标签从0到n_classeslabels = np.array([num for _ in range(n_row) for num in range(n_row)])labels = Variable(LongTensor(labels))gen_imgs = generator(z, labels)save_image(gen_imgs.data, "images/%d.png" % batches_done, nrow=n_row, normalize=True)# ----------
# 训练
# ----------for epoch in range(opt.n_epochs):for i, (imgs, labels) in enumerate(dataloader):batch_size = imgs.shape[0]# 真实数据的标签valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)# 生成数据的标签fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)# 配置输入real_imgs = Variable(imgs.type(FloatTensor))labels = Variable(labels.type(LongTensor))# -----------------# 训练生成器# -----------------optimizer_G.zero_grad()# 采样噪声和标签作为生成器的输入z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim))))gen_labels = Variable(LongTensor(np.random.randint(0, opt.n_classes, batch_size)))# 生成一批图像gen_imgs = generator(z, gen_labels)# 损失度量生成器的欺骗判别器的能力validity, pred_label = discriminator(gen_imgs)g_loss = 0.5 * (adversarial_loss(validity, valid) + auxiliary_loss(pred_label, gen_labels))g_loss.backward()optimizer_G.step()# ---------------------# 训练判别器# ---------------------optimizer_D.zero_grad()# 真实图像的损失real_pred, real_aux = discriminator(real_imgs)d_real_loss = (adversarial_loss(real_pred, valid) + auxiliary_loss(real_aux, labels)) / 2# 生成图像的损失fake_pred, fake_aux = discriminator(gen_imgs.detach())d_fake_loss = (adversarial_loss(fake_pred, fake) + auxiliary_loss(fake_aux, gen_labels)) / 2# 判别器的总损失d_loss = (d_real_loss + d_fake_loss) / 2# 计算判别器的准确率pred = np.concatenate([real_aux.data.cpu().numpy(), fake_aux.data.cpu().numpy()], axis=0)gt = np.concatenate([labels.data.cpu().numpy(), gen_labels.data.cpu().numpy()], axis=0)d_acc = np.mean(np.argmax(pred, axis=1) == gt)d_loss.backward()optimizer_D.step()print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %d%%] [G loss: %f]"% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), 100 * d_acc, g_loss.item()))batches_done = epoch * len(dataloader) + iif batches_done % opt.sample_interval == 0:sample_image(n_row=10, batches_done=batches_done)

判别器

  1. def discriminator_block(in_filters, out_filters, bn=True):: 这是一个内部函数,用于定义判别器的卷积块。它接受输入的通道数 in_filters 和输出的通道数 out_filters,并返回一个卷积块的列表。

  2. self.conv_blocks = nn.Sequential(...):定义了判别器的卷积层块,它使用了 nn.Sequential 来组合多个卷积块。通过调用 discriminator_block 函数定义了四个卷积块,每个卷积块由一个卷积层、一个 LeakyReLU 激活函数和一个 Dropout2d 层组成。

  3. ds_size = opt.img_size // 2 ** 4:计算下采样后图像的高度和宽度。在这段代码中,每个卷积块都将输入图像的尺寸减半,共执行了 4 次这样的操作。

  4. self.adv_layer = nn.Sequential(...):定义了判别器的输出层。adv_layer 是用于判断图像真假的部分,它是一个全连接层,将卷积层块输出的特征展平后输入到一个 Sigmoid 激活函数中,以输出一个范围在 0 到 1 之间的值,表示图像的真实度。

  5. self.aux_layer = nn.Sequential(...):定义了判别器的辅助输出层。aux_layer 是用于对图像进行分类的部分,它也是一个全连接层,将卷积层块输出的特征展平后输入到一个 Softmax 激活函数中,以输出类别概率分布,其中 opt.n_classes 是类别的数量。

  6. def forward(self, img)::定义了前向传播函数。接收一个输入图像 img,将其输入到卷积层块中进行特征提取,然后将特征展平后分别输入到判别器的输出层 adv_layeraux_layer 中,得到判别器的输出:真假判别结果 validity 和图像类别预测结果 label

生成器 

  1. self.label_emb = nn.Embedding(opt.n_classes, opt.latent_dim): 创建了一个嵌入层 label_emb,用于将类别标签转换为一个与噪声相同维度的向量。这里假设 opt.n_classes 是类别的数量,opt.latent_dim 是噪声的维度。

  2. self.init_size = opt.img_size // 4: 计算了上采样前的初始大小。在这段代码中,初始大小是图像大小的 1/4。

  3. self.l1 = nn.Sequential(nn.Linear(opt.latent_dim, 128 * self.init_size ** 2)): 定义了一个线性层 l1,将噪声输入映射到一个特定大小的张量,以供后续卷积层块使用。

  4. self.conv_blocks = nn.Sequential(...):定义了生成器的卷积层块。通过 nn.Sequential 组合了多个层,包括批归一化层、上采样层、卷积层、LeakyReLU 激活函数和 Tanh 激活函数。这些层组合在一起,用于从输入的特征张量生成图像。

  5. def forward(self, noise, labels):: 定义了前向传播函数。接收噪声 noise 和类别标签 labels 作为输入,并经过一系列操作生成图像。首先,通过将标签嵌入到噪声中,将标签信息融合到生成的噪声中。然后,将融合后的输入通过线性层 l1,将其映射到适当的大小。接着,将线性层输出重塑为合适的形状,以适应后续的卷积层块。最后,通过卷积层块生成图像,并将生成的图像作为输出返回。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/502732.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

(定时器/计数器)中断系统(详解与使用)

讲解 简介 定时器/计数器 定时器实际上也是计数器,只是计数的是固定周期的脉冲 定时和计数只是触发来源不同(时钟信号和外部脉冲)其他方面是一样的。 定时器在单片机内部就像一个小闹钟一样,根据时钟的输出信号,每隔“一秒”,计数单元的数值就增加一,当计数单元数值…

多输入多输出 | Matlab实现RIME-BP霜冰算法优化BP神经网络多输入多输出预测

多输入多输出 | Matlab实现RIME-BP霜冰算法优化BP神经网络多输入多输出预测 目录 多输入多输出 | Matlab实现RIME-BP霜冰算法优化BP神经网络多输入多输出预测预测效果基本介绍程序设计往期精彩参考资料 预测效果 基本介绍 多输入多输出 | Matlab实现RIME-BP霜冰算法优化BP神经网…

Java SPI:Service Provider Interface

SPI机制简介 SPI(Service Provider Interface),是从JDK6开始引入的,一种基于ClassLoader来发现并加载服务的机制。 一个标准的SPI,由3个组件构成,分别是: Service:是一个公开的接口…

线性规划在多种问题形式下的应用

线性规划的用处非常的广泛,这主要是因为很多类型的问题是可以通过转化的方式转化为线性规划的问题。例如需要再图论中寻找起始点到给定的点的最短路径问题: 添加图片注释,不超过 140 字(可选) 假设要计算从节点0到节点…

Java毕业设计 基于SpringBoot jsp 交友系统

Java毕业设计 基于SpringBoot jsp 交友系统 SpringBoot jsp 交友系统 功能介绍 登录 验证码 注册 忘记密码 完善资料 后台首页 个人信息 修改密码 寻找好友 随机 添加好友 我的好友 留言板 申请列表 我的申请 我要充值 充值消费记录 管理员用户功能 登录 验证码 忘记密码 完…

如何在群晖Docker运行本地聊天机器人并结合内网穿透发布到公网访问

文章目录 1. 拉取相关的Docker镜像2. 运行Ollama 镜像3. 运行Chatbot Ollama镜像4. 本地访问5. 群晖安装Cpolar6. 配置公网地址7. 公网访问8. 固定公网地址 随着ChatGPT 和open Sora 的热度剧增,大语言模型时代,开启了AI新篇章,大语言模型的应用非常广泛,包括聊天机…

【改进算法】【IHAOAVOA】天鹰优化算法和非洲秃鹫混合优化算法

目录 1 主要内容 IHAOAVOA流程图 主要创新点 2 部分代码 3 程序结果 4 下载链接 1 主要内容 该程序复现《IHAOAVOA: An improved hybrid aquila optimizer and African vultures optimization algorithm for global optimization problems》,天鹰优化算法&am…

c++函数指针 回调函数

目录 函数指针 ​编辑 实例 函数指针作为某个函数的参数 实例 std::function轻松实现回调函数 绑定一个函数 作为回调函数 作为函数入参 函数指针 函数指针是指向函数的指针变量。 通常我们说的指针变量是指向一个整型、字符型或数组等变量,而函数指针是指向…

第六节:Vben Admin权限-后端控制方式

系列文章目录 第一节:Vben Admin介绍和初次运行 第二节:Vben Admin 登录逻辑梳理和对接后端准备 第三节:Vben Admin登录对接后端login接口 第四节:Vben Admin登录对接后端getUserInfo接口 第五节:Vben Admin权限-前端控制方式 文章目录 系列文章目录前言一、角色权限(后端…

网安播报|开源Xeno RAT特洛伊木马在GitHub上成为潜在威胁

1、开源Xeno RAT特洛伊木马在GitHub上成为潜在威胁 一种“设计复杂”的远程访问特洛伊木马(RAT),称为Xeno RAT已在GitHub上提供,使其他参与者可以轻松访问,无需额外费用。开源RAT是用C#编写的,与Windows 10…

react-router 源码之matchPath方法

1. 基础依赖path-to-regexp react-router提供了专门的路由匹配方法matchPath(位于packages/react-router/modules/matchPath.js),该方法背后依赖的其实是path-to-regexp包。 path-to-regexp输入是路径字符串(也就是Route中定义的path的值)&…

Ansible的playbook的编写和解析

目录 什么是playbook Ansible 的脚本 --- playbook 剧本 实例部署(使用playbook安装启动httpd服务) 1.编写一个.yaml文件 在主机下载安装http,将配置文件复制到opt目录下 运行playbook 在192.168.17.77主机上查看httpd服务是否成功开启…