人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用

大家好,我是微学AI,今天给大家介绍一下人工智能(pytorch)搭建模型23-pytorch搭建生成对抗网络(GAN):手写数字生成的项目应用。生成对抗网络(GAN)是一种强大的生成模型,在手写数字生成方面具有广泛的应用前景。通过生成逼真的手写数字图像,GAN可以用于数据增强、图像修复、风格迁移等任务,提高模型的性能和泛化能力。生成对抗网络在手写数字生成领域具有广泛的应用前景。主要应用场景包括数据增强、图像修复、风格迁移和跨领域生成。数据增强可以通过生成逼真的手写数字图像,为训练数据集提供更多的样本,提高模型的泛化能力。

一、项目背景

随着深度学习技术的不断发展,生成模型在计算机视觉、自然语言处理等领域取得了显著的成果。生成对抗网络(GAN)作为一种新兴的生成模型,近年来备受关注。在手写数字生成方面,GAN可以生成逼真的手写数字图像,为数据增强、图像修复等任务提供有力支持。

二、生成对抗网络原理

生成对抗网络(GAN)由Goodfellow等人于2014年提出,它由两个神经网络——生成器(Generator)和判别器(Discriminator)——组成。生成器的目标是生成逼真的假样本,而判别器的目标是区分真实样本和生成器生成的假样本。在训练过程中,生成器和判别器相互竞争,不断调整参数,以达到纳什均衡。
GAN的目标是最小化以下价值函数:
min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log (1 - D(G(z)))] GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[log(1D(G(z)))]
其中, G G G表示生成器, D D D表示判别器, x x x表示真实样本, z z z表示生成器的输入噪声, p data p_{\text{data}} pdata表示真实数据分布, p z p_z pz表示噪声分布。
在这里插入图片描述

三、生成对抗网络应用场景

生成对抗网络(GAN)在手写数字生成领域的应用具有广泛的前景。以下是几个主要的应用场景:
1.数据增强:通过生成逼真的手写数字图像,GAN可以为训练数据集提供更多的样本,提高模型的泛化能力。
2. 图像修复:GAN可以用于修复损坏或缺失的手写数字图像,提高图像的质量和可读性。
3. 风格迁移:GAN可以将一种手写风格转换为另一种风格,为个性化手写数字生成提供可能。
4. 跨领域生成:GAN可以实现不同手写数字数据集之间的转换,为多任务学习提供支持。

四、生成对抗网络实现手写数字生成

下面我将利用pytorch深度学习框架构建生成对抗网络的生成器模型Generator、判别器模型Discriminator。

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image# 超参数设置
batch_size = 128
learning_rate = 0.0002
num_epochs = 80# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载并加载训练数据
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)# 定义生成器模型
class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.model = nn.Sequential(nn.Linear(100, 256),nn.LeakyReLU(0.2),nn.Linear(256, 512),nn.LeakyReLU(0.2),nn.Linear(512, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 28*28),nn.Tanh())def forward(self, x):return self.model(x).view(x.size(0), 1, 28, 28)# 定义判别器模型
class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.model = nn.Sequential(nn.Linear(28*28, 1024),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Dropout(0.3),nn.Linear(256, 1),nn.Sigmoid())def forward(self, x):x = x.view(x.size(0), -1)return self.model(x)# 初始化模型
generator = Generator()
discriminator = Discriminator()# 损失函数和优化器
criterion = nn.BCELoss()
optimizerG = optim.Adam(generator.parameters(), lr=learning_rate)
optimizerD = optim.Adam(discriminator.parameters(), lr=learning_rate)# 训练模型
for epoch in range(num_epochs):for i, (images, _) in enumerate(train_loader):# 确保标签的大小与当前批次的数据大小一致real_labels = torch.ones(images.size(0), 1)fake_labels = torch.zeros(images.size(0), 1)# 训练判别器optimizerD.zero_grad()real_outputs = discriminator(images)d_loss_real = criterion(real_outputs, real_labels)z = torch.randn(images.size(0), 100)fake_images = generator(z)fake_outputs = discriminator(fake_images.detach())d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizerD.step()# 训练生成器optimizerG.zero_grad()fake_images = generator(z)fake_outputs = discriminator(fake_images)g_loss = criterion(fake_outputs, real_labels)g_loss.backward()optimizerG.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')# 保存生成器生成的图片save_image(fake_images.data[:25], './fake_images/fake_images-{}.png'.format(epoch+1), nrow=5, normalize=True)# 保存模型
torch.save(generator.state_dict(), 'generator.pth')
torch.save(discriminator.state_dict(), 'discriminator.pth')

最后我们打开fake_images/文件夹,可以看到生成手写图片的过程:
在这里插入图片描述

五、总结

本项目利用生成对抗网络(GAN)实现了手写数字的生成。通过训练生成器和判别器,我们成功生成了逼真的手写数字图像。这些生成的图像可以应用于数据增强、图像修复、风格迁移等领域,为手写数字识别等相关任务提供有力支持。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/440138.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Visual Studio 2022 C++ 生成dll或so文件在windows或linux下用C#调用

背景 开发中我们基本使用windows系统比较快捷,但是部署的时候我们又希望使用linux比较便宜,硬件产商还仅提供了c sdk!苦了我们做二次开发的码农。 方案 需要确认一件事,目前c这门语言不是跨平台的 第一个问题【C生成dll在window…

【Linux】—— 信号的产生

本期,我们今天要将的是信号的第二个知识,即信号的产生。 目录 (一)通过终端按键产生信号 (二)调用系统函数向进程发信号 (三)由软件条件产生信号 (四)硬件…

【MySQL】学习如何通过DQL进行数据库数据的基本查询

🌈个人主页: Aileen_0v0 🔥热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法 ​💫个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-KvH5jXnPNsRtMkOC {font-family:"trebuchet ms",verdana,arial,sans-serif;font-siz…

【EI会议征稿通知】2024年第四届激光,光学和光电子技术国际学术会议(LOPET 2024)

2024年第四届激光,光学和光电子技术国际学术会议(LOPET 2024) 2024 4th International Conference on Laser, Optics and Optoelectronic Technology 2024年第四届激光,光学和光电子技术国际学术会议(LOPET 2024)将于2024年5月17日-19日在中国重庆举行。…

Mysql 删除数据

从数据表中删除数据使用DELETE语句&#xff0c;DELETE语句允许WHERE子句指定删除条件。DELETE语句基本语法格式如下&#xff1a; DELETE FROM table_name [WHERE <condition>]; table_name指定要执行删除操作的表&#xff1b;“[WHERE <condition>]”为可选参数&a…

利用OpenCV检测物流过程中的暴力拆箱和暴力拿放行为

背景介绍&#xff1a; 随着电子商务的快速发展&#xff0c;物流行业面临着越来越多的挑战。其中&#xff0c;暴力拆箱和暴力拿放成为最突出的问题之一。这些行为不仅会导致货物损坏&#xff0c;还会给物流公司和消费者带来巨大的经济损失。传统的解决方法依赖于人工监控&#x…

C语言应用实例——贪吃蛇

&#xff08;图片由AI生成&#xff09; 0.贪吃蛇游戏背景 贪吃蛇游戏&#xff0c;最早可以追溯到1976年的“Blockade”游戏&#xff0c;是电子游戏历史上的一个经典。在这款游戏中&#xff0c;玩家操作一个不断增长的蛇&#xff0c;目标是吃掉出现在屏幕上的食物&#xff0c…

什么是数据API接口,数据API有哪些应用?

​自2020年4月“数据”正式被纳入生产要素范围以来&#xff0c;已经和其它生产要素一起融入经济价值创造过程&#xff0c;近年来我国数据交易市场规模迅速增长&#xff0c;数据需求逐年扩增&#xff0c;“数据”日益成为推动数字中国建设和加快数字经济发展的重要战略资源。 作…

C# Socket 允许控制台应用通过防火墙

需求&#xff1a; 在代码中将exe添加到防火墙规则中&#xff0c;允许Socket通过 添加库引用 效果&#xff1a; 一键三联 若可用记得点赞评论收藏哦&#xff0c;你的支持就是写作的动力。 源地址: https://gist.github.com/cstrahan/513804 调用代码: private static void …

idea报错:Cannot resolve symbol ‘springframework‘

说明maven没有配置好或者加载好 解决&#xff1a; 1&#xff09;File–>Invalidate Caches… 清理缓存&#xff0c;重启idea客户端 然后我这里只进行了第一步就不报错了&#xff01;&#xff01;&#xff01; 如果你依然报错&#xff0c;就继续第二步&#xff1a; 2&…

3、css设置样式总结、节点、节点之间关系、创建元素的方式、BOM

一、css设置样式的方式总结&#xff1a; 对象.style.css属性 对象.className ‘’ 会覆盖原来的类 对象.setAttribut(‘style’,‘css样式’) 对象.setAttribute(‘class’,‘类名’) 对象.style.setProperty(css属性名,css属性值) 对象.style.cssText “css样式表” …

Linux:重定向

Linux&#xff1a;重定向 输出重定向追加重定向输出重定向与追加重定向的本质输入重定向 输出重定向 在Linux中&#xff0c;输出重定向是一种将命令的输出发送到不同位置的方法。通常&#xff0c;执行命令时&#xff0c;输出会显示在终端上。然而&#xff0c;使用输出重定向&a…