在kaggle中用GPU使用CGAN生成指定mnist手写数字

文章目录

  • 1项目介绍
  • 2参考文章
  • 3代码的实现过程及对代码的详细解析
    • 独热编码
    • 定义生成器
    • 定义判别器
    • 打印我们的引导信息
    • 模型训练
    • 迭代过程中生成的图片
    • 损失函数的变化
  • 4总结
  • 5 模型相关的文件

1项目介绍

在GAN的基础上进行有条件的引导生成图片cgan

2参考文章

GAN实战之Pytorch 使用CGAN生成指定MNIST手写数字
GANs系列:CGAN(条件GAN)原理简介以及项目代码实现

3代码的实现过程及对代码的详细解析


import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)import os
for dirname, _, filenames in os.walk('/kaggle/input'):for filename in filenames:print(os.path.join(dirname, filename))
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torch.utils import data
import os
import glob
from PIL import Image

独热编码


# 输入x代表默认的torchvision返回的类比值,class_count类别值为10
def one_hot(x, class_count=10):return torch.eye(class_count)[x, :]  # 切片选取,第一维选取第x个,第二维全要

torch.eye(10)函数的作用是生成一个10*10的对角矩阵
该函数的作用是得到第x个位置为1的独热编码,如果传入为列表,则得到一个矩阵
在这里插入图片描述

 
transform =transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])#minist数据集中的图片数据的维度是[batch_size, 1, 28, 28],其中batch_size是每个批次的图像数量。这个数据集中的每个图像都是28x28像素的灰度图像,因此它们只有一个通道
dataset = torchvision.datasets.MNIST('data',train=True,transform=transform,target_transform=one_hot,download=True)
#这里target_transform参数的作用是对标签进行转换。在这个例子中,它的作用是将标签转换为one-hot编码。
dataloader = data.DataLoader(dataset, batch_size=64, shuffle=True)

定义生成器


class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()#因此,这个函数的输入张量维度为[batch_size, 10]和[batch_size, 100],输出张量维度为[batch_size, 1, 1, 1]。self.linear1 = nn.Linear(10, 128 * 7 * 7)self.bn1 = nn.BatchNorm1d(128 * 7 * 7)self.linear2 = nn.Linear(100, 128 * 7 * 7)self.bn2 = nn.BatchNorm1d(128 * 7 * 7)#这个函数的作用是将一个输入张量进行反卷积操作,得到一个输出张量。#nn.ConvTranspose2d函数的作用是将一个256通道的输入张量转换为一个128通道的输出张量,使用3x3的卷积核进行卷积操作,并在卷积操作后进行1像素的paddingself.deconv1 = nn.ConvTranspose2d(256, 128,kernel_size=(3, 3),padding=1)self.bn3 = nn.BatchNorm2d(128)self.deconv2 = nn.ConvTranspose2d(128, 64,kernel_size=(4, 4),stride=2,padding=1)self.bn4 = nn.BatchNorm2d(64)self.deconv3 = nn.ConvTranspose2d(64, 1,kernel_size=(4, 4),stride=2,padding=1)def forward(self, x1, x2):x1 = F.relu(self.linear1(x1))x1 = self.bn1(x1)x1 = x1.view(-1, 128, 7, 7)x2 = F.relu(self.linear2(x2))x2 = self.bn2(x2)x2 = x2.view(-1, 128, 7, 7)#将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量x = torch.cat([x1, x2], axis=1)x = F.relu(self.deconv1(x))#形状变为为[64, 128, 7, 7]的张量x = self.bn3(x)x = F.relu(self.deconv2(x))#形状变为为[64, 64, 14, 14]的张量x = self.bn4(x)# 形状变为为[64, 1, 28, 28]的张量x = torch.tanh(self.deconv3(x))return x

生成器对数据的处理过程:
这个函数对于输入张量[64, 1, 28, 28]的维度变化过程如下:
输入张量维度为[64, 1, 28, 28]
经过线性变换和ReLU激活函数处理后,得到两个形状为[64, 128 * 7 * 7]的张量
将两个张量分别通过BatchNorm1d进行归一化处理
将两个处理后的结果reshape成形状为[64, 128, 7, 7]的张量
将两个处理后的结果拼接在一起,得到形状为[64, 256, 7, 7]的张量
经过反卷积操作得到输出张量,维度为[64, 1, 28, 28]

定义判别器


