变分自编码器(VAE,Variational Auto-Encoder)是一种生成模型,它通过学习数据的潜在表示来生成新的样本。
在学习潜空间时,需要保持生成样本与真实数据的相似性,并尽量让潜变量的分布接近标准正态分布。
VAE的基本结构:
1. 编码器(Encoder):将输入数据转换为潜在空间的分布,输出潜在变量的均值和方差。
2. 重参数化层(Reparameterization Layer):从编码器输出的均值和方差中进行重参数化采样,生成潜在变量。
3. 解码器(Decoder):接收潜在变量并将其转换回原始数据的分布。
为了让生成样本接近原始数据,最终loss是样本与真实数据相似度和潜变量与标准高斯分布相似度之和。
生成样本和真实数据相似度可以通过mse计算。
潜变量与标准高斯分布相似度可以通过KL散度计算。
下面是两个高斯分布计算KL散度的推导:
设其中一个为标准高斯函数:
下面代码是用FashionMNIST作为数据集,生成样本的示例:
import os import torch import torch.nn as nn import torch.nn.functional as F from torchvision import transforms,datasets from torchvision.utils import save_imagedevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')#dataset = datasets.MNIST(root='./data',train=True,transform=transforms.ToTensor(),download=True) dataset = datasets.FashionMNIST(root='./fasion_data',train=True,transform=transforms.ToTensor(),download=True)data_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=128, shuffle=True)class VAE(nn.Module):def __init__(self, image_size=784, h=400, z=20):super(VAE, self).__init__()self.fc1 = nn.Linear(image_size, h)self.fc2 = nn.Linear(h, z)self.fc3 = nn.Linear(h, z)self.fc4 = nn.Linear(z, h)self.fc5 = nn.Linear(h, image_size)def encode(self, x):h = F.relu(self.fc1(x))mu = self.fc2(h)log_var = self.fc3(h)return mu,log_var def reparameterize(self, mu, log_var):std = torch.exp(log_var/2)eps = torch.randn_like(std)return mu + eps * stddef decode(self, z):h = F.relu(self.fc4(z))reconst_x = F.sigmoid(self.fc5(h))return reconst_xdef forward(self, x):mu, log_var = self.encode(x)z = self.reparameterize(mu, log_var)reconst_x = self.decode(z)return reconst_x, mu, log_vardef loss_function(reconst_x, x, mu, log_var): mse = F.binary_cross_entropy(reconst_x, x, size_average=False)kld = - 0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())return mse+kldmodel = VAE().to(device) optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)for epoch in range(10):for i, (x, _) in enumerate(data_loader):x = x.to(device).view(-1, 784)reconst_x, mu, log_var = model(x)loss = loss_function(reconst_x,x,mu,log_var) optimizer.zero_grad()loss.backward()optimizer.step()if (i+1) % 10 == 0:print("epoch : ",epoch, "loss:", loss.item())with torch.no_grad():out, _, _ = model(x)x_concat = torch.cat([x.view(-1, 1, 28, 28), out.view(-1, 1, 28, 28)], dim=3)save_image(x_concat, os.path.join('./', '{}.png'.format(epoch)))
结果如下: