ViT: transformer在图像领域的应用

文章目录

  • 1. 概要
  • 2. 方法
  • 3. 实验
    • 3.1 Compare with SOTA
    • 3.2 PRE-TRAINING DATA REQUIREMENTS
    • 3.3 SCALING STUDY
    • 3.4 自监督学习
  • 4. 总结
  • 参考

论文: An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
代码:https://github.com/google-research/vision_transformer
代码2:https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py

我们在Transformer详解(1)—原理部分详细介绍了transformer在NLP领域应用的原理,transformer架构自发布以来已经在自然语言处理任务上广泛应用,今天我们将介绍如何将transformer架构应用在图像领域。

1. 概要

基于self-attention的网络架构在NLP领域中取得了很大的成功,但是在CV领域卷积网络架构仍然占据主导地位。受到transformer在NLP中应用成功的启发,也有很多工作尝试将self-attention与CNN网络结合,甚至有些工作直接替换CNN网络,理论上这些模型是高效的,由于这些特殊的注意力机制未与硬件加速器有效适配,因此在大规模的图像检测中,经典的ResNet网络架构仍然是SOTA

受到Transformer网络在NLP领域中成功适配的启发,作者提出对transformer尽可能少的修改,直接在图片上应用标准的transformer。为了实现这个目标,首先需要将图片分割成多个patch,并将这些patch转换成embedding作为transformer的输入。图片的patch就相当于NLP中的token

最后作者得到结论:在数据量不足的情况下进行训练时,ViT不能很好地泛化,效果不如CNN,不过在训练大规模数据时,vit的效果会反超CNN

2. 方法

在模型设计方面,version transformer尽量与原始transformer结构保持一致,因为NLP中的transformer具有高效的实现方式,这样可以开箱即用。模型的整体结构如下所示:
在这里插入图片描述
标准的 transformer 输入是一维向量序列,为了处理二维图像,将输入图片 x ∈ R H × W × C \mathbf{x}\in\mathbb{R}^{H\times W\times C} xRH×W×C 分割成一系列的patch,并将这些patch平整成一维向量,最终得到 x p ∈ R N × ( P 2 ⋅ C ) \mathbf{x}_p\in\mathbb{R}^{N\times(P^2\cdot C)} xpRN×(P2C),其中 ( H , W ) (H,W) (H,W)是原始图片分辨率, C C C 是图片的通道数, ( P , P ) (P,P) (P,P)是每个patch的分辨率, N = H W P 2 N=\frac{HW}{P^2} N=P2HW 是patch的个数,也可以看作是输入序列的长度。由于transformer每一层的输入向量维度都是固定的 D D D,因此需要通过一个可训练的线性层将 flatten patch 的维度从 P 2 C P^2C P2C转换成 D D D,这个线性层的输出称为patch的embedding.

和BERT的 [class] token 类似,在path embedding序列的首位增加了一个可学习向量 z 0 0 = x c l a s s z_0^0=x_{class} z00=xclass,该向量在transformer encoder的输出部分看做是图片的表征,在预训练和微调阶段,该表征后都会接一个分类层。

为了保持位置信息,位置embedding会加到patch embedding上,这里作者使用了一个一维可学习的位置向量,因为通过实验发现使用二维位置向量并没有获得很大的性能提升,通过以上流程处理后的embedding就是transformer的输入embedding。从输入图片到transformer encoder输出可由以下式子表示:

z 0 = [ x class ; x p 1 E ; x p 2 E ; ⋯ ; x p N E ] + E p o s ; E ∈ R ( P 2 ⋅ C ) × D , E p o s ∈ R ( N + 1 ) × D z ′ ℓ = M S A ( L N ( z ℓ − 1 ) ) + z ℓ − 1 ; ℓ = 1 … L z ℓ = M L P ( L N ( z ′ ℓ ) ) + z ′ ℓ ; ℓ = 1 … L y = L N ( z L 0 ) \begin{align} z_0 =&[\mathbf{x}_\text{class};\mathbf{x}_p^1\mathbf{E};\mathbf{x}_p^2\mathbf{E};\cdots;\mathbf{x}_p^N\mathbf{E}]+\mathbf{E}_{pos}; \ \ \ \mathbf{E}\in\mathbb{R}^{(P^{2}\cdot C)\times D}, \mathbf{E}_{pos}\in\mathbb{R}^{(N+1)\times D}\\ \mathbf{z}^{\prime}{}_{\ell} =& \mathrm{MSA(LN(z_{\ell-1}))+z_{\ell-1}};\ \ \ell=1\ldots L \\ \mathbf{z}_{\ell} = &\mathrm{MLP}(\mathrm{LN}(\mathbf{z^{\prime}}_\ell))+\mathbf{z^{\prime}}_\ell; \ \ \ \ell=1\ldots L \\ y =& \mathrm{LN}(\mathbf{z}_{L}^{0}) \end{align} z0=z=z=y=[xclass;xp1E;xp2E;;xpNE]+Epos;   ER(P2C)×D,EposR(N+1)×DMSA(LN(z1))+z1;  =1LMLP(LN(z))+z;   =1LLN(zL0)

