AE——重构数字(Pytorch+mnist)

1、简介

  • AE(自编码器)由编码器和解码器组成,编码器将输入数据映射到潜在空间,解码器将潜在表示映射回原始输入空间。
  • AE的训练目标通常是最小化重构误差,即尽可能地重构输入数据,使得解码器输出与原始输入尽可能接近。
  • AE通常用于数据压缩、去噪、特征提取等任务。
  • 本文利用AE,输入数字图像。训练后,输入测试数字图像,重构生成新的数字图像。
    • 【注】本文案例需要输入才能生成输出,目标是重构,而不是生成。
  • 可以看出,重构图片和原始图片差别不大。 
  • 【注】输出的10张数字图像是输入的测试图像的第一批次。

2、代码

  • import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torchvision# 在一个类中编写编码器和解码器层。为编码器和解码器层的组件都定义了全连接层
    class AE(nn.Module):def __init__(self, **kwargs):super().__init__()self.encoder_hidden_layer = nn.Linear(in_features=kwargs["input_shape"], out_features=128)  # 编码器隐藏层self.encoder_output_layer = nn.Linear(in_features=128, out_features=128)  # 编码器输出层self.decoder_hidden_layer = nn.Linear(in_features=128, out_features=128)  # 解码器隐藏层self.decoder_output_layer = nn.Linear(in_features=128, out_features=kwargs["input_shape"])  # 解码器输出层# 定义了模型的前向传播过程,包括激活函数的应用和重构图像的生成def forward(self, features):activation = self.encoder_hidden_layer(features)activation = torch.relu(activation)  # ReLU 激活函数,得到编码器的激活值code = self.encoder_output_layer(activation)code = torch.sigmoid(code)  # Sigmoid 激活函数,以确保编码后的表示在 [0, 1] 范围内activation = self.decoder_hidden_layer(code)activation = torch.relu(activation)activation = self.decoder_output_layer(activation)reconstructed = torch.sigmoid(activation)return reconstructedif __name__ == '__main__':# 设置批大小、学习周期和学习率batch_size = 512epochs = 30learning_rate = 1e-3# 载入 MNIST 数据集中的图片进行训练transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])  # 将图像转换为张量train_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=True, transform=transform, download=True)  # 加载 MNIST 数据集的训练集,设置路径、转换和下载为 Truetrain_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)  # 创建一个数据加载器,用于加载训练数据,设置批处理大小和是否随机打乱数据# 在使用定义的 AE 类之前,有以下事情要做:# 配置要在哪个设备上运行device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 建立 AE 模型并载入到 CPU 设备model = AE(input_shape=784).to(device)# Adam 优化器,学习率 10e-3optimizer = optim.Adam(model.parameters(), lr=learning_rate)# 使用均方误差(MSE)损失函数criterion = nn.MSELoss()# 在GPU设备上运行,实例化一个输入大小为784的AE自编码器,并用Adam作为训练优化器用MSELoss作为损失函数# 训练:for epoch in range(epochs):loss = 0for batch_features, _ in train_loader:# 将小批数据变形为 [N, 784] 矩阵,并加载到 CPU 设备batch_features = batch_features.view(-1, 784).to(device)# 梯度设置为 0,因为 torch 会累加梯度optimizer.zero_grad()# 计算重构outputs = model(batch_features)# 计算训练重建损失train_loss = criterion(outputs, batch_features)# 计算累积梯度train_loss.backward()# 根据当前梯度更新参数optimizer.step()# 将小批量训练损失加到周期损失中loss += train_loss.item()# 计算每个周期的训练损失loss = loss / len(train_loader)# 显示每个周期的训练损失print("epoch : {}/{}, recon loss = {:.8f}".format(epoch + 1, epochs, loss))# 用训练过的自编码器提取一些测试用例来重构test_dataset = torchvision.datasets.MNIST(root="~/torch_datasets", train=False, transform=transform, download=True)  # 加载 MNIST 测试数据集test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=10, shuffle=False)  # 创建一个测试数据加载器test_examples = None# 通过循环遍历测试数据加载器,获取一个批次的图像数据with torch.no_grad():  # 使用 torch.no_grad() 上下文管理器,确保在该上下文中不会进行梯度计算for batch_features in test_loader:  # 历测试数据加载器中的每个批次的图像数据batch_features = batch_features[0]  # 获取当前批次的图像数据test_examples = batch_features.view(-1, 784).to(device)  # 将当前批次的图像数据转换为大小为 (批大小, 784) 的张量,并加载到指定的设备(CPU 或 GPU)上reconstruction = model(test_examples)  # 使用训练好的自编码器模型对测试数据进行重构,即生成重构的图像break# 试着用训练过的自编码器重建一些测试图像with torch.no_grad():number = 10  # 设置要显示的图像数量plt.figure(figsize=(20, 4))  # 创建一个新的 Matplotlib 图形,设置图形大小为 (20, 4)for index in range(number):  # 遍历要显示的图像数量# 显示原始图ax = plt.subplot(2, number, index + 1)plt.imshow(test_examples[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)# 显示重构图ax = plt.subplot(2, number, index + 1 + number)plt.imshow(reconstruction[index].cpu().numpy().reshape(28, 28))plt.gray()ax.get_xaxis().set_visible(False)ax.get_yaxis().set_visible(False)plt.savefig('reconstruction_results.png')  # 保存图像plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/586618.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分享一个宝藏课程:近屿AIGC工程师和产品经理训练营

