Vision Transformer (ViT): Paper Walkthrough and Implementation

1 Paper Walkthrough

Paper: ViT ("An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale")

1.1 The ViT model architecture is shown in the figure below:

  • The raw input image has dimensions H * W * C.
  • H and W are split into patches of P pixels each, giving N patches of size P × P, where N = HW/P². N is the length of the Transformer input sequence.
  • $x \in R^{H \times W \times C} \Rightarrow x \in R^{N \times (P^2 \cdot C)}$ (a small reshape sketch follows this list)
  • The latent dimension D is fixed across layers: "The Transformer uses constant latent vector size D through all of its layers, so we flatten the patches and map to D dimensions with a trainable linear projection."
  • A class token is prepended to the N-token sequence, analogous to BERT's [CLS] token, and is used for the classification task.
  • Position information is added as learnable 1-D position embeddings over the flattened sequence. (Using 2-D position embeddings based on image coordinates gave no clear improvement.)
    [Figure: ViT model architecture]
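To make the patch-flattening reshape concrete, here is a minimal sketch, assuming H = W = 224, P = 16, C = 3 (the ViT-B/16 setting used later in this post; the values are illustrative, not part of the paper's derivation):

import torch

H = W = 224; P = 16; C = 3
x = torch.randn(H, W, C)

# [H, W, C] -> [H/P, P, W/P, P, C] -> [H/P, W/P, P, P, C] -> [N, P*P*C]
patches = (x.reshape(H // P, P, W // P, P, C)
            .permute(0, 2, 1, 3, 4)
            .reshape(-1, P * P * C))
print(patches.shape)   # torch.Size([196, 768]), i.e. N = 196 and P^2 * C = 768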

ViT model equations

Input: $x \in R^{N \times (P^2 \cdot C)}$
First patch: $x_p^1 \in R^{P^2 \cdot C}$ (likewise for every patch $x_p^i$)
Projection matrix: $E \in R^{(P^2 \cdot C) \times D}$
E is shared across every $x_p^i$ in the length-N sequence; the resulting $z_0$ has shape $(N+1) \times D$ (the +1 is the prepended class token).
Equation (2): MSA (multi-head self-attention) does not change the shape of $z$.
Equation (3): the MLP output is added back to its input, a residual connection as in ResNet.
Equation (4): only the first token of $z_L$ (the class token manually prepended at position 0) is used as the representation for the classification task.
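For reference, the four equations from the ViT paper, reconstructed in LaTeX (this is the pre-norm formulation that the code below implements):

z_0 = [x_{class};\, x_p^1 E;\, x_p^2 E;\, \cdots;\, x_p^N E] + E_{pos},
      \quad E \in \mathbb{R}^{(P^2 \cdot C) \times D},\; E_{pos} \in \mathbb{R}^{(N+1) \times D}   % (1)
z'_\ell = \mathrm{MSA}(\mathrm{LN}(z_{\ell-1})) + z_{\ell-1}, \quad \ell = 1 \ldots L              % (2)
z_\ell  = \mathrm{MLP}(\mathrm{LN}(z'_\ell)) + z'_\ell,       \quad \ell = 1 \ldots L              % (3)
y = \mathrm{LN}(z_L^0)                                                                             % (4)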

2 Code Implementation

2.1 Embedding layer

  • Model input: x.shape = [16, 3, 224, 224]  # 16 is the batch_size
  • After x passes through patch_embeddings, shape = [16, 768, 14, 14]
  • Flatten the last two dims (H, W) of the patch embedding and swap them with the channel dim: shape = [16, 196, 768]
  • Concatenate with the learnable cls_token: shape = [16, 197, 768]
  • Add the position embeddings to get the embedding output: shape = [16, 197, 768] (these five steps are sketched end-to-end after the code below)

self.patch_embeddings = Conv2d(in_channels=in_channels, out_channels=config.hidden_size, kernel_size=16, stride=16)

  • cls_token shape = [1, 1, 768]

self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
# Note: n_patches = 14*14, config.hidden_size = 768
self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
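Putting these five steps together, a minimal self-contained sketch (shapes hardcoded for ViT-B/16 at 224×224 with batch size 16; the full Embeddings class in Section 4.1 derives them from config instead):

import torch
import torch.nn as nn

x = torch.randn(16, 3, 224, 224)                      # input batch
patch_embed = nn.Conv2d(3, 768, kernel_size=16, stride=16)
cls_token = nn.Parameter(torch.zeros(1, 1, 768))
pos_embed = nn.Parameter(torch.zeros(1, 14 * 14 + 1, 768))

feats = patch_embed(x)                                # [16, 768, 14, 14]
feats = feats.flatten(2).transpose(1, 2)              # [16, 196, 768]
cls = cls_token.expand(x.shape[0], -1, -1)            # [1, 1, 768] -> [16, 1, 768]
out = torch.cat((cls, feats), dim=1) + pos_embed      # [16, 197, 768]
print(out.shape)                                      # torch.Size([16, 197, 768])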

2.2 Block layer

  • Input is the Embedding layer's output x; shape = [16, 197, 768]
  • Pass through a layer_norm layer; shape unchanged
  • Pass through the attn layer (multi-head attention); query, key, and value each have shape [16, 12, 197, 64] (see the head-splitting sketch after the code below)
  • Add the original x back (residual); this is the multi-head output, shape = [16, 197, 768]
  • Pass through layer_norm and the feed-forward layer, then add the previous x; this is the Block output, shape = [16, 197, 768]

The layer_norm layer:

  self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
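Where the [16, 12, 197, 64] shape in the third bullet comes from: hidden_size 768 is split into 12 heads of 64 dimensions each, and the head axis is then moved in front of the sequence axis. A standalone sketch with those dimensions hardcoded:

import torch

q = torch.randn(16, 197, 768)            # [batch, seq_len, hidden_size]
q = q.view(16, 197, 12, 64)              # split 768 into 12 heads * 64 dims
q = q.permute(0, 2, 1, 3)                # [16, 12, 197, 64]
print(q.shape)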

2.3 Encoder layer

The Encoder stacks L Blocks; after the final Block (followed by one last LayerNorm, see Section 4.3) the output is the encoder output, shape = [16, 197, 768].

2.4 Model output

  • The Transformer's final output has shape = [16, 197, 768]
  • Take the first of the 197 tokens as the output x; x shape = [16, 768]
  • Pass x through a fully connected layer; shape = [16, num_class]
  • The model loss is cross-entropy (a sketch of this head follows the list)
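A minimal sketch of this output head; random data stands in for the encoder output, and num_class = 10 is a hypothetical value chosen only for illustration:

import torch
import torch.nn as nn

encoded = torch.randn(16, 197, 768)           # stand-in for the encoder output
labels = torch.randint(0, 10, (16,))          # hypothetical targets

head = nn.Linear(768, 10)                     # 10 = num_class (assumed here)
logits = head(encoded[:, 0])                  # CLS token only -> [16, 10]
loss = nn.CrossEntropyLoss()(logits, labels)
print(logits.shape, loss.item())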

3 Transformer structure

(embeddings): Embeddings(
  (patch_embeddings): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (dropout): Dropout(p=0.1, inplace=False)
)
(encoder): Encoder(
  (layer): ModuleList(
    (0): Block(
      (attention_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (ffn_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
      (ffn): Mlp(
        (fc1): Linear(in_features=768, out_features=3072, bias=True)
        (fc2): Linear(in_features=3072, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (attn): Attention(
        (query): Linear(in_features=768, out_features=768, bias=True)
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (out): Linear(in_features=768, out_features=768, bias=True)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (proj_dropout): Dropout(p=0.0, inplace=False)
        (softmax): Softmax(dim=-1)
      )
    )
    ... Blocks (1) through (10) omitted ...
    (11): Block(
      ... same structure as Block (0) ...
    )
  )
  (encoder_norm): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
)

4 Code Overview

4.1 Embeddings class

# Imports used by the classes in this section:
import copy
import math

import torch
import torch.nn as nn
from torch.nn import Conv2d, Dropout, LayerNorm, Linear, Softmax
from torch.nn.modules.utils import _pair
# ResNetV2 (used only by the hybrid variant) and Mlp come from the same repo
# and are not reproduced in this post.

class Embeddings(nn.Module):
    """Construct the embeddings from patch, position embeddings."""
    def __init__(self, config, img_size, in_channels=3):
        super(Embeddings, self).__init__()
        self.hybrid = None
        img_size = _pair(img_size)

        if config.patches.get("grid") is not None:   # hybrid ResNet + ViT variant
            grid_size = config.patches["grid"]
            patch_size = (img_size[0] // 16 // grid_size[0], img_size[1] // 16 // grid_size[1])
            n_patches = (img_size[0] // 16) * (img_size[1] // 16)
            self.hybrid = True
        else:                                        # pure ViT: one patch per P x P pixels
            patch_size = _pair(config.patches["size"])
            n_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
            self.hybrid = False

        if self.hybrid:
            self.hybrid_model = ResNetV2(block_units=config.resnet.num_layers,
                                         width_factor=config.resnet.width_factor)
            in_channels = self.hybrid_model.width * 16
        self.patch_embeddings = Conv2d(in_channels=in_channels,
                                       out_channels=config.hidden_size,
                                       kernel_size=patch_size,
                                       stride=patch_size)
        self.position_embeddings = nn.Parameter(torch.zeros(1, n_patches+1, config.hidden_size))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
        self.dropout = Dropout(config.transformer["dropout_rate"])

    def forward(self, x):
        B = x.shape[0]                                   # x: [16, 3, 224, 224]
        cls_tokens = self.cls_token.expand(B, -1, -1)    # [16, 1, 768]
        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)                     # [16, 768, 14, 14]
        x = x.flatten(2)                                 # [16, 768, 196]
        x = x.transpose(-1, -2)                          # [16, 196, 768]
        x = torch.cat((cls_tokens, x), dim=1)            # [16, 197, 768]
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)            # [16, 197, 768]
        return embeddings
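A hedged usage sketch of the class above; SimpleNamespace is only a stand-in for the repo's real config object, which this post does not show:

from types import SimpleNamespace
import torch

config = SimpleNamespace(hidden_size=768,
                         patches={"size": (16, 16)},
                         transformer={"dropout_rate": 0.1})
emb = Embeddings(config, img_size=224)
print(emb(torch.randn(16, 3, 224, 224)).shape)   # torch.Size([16, 197, 768])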

4.2 Block class

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)

    def forward(self, x):
        h = x                           # x: [16, 197, 768]
        x = self.attention_norm(x)      # pre-norm before attention
        x, weights = self.attn(x)
        x = x + h                       # residual connection
        h = x
        x = self.ffn_norm(x)            # pre-norm before the MLP
        x = self.ffn(x)
        x = x + h                       # residual connection; [16, 197, 768]
        return x, weights

4.3 Encoder class

class Encoder(nn.Module):
    def __init__(self, config, vis):
        super(Encoder, self).__init__()
        self.vis = vis
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config.hidden_size, eps=1e-6)
        for _ in range(config.transformer["num_layers"]):
            layer = Block(config, vis)
            self.layer.append(copy.deepcopy(layer))

    def forward(self, hidden_states):
        attn_weights = []
        for layer_block in self.layer:
            hidden_states, weights = layer_block(hidden_states)   # [16, 197, 768] each time
            if self.vis:
                attn_weights.append(weights)
        encoded = self.encoder_norm(hidden_states)                # final LayerNorm
        return encoded, attn_weights

4.4 Attention class

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)
        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # [B, seq, hidden] -> [B, num_heads, seq, head_size]
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)    # [16, 197, 768]
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)   # [16, 12, 197, 64]
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))  # [16, 12, 197, 197]
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)  # scaled dot product
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)        # [16, 12, 197, 64]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()    # [16, 197, 12, 64]
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)      # [16, 197, 768]
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights
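To sanity-check the shapes annotated above, a hedged usage sketch with the same stand-in config style as before (SimpleNamespace replaces the repo's real config object):

import torch
from types import SimpleNamespace

config = SimpleNamespace(hidden_size=768,
                         transformer={"num_heads": 12, "attention_dropout_rate": 0.0})
attn = Attention(config, vis=True)
out, weights = attn(torch.randn(16, 197, 768))
print(out.shape, weights.shape)   # [16, 197, 768] and [16, 12, 197, 197]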
