5、MAE:探索视觉预训练模型

目录

1、论文

2、背景与动机

3、回答的问题

4、创新与卖点

5、实现细节

模型框架

具体步骤

简单代码示例

6、一些资料


1、论文

Masked Autoencoders Are Scalable Vision Learnersicon-default.png?t=N7T8https://arxiv.org/pdf/2111.06377.pdf

2、背景与动机

        在深度学习和计算机视觉的领域中,预训练模型已经成为了提高下游任务性能的重要手段。传统上,许多预训练模型如ResNet、VGG等都是在大规模数据集(如ImageNet)上通过监督学习训练得到的。然而,监督学习需要大量的标记数据,这在成本和可扩展性上都是一个不小的挑战。

        最近,自监督学习作为一个新兴研究领域,提供了一种无需手工标注数据的解决方案。自监督学习的一个关键点是设计预测任务,通过这些任务模型可以从输入数据本身学习到有用的表示。在自然语言处理(NLP)领域,BERT通过掩码语言模型(MLM)任务表现出色,这激发了计算机视觉领域对类似方法的探索。

        MAE (Masked Autoencoder) 正是从这样的背景和动机出发,它将自监督学习中的掩码预测任务引入到视觉领域,致力于从图像数据中以无监督的方式学习高效的特征表示。

3、回答的问题

        论文中回答了一个问题。为什么自监督在CV领域的发展要滞后于NLP呢?论文中给了两个解释:

(1)NLP主流方法是Transformer,视觉里CNN是主流方法,结构差异让视觉很难构造类似于“masked autoencoding”的任务。但是ViT的提出解决了这个问题;

(2)语言和视觉的信息密度(information density)差异巨大,前者是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding),但是对视觉图像来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断

        所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效

4、创新与卖点

MAE 的核心创新在于其独特的自监督预训练方法。不同于之前的自监督视觉模型通常需要对比学习或复杂的数据增强,MAE 提出了一种简洁高效的方法:

  1. Masking 策略:MAE 对输入图像进行随机遮蔽,只露出一小部分像素,模型的任务是预测被遮蔽部分的原始像素。这种策略减少了模型需要处理的数据量,同时迫使模型学习丰富的上下文信息来重建图像。

  2. 编码器-解码器架构:MAE 采用了一个不对称的编码器-解码器架构,其中编码器只对未被遮蔽的部分进行处理,大幅减少了计算量。解码器则负责图像的重建工作,它的结构相对简单,因为其主要任务是理解编码器提供的特征。

  3. 预训练与微调:MAE 的预训练阶段不依赖于标签,这使得模型可以在非常大的数据集上进行训练。一旦预训练完成,MAE 可以通过微调在各种下游任务上实现优异的性能,包括分类、检测和分割等。

5、实现细节

模型框架

具体步骤

  1. 数据遮掩:首先,在输入图像或序列数据中随机选择一定比例的区域进行遮掩,将其替换为特定的遮掩标记(如0或[MASK])。

  2. 编码阶段:仅将未遮掩的数据部分输入到一个轻量级的Transformer编码器中,以提取局部上下文特征。

  3. 解码阶段:将编码后的向量传递给一个解码器,该解码器通常也是一个Transformer,但会对所有像素或位置进行解码预测,恢复出被遮掩部分的信息。

  4. 损失函数:使用L1或L2距离作为损失函数,衡量预测的像素值或词向量与原始未遮掩数据之间的差异。

  5. 预训练与微调:经过大规模无标签数据上的预训练后,可以将模型参数迁移到特定的下游任务中进行微调,进一步提升任务性能。

简单代码示例

import torch
import torch.nn as nn
import torch.nn.functional as Fclass PositionalEncoding(nn.Module):# 用于添加位置信息的模块,通常在Transformer结构中使用def __init__(self, d_model, dropout=0.1, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0).transpose(0, 1)self.register_buffer('pe', pe)def forward(self, x):x = x + self.pe[:x.size(0), :]return self.dropout(x)class Encoder(nn.Module):def __init__(self, embed_dim, num_layers, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):super(Encoder, self).__init__()self.layers = nn.ModuleList([nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),dropout=drop_rate, attention_dropout=attn_drop_rate, bias_qkv=qkv_bias)for _ in range(num_layers)])def forward(self, src, mask=None):output = srcfor layer in self.layers:output = layer(output, src_key_padding_mask=mask)return outputclass MaskedAutoencoder(nn.Module):def __init__(self, image_size, patch_size, num_channels, embed_dim, num_layers, num_heads, mlp_ratio, num_classes):super(MaskedAutoencoder, self).__init__()self.patch_size = patch_sizeself.embed_dim = embed_dimself.num_patches = (image_size // patch_size) ** 2self.encoder = nn.Sequential(nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size),nn.LayerNorm(embed_dim),)self.pos_embed = PositionalEncoding(embed_dim)self.transformer_encoder = Encoder(embed_dim, num_layers, num_heads, mlp_ratio)self.decoder = nn.Sequential(nn.Linear(embed_dim, embed_dim),nn.GELU(),nn.Linear(embed_dim, num_channels * patch_size ** 2),nn.PixelShuffle(patch_size),)self.to_patch_embedding = nn.Sequential(nn.Unflatten(dim=1, unflattened_size=(num_patches, embed_dim)),nn.Dropout(p=0.1),)def forward(self, x, mask_ratio=0.75):B, C, H, W = x.shapeassert H == W, "Input image must be square"x = self.encoder(x)x = self.pos_embed(x)# 随机掩码rand_mask = torch.rand(B, self.num_patches, 1, 1, device=x.device) < mask_ratiomasked_x = x.clone()masked_x[rand_mask] = 0.# 编码encoded_patches = self.transformer_encoder(self.to_patch_embedding(masked_x))# 解码reconstructed_image = self.decoder(encoded_patches)return reconstructed_image# 初始化模型
model = MaskedAutoencoder(image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4., num_classes=0)# 假设我们有输入数据x
x = torch.randn((10, 3, 224, 224))# 计算重构后的图像
reconstruction = model(x)

