目录
1、论文
2、背景与动机
3、回答的问题
4、创新与卖点
5、实现细节
模型框架
具体步骤
简单代码示例
6、一些资料
1、论文
Masked Autoencoders Are Scalable Vision Learnershttps://arxiv.org/pdf/2111.06377.pdf
2、背景与动机
在深度学习和计算机视觉的领域中,预训练模型已经成为了提高下游任务性能的重要手段。传统上,许多预训练模型如ResNet、VGG等都是在大规模数据集(如ImageNet)上通过监督学习训练得到的。然而,监督学习需要大量的标记数据,这在成本和可扩展性上都是一个不小的挑战。
最近,自监督学习作为一个新兴研究领域,提供了一种无需手工标注数据的解决方案。自监督学习的一个关键点是设计预测任务,通过这些任务模型可以从输入数据本身学习到有用的表示。在自然语言处理(NLP)领域,BERT通过掩码语言模型(MLM)任务表现出色,这激发了计算机视觉领域对类似方法的探索。
MAE (Masked Autoencoder) 正是从这样的背景和动机出发,它将自监督学习中的掩码预测任务引入到视觉领域,致力于从图像数据中以无监督的方式学习高效的特征表示。
3、回答的问题
论文中回答了一个问题。为什么自监督在CV领域的发展要滞后于NLP呢?论文中给了两个解释:
(1)NLP主流方法是Transformer,视觉里CNN是主流方法,结构差异让视觉很难构造类似于“masked autoencoding”的任务。但是ViT的提出解决了这个问题;
(2)语言和视觉的信息密度(information density)差异巨大,前者是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding),但是对视觉图像来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断
所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效
4、创新与卖点
MAE 的核心创新在于其独特的自监督预训练方法。不同于之前的自监督视觉模型通常需要对比学习或复杂的数据增强,MAE 提出了一种简洁高效的方法:
-
Masking 策略:MAE 对输入图像进行随机遮蔽,只露出一小部分像素,模型的任务是预测被遮蔽部分的原始像素。这种策略减少了模型需要处理的数据量,同时迫使模型学习丰富的上下文信息来重建图像。
-
编码器-解码器架构:MAE 采用了一个不对称的编码器-解码器架构,其中编码器只对未被遮蔽的部分进行处理,大幅减少了计算量。解码器则负责图像的重建工作,它的结构相对简单,因为其主要任务是理解编码器提供的特征。
-
预训练与微调:MAE 的预训练阶段不依赖于标签,这使得模型可以在非常大的数据集上进行训练。一旦预训练完成,MAE 可以通过微调在各种下游任务上实现优异的性能,包括分类、检测和分割等。
5、实现细节
模型框架
具体步骤
数据遮掩:首先,在输入图像或序列数据中随机选择一定比例的区域进行遮掩,将其替换为特定的遮掩标记(如0或[MASK])。
编码阶段:仅将未遮掩的数据部分输入到一个轻量级的Transformer编码器中,以提取局部上下文特征。
解码阶段:将编码后的向量传递给一个解码器,该解码器通常也是一个Transformer,但会对所有像素或位置进行解码预测,恢复出被遮掩部分的信息。
损失函数:使用L1或L2距离作为损失函数,衡量预测的像素值或词向量与原始未遮掩数据之间的差异。
预训练与微调:经过大规模无标签数据上的预训练后,可以将模型参数迁移到特定的下游任务中进行微调,进一步提升任务性能。
简单代码示例
import torch
import torch.nn as nn
import torch.nn.functional as Fclass PositionalEncoding(nn.Module):# 用于添加位置信息的模块,通常在Transformer结构中使用def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0), :]return self.dropout(x)class Encoder(nn.Module):def __init__(self, embed_dim, num_layers, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):super(Encoder, self).__init__()self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),dropout=drop_rate, attention_dropout=attn_drop_rate, bias_qkv=qkv_bias)for _ in range(num_layers)])def forward(self, src, mask=None):output = srcfor layer in self.layers:output = layer(output, src_key_padding_mask=mask)return outputclass MaskedAutoencoder(nn.Module):def __init__(self, image_size, patch_size, num_channels, embed_dim, num_layers, num_heads, mlp_ratio, num_classes):super(MaskedAutoencoder, self).__init__()self.patch_size = patch_sizeself.embed_dim = embed_dimself.num_patches = (image_size // patch_size) ** 2self.encoder = nn.Sequential(nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size),nn.LayerNorm(embed_dim),)self.pos_embed = PositionalEncoding(embed_dim)self.transformer_encoder = Encoder(embed_dim, num_layers, num_heads, mlp_ratio)self.decoder = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, num_channels * patch_size ** 2),nn.PixelShuffle(patch_size),)self.to_patch_embedding = nn.Sequential(nn.Unflatten(dim=1, unflattened_size=(num_patches, embed_dim)),nn.Dropout(p=0.1),)def forward(self, x, mask_ratio=0.75):B, C, H, W = x.shapeassert H == W, "Input image must be square"x = self.encoder(x)x = self.pos_embed(x)# 随机掩码rand_mask = torch.rand(B, self.num_patches, 1, 1, device=x.device) < mask_ratiomasked_x = x.clone()masked_x[rand_mask] = 0.# 编码encoded_patches = self.transformer_encoder(self.to_patch_embedding(masked_x))# 解码reconstructed_image = self.decoder(encoded_patches)return reconstructed_image# 初始化模型
model = MaskedAutoencoder(image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4., num_classes=0)# 假设我们有输入数据x
x = torch.randn((10, 3, 224, 224))# 计算重构后的图像
reconstruction = model(x)
6、一些资料
MAE(Masked Autoencoders) - 知乎简介MAE(Masked Autoencoders)是用于CV的自监督学习方法,优点是扩展性强的(scalable),方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。MAE基于两个核心设计:(1)不对称的(…https://zhuanlan.zhihu.com/p/446761025