# input:1,28,28的图片以及长度为10的condition
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.linear = nn.Linear(10, 1*28*28)self.conv1 = nn.Conv2d(2, 64, kernel_size=3, stride=2)self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2)self.bn = nn.BatchNorm2d(128)self.fc = nn.Linear(128*6*6, 1) # 输出一个概率值def forward(self, x1, x2):#leak_relu激活函数:它在输入小于0时返回一个小的斜率,而在输入大于等于0时返回输入本身x1 =F.leaky_relu(self.linear(x1))x1 = x1.view(-1, 1, 28, 28)#torch.cat([x1, x2], axis=1)函数将张量x1和张量x2沿着第二个维度(即列)拼接起来x = torch.cat([x1, x2], axis=1)#处理过后变为(64,2,28,28)x = F.dropout2d(F.leaky_relu(self.conv1(x)))#维度变为(64,64,13,13)x = F.dropout2d(F.leaky_relu(self.conv2(x)))#维度变为(64,128,6,6)x = self.bn(x)x = x.view(-1, 128*6*6)#最后键位了64*1(同时把值映射到0~1之间)x = torch.sigmoid(self.fc(x))return x
# 初始化模型
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)# 损失计算函数
loss_function = torch.nn.BCELoss()# 定义优化器
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-5)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)
# 定义可视化函数
def generate_and_save_images(model, epoch, label_input, noise_input):#生成器生成取片,label_input为输入的引导信息,noise_input为随机的噪声点predictions = np.squeeze(model(label_input, noise_input).cpu().numpy())#numpy.squeeze()函数的作用是去掉矩阵里维度为1的维度。fig = plt.figure(figsize=(4, 4))for i in range(predictions.shape[0]):plt.subplot(4, 4, i + 1)plt.imshow((predictions[i] + 1) / 2, cmap='gray')plt.axis("off")from IPython.display import FileLinkplt.savefig('data/img/image_at_epoch_{:04d}.png'.format(epoch))plt.show()
import os 
os.makedirs("data/img")

打印我们的引导信息

noise_seed = torch.randn(16, 100, device=device)label_seed = torch.randint(0, 10, size=(16,))
label_seed_onehot = one_hot(label_seed).to(device)print(label_seed)
tensor([1, 3, 5, 4, 9, 3, 0, 0, 1, 3, 4, 5, 9, 2, 3, 7])

模型训练

D_loss = []
G_loss = []


# 训练循环
for epoch in range(150):d_epoch_loss = 0g_epoch_loss = 0count = len(dataloader.dataset)# 对全部的数据集做一次迭代#dataloader中的图像是四维的。在for循环中,每次迭代会返回一个batch_size大小的数据#其中每个数据都是一个四维张量,形状为[batch_size, channels, height, width]for step, (img, label) in enumerate(dataloader):img = img.to(device)label = label.to(device)size = img.shape[0]random_noise = torch.randn(size, 100, device=device)d_optim.zero_grad()real_output = dis(label, img)d_real_loss = loss_function(real_output,torch.ones_like(real_output, device=device))#torch.ones_like(real_output, device=device)函数的作用是生成一个与real_output形状相同的张量,其中所有元素都为1。                         d_real_loss.backward() #求解梯度# 得到判别器在生成图像上的损失gen_img = gen(label,random_noise)fake_output = dis(label, gen_img.detach())  # 判别器输入生成的图片,f_o是对生成图片的预测结果d_fake_loss = loss_function(fake_output,torch.zeros_like(fake_output, device=device))d_fake_loss.backward()d_loss = d_real_loss + d_fake_lossd_optim.step()  # 优化# 得到生成器的损失g_optim.zero_grad()fake_output = dis(label, gen_img)g_loss = loss_function(fake_output,torch.ones_like(fake_output, device=device))g_loss.backward()g_optim.step()with torch.no_grad():d_epoch_loss += d_loss.item()g_epoch_loss += g_loss.item()with torch.no_grad():d_epoch_loss /= countg_epoch_loss /= countD_loss.append(d_epoch_loss)G_loss.append(g_epoch_loss)if epoch % 10 == 0:print('Epoch:', epoch)generate_and_save_images(gen, epoch, label_seed_onehot, noise_seed)print("epoch:{}/150".format(epoch))plt.plot(D_loss, label='D_loss')
plt.plot(G_loss, label='G_loss')
plt.legend()
plt.show()

迭代过程中生成的图片

迭代1次
在这里插入图片描述
迭代10次
在这里插入图片描述
迭代20次
在这里插入图片描述

迭代30次
在这里插入图片描述

迭代40次
在这里插入图片描述

迭代150次
在这里插入图片描述

损失函数的变化

在这里插入图片描述

4总结

cGAN相比于GAN而言,将label的信息通过一系列的卷积操作和图像的信息融合在一起,然后放进模型进行训练,让我们的模型能和label相匹配的图像,从而在我们给出制定的数字label时能够生成对应的数字图片,实现了引导的过程。

