一种简单的自编码器PyTorch代码实现

1. 引言

对于许多新接触深度学习爱好者来说,玩AutoEncoder总是很有趣的,因为它具有简单的处理逻辑、简易的网络架构,方便可视化潜在的特征空间。在本文中,我将从头开始介绍一个简单的AutoEncoder模型,以及一些可视化潜在特征空间的一些的方法,以便使本文变得生动有趣。

闲话少说,我们直接开始吧!

2. 数据集介绍

在本文中,我们使用FashionMNIST数据集来完成此任务。
在这里插入图片描述

以下是Kaggle上数据集的链接:戳我。
该数据集已在torchvision库中集成;我们可以通过几行代码直接导入和处理该数据集。

为此,首先需要是编写一个collate_fn函数,将数据集从PIL图像转换为torch张量,并进行相应的pad操作:

# This function convert the PIL images to tensors then pad them
def collate_fn(batch):process = transforms.Compose([transforms.ToTensor(),transforms.Pad([2])])# x - images; we process each image in the batchx = [process(data[0]) for data in batch]x = torch.concat(x).unsqueeze(1)# y - labels, note that we should convert the labels to LongTensory = torch.LongTensor([data[1] for data in batch])return x, y

3. 实现DataLoader

接着,我们就可以使用以下代码来完成相应的DataLoader的实现:

labels = ["T-shirt/top", "Trouser", "Pullover", "Dress","Coat", "Sandla", "Shirt", "Sneaker", "Bag", "Ankle boot"]# download/load dataset
train_data = FashionMNIST("./MNIST_DATA", train=True, download=True)
valid_data = FashionMNIST("./MNIST_DATA", train=False, download=True)# put datasets into dataloaders
train_loader = DataLoader(train_data, batch_size=config["batch_size"], shuffle=True, collate_fn=collate_fn)
valid_loader = DataLoader(valid_data, batch_size=config["batch_size"], shuffle=False, collate_fn=collate_fn)

接着我们可以使用以下代码来检验上述代码是否符合我们的预期,测试代码如下:

print("Inspecting train data: ")
for _, data in enumerate(train_loader):print("Batch shape: ", data[0].shape)fig, ax = plt.subplots(1, 4, figsize=(10, 4))for i in range(4):# Ture 3D tensor to 2D tensor due to image's single channelax[i].imshow(data[0][i].squeeze(), cmap="gray")ax[i].axis("off")ax[i].set_title(labels[data[1][i]])plt.show()# And don't forget to breakbreak

运行结果如下:
在这里插入图片描述
观察上图,图像和标签一一对应关系正常,接着我们就可以进入我们的网络设计部分。

4. 实现encoder

我们知道自编码器是由编码器encoder和解码器decoder实现的,其中编码器的作用为将输入的图像编码为特征空间的特征向量,解码器的作用相反,尽可能的将上述特征向量结果恢复为原图。基于此,我们首先来一步步实现编码器。首先,我们来定义模型的基本超参数如下:

# Model parameters:
LAYERS = 3
KERNELS = [3, 3, 3]
CHANNELS = [32, 64, 128]
STRIDES = [2, 2, 2]
LINEAR_DIM = 2048

同时相应的编码器的网络结构设计如下:

class Encoder(nn.Module):def __init__(self, output_dim=2, use_batchnorm=False, use_dropout=False):super(Encoder, self).__init__()# bottleneck dimentionalityself.output_dim = output_dim# variables deciding if using dropout and batchnorm in modelself.use_dropout = use_dropoutself.use_batchnorm = use_batchnorm# convolutional layer hyper parametersself.layers = LAYERSself.kernels = KERNELSself.channels = CHANNELSself.strides = STRIDESself.conv = self.get_convs()# layers for latent space projectionself.fc_dim = LINEAR_DIMself.flatten = nn.Flatten()self.linear = nn.Linear(self.fc_dim, self.output_dim)def get_convs(self):"""generating convolutional layers based on model's hyper parameters"""conv_layers = nn.Sequential()for i in range(self.layers):# The input channel of the first layer is 1if i == 0: conv_layers.append(nn.Conv2d(1, self.channels[i], kernel_size=self.kernels[i],stride=self.strides[i],padding=1))else: conv_layers.append(nn.Conv2d(self.channels[i-1], self.channels[i],kernel_size=self.kernels[i],stride=self.strides[i],padding=1))if self.use_batchnorm:conv_layers.append(nn.BatchNorm2d(self.channels[i]))# Here we use GELU as activation functionconv_layers.append(nn.GELU()) if self.use_dropout:conv_layers.append(nn.Dropout2d(0.15))return conv_layersdef forward(self, x):x = self.conv(x)x = self.flatten(x)return self.linear(x)

