Pytorch Advanced(二) Variational Auto-Encoder

自编码说白了就是一个特征提取器,也可以看作是一个降维器。下面找了一张很丑的图来说明自编码的过程。

自编码分为压缩和解码两个过程。从图中可以看出来,压缩过程就是将一组数据特征进行提取, 得到更深层次的特征。解码的过程就是利用之前的深层次特征再还原成为原来的数据特征。那么如何保证从压缩到解码两部分,原数据和解码数据保持一致呢?这就是要训练的过程。

如何理解降维?如果压缩的过程是卷积,维度可以根据核的个数变化,特征维度因此而改变。


import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sample_dir = 'samples'
if not os.path.exists(sample_dir):os.makedirs(sample_dir)
image_size = 784
h_dim = 400
z_dim = 20
num_epochs = 15
batch_size = 128
learning_rate = 1e-3dataset = torchvision.datasets.MNIST(root='../../data',train=True,transform=transforms.ToTensor(),download=True)# Data loader
data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size, shuffle=True)

模型搭建:这里搭建的是一个变分自编码,Variational Autoencoder

那么变分自编码是为了解决什么问题呢? ——- 其主要思想还是希望学习隐层变量,并将其用来表示原始数据,但是它加另一个条件, 即隐层变量能学习原始数据的分布, 并反过来生产一些和原始数据相似的数据(这有啥用?—-可用于图片修复,让图片按训练集的数据分布变化)。

变分自编码 (Variational Autoencoder) 为了让隐层抓住输入数据特性, 而不是简单的输出数据=输入数据,他在隐层中加入随机噪声(单位高斯噪声)(这个过程也叫reparametrize),以确保隐层能较好抽象输入数据特点。

代码中怎么做的呢?

1、编码过程中我们保存了第二层线性层的输出。其中第二层包含有fc2与fc3两部分,他们是并联的。

2、给隐藏层加入随机噪声,作为解码的输入

class VAE(nn.Module):def __init__(self, image_size=784, h_dim=400, z_dim=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h_dim)self.fc2 = nn.Linear(h_dim, z_dim)self.fc3 = nn.Linear(h_dim, z_dim)self.fc4 = nn.Linear(z_dim, h_dim)self.fc5 = nn.Linear(h_dim, image_size)def encode(self, x):h = F.relu(self.fc1(x))return self.fc2(h), self.fc3(h)def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))return F.sigmoid(self.fc5(h))def forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)x_reconst = self.decode(z)return x_reconst, mu, log_var

训练:由于训练中加入了噪声,所以损失值的结构也因此改变。一部分来源于解码内容核原内容的相似度,另一部分是kl_div,具体是什么意义需查看论文。

model = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# Start training
for epoch in range(num_epochs):for i, (x, _) in enumerate(data_loader):# Forward passx = x.to(device).view(-1, image_size)x_reconst, mu, log_var = model(x)# Compute reconstruction loss and kl divergence# For KL divergence, see Appendix B in VAE paper or http://yunjey47.tistory.com/43reconst_loss = F.binary_cross_entropy(x_reconst, x, size_average=False)kl_div = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())# Backprop and optimizeloss = reconst_loss + kl_divoptimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print ("Epoch[{}/{}], Step [{}/{}], Reconst Loss: {:.4f}, KL Div: {:.4f}" .format(epoch+1, num_epochs, i+1, len(data_loader), reconst_loss.item(), kl_div.item()))with torch.no_grad():# Save the sampled imagesz = torch.randn(batch_size, z_dim).to(device)out = model.decode(z).view(-1, 1, 28, 28)save_image(out, os.path.join(sample_dir, 'sampled-{}.png'.format(epoch+1)))# Save the reconstructed imagesout, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join(sample_dir, 'reconst-{}.png'.format(epoch+1)))

模型训练完成了之后该如何使用这个模型呢?

