Visual Transformer (ViT)模型详解

1 Vit简介

1.1 Vit的由来

ViT是2020年Google团队提出的将Transformer应用在图像分类的模型,虽然不是第一篇将transformer应用在视觉任务的论文,但是因为其模型“简单”且效果好,可扩展性强(scalable,模型越大效果越好),成为了transformer在CV领域应用的里程碑著作,也引爆了后续相关研究。

论文地址:https://arxiv.org/pdf/2010.11929.pdf

Visual Transformer (ViT) 出自于论文《AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE》,是基于Transformer的模型在视觉领域的开篇之作。ViT模型是基于Transformer Encoder模型的。

1.2 Vit如何工作

我们知道Transformer模型最开始是用于自然语言处理(NLP)领域的,NLP主要处理的是文本、句子、段落等,即序列数据。但是视觉领域处理的是图像数据,因此将Transformer模型应用到图像数据上面临着诸多挑战,理由如下:

  • 与单词、句子、段落等文本数据不同,图像中包含更多的信息,并且是以像素值的形式呈现。

  • 如果按照处理文本的方式来处理图像,即逐像素处理的话,即使是目前的硬件条件也很难。

  • Transformer缺少CNNs的归纳偏差,比如平移不变性和局部受限感受野。

  • CNNs是通过相似的卷积操作来提取特征,随着模型层数的加深,感受野也会逐步增加。但是由于Transformer的本质,其在计算量上会比CNNs更大。

  • Transformer无法直接用于处理基于网格的数据,比如图像数据。

为了解决上述问题,Google的研究团队提出了ViT模型,它的本质其实也很简单,既然Transformer只能处理序列数据,那么我们就把图像数据转换成序列数据就可以了呗。下面来看下ViT是如何做的。

1.3 ViT模型架构

我们先结合下面的动图来粗略地分析一下ViT的工作流程,如下:

  • 将一张图片分成patches

  • 将patches铺平

  • 将铺平后的patches的线性映射到更低维的空间

  • 添加位置embedding编码信息

  • 将图像序列数据送入标准Transformer encoder中去

  • 在较大的数据集上预训练

  • 在下游数据集上微调用于图像分类

 模型由三个模块组成:

  • Linear Projection of Flattened Patches(Embedding层)
  • Transformer Encoder
  • MLP Head(最终用于分类的层结构)

Embedding层

对于标准的Transformer模块,要求输入的是token(向量)序列,即二维矩阵[num_token, token_dim],如下图,token0-9对应的都是向量,以ViT-B/16为例,每个token向量长度为768。

对于图像数据而言,其数据格式为[H, W, C]是三维矩阵明显不是Transformer想要的。所以需要先通过一个Embedding层来对数据做个变换。如下图所示,首先将一张图片按给定大小分成一堆Patches。以ViT-B/16为例,将输入图片(224x224)按照16x16大小的Patch进行划分,划分后会得到196个Patches。接着通过线性映射将每个Patch映射到一维向量中,以ViT-B/16为例,每个Patche数据shape为[16, 16, 3]通过映射得到一个长度为768的向量(后面都直接称为token)。[16, 16, 3] -> [768]

在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。

Transformer Encoder

Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

​​
MLP Head

上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。


2 ViT工作原理

我们将上图展示的过程近一步分解为6步,接下来一步一步地来解析它的原理。如下图:

2.1 步骤1、将图片转换成patches序列

这一步很关键,为了让Transformer能够处理图像数据,第一步必须先将图像数据转换成序列数据,但是怎么做呢?假如我们有一张图片x\epsilon R^{H*W*C},patch大小为p,那么我们可以创建N个图像patches,可以表示为x_{p}\epsilon R^{N*(p^{2}c)},其中N=\frac{HW}{P^{2}}N就是序列的长度,类似一个句子中单词的个数。在上面的图中,可以看到图片被分为了9个patches。

2.2 步骤2、将Patches铺平

在原论文中,作者选用的patch大小为16,那么一个patch的shape为(3,16,16),维度为3,将它铺平之后大小为3x16x16=768。即一个patch变为长度为768的向量。不过这看起来还是有点大,此时可以使用加一个Linear transformation,即添加一个线性映射层,将patch的维度映射到我们指定的embedding的维度,这样就和NLP中的词向量类似了。

2.3 步骤3、添加Position embedding

与CNNs不同,此时模型并不知道序列数据中的patches的位置信息。所以这些patches必须先追加一个位置信息,也就是图中的带数字的向量。实验表明,不同的位置编码embedding对最终的结果影响不大,在Transformer原论文中使用的是固定位置编码,在ViT中使用的可学习的位置embedding 向量,将它们加到对应的输出patch embeddings上。

2.4 步骤4、添加class token

在输入到Transformer Encoder之前,还需要添加一个特殊的class token,这一点主要是借鉴了BERT模型。添加这个class token的目的是因为,ViT模型将这个class token在Transformer Encoder的输出当做是模型对输入图片的编码特征,用于后续输入MLP模块中与图片label进行loss计算。

2.5 步骤5、输入Transformer Encoder

将patch embedding和class token拼接起来输入标准的Transformer Encoder中。 Transformer Encoder其实就是重复堆叠Encoder Block L次,主要由Layer Norm、Multi-Head Attention、Dropout和MLP Block几部分组成。

