进行生成简单数字图片

1.之前只能做一些图像预测,我有个大胆的想法,如果神经网络正向就是预测图片的类别,如果我只有一个类别那就可以进行生成图片,专业术语叫做gan对抗网络
在这里插入图片描述
2.训练代码

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import os# 设置环境变量
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'# 定义生成器模型
class Generator(nn.Module):def __init__(self, input_dim=100, output_dim=784):super(Generator, self).__init__()self.fc1 = nn.Linear(input_dim, 256)self.fc2 = nn.Linear(256, 512)self.fc3 = nn.Linear(512, 1024)self.fc4 = nn.Linear(1024, output_dim)self.relu = nn.ReLU()self.tanh = nn.Tanh()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.relu(self.fc3(x))x = self.tanh(self.fc4(x))return x# 定义判别器模型
class Discriminator(nn.Module):def __init__(self, input_dim=784, output_dim=1):super(Discriminator, self).__init__()self.fc1 = nn.Linear(input_dim, 1024)self.fc2 = nn.Linear(1024, 512)self.fc3 = nn.Linear(512, 256)self.fc4 = nn.Linear(256, output_dim)self.relu = nn.ReLU()self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.relu(self.fc1(x))x = self.relu(self.fc2(x))x = self.relu(self.fc3(x))x = self.sigmoid(self.fc4(x))return x# 加载 MNIST 手写数字图片数据集
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])
dataroot = "path_to_your_mnist_dataset"  # 替换为 MNIST 数据集的路径
dataset = dset.MNIST(root=dataroot, train=True, transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)# 创建生成器和判别器实例
input_dim = 100
output_dim = 784
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)# 定义优化器和损失函数
lr = 0.0002
beta1 = 0.5
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
criterion = nn.BCELoss()# 训练 GAN 模型
num_epochs = 50
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Device:", device)
generator.to(device)
discriminator.to(device)
for epoch in range(num_epochs):for i, data in enumerate(dataloader, 0):real_images, _ = datareal_images = real_images.to(device)batch_size = real_images.size(0)  # 获取批次样本数量# 训练判别器optimizer_d.zero_grad()real_labels = torch.full((batch_size, 1), 1.0, device=device)fake_labels = torch.full((batch_size, 1), 0.0, device=device)noise = torch.randn(batch_size, input_dim, device=device)fake_images = generator(noise)real_outputs = discriminator(real_images.view(batch_size, -1))fake_outputs = discriminator(fake_images.detach())d_loss_real = criterion(real_outputs, real_labels)d_loss_fake = criterion(fake_outputs, fake_labels)d_loss = d_loss_real + d_loss_faked_loss.backward()optimizer_d.step()# 训练生成器optimizer_g.zero_grad()noise = torch.randn(batch_size, input_dim, device=device)fake_images = generator(noise)fake_outputs = discriminator(fake_images)g_loss = criterion(fake_outputs, real_labels)g_loss.backward()optimizer_g.step()# 输出训练信息if i % 100 == 0:print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"% (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))# 保存生成器的权重和图片示例if epoch % 10 == 0:with torch.no_grad():noise = torch.randn(64, input_dim, device=device)fake_images = generator(noise).view(64, 1, 28, 28).cpu().numpy()fig, axes = plt.subplots(nrows=8, ncols=8, figsize=(12, 12), sharex=True, sharey=True)for i, ax in enumerate(axes.flatten()):ax.imshow(fake_images[i][0], cmap='gray')ax.axis('off')plt.subplots_adjust(wspace=0.05, hspace=0.05)plt.savefig("epoch_%d.png" % epoch)plt.close()torch.save(generator.state_dict(), "generator_epoch_%d.pth" % epoch)

3.测试模型的代码

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image# 定义生成器模型
class Generator(nn.Module):def __init__(self, input_dim, output_dim):super(Generator, self).__init__()self.fc1 = nn.Linear(input_dim, 256)self.fc2 = nn.Linear(256, 512)self.fc3 = nn.Linear(512, 1024)self.fc4 = nn.Linear(1024, output_dim)def forward(self, x):x = F.leaky_relu(self.fc1(x), 0.2)x = F.leaky_relu(self.fc2(x), 0.2)x = F.leaky_relu(self.fc3(x), 0.2)x = torch.tanh(self.fc4(x))return x# 创建生成器模型
generator = Generator(input_dim=100, output_dim=784)# 加载预训练权重
generator_weights = torch.load("generator_epoch_40.pth", map_location=torch.device('cpu'))# 将权重加载到生成器模型
generator.load_state_dict(generator_weights)# 生成随机噪声
noise = torch.randn(1, 100)# 生成图像
fake_image = generator(noise).view(1, 1, 28, 28)# 保存生成的图片
save_image(fake_image, "generated_image.png", normalize=False)

#测试结果,由于我的训练集是数字的,所以会生成各种各样的数字,下面明显的是1
在这里插入图片描述
#应该也是1
在这里插入图片描述