在Pytorch中torchsummary是一个非常方便的工具,用于检查和调试模型的网络结构;我们可以检查层、每层中的张量形状以及模型的参数。代码如下:

from torchsummary import summary
# Get the summary of autoencoder architecture
encoder = Encoder(use_batchnorm=True, use_dropout=True).to(DEVICE)
summary(encoder, (1, 32, 32))
pass

得到输出如下:
在这里插入图片描述

5. 实现decoder

在我们的例子中,解码器层decoder是编码器的反向操作;确保每一层的输入和输出形状是很重要的。此外,我们应该调整转置卷积层中的paddingoutput_pading参数,以确保输出图像和输入图像的维度相同。代码实现如下:

class Decoder(nn.Module):def __init__(self, input_dim=2, use_batchnorm=False, use_dropout=False):super(Decoder, self).__init__()# variables deciding if using dropout and batchnorm in modelself.use_dropout = use_dropoutself.use_batchnorm = use_batchnormself.fc_dim = LINEAR_DIMself.input_dim = input_dim# Conv layer hypyer parametersself.layers = LAYERSself.kernels = KERNELSself.channels = CHANNELS[::-1] # flip the channel dimensionsself.strides = STRIDES# In decoder, we first do fc project, then conv layersself.linear = nn.Linear(self.input_dim, self.fc_dim)self.conv =  self.get_convs()self.output = nn.Conv2d(self.channels[-1], 1, kernel_size=1, stride=1)def get_convs(self):conv_layers = nn.Sequential()for i in range(self.layers):if i == 0: conv_layers.append(nn.ConvTranspose2d(self.channels[i],self.channels[i],kernel_size=self.kernels[i],stride=self.strides[i],padding=1,output_padding=1))else: conv_layers.append(nn.ConvTranspose2d(self.channels[i-1], self.channels[i],kernel_size=self.kernels[i],stride=self.strides[i],padding=1,output_padding=1))if self.use_batchnorm and i != self.layers - 1:conv_layers.append(nn.BatchNorm2d(self.channels[i]))conv_layers.append(nn.GELU())if self.use_dropout:conv_layers.append(nn.Dropout2d(0.15))return conv_layersdef forward(self, x):x = self.linear(x)# reshape 3D tensor to 4D tensorx = x.reshape(x.shape[0], 128, 4, 4)x = self.conv(x)return self.output(x)

相应的解码器实现如下:

decoder = Decoder(use_batchnorm=True, use_dropout=True).to(DEVICE)
summary(decoder, (1, 2))
pass

运行后,结果如下:
在这里插入图片描述

6. 实现自编码器

接着,我们将上述编码器和解码器串联起来,代码实现如下:

class AutoEncoder(nn.Module):def __init__(self):super(AutoEncoder, self).__init__()self.encoder = Encoder(output_dim=2, use_batchnorm=True, use_dropout=False)self.decoder = Decoder(input_dim=2, use_batchnorm=True, use_dropout=False)def forward(self, x):return self.decoder(self.encoder(x))model = AutoEncoder().to(DEVICE)
summary(model, (1, 32, 32))
pass

得到结果如下:
在这里插入图片描述

7. 可视化函数

在进入训练部分之前,让我们花一些时间编写一个函数来可视化我们模型的潜在特征空间,即编码后二维特征向量的可视化表示。

def plotting(step:int=0, show=False):model.eval() # Switch the model to evaluation modepoints = []label_idcs = []path = "./ScatterPlots"if not os.path.exists(path): os.mkdir(path)for i, data in enumerate(valid_loader):img, label = [d.to(DEVICE) for d in data]# We only need to encode the validation imagesproj = model.encoder(img)points.extend(proj.detach().cpu().numpy())label_idcs.extend(label.detach().cpu().numpy())del img, labelpoints = np.array(points)# Creating a scatter plotfig, ax = plt.subplots(figsize=(10, 10) if not show else (8, 8))scatter = ax.scatter(x=points[:, 0], y=points[:, 1], s=2.0, c=label_idcs, cmap='tab10', alpha=0.9, zorder=2)ax.spines["right"].set_visible(False)ax.spines["top"].set_visible(False)if show: ax.grid(True, color="lightgray", alpha=1.0, zorder=0)plt.show()else: # Do not show but only save the plot in trainingplt.savefig(f"{path}/Step_{step:03d}.png", bbox_inches="tight")plt.close() # don't forget to close the plot, or it is always in memorymodel.train()

以下是训练过程中生成的图;该过程显示了模型的潜在空间随时间的分布,可以看出尽管有个别离群点,整体不同类别的数据在特征空间呈现出聚类趋势:
在这里插入图片描述

8. 损失函数

