Vision Transformer模型架构详解

🎀个人主页: https://zhangxiaoshu.blog.csdn.net
📢欢迎大家:关注🔍+点赞👍+评论📝+收藏⭐️,如有错误敬请指正!
💕未来很长,值得我们全力奔赴更美好的生活!

前言

2019年开始,自然语言处理(NLP)领域抛弃了循环神经网络(RNN)序列依赖的问题,开始采用Attention is All you need的Transformer结构[1],其中的Attention是一种可以让模型专注于重要的信息并能够充分学习和吸收的技术。在NLP领域中,伴随着各种语言Transformer模型的提出使得多项语言处理任务的精度和模型深度开始飞速提升。由于基于Transformer的预训练语言模型非常成功,研究者开始探索其在视觉领域的应用。2020年10月,Google创新性的设计了用于分类的Vision Transformer模型[2]—ViT。此后视觉Transformer模型的研究进入了快车道,本文主要对Vision Transformer模型架构进行详细介绍以及在pytorch中的使用方法进行介绍。


文章目录

  • 前言
  • 一、Vision Transformer模型架构
    • 1. Embedding层结构
    • 2. Transformer Encoder结构
      • (1)层归一化(Layer Norm)
      • (2)多头注意力机制(Multi-Head Attention)
      • (3)Dropout/DropPath
      • (4)MLP Block
    • 3. MLP Head结构
  • 二、PyTorch实现
    • 1. 首先安装vit-pytorch库:
    • 2.导入进行调用:
  • 总结


一、Vision Transformer模型架构

下图是原论文中作者给出的关于Vision Transformer的模型总体框架图:

在这里插入图片描述
从图中可以看出,Vision Transformer模型主要由三部分组成:第一部分为Linear Projection of Flattened Patches,也被称为Embedding层,主要用于将输入的图片数据转化为适合Transformer结构处理的形式。第二部分为Transformer Encoder部分,它是整个ViT模型的核心板块,在图右侧给出了更加详细的结构,它主要由层归一化(Layer Norm)、多头注意力机制(Multi-Head Attention)、Dropout/DropPath、MLP Block四部分组成用于学习输入图像数据的特征。第三部分为MLP Head,它是最终用于分类的层结构。下面本设计将对每一个组成部分进行一个详细介绍。

1. Embedding层结构

在视觉Transform模型中,其Transformer Encoder模块的输入形式是一个向量(token)序列,即一个二维矩阵[num_token, token_dim]的形式,如上图所示,输入的粉色小块token0-9对应的都是向量序列。

但是,图像处理和语言处理不一样,它的数据格式和Transformer Encoder输入格式是不一样,而是一个三维矩阵[H, W, C]的形式。所以在视觉Transform模型中首先加入了一个Embedding层结构用于将数据变化为向量序列。其主要过程为:首先将输入的图片形式数据按照模型定义的切割大小切割成多个小块(Patches),然后将切割的小块通过维度变化映射成向量形式。以常见的ViT-B/16为例,它首先将输入图片( 224 × 224 224\times224 224×224)按照 16 × 16 16\times16 16×16的大小进行切分得到196个Patches,接着通过线性映射将每一个Patches(16, 16, 3)映射成一个长度为768的向量。

在具体实现代码时,可以通过一个卷积层和Flatten层来直接实现。以ViT-B/16为例,如图所示其卷积层的参数为:卷积核大小是16x16、步距是16、卷积核的个数是768。数据通过卷积层后维度从(224, 224, 3)变化为(14, 14, 768),接着,将H和W两个维度展平即Flatten操作即可变化为(196, 768)这样的二维矩阵形式,这正是Transformer Encoder的输入格式。
在这里插入图片描述
除了将输入数据的形式变化为Transformer Encoder的输入格式,模型还在输入Transformer Encoder之前加入了[class]token以及Position Embedding,如下图所示。[class]token是参考了BERT所设计的,它是一个可以学习的参数,用于拼接到tokens中专门用于图像数据的分类。以ViT-B/16为例,就是让一个768长度的向量,与从Flatten层输出的数据拼接在一起,即,Cat((1, 768),(196, 768))—>(197, 768)。Position Embedding也是一个可以学习的参数。它是直接叠加在tokens上的(Add),因为对于图像数据而言,每一块和每一块在都有一定的位置依赖关系,所以Position Embedding主要用于表达Patches之间的位置关系。以ViT-B/16为例,就是让一个(197, 768)的向量与之前得到的(197, 768)向量相加。
在这里插入图片描述