#再次运行,我也看不出来,不过只要我训练只有一个种类的问题就可以生成这个种类的图像
在这里插入图片描述
#搞定黑白图,那彩色图应该距离不远了,我需要改进的是把对抗网络的代码改为训练一个种类的图形,不过我感觉这种图形具有随机性,虽然通过训练我们得到了所有图像他们的规律,但是如果需要正常点的图片还是挺难的,就像是上面这张人都不一定知道他是什么东西(在没有颜色的情况下)总结就是精度不够,而且随机性太强了,现在普遍图片AI生成工具具有这个缺点(生成的物体可能会扭曲,挺阴间的),而且生成的图片速度慢,如果谁比较受益那一定是老黄(英伟达)哈哈哈
//比如下面这个图片生成视频的网站
https://app.runwayml.com/login

#每一帧看起来都没有问题,就是连起来变成视频不自然,如果有改进方法的话那可能需要引入重力/加速度/光处理 等等物理公式,来让图片更自然…
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/256544.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ArcGIS Pro中怎么设置标注换行

在ArcGIS Pro中进行文字标注的时候,如果标注的字段内容太长,直接标注的话会不美观,而且还会影响旁边的标注显示,这里为大家介绍一下在ArcGIS Pro中设置文字换行的方法,希望能对你有所帮助。 数据来源 本教程所使用的…

【EI会议征稿】第三届密码学、网络安全和通信技术国际会议(CNSCT 2024)

第三届密码学、网络安全和通信技术国际会议(CNSCT 2024) 2024 3rd International Conference on Cryptography, Network Security and Communication Technology 随着互联网和网络应用的不断发展,网络安全在计算机科学中的地位越来越重要&…

【深度学习】迁移学习中的领域转移及迁移学习的分类

领域转移 根据分布移位发生的具体部分,域移位可分为三种类型,包括协变量移位、先验移位和概念移位 协变量移位: 在协变量移位的情况下,源域和目标域的边际分布是不同的,即ps(x)∕ pt(x),而给定x的y的后验分布在域之间…

ardupilot开发 --- MAVSDK 篇

一些概念 MAVSDK用于与MAVLink系统(如无人机、相机或地面系统)接口。 这些库提供了一些简单的API,用于管理一个或多个vehicles,提供对vehicles信息和遥测的程序访问,以及对任务、移动和其他操作的控制。 这些库可以在…

TCP单聊和UDP群聊

TCP协议单聊 服务端: import java.awt.BorderLayout; import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.PrintWriter; import java.net.ServerSocket; import java.net.Socket; import java.util.V…

ChatGPT可能即将发布新版本,带有debug功能:支持下载原始对话、可视化对话分支等

本文原文来自DataLearnerAI官方网站:ChatGPT内置隐藏debug功能:支持下载原始对话、可视化对话分支等 | 数据学习者官方网站(Datalearner) AIPRM的工作人员最近发现ChatGPT的客户端隐藏内置了一个新的debug特性,可以提高ChatGPT对话的问题调试…

Navicat 连接 GaussDB分布式的快速入门

Navicat Premium(16.3.3 Windows版或以上)正式支持 GaussDB 分布式数据库。GaussDB分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能,还提供强大的高阶功能(如模型、结构…

AZURE==SQL managed instances

创建资源 创建DB 创建完成后,拿着刚才的账号密码依然连接不上 远程连接 需要开启公网访问和开放相关端口 参考Configure public endpoint - Azure SQL Managed Instance | Microsoft Learn 连接成功

Django模板,Django中间件,ORM操作(pymysql + SQL语句),连接池,session和cookie, 缓存

day04 django进阶-知识点 今日概要: 模板中间件ORM操作(pymysql SQL语句)session和cookie缓存(很多种方式) 内容回顾 请求周期 路由系统 最基本路由关系动态路由(含正则)路由分发不同的app中…

如何在Spring Boot中集成RabbitMQ

如何在Spring Boot中集成RabbitMQ 在现代微服务架构中,消息队列(如RabbitMQ)扮演了关键的角色,它不仅能够提供高效的消息传递机制,还能解耦服务间的通信。本文将介绍如何在Spring Boot项目中集成RabbitMQ,…

漏洞复现--Apache Ofbiz XML-RPC RCE(CVE-2023-49070)

免责声明: 文章中涉及的漏洞均已修复,敏感信息均已做打码处理,文章仅做经验分享用途,切勿当真,未授权的攻击属于非法行为!文章中敏感信息均已做多层打马处理。传播、利用本文章所提供的信息而造成的任何直…

短视频账号矩阵系统源码搭建步骤包括以下几个方面:

短视频账号矩阵系统源码搭建步骤包括以下几个方面: 1. 确定账号类型和目标受众:确定要运营的短视频账号类型,如搞笑、美食、美妆等,并明确目标受众和定位。 2. 准备账号资料:准备相关资质和资料,如营业执照…