其中 E E E 是patch维度转换矩阵, M S A MSA MSA是多头注意力层(multi-head self attention), L N LN LN是layer normalization 层, M L P MLP MLP是transformer中前馈网络层

另外,也可以使用CNN网络的特征图作为输入序列,在这种混合模型中,patch embeding 投影层将被用于改变CNN特征图的形状。

在微调阶段,将移除预训练的prediction layer,并新增一个零初始化的预测层,一般来说,在更高分辨率图像上微调是非常有益的。在喂入更高分辨率图像时,保持patch的尺寸不变,这样会造成输入序列长度增加,虽然ViT模型可以处理任意长的输入序列(直到内存不够),但是预训练的位置编码将无效,因此作者根据当前位置在原始图片中的位置,对预训练的位置编码采用2D插值的方法获取最新的位置编码

3. 实验

下文中将用一些简写来代表模型的尺寸和输入patch的尺寸,如ViT-L/16 代表模型为ViT-Large,输入patch的尺寸为 16 × 16 16 \times 16 16×16,下表展示了不同尺寸模型的配置及参数量
在这里插入图片描述
这里需要注意,由于输入序列长度与patch的尺寸成反比,所以,patch 尺寸越小,反而计算量越大

3.1 Compare with SOTA

在这里插入图片描述
TPU v3-core-days:代表计算量,All models were trained on TPUv3 hardware, and we
report the number of TPUv3-core-days taken to pre-train each of them, that is, the number of TPU
v3 cores (2 per chip) used for training multiplied by the training time in days

在这里插入图片描述
不同模型简介:

  • Big Transfer (BiT), which performs supervised transfer learning with large ResNets
  • VIVI – a ResNet co-trained on ImageNet and Youtube
  • S4L – supervised plus semi-supervised learning on ImageNet

3.2 PRE-TRAINING DATA REQUIREMENTS

作者经过实验得到如下结论:

  • 在小数据集上预训练,ViT-Large比ViT-Base要差,在大数据集上训练对ViT-Large比较有益
  • 在小数据集上预训练,ViT的效果比CNN还要差,在大数据集上预训练ViT的效果超过CNN
  • CNN网络的归纳有偏性在小数据集上是有用的,但是在大数据集上,直接从数据中学习相关的模式更有效
    在这里插入图片描述

3.3 SCALING STUDY

如下图所示,作者得到如下结论:

  • ViT在效果和计算量平衡之间相比ResNet占绝对优势,ResNet需要使用约3倍的算力来获得与ViT相似的结果
  • 混合模型在小计算量上相比ViT具有一定的优势,但是这种优势在大模型(大计算量)上逐渐消失
  • ViT在当前实验中貌似并没有饱和,这激励着未来的研究
    在这里插入图片描述

3.4 自监督学习

作者模仿BERT通过mask patch prediction任务进行自监督预训练,ViT-B/16在ImageNet上获得了79.9%的准确率,相比从随机初始化开始训练提升了2%,但是相比于监督学习仍然落后4%。

4. 总结

作者将图片看作是patch序列,并使用标准的Transformer对patch序列进行处理,最终在大数据集上预训练取得了很不错的效果,在图片分类任务上超过了很多SOTA模型。但也还存在一些挑战等待后期处理:

  • 将ViT应用在其他计算机视觉任务中,如目标检测、语义分割等
  • 还需进一步探索自监督预训练方法
  • 进一步扩大ViT模型的规模,可能会取得更好的效果

参考

如何理解Inductive bias?
Translation Equivariance
CNN中的Translation Equivariance【理解】
2D插值(2D interpolation)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/474241.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

读十堂极简人工智能课笔记05_无监督学习