2.6 步骤6、分类

注意Transformer Encoder的输出其实也是一个序列,但是在ViT模型中只使用了class token的输出,将其送入MLP模块中,去输出最终的分类结果。

3 模型搭建参数

在论文的Table1中有给出三个模型(Base/ Large/ Huge)的参数,在源码中除了有Patch Size为16x16的外还有32x32的。

其中:

Layers就是Transformer Encoder中重复堆叠Encoder Block的次数 L。
Hidden Size就是对应通过Embedding层(Patch Embedding + Class Embedding + Position Embedding)后每个token的dim(序列向量的长度)
MLP Size是Transformer Encoder中MLP Block第一个全连接的节点个数(是token长度的4倍)
Heads代表Transformer中Multi-Head Attention的heads数。
 

4 结果分析

上表是论文用来对比ViT,Resnet(和刚刚讲的一样,使用的卷积层和Norm层都进行了修改)以及Hybrid模型的效果。通过对比可得出结论:

  • 在训练epoch较少时Hybrid优于ViT -> Epoch小选Hybrid
  • 当epoch增大后ViT优于Hybrid -> Epoch大选ViT
     

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/313912.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Windows系统历史版本简介详细版

学习目标: 目录 学习目标: 学习内容: 学习产出: Windows 11的全新用户界面设计:学习新的任务栏、开始菜单、窗口管理等界面元素的使用与操作。 Windows 11的新功能和特点:学习新的虚拟桌面、Microsoft Team…

07-2-接口文档管理工具-swagger注解使用__ev

swagger参考demo package com.example.swagger2.controller;import com.example.swagger2.exception.SwaggerException; import com.example.swagger2.model.User; import io.swagger.annotations.*; import org.springframework.web.bind.annotation.*;import java.util.Has…

多模态大模型的前世今生

1 引言 前段时间 ChatGPT 进行了一轮重大更新:多模态上线,能说话,会看图!微软发了一篇长达 166 页的 GPT-4V 测评论文,一时间又带起了一阵多模态的热议,随后像是 LLaVA-1.5、CogVLM、MiniGPT-5 等研究工作…

HTTP协议编程实战(二)实战二

使用析构函数主要是在里面关闭套接字(socket); waitForReadyRead()里面参数是毫秒,失败返回false; \r\n表示请求头部已经结束了,HTTP/1.1是版本号,200 ok表示请求响应成功 关闭的话就在前面加/

【阅读笔记】LoRAHub:Efficient Cross-Task Generalization via Dynamic LoRA Composition

一、论文信息 1 论文标题 LoRAHub:Efficient Cross-Task Generalization via Dynamic LoRA Composition 2 发表刊物 NIPS2023_WorkShop 3 作者团队 Sea AI Lab, Singapore 4 关键词 LLMs、LoRA 二、文章结构 #mermaid-svg-Gn81hPysu7z59nlv {font-family:&…

ARM CCA机密计算软件架构之内存加密上下文(MEC)

内存加密上下文(MEC) 内存加密上下文是与内存区域相关联的加密配置,由MMU分配。 MEC是Arm Realm Management Extension(RME)的扩展。RME系统架构要求对Realm、Secure和Root PAS进行加密。用于每个PAS的加密密钥、调整或加密上下文在该PAS内是全局的。例如,对于Realm PA…

LLM应用的分块策略

每日推荐一篇专注于解决实际问题的外文,精准翻译并深入解读其要点,助力读者培养实际问题解决和代码动手的能力。 欢迎关注公众号 原文标题:Chunking Strategies for LLM Applications 原文地址:https://www.pinecone.io/learn/c…

电子招标采购系统源码之从供应商管理到采购招投标、采购合同、采购执行的全过程数字化管理。

在数字化时代,采购管理也正经历着前所未有的变革。全过程数字化采购管理成为了企业追求高效、透明和规范的关键。该系统通过Spring Cloud、Spring Boot2、Mybatis等先进技术,打造了从供应商管理到采购招投标、采购合同、采购执行的全过程数字化管理。通过…

MySQL 数值函数,字符串函数与多表查询

MySQL像其他语言一样,也提供了很多库函数,分为单行函数和分组函数(聚合函数),我们这里先简易介绍一些函数,熟悉就行,知道怎么使用即可. 数值函数 三角函数 指数与对数函数 进制间的转换函数 字符串函数 注:LPAD函数是右对齐,RPAD函数是左对齐 多表查询 注:如果为表起了别名,就…

Origin 2021软件安装包下载及安装教程

Origin 2021下载链接:https://docs.qq.com/doc/DUnJNb3p4VWJtUUhP 1.选中下载的压缩包,然后鼠标右键选择解压到"Origin 2021"文件夹 2.双击打开“Setup”文件夹 3.选中“Setup.exe”鼠标右键点击“以管理员身份运行” 4.点击“下一步" 5…

Photoshop显示16位/32位像素值

打开“信息”窗口-单击“画笔”图标-子菜单中选择16位/32位

c++_09_继承

1 继承 C的继承是弱继承 继承的语法: class 子类 : 继承方式1 基类1, 继承方式2 基类2, ... { ... }; 继承方式: 共有继承 public 保护继承 protected 私有继承 private 2 继承的基本属性(3种继承方式均有) 继承所…