在编写训练和验证函数之前,还有一个步骤是定义目标函数和优化方法。由于自动编码器是一个自监督模型,输入也是网络输出重建图像逼近的对象,因此我们可以使用MSE(均方误差)损失来评估输入和重建图像之间的逐像素损失。当然有很多优化器可供选择,这里我选择的是AdamW,因为我在过去几个月里经常使用它。

criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=1e-5)# For mixed precision training
scaler = torch.cuda.amp.GradScaler()
steps = 0 # tracking the training steps

9. 训练函数

接着我们来定义训练一个epoch的函数,代码实现如下:

def train(model, dataloader, criterion, optimizer, save_distrib=False):# steps is used to track training progress, purely for latent space plotsglobal steps model.train()train_loss = 0.0# Process tqdm bar, helpful for monitoring training processbatch_bar = tqdm(total=len(dataloader), dynamic_ncols=True, leave=False, position=0, desc="Train")for i, batch in enumerate(dataloader):optimizer.zero_grad()x = batch[0].to(DEVICE)# Here we implement the mixed precision trainingwith torch.cuda.amp.autocast():y_recons = model(x)loss = criterion(y_recons, x)train_loss += loss.item()scaler.scale(loss).backward()scaler.step(optimizer)scaler.update()batch_bar.set_postfix(loss=f"{train_loss/(i+1):.4f}",lr = f"{optimizer.param_groups[0]['lr']:.4f}")batch_bar.update()        # Saving latent space plotsif steps % 10 == 0 and save_distrib and steps <= 400: plotting(steps)steps += 1        # remove unnecessary cache in CUDA memorytorch.cuda.empty_cache()del x, y_reconsbatch_bar.close()train_loss /= len(dataloader)return train_loss

10 验证函数

相应的验证函数的实现稍微简单一点,代码如下:

def validate(model, dataloader, criterion):model.eval() # Don't forget to turn the model to eval modevalid_loss = 0.0# Progress tqdm barbatch_bar = tqdm(total=len(dataloader), dynamic_ncols=True,leave=False, position=0, desc="Validation")for i, batch in enumerate(dataloader):x = batch[0].to(DEVICE)with torch.no_grad(): # we don't need gradients in validationy_recons = model(x)loss = criterion(y_recons, x)valid_loss += loss.item()batch_bar.set_postfix(loss=f"{valid_loss/(i+1):.4f}",lr = f"{optimizer.param_groups[0]['lr']:.4f}")batch_bar.update()torch.cuda.empty_cache()del x, y_reconsbatch_bar.close()valid_loss /= len(dataloader)return valid_loss

11 训练过程

接着,我们将上述代码串起来,来实现我们模型的训练,由于FashionMNIST是一个很小的数据集,我们实际上不需要大量训练;初始训练和验证损失非常低,并且在三个epoch之后没有太大的改进空间。

for i in range(config["epochs"]):curr_lr = float(optimizer.param_groups[0]["lr"])train_loss = train(model, train_loader, criterion, optimizer, save_distrib=True)valid_loss = validate(model, valid_loader, criterion)print(f"Epoch {i+1}/{config['epochs']}\nTrain loss: {train_loss:.4f}\t Validation loss: {valid_loss:.4f}\tlr: {curr_lr:.4f}")

输出如下:
在这里插入图片描述

12 结果可视化

我们现在可以再次绘制和检查收敛后的特征空间,可视化输出如下:
在这里插入图片描述
观察上图可知,相应的聚类后的效果比训练过程中的要好,但有些个别类混合在同一集群中。这个问题可以通过增加编码器输出的特征向量的维度或使用其他损失函数函数来解决。

13 预测效果可视化

为了验证我们的解码器确实学到了东西,我们可以在随机绘制一些离散点来观察解码器重建图像的效果,代码如下:

# randomly sample x and y values
xs = [random.uniform(-6.0, 8.0) for i in range(8)]
ys = [random.uniform(-7.5, 10.0) for i in range(8)]points = list(zip(xs, ys))
coords = torch.tensor(points).unsqueeze(1).to(DEVICE)
nrows, ncols = 2, 4
fig, axes = plt.subplots(nrows, ncols, figsize=(10, 5))
model.eval()
with torch.no_grad():generates = [model.decoder(coord) for coord in coords]
# plot points
idx = 0
for row in range(0, nrows):for col in range(0, ncols):ax = axes[row, col]im = generates[idx].squeeze().detach().cpu()ax.imshow(im, cmap="gray")ax.axis("off")coord = coords[idx].detach().cpu().numpy()[0]ax.set_title(f"({coord[0]:.3f}, {coord[1]:.3f})")idx += 1plt.show()

代码输出如下:
在这里插入图片描述

14. 总结

本文重点介绍了如何利用Pytorch来实现自编码器,从数据集,到搭建网络结构,以及特征可视化和网络预测输出几个方面,分别进行了详细的阐述,并给出了相应的代码示例。

您学废了吗?