2. Transformer Encoder结构

Transformer Encoder其实就是将Encoder Block 重复堆叠L次, Encoder Block结构图如下图2.4所示,主要由层归一化(Layer Norm)多头注意力机制(Multi-Head Attention)Dropout/DropPathMLP Block四部分组成。

(1)层归一化(Layer Norm)

层归一化(Layer Norm):这是一种主要针对NLP领域提出的归一化方法,这里是对每个token进行归一化处理。目前的归一化层主要有BN、LN、IN、GN和SN五种方法,它解决了深度神经网络内部协方差偏移问题,是一种将深度神经网络之间的数据进行归一化的算法,使得深度学习的训练过程中梯度变化趋于稳定,从而使网络在训练时达到快速收敛的目的。将输入的图像shape记为[N, C, H, W],这些方法的主要不同之处是,BatchNorm是在Batch上进行的,对NHW做归一化,对于较小的Batch Size没有太大的作用;LayerNorm是在通道方向上进行的,对CHW归一化,对RNN有很大的作用;InstanceNorm是在图像的像素上进行的,对HW做归一化,主要用在风格化迁移等方面;GroupNorm首先将Channel进行分组,然后再做归一化;SwitchableNorm是将BN、LN、IN结合并给予权重,让网络自己去学习归一化层应当使用的方法。

*有关BN、LN、IN、GN归一化方法的详细介绍可以看我这篇文章:神经网络常用归一化和正则化方法解析(一);

在这里插入图片描述
Layer Norm即层归一化针对神经网络的某一层的所有输入按照以下公式进行归一化操作:

H H H是某一层中隐藏结点的数量, l l l表示层数,可以计算得到Layer Norm的归一化统计量 μ l \mu^l μl σ l \sigma^l σl,如下式:

μ l = 1 H ∑ i = 1 H a i l \mu^l=\frac{1}{H}\sum_{i=1}^{H}a_i^l μl=H1i=1Hail

σ l = 1 H ∑ i = 1 H ( a l − μ l ) 2 \sigma^l=\sqrt{\frac{1}{H}\sum_{i=1}^{H}\left(a^l-\mu^l\right)^2} σl=H1i=1H(alμl)2

其中 a l a^l al表示一个中间输出结果的总和。上面的统计量和样本数没有关系,而是和隐藏层的结点数有关,我们甚至可以使 Batch Size = 1。于是,我们可以根据约定的统计量进行归一化处理,

a ^ l = a l − μ l ( σ l ) 2 + ε {\hat{a}}^l=\frac{a^l-\mu^l}{\sqrt{\left(\sigma^l\right)^2+\varepsilon}} a^l=(σl)2+ε alμl

同样,在Layer Norm中常使用参数增益(gain)和偏置(bias)这两个参数来保障归一化操作不会破坏之前的信息,同BatchNorm中的 γ \gamma γ β \beta β

y i = γ a ^ l + β y_i=\gamma{\hat{a}}^l+\beta yi=γa^l+β

从以上公式可以看到, LN中同层神经元输入拥有相同的均值和方差,不同的输入样本有不同的均值和方差。所以,LN与Batch的大小无关,也不取决于输入Sequence的深度,所以可以在batchsize为1和RNN中对边长的输入Sequence进行Normalize操作。

(2)多头注意力机制(Multi-Head Attention)

多头注意力机制(Multi-Head Attention):通过多个注意力机制的并行组合,将独立的注意力输出串联起来,预期维度得到线性地转化。直观看来,多个注意头允许对序列的不同部分进行注意力运算