1. 自我改善 1.1. 只有学会了如何学习和改变的人,才称得上是受过教育的人 1.1.1. 卡尔罗杰斯 1.2. 人工智能如果只是学习纯理论的游戏(从国际象棋和围棋到电脑游戏),其结果已然可以令人惊叹 1.3. 让大多数机器人玩叠叠乐游戏&…

Python算法题集_将有序数组转换为二叉搜索树

Python算法题集_将有序数组转换为二叉搜索树 题108:将有序数组转换为二叉搜索树1. 示例说明2. 题目解析- 题意分解- 优化思路- 测量工具 3. 代码展开1) 标准求解【极简代码递归】2) 改进版一【多行代码递归】3) 改进版二【极简代码递归传递下标】 4. 最优算法 本文为…

【类与对象 -2】学习类的6个默认成员函数中的构造函数与析构函数

目录 1.类的6个默认成员函数 2.构造函数 2.1概念 2.2特性 3.析构函数 3.1析构函数的概念 3.2特性 1.类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,…

解锁Spring Boot中的设计模式—05.策略模式:探索【策略模式】的奥秘与应用实践!

1.策略者工厂模式(Map版本) 1.需求背景 假设有一个销售系统,需要根据不同的促销活动对商品进行打折或者其他形式的优惠。这些促销活动可以是针对不同商品类别的,比如男装、女装等。 2.需求实现 活动策略接口:定义了…

后端学习:Maven模型与Springboot框架

Maven 初识Maven Maven:是Apache旗下的一个开源项目,是一款用于管理和构建java项目的工具。 Maven的作用1.依赖管理2.统一项目结构3.项目构建依赖管理:方便快捷的管理项目依赖的资源(jar包),避免版本冲突问题   当使用maven进行项目依赖…

【教学类-19-10】20240214《ABAB式-规律黏贴18格-手工纸15*15CM-一页3种图案,AB一组样板,纵向、有边框》(中班)

背景需求 利用15*15CM手工纸制作AB色块手环(手工纸自带色彩),一页3个图案,2条为一组,画图案,黏贴成一个手环。 素材准备 代码展示 # # 作者:阿夏 # 时间:2024年2月14日 # 名称&…

19-k8s的附加组件-coreDNS组件

一、概念 coreDNS组件:就是将svc资源的名称解析成ClusterIP; kubeadm部署的k8s集群自带coreDNS组件,二进制部署需要自己手动部署; [rootk8s231 ~]# kubectl get pods -o wide -A k8s系统中安装了coreDNS组件后,会有一个…

BuildAdmin - 免费开源可商用!基于 ThinkPHP8 和 Vue3 等流行技术栈打造的商业级后台管理系统

一款包含 PHP 服务端和 Vue 前端代码的 admin 管理系统,实用性很强,推荐给大家。 BuildAdmin 是一个成熟的后台管理系统,后端服务采用 ThinkPHP8 ,数据库使用 Mysql,前端部分则使用当前流行的 Vue3 / TypeScript / Vi…

K8s进阶之路-安装部署K8s

参考:(部署过程参考的下面红色字体文档链接就可以,步骤很详细,重点部分在下面做了标注) 安装部署K8S集群文档: 使用kubeadm方式搭建K8S集群 GitBook 本机: master:10.0.0.13 maste…

Vulnhub靶场 DC-6

目录 一、环境搭建 二、主机发现 三、漏洞复现 1、wpscan工具 2、后台识别 dirsearch 3、爆破密码 4、rce漏洞利用 activity monitor 5、rce写shell 6、新线索 账户 7、提权 8、拿取flag 四、总结 一、环境搭建 Vulnhub靶机下载: 官网地址&#xff1a…

stable diffusion webui学习总结(2):技巧汇总

一、脸部修复:解决在低分辨率下,脸部生成异常的问题 勾选ADetailer,会在生成图片后,用更高的分辨率,对于脸部重新生成一遍 二、高清放大:低分辨率照片提升到高分辨率,并丰富内容细节 1、先通过…

springboot登录校验

一、登录功能 二、登录校验 2.1 会话技术 2.2 JWT令牌 JWT令牌解析: 如何校验JWT令牌?Filter和Interceptor两种方式。 2.3 过滤器Filter 2.3.1 快速入门 修改上述代码: 2.3.2 详解 2.3.3 登录校验-Filter 2.4 Interceptor拦截器 2.4.1 …