model.decode()是一个解码的过程,我们给他一个随机的中间特征z就可以输出一个数字图片了。

z = torch.randn(1,z_dim).to(device)
out = model.decode(z)
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

有了随机的一张图片之后,我们把他完整的放入模型中,生成了和输入相似的一张图片,也没看出来是修复了图像......

out,_,_ = model(out) 
plt.imshow(out.cpu().data.numpy().reshape(28,28),cmap='gray')
plt.show()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/111981.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详细介绍下路由器中的WAN口

路由器的 WAN 口(Wide Area Network port)是指用于连接广域网(WAN)的接口。它是路由器与外部网络(如互联网)之间的物理连接点,允许路由器与互联网服务提供商(ISP)或其他广…

风车时间锁管理 - 构建IPA文件加锁+签名+管理一站式解决方案

时间锁管理:是一种用于控制对某些资源、功能或操作的访问权限的机制,它通过设定时间限制来限制对特定内容、系统或功能的访问或执行,以提高安全性和控制性,时间锁管理常见于以下场景: 1. 文件或文档的保密性&#xff…

浏览器代理解决方案

当谈到网络浏览器, 浏览器 无疑是最受欢迎和广泛使用的选项之一。然而,你可能已经注意到, 浏览器并不原生支持 SOCKS5 代理协议。不过,别担心!在本文中,我将与你分享一些解决方案,让你能够在 浏…

【C++初阶】动态内存管理

​👻内容专栏: C/C编程 🐨本文概括: C/C内存分布、C语言动态内存管理、C动态内存管理、operator new与operator delete函数、new和delete的实现原理、定位new表达式、常见面试问题等。 🐼本文作者: 阿四啊 …

Django系列:Django的项目结构与配置解析

Django系列 Django的项目结构与配置解析 作者:李俊才 (jcLee95):https://blog.csdn.net/qq_28550263 邮箱 :291148484163.com 本文地址:https://blog.csdn.net/qq_28550263/article/details/132893616 【介…

简单介绍神经网络中不同优化器的数学原理及使用特性【含规律总结】

当涉及到优化器时,我们通常是在解决一个参数优化问题,也就是寻找能够使损失函数最小化的一组参数。当我们在无脑用adam时,有没有斟酌过用这个是否合适,或者说凭经验能够有目的性换用不同的优化器?是否用其他的优化器可…

【分布式】分布式事务:2PC

分布式事务的问题可以分为两部分: 并发控制 concurrency control原子提交 atomic commit 分布式事务问题的产生场景:一份数据被分片存在多台服务器上,那么每次事务处理都涉及到了多台机器。 可序列化(并发控制)&…

SQL优化--分页优化(limit)

在数据量比较大时,如果进行limit分页查询,在查询时,越往后,分页查询效率越低。 通过测试我们会看到,越往后,分页查询效率越低,这就是分页查询的问题所在。 因为,当在进行分页查询时&…

帆软BI开发-Day2-趋势图的多种变形

前言: 在BI数据展示中,条形图、趋势图无疑是使用场景非常多的两种图形。与条形图不同的是,趋势图更能反馈出一定的客观规律和未来的趋势走向,因此用于作为预警和判异的业务场景,但实际业务场景的趋势图可没你想的那么简…

视频汇聚/视频云存储/视频监控管理平台EasyCVR分发rtsp流起播慢优化步骤详解

安防视频监控/视频集中存储/云存储/磁盘阵列EasyCVR平台可拓展性强、视频能力灵活、部署轻快,可支持的主流标准协议有国标GB28181、RTSP/Onvif、RTMP等,以及支持厂家私有协议与SDK接入,包括海康Ehome、海大宇等设备的SDK等。平台既具备传统安…

基于Java的大学生心理健康答题小程序设计与实现(亮点:选题新颖、可以发布试卷设置题目、自动判卷、上传答案、答案解析)

校园点餐小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统主要功能5.1 登…