深度学习——自编码器AutoEncoder

基本概念

概述

自编码器(Autoencoder)是一种无监督学习的神经网络模型,用于学习数据的低维表示。它由编码器(Encoder)和解码器(Decoder)两部分组成,通过将输入数据压缩到低维编码空间,再从编码空间中重构输入数据。

基本结构

自编码器的基本结构如下:
1.编码器(Encoder):接收输入数据,将其映射到低维编码空间。编码器由一系列隐藏层组成,通常逐渐减小维度以进行特征提取和数据压缩。
2.解码器(Decoder):接收编码器的输出,将编码后的数据映射回原始输入空间。解码器的结构与编码器相反,逐渐增加维度并尝试重构原始数据。
3.重构损失(Reconstruction Loss):自编码器的目标是尽可能准确地重构输入数据。因此,使用重构损失函数来衡量原始数据与重构数据之间的差异,如均方误差(MSE)或交叉熵损失。

训练过程

1.将输入数据提供给编码器,获得低维编码。
2.将编码结果传递给解码器,尝试重构输入数据。
3.计算重构损失,并通过反向传播优化网络参数,使重构误差最小化。
重复上述步骤,直到自编码器能够准确地重构输入数据。

应用

1.数据降维:自编码器可以学习数据的低维表示,有助于数据的压缩和降维。
2.特征学习:通过训练自编码器,可以学习到数据的有意义的特征表示,用于后续的监督学习任务。
3.异常检测:自编码器可以学习数据的正常分布,从而用于检测异常或异常数据的重构错误。

详细代码与注释

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import numpy as np# torch.manual_seed(1)    # reproducible# Hyper Parameters
EPOCH = 10
BATCH_SIZE = 64
LR = 0.005         # learning rate
DOWNLOAD_MNIST = TrueN_TEST_IMG = 5
# Mnist digits dataset
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,                                     # this is training datatransform=torchvision.transforms.ToTensor(),    # Converts a PIL.Image or numpy.ndarray to# torch.FloatTensor of shape (C x H x W) and normalize in the range [0.0, 1.0]download=DOWNLOAD_MNIST,                        # download it if you don't have it
)# plot one example
# 训练数据
print(train_data.train_data.size())     # (60000, 28, 28)
# 训练标签
print(train_data.train_labels.size())   # (60000)
plt.imshow(train_data.train_data[2].numpy(), cmap='gray')
plt.title('%i' % train_data.train_labels[2])
plt.show()# Data Loader for easy mini-batch return in training, the image batch shape will be (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()# 编码器self.encoder = nn.Sequential(nn.Linear(28*28, 128),nn.Tanh(),nn.Linear(128, 64),nn.Tanh(),nn.Linear(64, 12),nn.Tanh(),nn.Linear(12, 3),   # compress to 3 features which can be visualized in plt)# 解码器self.decoder = nn.Sequential(nn.Linear(3, 12),nn.Tanh(),nn.Linear(12, 64),nn.Tanh(),nn.Linear(64, 128),nn.Tanh(),nn.Linear(128, 28*28),nn.Sigmoid(),       # compress to a range (0, 1))def forward(self, x):encoded = self.encoder(x)decoded = self.decoder(encoded)return encoded, decodedautoencoder = AutoEncoder()optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)
loss_func = nn.MSELoss()# initialize figure
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))
plt.ion()   # continuously plot# original data (first row) for viewing
view_data = train_data.train_data[:N_TEST_IMG].view(-1, 28*28).type(torch.FloatTensor)/255.
for i in range(N_TEST_IMG):a[0][i].imshow(np.reshape(view_data.data.numpy()[i], (28, 28)), cmap='gray'); a[0][i].set_xticks(()); a[0][i].set_yticks(())# 训练
for epoch in range(EPOCH):for step, (x, b_label) in enumerate(train_loader):b_x = x.view(-1, 28*28)   # batch x, shape (batch, 28*28)b_y = x.view(-1, 28*28)   # batch y, shape (batch, 28*28)encoded, decoded = autoencoder(b_x)# 比对解码出来的数据和原始数据,计算lossloss = loss_func(decoded, b_y)      # mean square erroroptimizer.zero_grad()               # clear gradients for this training steploss.backward()                     # backpropagation, compute gradientsoptimizer.step()                    # apply gradientsif step % 100 == 0:print('Epoch: ', epoch, '| train loss: %.4f' % loss.data.numpy())# plotting decoded image (second row)_, decoded_data = autoencoder(view_data)for i in range(N_TEST_IMG):a[1][i].clear()a[1][i].imshow(np.reshape(decoded_data.data.numpy()[i], (28, 28)), cmap='gray')a[1][i].set_xticks(())a[1][i].set_yticks(())plt.draw()plt.pause(0.05)plt.ioff()
plt.show()# visualize in 3D plot
view_data = train_data.train_data[:200].view(-1, 28*28).type(torch.FloatTensor)/255.
encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
ax = Axes3D(fig)
X, Y, Z = encoded_data.data[:, 0].numpy(), encoded_data.data[:, 1].numpy(), encoded_data.data[:, 2].numpy()
values = train_data.train_labels[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):c = cm.rainbow(int(255*s/9)); ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max()); ax.set_ylim(Y.min(), Y.max()); ax.set_zlim(Z.min(), Z.max())
plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/27035.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