说起AIGC,大家都会自然地想到近两年火的一塌糊涂的ChatGPT,而开发出它的OpenAI,去年年底的年化收入已突破16亿美元,部分OpenAI的管理层认为,按目前进度,到2024年底,OpenAI的年化收入至少能达到50亿美元。而…

风险与收益

风险与收益 影响资产需求的主要因素财富总量预期收益率资产的流动性影响流动性的主要因素 风险 如何降低风险系统风险和非系统风险机会集合与有效集合资产组合理论 影响资产需求的主要因素 影响资产需求的主要因素包括:财富总量、预期收益率、资产的流动性和风险。…

移位运算与乘法

描述 题目描述: 已知d为一个8位数,请在每个时钟周期分别输出该数乘1/3/7/8,并输出一个信号通知此时刻输入的d有效(d给出的信号的上升沿表示写入有效) 信号示意图: 波形示意图: 输入描述&#…

如何同时安全高效管理多个谷歌账号?

您的业务活动需要多个 Gmail 帐户吗?出海畅游,Gmail账号是少不了的工具之一,可以关联到Twitter、Facebook、Youtube、Chatgpt等等平台,可以说是海外网络的“万能锁”。但是大家都知道,以上这些平台注册多账号如果产生关…

如何一键展示全平台信息?Python手把手教你搭建自己的自媒体展示平台

前言 灵感源于之前写过的Github中Readme.md中可以插入自己的js图片和动态api解析模块&#xff0c;在展示方面十分的美观&#xff1a; 这方面原理可以简化为&#xff0c;在Markdown中&#xff0c;你可以使用HTML标签来添加图像&#xff0c;就像这样&#xff1a; <tr><…

代码随想录Day25:回溯算法Part2

Leetcode 216. 组合总和III 讲解前&#xff1a; 这道题如果掌握了组合那道题的话就变得非常容易了&#xff0c;其实就是多加了一个参数的问题&#xff0c;我们可选的数字从可变的变成了1-9固定&#xff0c;然后呢要找的组合大小还是k&#xff0c;这次多加一个条件就是组合中的…

打断点调试代码的思路(找bug的思路)二分法

现象&#xff1a; 当断点运行到此处&#xff0c;卡死 二分法&#xff1a; 用断点把程序切段&#xff0c;前一段&#xff0c;后一段 **前一段&#xff1a;检查变量值&#xff0c;如无问题&#xff0c;则说明没有任何问题 问题必然出在后一段 后一段&#xff1a;人为检查&…

JVS智能BI数据分析:图表的数据联动配置详解

图表的数据联动 图表的数据联动是指在可视化图表中&#xff0c;当一个图表的数据发生变化时&#xff0c;另一个图表中的数据也会自动更新。这种功能通常用于展示相互关联的数据集&#xff0c;帮助用户更直观地了解数据之间的关系和趋势。 我们先看看实际的效果&#xff0c;如下…

安卓Android 架构模式及UI布局设计

文章目录 一、Android UI 简介1.1 在手机UI设计中&#xff0c;坚持的原则是什么1.2 安卓中的架构模式1.2.1 MVC (Model-View-Controller)设计模式优缺点 1.2.2 MVP(Model-View-Presenter)设计模式MVP与MVC关系&#xff1a; 1.2.3 MVVM(Model—View—ViewModel ) 设计模式1.2.4 …

<网络> 网络Socket 编程基于UDP协议模拟简易网络通信

目录 前言&#xff1a; 一、预备知识 &#xff08;一&#xff09;IP地址 &#xff08;二&#xff09;端口号 &#xff08;三&#xff09;端口号与进程PID &#xff08;四&#xff09;传输层协议 &#xff08;五&#xff09;网络字节序 二、socket 套接字 &#xff08;…

Redis (String 底层数据结构)

Redis是查询数据很快的no Sql数据库。其原因不只是因为它存储在内存中&#xff0c;还因为它存储各种类型的数据结构。比如这期说到的String类型。String类型在Redis中应用很广泛。Redis中存储数据是以键值对的方式。这个键都是String类型。所以下面我们看下String类型的底层数据…

mac+win10虚拟机+phpstudy便捷运行php+pgsql的方法

痛点&#xff1a;mac下要搭建nginxphp&#xff08;含pdo_pgsql&#xff09;pgsql比较麻烦 另类解决方法&#xff1a; 前提&#xff1a;mac下需要已安装win10虚拟机 方法&#xff1a; 1. win10虚拟机下安装phpstudy8.1 -> 开启php扩展&#xff08;pdo_pgsql&#xff09;&a…