对于Self-Attention来说,假设输入的token长度为 L L L,则输入为 [ x 1 , x 2 . . . x L , ] [x_1,x_2...x_L,] [x1,x2...xL,],然后分别将 x 1 x 2 . . . x L x_1x_2...x_L x1x2...xL分别通过三个变化矩阵 W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv(这三个参数是可训练的、共享的)得到对应的 q i , k i , v i , q^i,k^i,v^i, qi,ki,vi, 并将 q , k , v q,k,v q,k,v向量序列记为 Q , K , V Q,K,V Q,K,V。计算过程如下式所示,具体实现时使用全连接层实现。

( Q , K , V ) = ( q i , k i , v i ) = x i ∙ ( W q , W k , W v ) (Q,K,V)=\left(q^i,k^i,v^i\right)=x_i\bullet\left(W_q,W_k,W_v\right) (Q,K,V)=(qi,ki,vi)=xi(Wq,Wk,Wv)

其中 i = 1 , 2... L i=1,2...L i=1,2...L q q q表示query,后续会去和每一个k进行匹配, k k k代表key,后续会被每个 q q q匹配, v v v代表从 x x x中提取得到的信息value,后续 q q q k k k匹配的过程可以理解成计算两者的相关性,相关性越大对应 v v v的权重也就越大。

接着将 Q Q Q中的每一个 q i q^i qi去和 K K K中的每一个 k j k^j kj进行匹配,即点积操作。然后再除以 L \sqrt L L 得到对应的 α i , j \alpha_{i,j} αi,j,这样做的目的是进行点乘后的数值很大,导致通过Softmax后梯度变的很小,所以通过除以 L \sqrt L L 来进行缩放。具体计算过程如下式所示。

α i , j = q i ( k j ) T L \alpha_{i,j}=\frac{q^i\left(k^j\right)^T}{\sqrt L} αi,j=L qi(kj)T

α i , j \alpha_{i,j} αi,j表示 x i x_i xi x j x_j xj注意程度,然后对每一行分别进行Softmax处理得到 a ^ \hat{a} a^,相当于 x j x_j xj x i x_i xi权重,即对于 v v v的权重。具体计算过程如下式所示。

a ^ i , j = S o f t m a x ( α i , j ) {\hat{a}}_{i,j}=Softmax(α_{i,j}) a^i,j=Softmax(αij)

上面已经计算得到 a ^ i , j {\hat{a}}_{i,j} a^i,j,即针对每个 v v v的权重,接着进行加权得到最终结果,如下式所示。

b i = ∑ j = 1 L a ^ i , j × v j b^i=\sum_{j=1}^{L}{{\hat{a}}_{i,j}\times v^j} bi=j=1La^i,j×vj

其中 b i b^i bi表示 x i x_i xi经过Self-Attention后的结果。以上四式的过程习惯上用以下式来统一表示。

A t t e n t i o n ( Q , K , V ) = S o f t m a x ( Q ( K ) T L ) V Attention(Q,K,V)=Softmax\left(\frac{Q\left(K\right)^T}{\sqrt L}\right)V Attention(Q,K,V)=Softmax(L Q(K)T)V

对于Multi-Head Attention来说, 使用多头注意力机制能够联合来自不同head部分学习到的信息。首先根据使用的head的数目 h h h W q , W k , W v W_q,W_k,W_v Wq,Wk,Wv权值矩阵均分成 h h h份,即 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV, 其中 i = 1 , 2... h i=1,2...h i=1,2...h,然后还是和Self-Attention模块一样将 x i x_i xi分别通过变化矩阵 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV得到对应head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi, 接下来针对每个head使用和Self-Attention中相同的方法即可得到对应的结果。如下式所示。

h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W i V ) {head}_i=Attention\left(QW_i^Q,KW_i^K,VW_i^V\right) headi=Attention(QWiQ,KWiK,VWiV)