使用shell监控应用运行状态通过企业微信接收监控通知

目的:编写shell脚本来监控应用服务运行状态,若是应用异常则自动重启应用通过企业微信接收监控告警通知 知识要点: 使用shell脚本监控应用服务使用shell脚本自动恢复异常服务通过企业微信通知接收监控结果shell脚本使用数组知识,…

Word 常用操作总结

文章目录 【公式篇】编号右对齐自动编号多行公式对齐编号右靠下编号右居中 公式引用更新编号 【公式篇】 简述:通过“#换行”的方式使编号右对齐,通过插入题注的方式使其自动编号,通过交叉引用的方式引用公式编号。 编号右对齐自动编号 在公…

盛元广通科研院所实验室安全管理系统LIMS

实验室的管理与安全直接影响着教学与科研质量,从科研角度出发,实验室安全风险特点与生产现场安全风险特点存在较大差异,危险源种类复杂实验内容变更频繁,缺乏有效监管,实验室安全运行及管理长期游离于重点监管领域外&a…

Django实现接口自动化平台(十二)自定义函数模块DebugTalks 序列化器及视图【持续更新中】

上一章: Django实现接口自动化平台(十一)项目模块Projects序列化器及视图【持续更新中】_做测试的喵酱的博客-CSDN博客 本章是项目的一个分解,查看本章内容时,要结合整体项目代码来看: python django vue…

【Java面试丨并发编程】线程中并发安全

一、Synchronized关键字的底层原理 1. Synchronized的作用 Synchronized【对象锁】采用互斥的方式让同一时刻至多只有一个线程能持有【对象锁】,其他线程再想获取这个【对象锁】时就会阻塞住 2. Monitor Synchronized【对象锁】底层是由Monitor实现,…

计算机网络——VLan介绍

学习视频: 网工必会,十分钟搞明白,最常用的VLAN技术_哔哩哔哩_bilibili 技术总结:VLAN,网络中最常用的技术,没有之一_哔哩哔哩_bilibili 全国也没几个比我讲得好的:VLAN虚拟局域网 本来补充了…

巧妙使用 CSS 渐变来实现波浪动画

目录 一、波浪的原理 二、曲面的绘制 三、波浪动画 四、文字波浪动画 五、总结一下 参考资料 之前看到coco[1]的这样一篇文章:纯 CSS 实现波浪效果![2],非常巧妙,通过改变border-radius和不断旋转实现的波浪效果&#xff0c…

【unity细节】分不清楚__世界坐标,自身坐标,Vector3,transform和translate?

👨‍💻个人主页:元宇宙-秩沅 hallo 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 秩沅 原创 收录于专栏:unity细节和bug ⭐世界坐标系transform和自身坐标Trasform.local和Vector3⭐ 文章目录 ⭐世界坐标…

excel常用操作备忘

excel操作: 1、快速填充多列公式:选中多列后,按ctrlD填 充。 2、快速删除空行:全选行,按ctrlG,空值项前边打上钩,点确定,针对选中的空行,鼠标右击,点删除&…

APACHE KAFKA本机Hello World教程

目标 最近想要简单了解一下Apache Kafka,故需要在本机简单打个Kafka弄一弄Hello World级别的步骤。 高手Kafka大佬们,请忽略这里的内容。 步骤 Apacha Kafka要求按照Javak8以上版本的环境。从官网下载kafka并解压。 启动 # 生产kafka集群随机ID KA…

【机密计算标准】GB/T 41388-2022 可信执行环境基础安全规范

1 范围 本文件确立了可信执行环境系统整体技术架构,描述了可信执行环境基础要求、可信虚拟化系统、可信操作系统、可信应用与服务管理、跨平台应用中间件等主要内容及其测试评价方法。 2 规范性引用文件 下列文件中的内容通过文中的规范性引用面构成本文件必不…

❤️创意网页:如何创建一个漂亮的3D正六边形

✨博主:命运之光 🌸专栏:Python星辰秘典 🐳专栏:web开发(简单好用又好看) ❤️专栏:Java经典程序设计 ☀️博主的其他文章:点击进入博主的主页 前言:欢迎踏入…