6、一些资料

MAE(Masked Autoencoders) - 知乎简介MAE(Masked Autoencoders)是用于CV的自监督学习方法,优点是扩展性强的(scalable),方法简单。在MAE方法中会随机mask输入图片的部分patches,然后重构这些缺失的像素。MAE基于两个核心设计:(1)不对称的(…icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/446761025

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/341552.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot知识02

1、快速生成mapper和service &#xff08;自动生成简单的单表sql&#xff09; 2、springboot配置swagger&#xff08;路径不用加/api&#xff09; &#xff08;1&#xff09;主pom导包&#xff08;子pom要引用&#xff0c;可选依赖&#xff09; <!-- swagger3…

Odrive 学习系列一:vscode 编译Odrive

搭建环境可参考Markerbase教程,很详细了。 简单说一两点: 解压ODrive-fw-v0.5.1.zip: 打开ODrive-fw-v0.5.1文件夹,找到Firmware文件夹,用vscode打开该文件夹: 按照以下内容操作: 编译工程: 打开 中断(terminal),输入 make -j4 回车 进行编译。编译…

Linux之Iptables简易应用

文档形成时期&#xff1a;2009-2024年 和iptables打交道有15年了&#xff0c;经过无数实践后&#xff0c;形成一个简易应用文档。 文档主题是简易应用&#xff0c;所以其原理不详述了。 因软件世界之复杂和个人能力之限&#xff0c;难免疏漏和错误&#xff0c;欢迎指正。 文章目…

互联网金融P2P主业务场景自动化测试

互联网金融P2P行业&#xff0c;近三年来发展迅速&#xff0c;如火如荼。 据不完全统计&#xff0c;全国有3000的企业。 “互联网”企业&#xff0c;几乎每天都会碰到一些奇奇怪怪的bug&#xff0c;作为在互联网企业工作的测试人员&#xff0c;风险和压力都巨大。那么我们如何降…

C#上位机与欧姆龙PLC的通信11----【爆肝】上位机应用开发(Winform版)

1、先上图 前面10讲&#xff0c;让你爽煹了肝&#xff0c;已经进入最后收尾阶段&#xff0c;这节来个常规应用&#xff0c;让前面的技能直接飞上天&#xff0c;我们要做的界面软件是这样的&#xff0c;虽然没有潘金莲漂亮&#xff0c;但也是爆抱&#xff1a; 2、如何爆&#x…

案例:Web组件抽奖案例

文章目录 介绍相关概念相关权限约束与限制完整示例 代码结构解读H5小程序Web组件总结 介绍 本篇Codelab是基于ArkTS的声明式开发范式的样例&#xff0c;主要介绍了Web组件如何加载本地和云端H5小程序。所加载的页面是由HTMLCSSJavaScript实现的完整小应用。样例主要包含以下功…

【Harmony OS - 消息通知】

应用可以通过接口发送通知消息&#xff0c;提醒用户关注应用中的变化。用户可以在通知栏查看和操作通知内容&#xff0c;通常用于当应用处于后台时&#xff0c;发送&#xff0c;本文主要来介绍在Harmony OS中的三种消息通知。 基础通知 总体流程有三步&#xff1a; 导入noti…

Mac M1 Parallels CentOS7.9 Install Jenkins

官网: https://www.jenkins.io/ 一、Install & Check Java Env Oracle官网下载Java: https://www.oracle.com/cn/ # 拷贝到Jenkins服务器 scp Downloads/jdk-8u391-linux-aarch64.tar.gz root10.211.55.34:~# 解压 mkdir -p /opt/java && tar -zxvf jdk-8u391-li…

类名.this:内部类引用外部类实例

类名.this是啥意思&#xff1f; 今天在看尚硅谷的课程时里面讲了这么一句话&#xff1a; 集合在遍历时需要先创建一个容器&#xff0c;存放集合的数据&#xff0c;这样做浪费内存 想去验证下&#xff0c;就翻了翻ArrayList的迭代过程源码 在ArrayList的迭代器类Itr&#xff08;…

Vue过滤器详解

聚沙成塔每天进步一点点 本文内容 ⭐ 专栏简介基本用法多个过滤器的串联过滤器在指令中的应用全局过滤器 ⭐ 本期推荐 ⭐ 专栏简介 Vue学习之旅的奇妙世界 欢迎大家来到 Vue 技能树参考资料专栏&#xff01;创建这个专栏的初衷是为了帮助大家更好地应对 Vue.js 技能树的学习。每…

螺纹钢负公差轧制中的测径仪应用

1、负公差轧制意义 为了满足生产使用要求&#xff0c;并根据轧制水平&#xff0c;在产品标准冲规定钢材尺寸的波动范围&#xff0c;允许钢材的实际尺寸与公称尺之间有一定的偏差&#xff0c;这个偏差一般称公差&#xff0c;公差分正、负公差&#xff0c;钢材按负公差轧制时&…

selenium不自动关闭chrome,selenium hello world

selenium不自动关闭chrome 用visual studio的话&#xff0c;右键&#xff0c;在终端运行。 from selenium import webdriveroptions webdriver.ChromeOptions() options.add_experimental_option("detach", True) driver webdriver.Chrome(optionsoptions) url …