其中 Q W i Q QW_i^Q QWiQ同前式相比多了一个 W i Q W_i^Q WiQ,表示这里是根据划分的变化矩阵去计算每一个head的结果。即通过 W i Q , W i K , W i V W_i^Q,W_i^K,W_i^V WiQ,WiK,WiV映射得到每个head的 q i , k i , v i q^i,k^i,v^i qi,ki,vi,然后计算结果。
最后将每个head得到的结果进行concat拼接,接着将拼接后的结果通 过 W o 过W^o Wo(可学习的参数)进行融合,融合后得到最终的结果 b i b^i bi。如式(2-11)所示。

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , h e a d 1 … h e a d h ) W o MultiHead\left(Q,K,V\right)=Concat\left({head}_1,{head}_1\ldots{head}_h\right)W^o MultiHead(Q,K,V)=Concat(head1,head1headh)Wo

(3)Dropout/DropPath

Dropout/DropPath:在学习深度学习模型时,由于模型的参数过多、样本数量过少,导致了模型的过度拟合。在神经网络的训练中,常常会碰到一些问题。该方法具有较低的训练数据损失,具有较高的训练准确率。但是,测试数据的损失函数比较大,导致预测的准确性不高。

Dropout能在一定程度上减轻过度拟合,并能在某种程度上实现正规化。其基本原理是:在前向传播前进的过程中,使一个神经元的激活值以 p的概率不能工作,这在下面的图中可以看到。停止工作的神经元用虚线表示,与该神经元相连的相应传播过程将不在存在。这使得模型更加一般化,因为它不会依赖于一些局部特征。

DropPath类似于Dropout,不同的是Dropout 是对神经元随机“失效”,而DropPath是随机“失效”模型中的多分支结构。例如如下图右图所示,若 x x x为输入的张量,其通道为[B,C,H,W],那么DropPath的含义为一个Batch_size中,在经过多分支结构时,随机有drop_prob的样本,不经过主干,而直接经过分支(图中虚线)进行恒等映射。这在一定程度上使模型泛化性更强。
在这里插入图片描述

(4)MLP Block

MLP Block:如前文中Transformer Encoder结构图右侧所示,MLP Block由全连接层、GELU激活函数、Dropout组成,以ViT-B/16为例,第一个全连接层会把输入节点个数翻4倍(197, 768)—> (197, 3072),第二个全连接层会还原回原节点个数(197, 3072)—> (197, 768)。

3. MLP Head结构

通过Transformer Encoder后输出的维度和输入的维度是保持不变的,以ViT-B/16为例,输入的是(197, 768)输出的还是(197, 768)。这里只需要从[class]token抽取生成的对应结果,即从(197, 768)中抽取出[class]token对应的(1, 768),即为需要的分类信息。然后就可以用 MLP Head进行最后的分类得到结果。原论文中提到,在训练ImageNet21K时MLP Head是由全连接层+tanh激活函数+全连接层组成。但是如果是在ImageNet1K或者自己的数据集上时,只需要使用一个全连接层(Linear)即可,其结构如下图所示。
在这里插入图片描述

二、PyTorch实现

ViT模型共有三个不同的规模,如下所示:
。

1. 首先安装vit-pytorch库:

$ pip install vit-pytorch

2.导入进行调用:

import torch
from vit_pytorch import ViTmodel = ViT(image_size = 224,patch_size = 32,num_classes = 1000,dim = 1024,depth = 6,heads = 16,mlp_dim = 2048,dropout = 0.1,emb_dropout = 0.1
)imgs = torch.randn(1, 3, 224, 224)preds = model(imgs) # (1, 1000)

总结

以上就是对Vision Transformer模型架构的详细介绍及其适用,Vision Transformer模型作为第一个将Transformer结构应用到计算机视觉上的模型,对近年来计算机视觉的研究具有很大的意义,其常常与swin Transformer(可以理解为FPN结构的ViT)用作其他任务如检测、分割的backbone以及视觉特征提取器。

参考:
Attention is all you need
An image is worth 16x16 words: Transformers for image recognition at scale

文中图片大多来自论文和网络,如有侵权,联系删除,文中有不对的地方欢迎指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/268325.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WordPress插件你好多莉( Hello Dolly )可否卸載