完整代码链接:戳我

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/295695.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Hadoop入门学习笔记——七、Hive语法

视频课程地址&#xff1a;https://www.bilibili.com/video/BV1WY4y197g7 课程资料链接&#xff1a;https://pan.baidu.com/s/15KpnWeKpvExpKmOC8xjmtQ?pwd5ay8 Hadoop入门学习笔记&#xff08;汇总&#xff09; 目录 七、Hive语法7.1. 数据库相关操作7.1.1. 创建数据库7.1.2…

w15初识php基础

一、计算100之内的偶数之和 实现思路 所有的偶数除2都为0 代码实现 <?php # 记录100以内的偶数和 $number1; $num0; while($number<100){if($number%20){ $num$number;}$number1; } echo $num; ?>输出的结果 二、计算100之内的奇数之和 实现思路 所有的奇数除…

【hcie-cloud】【9】华为云Stack_Deploy部署工具介绍

文章目录 前言华为云Stack Deploy简介华为云Stack Deploy工具简介华为云Stack Deploy工具部署范围华为云Stack Deploy工具节点网络要求华为云Stack Deploy工具部署流程 华为云Stack Deploy功能介绍部署工具工程场景部署流程介绍创建工程 - 基本信息填写创建工程 - 基本参数选择…

检漏继电器JJB-660/380B 柜内安装,可选新款数显型可选面板安装

JY82A检漏继电器 JY82B检漏继电器 JY82-380/660检漏继电器 JY82-IV检漏继电器 JY82-2P检漏继电器 JY82-2/3检漏继电器 JJKY检漏继电器 JD型检漏继电器 JY82-IV;JY82J JY82-II;JY82-III JY82-1P;JY82-2PA;JY82-2PB JJB-380;JJB-380/660 JD-127V/127V;JD-380V/380V; …

[Angular] 笔记 6:ngStyle

ngStyle 指令: 用于更新 HTML 元素的样式。设置一个或多个样式属性&#xff0c;用以冒号分隔的键值对指定。键是样式名称&#xff0c;带有可选的 .<unit> 后缀&#xff08;如 ‘top.px’、‘font-style.em’&#xff09;&#xff0c;值为待求值的表达式&#xff0c;得到…

从零构建tomcat环境

一、官网构建 1.1 下载 一般来说对于开源软件都有自己的官方网站&#xff0c;并且会附上使用文档以及一些特性和二次构建的方法&#xff0c;那么我们首先的话需要从官网或者tomcat上下载到我们需要的源码包。下载地址&#xff1a;官网、Github。 这里需要声明一下&#xff…

arm和x86架构服务器拉取arm64架构的docker镜像

dockerhub提供的镜像部分支持arm64架构 Docker arm架构服务器拉取docker镜像&#xff0c;默认是arm架构 # docker pull centos Using default tag: latest latest: Pulling from library/centos 52f9ef134af7: Pull complete Digest: sha256:a27fd8080b517143cbbbab9dfb7c8…

springboot实现发送邮件开箱即用

springboot实现发送邮件开箱即用 环境依赖包yml配置Service层Controller层测试 环境 jdk17 springboot版本3.2.1 依赖包 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-mail</artifactId><ver…

C++初阶——类和对象

呀哈喽&#xff0c;我是结衣 C入门之后&#xff0c;我们就进入了C的初阶的学习了&#xff0c;在了解类和对象之前&#xff0c;我们还是先了解&#xff0c;面向过程和面向对象的初步认识。 在本篇博客中&#xff0c;我们要讲的内容有 1.面向过程和面向对象初步认识 2.类的引入 3…

DALL-E:Zero-Shot Text-to-Image Generation

DALL-E 论文是一个文本生成图片模型。 训练分为两个阶段 第一阶段&#xff0c;训练一个dVAE&#xff08;discrete variational autoencoder离散变分自动编码器&#xff09;&#xff0c;其将256 x 256的RGB图片转换为32 x 32的图片token。目的&#xff1a;降低图片的分辨率。图…

【ubuntu 22.04】安装中文版系统、中文语言包和中文输入法

在系统安装中的键盘布局选择时&#xff0c;选择Chinese - Chinese&#xff0c;此时会自动安装所有的中文语言包和ibus中文输入法系统安装成功重启后&#xff0c;点击设置 - 区域和语言 - 管理已安装的语言 * 根据提示安装更新后&#xff0c;将汉语&#xff08;中国&#xff09;…

小白入门之安装Navicat

重生之我在大四学JAVA 第四章 安装Navicat (mysql可视化工具) 这里Navicat是15版本&#xff0c;不是最新版&#xff0c;有新版强迫症的自行百度 傻瓜式安装一直下一步就行 完成后切记不要打开&#xff0c;不要打开&#xff0c;不要打开 可以打开刚刚安装的navicat了 切…