深度学习(VAE)

news/2024/11/14 19:29:01/文章来源:https://www.cnblogs.com/tiandsp/p/18453132

变分自编码器(VAE,Variational Auto-Encoder)是一种生成模型,它通过学习数据的潜在表示来生成新的样本。

在学习潜空间时,需要保持生成样本与真实数据的相似性,并尽量让潜变量的分布接近标准正态分布。

VAE的基本结构:

1. 编码器(Encoder):将输入数据转换为潜在空间的分布,输出潜在变量的均值和方差。

2. 重参数化层(Reparameterization Layer):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。

3. 解码器(Decoder):接收潜在变量并将其转换回原始数据的分布。

为了让生成样本接近原始数据,最终loss是样本与真实数据相似度和潜变量与标准高斯分布相似度之和。

生成样本和真实数据相似度可以通过mse计算。

潜变量与标准高斯分布相似度可以通过KL散度计算。

下面是两个高斯分布计算KL散度的推导:

设其中一个为标准高斯函数:

下面代码是用FashionMNIST作为数据集,生成样本的示例:

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms,datasets
from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True)
dataset = datasets.FashionMNIST(root='./fasion_data',train=True,transform=transforms.ToTensor(),download=True)data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=128, shuffle=True)class VAE(nn.Module):def __init__(self, image_size=784, h=400, z=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h)self.fc2 = nn.Linear(h, z)self.fc3 = nn.Linear(h, z)self.fc4 = nn.Linear(z, h)self.fc5 = nn.Linear(h, image_size)def encode(self, x):h = F.relu(self.fc1(x))mu = self.fc2(h)log_var = self.fc3(h)return mu,log_var def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))reconst_x = F.sigmoid(self.fc5(h))return reconst_xdef forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)reconst_x = self.decode(z)return reconst_x, mu, log_vardef loss_function(reconst_x, x, mu, log_var): mse = F.binary_cross_entropy(reconst_x, x, size_average=False)kld = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())return mse+kldmodel = VAE().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):for i, (x, _) in enumerate(data_loader):x = x.to(device).view(-1, 784)reconst_x, mu, log_var = model(x)loss = loss_function(reconst_x,x,mu,log_var) optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print("epoch : ",epoch, "loss:", loss.item())with torch.no_grad():out, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join('./', '{}.png'.format(epoch)))

 结果如下:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/832044.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于 dp 凸性的优化策略(待修缮)

斜率优化 \(y=kx+b\) 形式维护队列,询问不单调则二分决策点。 Slope Trick 如果决策函数满足以下条件:连续 凸包,每一段斜率为整数 凸包上断点之间的一次函数斜率总和为 \(\mathcal O(n)\) 级别则称这个函数满足性质 \(T\),且如果 \(f,h\) 都满足性质 \(T\),则 \(f+h\) 也…

warmup_csaw_2016

题目链接:warmup_csaw_2016。 下载附件后,使用 IDA 反编译,定位到 main 函数,如下。 __int64 __fastcall main(int a1, char **a2, char **a3) {char s[64]; // [rsp+0h] [rbp-80h] BYREFchar v5[64]; // [rsp+40h] [rbp-40h] BYREFwrite(1, "-Warm Up-\n", 0xAu…

System

System 类常见的成员方法:图1System 是一个工具类, 提供了一些与系统相关的方法. public static void exit(int status) // 终止当前运行的 Java 虚拟机status 是一个状态码, 有两种情况, 第一种情况是等于 0, 表示当前虚拟机是正常停止的. 第二种情况是非零, 一般是写 1, 表示…

Java中的 Exception 和 Error 有什么区别

Java中的 Exception 和 Error 有什么区别Exception 和 Error 都是 Throwable 类的子类(在Java代码中只有继承了 Throwable 类的实例才可以被 throw 或者被 catch)它们表示在程序运行时发生的异常或错误情况。 总结来看: Exception 表示可以被处理的程序异常,Error 表示系统…

数据采集与融合技术实验课程作业三

数据采集与融合技术实验课程作业三作业所属课程 https://edu.cnblogs.com/campus/fzu/2024DataCollectionandFusiontechnology作业链接 https://edu.cnblogs.com/campus/fzu/2024DataCollectionandFusiontechnology/homework/13287gitee码云代码位置 https://gitee.com/wang-qi…

HTTPS ppt素材

本来的主题是介绍一下我之前做的搜索与推荐的业务,但9月份开始我主要开始承担一些医那块的业务测试,就想做点别的分享,但换成医的业务介绍,想了想我目前对医的了解程度,实在没勇气拿出来分享,所以就换了这个主题。 这个主题其实也是早有预谋,一个初衷是想对某一个通用性…

jvm 堆内存

堆、方法区、直接内存,多个线程之间是共享的。 ------------ 堆内存是会溢出的。 堆内存默认最大是7G

CdnCheck工具

前言:CdnCheck工具实现,记录下我这边实现的几个点 参考文章:https://github.com/projectdiscovery/cdncheck 参考文章:https://github.com/YouChenJun/CheckCdn 参考文章:https://github.com/zu1k/nali 参考文章:https://github.com/u9sky/cdn-cname-domain/blob/main/cd…

考研打卡(14)

开局(14) 开始时间 2024-11-11 20:21:43 结束时间 2024-11-11 22:00:55今天考研数学的资料到了数据结构设一组初始记录关键字序列为(50,40,95,20,15,70,60,45), 则以增量d=4的一趟希尔排序结束后前4条记录关键字为_____(中国地质大学2017年) A 40,50,20,95 B 15,40,6…

说明与笔记导航(咕咕咕)

对使用这些笔记的同学想说的话,以及更新进度。为什么写这么多B东西? 其一呢是帮助我自己,边写笔记边梳理知识;其二呢是帮助各位义父义母考试成功。 更新进度与内容说明 11.11:本周工作日需突击学习python,有限体积N-S方程推导已写完。 目前进度:3009 建模:数值方法写完…

运用Windows API进行编程

目录运用Windows API进行编程实验环境窗口创建基本流程基本代码流程1、头文件和库2、全局变量和函数3、入口主函数4、注册窗口类函数5、创建和显示窗口函数6、窗口过程函数运行结果实验小结 运用Windows API进行编程 运行Windows应用程序在桌面显示Windows窗口。窗口内背景色为…

Windows API窗口绘图程序设计

目录Windows API窗口绘图程序设计1、窗口过程函数2、WM_LBUTTONDOWN:处理鼠标左键按下的消息鼠标消息相关知识点基本鼠标消息双击消息附带信息滚轮消息附带信息:3、WM_PAINT:处理窗口重绘的消息窗口绘图相关知识点窗口绘图基本流程开始绘画绘制封闭图形(能使用画刷填充的图形…