什么是你好多莉Hello Dolly WordPress插件 你好多莉是 WordPress插件 简单的预装在 WordPress 上。 如果您激活后者,它将显示出名曲的歌词“ 你好,多莉 “要 路易斯阿姆斯特朗. 您可能已经注意到,在阅读插件说明时,还不够清楚。 …

5G CPE可代替宽带,解决断网问题

最近某运营商就玩起了套餐,断用户的网。 老百姓对宽带半知不解,网络断了没法上网,很着急。因为相信运营商,维修人员怎么说,老百姓就怎么办呗,直到最后才发现自己上当,但钱都给了。 截至2023年9月…

thinkphp 使用array_reduce 处理返回的数据格式

我想要的效果: 不使用array_reduce 的效果 : 代码: public function teamList($userId,$good_id){$nowbuyers $this->order->where(good_id,$good_id)->count();$data GroupTotalOrder::alias(t_order)->where(merchant_Id,$u…

JVM虚拟机系统性学习-运行时数据区(虚拟机栈、本地方法栈)

虚拟机栈 虚拟机栈为每个线程所私有的,如下图: 栈帧是什么? 栈帧存储了方法的局部变量表、操作数栈、动态链接和方法返回地址等信息 栈内存为线程私有的空间,每个方法在执行时都会创建一个栈帧,执行该方法时&…

Nginx首页修改及使用Nginx实现端口转发

按照我之前博客给的方法搭建好这样一个CTF靶场 但是呢它默认是在8000端口 如何直接访问IP地址或者域名就可以实现直接访问到靶场呢 我们需要将80端口的内容转发到8000,使用nginx实现端口转发功能 首先我们安装nginx: 安装工具和库 yum -y install gc…

【产品经理】需求池和版本树

在这个人人都是产品经理的时代,每位入行的产品人进阶速度与到达高度各有不同。本文作者结合自身三年产品行业的经历,根据案例拆解产品行业的极简研发过程、需求池、版本树、产品自我优化等相关具体方法论。 一、产品研发的极简过程 1. 产品概述 产品就…

一体化超声波气象站科普解说

随着科技的不断发展,气象监测设备也在逐步升级。一体化超声波气象站作为新型气象监测设备,以其优势和预报能力,成为了气象监测领域的新宠。 一、一体化超声波气象站的特点 WX-CSQX12 一体化超声波气象站是一种集成了多种气象监测设备的新型…

java反序列化数据过滤

前言: 反序列化漏洞的危害稍微了解一点的都知道,如果能找到前端某处存在反序列化漏洞,那基本上距离拿下服务器仅一步之遥,这个时候我们可以通过继承ObjectInputFilter添加tFilter实现对所有反序列化类的校验,当然这个需…

mysql——数据库基础

目录 一.什么是数据库 二.主流的数据库 三.服务器,数据库,表关系 四.数据逻辑存储 五.MySQL架构 六.SQL语句分类 七.存储引擎 一.什么是数据库 存储数据用文件就可以了,为什么还要弄个数据库? 文件保存数据有以下几个缺点&#xff1…

java--Collection的常用方法

1.集合体系结构 ①Collection代表单列集合,每个元素(数据)只包含一个值 ②Map代表双列集合,每个元素包含两个值(键值对) 2.Collection集合体系 3.Collection集合特点 1.List系列集合:添加的元素是有序、可重复、有索引 ①ArrayList、Line…

聊聊国内的汽车改装现状以及SUV车型中的CAN数据改装应用

随着汽车个性化、文旅需求旺盛以及运动赛事的兴起及线下活动的参与,改装人群在不断扩大,国内改装行业也在不断发展,出现很多不同风格的,包括成文化、成体系的汽车改装。并且,在这里面孵化出很多优秀的公司,…

银行数据分析进阶篇:银行外呼业务数据分析与客户精准营销优化研究

上次和大家分享了“信用卡全生命周期分析”的案例,不少朋友都有正向的反馈,今天继续和大家分享我之前看到的银行数据分析的案例,这个案例结构清晰,内容详细,相信朋友们能很快掌握! 01 需求痛点 我们先来了…