5 模型相关的文件

模型的相关文件:提取码(ujki)

本模型是放在kaggle中运行的,kaggle的部署流程请参考:在kaggle中用GPU训练模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/87283.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于mysql5.7制作自定义的docker镜像,适用于xxl-job依赖的数据库,自动执行初始化脚本(ddl语句和dml语句)

一、背景 xxl-job-admin依赖mysql数据库,且需执行初始化脚本,包括ddl和dml语句。 具体的步骤总结如下: 1、新建数据库xxl_job2、创建mysql表table3、执行dml语句,包括新建admin用户及密码,创建执行器和任务。 毫无疑…

Android App的设计规范

Android App 设计规范是为开发者和设计师提供的一系列准则和建议,以确保应用在 Android 设备上的外观、交互和用户体验保持一致。以下是一些常见的 Android App 设计规范要点,希望对大家有所帮助。北京木奇移动技术有限公司,专业的软件外包开…

基于SSM+vue框架的个人博客网站源码和论文

基于SSMvue框架的个人博客网站源码和论文061 开发工具:idea 数据库mysql5.7 数据库链接工具:navcat,小海豚等 技术:ssm (设计)研究背景与意义 关于博客的未来:在创办了博客中国(blogchina)、被誉为“…

Python爬虫实战案例——第三例

文章中所有内容仅供学习交流使用,不用于其他任何目的!严禁将文中内容用于任何商业与非法用途,由此产生的一切后果与作者无关。若有侵权,请联系删除。 起点中文网月票榜加密字体处理 字体加密的原理:就是将一种特定的…

通俗理解DDPM到Stable Diffusion原理

代码1:stabel diffusion 代码库代码2:diffusers 代码库论文:High-Resolution Image Synthesis with Latent Diffusion Models模型权重:runwayml/stable-diffusion-v1-5 文章目录 1. DDPM的通俗理解1.1 DDPM的目的1.2 扩散过程1.3 …

PHPEXCEL 导出excel

$styleArray [alignment > [horizontal > Alignment::HORIZONTAL_CENTER,vertical > Alignment::VERTICAL_CENTER],];$border_style [borders > [allborders > [style > \PHPExcel_Style_Border::BORDER_THIN ,//细边框]]];$begin_date $request->beg…

接口经典题目

​ White graces:个人主页 🙉专栏推荐:《Java入门知识》🙉 🙉 内容推荐:继承与组合:代码复用的两种策略🙉 🐹今日诗词:人似秋鸿来有信,事如春梦了无痕。🐹 目录 &…

Go 第三方库引起的线上问题、如何在线线上环境进行调试定位问题以及golang开发中各种问题精华整理总结

Go 第三方库引起的线上问题、如何在线线上环境进行调试定位问题以及golang开发中各种问题精华整理总结。 01 前言 在使用 Go 语言进行 Web 开发时,我们往往会选择一些优秀的库来简化 HTTP 请求的处理。其中,go-resty 是一个被广泛使用的 HTTP 客户端。…

win10家庭版远程桌面补丁_rdp wrapper

RDP Wrapper Library 就是可以帮你在 Windows 7、Windows 8、Windows 10 家庭版中打开远程桌面的工具。 1、把电脑上打开的安全软件与杀毒软件都关掉,因为这个远程桌面补丁会修改系统文件,所以安全软件可能会拦截。 2、下载RDP Wrapper Library补丁压缩…

windows安装mysql8.0.34的压缩包

文章目录 目录 文章目录 前言 一、下载安装包zip格式 二、使用步骤 总结 前言 一、下载安装包zip格式 MySQL :: Begin Your Download 二、使用步骤 解压缩之后在解压之后的目录里创建data和my.ini my.ini内容 # 设置mysql客户端连接服务端时默认使用的端口 port3306#默认…

Wlan——锐捷零漫游网络解决方案以及相关配置

目录 零漫游介绍 一代零漫游 二代单频率零漫游 二代双频率零漫游 锐捷零漫游方案总结 锐捷零漫游方案的配置 配置无线信号的信道 开启关闭5G零漫游 查看配置 零漫游介绍 普通的漫游和零漫游的区别 普通漫游 漫游是由一个AP到另一个AP或者一个射频卡到另一个射频卡的漫…

14-redis

一 Redis概述 1 为什么要用NoSQL 单机Mysql的美好年代 在90年代,一个网站的访问量一般都不大,用单个数据库完全可以 轻松应付。在那个时候,更多的都是静态网页,动态交互类型的网站不多。 遇到问题: 随着用户数的增长…