论文阅读 Vision Transformer - VIT

文章目录

  • 1 摘要
    • 1.1 核心
  • 2 模型架构
    • 2.1 概览
    • 2.2 对应CV的特定修改和相关理解
  • 3 代码
  • 4 总结

1 摘要

1.1 核心

通过将图像切成patch线形层编码成token特征编码的方法,用transformer的encoder来做图像分类

2 模型架构

2.1 概览

在这里插入图片描述

2.2 对应CV的特定修改和相关理解

解决问题:

  1. transformer输入限制: 由于自注意力+backbone,算法复杂度为o(n²),token长度一般要<512才足够运算
    解决:a) 将图片转为token输入 b) 将特征图转为token输入 c)√ 切patch转为token输入
  2. transformer无先验知识:卷积存在平移不变性(同特征同卷积核同结果)和局部相似性(相邻特征相似结果),
    而transformer无卷积核概念,只有整个编解码器,需要从头学
    解决:大量数据训练
  3. cv的各种自注意力机制需要复杂工程实现:
    解决:直接用整个transformer模块
  4. 分类head:
    解决:直接沿用transformer cls token
  5. position编码:
    解决:1D编码

pipeline:
224x224输入切成16x16patch进行位置编码和线性编码后增加cls token 一起输入的encoder encoder中有L个selfattention模块
输出的cls token为目标类别

3 代码

如果理解了transformer,看完这个结构感觉真的很简单,这篇论文也只是开山之作,没有特别复杂的结构,所以想到代码里看看。

import torch
from torch import nnfrom einops import rearrange, repeat
from einops.layers.torch import Rearrange# helpersdef pair(t):return t if isinstance(t, tuple) else (t, t)# classesclass FeedForward(nn.Module):def __init__(self, dim, hidden_dim, dropout = 0.):super().__init__()self.net = nn.Sequential(nn.LayerNorm(dim),nn.Linear(dim, hidden_dim),nn.GELU(),nn.Dropout(dropout),nn.Linear(hidden_dim, dim),nn.Dropout(dropout))def forward(self, x):return self.net(x)class Attention(nn.Module):def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):super().__init__()inner_dim = dim_head *  headsproject_out = not (heads == 1 and dim_head == dim)self.heads = headsself.scale = dim_head ** -0.5self.norm = nn.LayerNorm(dim)self.attend = nn.Softmax(dim = -1)self.dropout = nn.Dropout(dropout)# linear(1024 , 3072)self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)self.to_out = nn.Sequential(nn.Linear(inner_dim, dim),nn.Dropout(dropout)) if project_out else nn.Identity()def forward(self, x):# [1, 65, 1024]x = self.norm(x)# [1, 65, 1024]qkv = self.to_qkv(x).chunk(3, dim = -1)# self.to_qkv(x)                [1, 65, 3072]# self.to_qkv(x).chunk(3,-1)    [3, 1, 65, 1024]q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)# q,k,v                         [1, 65, 1024] -> [1, 16, 65, 64]# 把 65个1024的特征分为 heads个65个d维的特征 然后每个heads去分别有自己要处理的隐藏层,对不同的特征建立不同学习能力dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale# [1, 16, 65, 64] * [1, 16, 64, 65] -> [1, 16, 65, 65]# scale 保证在softmax前所有的值都不太大attn = self.attend(dots)# softmax [1, 16, 65, 65]attn = self.dropout(attn)# dropout [1, 16, 65, 65]out = torch.matmul(attn, v)# out [1, 16, 65, 64]out = rearrange(out, 'b h n d -> b n (h d)')# out [1, 65, 1024]return self.to_out(out)# out [1, 65, 1024]class Transformer(nn.Module):def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):super().__init__()self.norm = nn.LayerNorm(dim)self.layers = nn.ModuleList([])for _ in range(depth):self.layers.append(nn.ModuleList([Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),FeedForward(dim, mlp_dim, dropout = dropout)]))def forward(self, x):# [1, 65, 1024]for attn, ff in self.layers:# [1, 65, 1024]x = attn(x) + x# [1, 65, 1024]x = ff(x) + x# [1, 65, 1024]return self.norm(x)# shape不会改变class ViT(nn.Module):def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):super().__init__()image_height, image_width = pair(image_size)patch_height, patch_width = pair(patch_size)assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'num_patches = (image_height // patch_height) * (image_width // patch_width)patch_dim = channels * patch_height * patch_widthassert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'# num_patches   64# patch_dim     3072# dim           1024self.to_patch_embedding = nn.Sequential(#Rearrange是einops中的一个方法# einops:灵活和强大的张量操作,可读性强和可靠性好的代码。支持numpy、pytorch、tensorflow等。# 代码中Rearrage的意思是将传入的image(3,224,224),按照(3,(h,p1),(w,p2))也就是224=hp1,224 = wp2,接着把shape变成b (h w) (p1 p2 c)格式的,这样把图片分成了每个patch并且将patch拉长,方便下一步的全连接层# 还有一种方法是采用窗口为16*16,stride 16的卷积核提取每个patch,然后再flatten送入全连接层。Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),nn.LayerNorm(patch_dim),nn.Linear(patch_dim, dim),nn.LayerNorm(dim),)self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))self.cls_token = nn.Parameter(torch.randn(1, 1, dim))self.dropout = nn.Dropout(emb_dropout)self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)self.pool = poolself.to_latent = nn.Identity()self.mlp_head = nn.Linear(dim, num_classes)def forward(self, img):# 1. [1, 3, 256, 256]       输入imgx = self.to_patch_embedding(img)# 2. [1, 64, 1024]          patch embdb, n, _ = x.shape# 3. [1, 1, 1024]           cls_tokenscls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)# 4. [1, 65, 1024]          cat [cls_tokens, x]x = torch.cat((cls_tokens, x), dim=1)# 5. [1, 65, 1024]          add [x] [pos_embedding]x += self.pos_embedding[:, :(n + 1)]# 6. [1, 65, 1024]          dropoutx = self.dropout(x)# 7. [1, 65, 1024]          N * transformerx = self.transformer(x)# 8. [1,1024]               cls_x outputx = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]# 9. [1,1024]               cls_x output meanx = self.to_latent(x)# 10.[1,1024]               nn.Identity()不改变输入和输出 占位层return self.mlp_head(x)# 11.[1,cls]                mlp_cls_head

4 总结

multihead和我原有的理解偏差修正。
我以为的是QKV会有N块相同的copy(),每一份去做后续的linear等操作。
代码里是直接用linear将QKV分为一整个大块,用permute/rearrange的操作切成了N块,f(Q,K)之后再恢复成一整个大块,很强。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/416140.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【算法Hot100系列】跳跃游戏

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学习,不断总结,共同进步,活到老学到老导航 檀越剑指大厂系列:全面总结 jav…

为什么说在java中万物皆方法?

为什么说在java中万物皆方法&#xff1f; 在开始前我有一些资料&#xff0c;是我根据网友给的问题精心整理了一份「java的资料从专业入门到高级教程」&#xff0c; 点个关注在评论区回复“888”之后私信回复“888”&#xff0c;全部无偿共享给大家&#xff01;&#xff01;&am…

SwitchyOmega插件管理海外動態IP代理設置教程

SwitchyOmega插件很好解決了管理多個代理並在它們之間切換的問題&#xff0c;通過本文來全面瞭解SwitchyOmega&#xff0c;比如SwitchyOmega插件的用途、它的主要功能和應用、怎麼下載和使用&#xff0c;如何管理海外動態IP代理。 SwitchyOmega插件有什麼用途&#xff1f; Swit…

关于Windows 10的操作中心 ,看这篇文章就可以了

这篇文章介绍了Windows 10操作中心&#xff0c;也称为通知中心&#xff0c;以及如何使用它。操作中心会在需要你注意的事情时发送警报。 如何在操作中心中访问和解决通知 Windows操作中心显示为Windows任务栏右下角的发言气泡。图标下的数字表示你有未解析的通知。 通知会在…

取代房子,中国又一种资本在崛起(深度)

我一直有一个观点&#xff1a;经济形势好的时候&#xff0c;只要不是夕阳行业&#xff0c;做什么都能过得不错。经济形势差的时候&#xff0c;对于个人来说&#xff0c;拼的就是学习能力。 10年前&#xff0c;在市场上很吃香的是MBA&#xff0c;那时候企业需要高速发展&#x…

maven多个module打包

common是父组件&#xff0c;servicebase依赖于commonutils&#xff0c;如下图 1.打servicebase包时出现问题&#xff1a;找不到commonutils的jar包&#xff0c;但是commonutils包可以正常打&#xff0c;并且可以install到本地maven仓库。 解决方式&#xff1a; servicebase 的…

JAVA正则表达式第二个作用:爬取

目录 本地数据爬取&#xff1a; 本地爬取练习&#xff1a; 网络爬取&#xff1a; ----- 以下为均本地数据爬取&#xff1a; 带条件爬取 贪婪爬取和非贪婪爬取&#xff1a; 例题 1&#xff1a;使获取 1 为不贪婪 *例题 2&#xff1a;使获取 0、1 都为不贪婪 之前介绍了正…

Kafka 的 Consumer Group 解读

作为一份笔记&#xff0c;本文再次梳理一下 Kafka 的 Consumer Group。我们知道&#xff0c;一个 Topic 往往会有多个 Partition&#xff0c;一条消息只会被写到一个 Kafka 的 Partition 中&#xff0c;那 Consumer 是怎么消费 Message 的呢&#xff1f; Consumer Group 又从中…

MySQL复合查询解析

&#x1f388;行百里者半九十&#x1f388; &#x1f388;目录&#x1f388; 概念多表查询自连接子查询单行子查询多行子查询in关键字all关键字any关键字 多列子查询在from中使用子查询合并查询unionunion all 总结 概念 之前我们很多的查询都只是对于单表进行查询&#xff0c…

Sqoop故障排除指南:处理错误和问题

故障排除是每位数据工程师和分析师在使用Sqoop进行数据传输时都可能遇到的关键任务。Sqoop是一个功能强大的工具&#xff0c;但在实际使用中可能会出现各种错误和问题。本文将提供一个详尽的Sqoop故障排除指南&#xff0c;涵盖常见错误、问题和解决方法&#xff0c;并提供丰富的…

RabbitMQ的基本使用,进行实例案例的消息队列

目录 一、介绍 1. 概述 2. 作用 3. 工作原理 二、RabbitMQ安装部署 1. 安装 2. 部署 3. 增加用户 三、实现案例 1. 项目创建 2. 项目配置 3. 生产者代码 4. 消费者代码 四、测试 每篇一获 一、介绍 1. 概述 RabbitMQ 是一种开源的消息代理和队列服务器&#x…

Logistic回归实战

一、题目 假设你是一所大学的行政管理人员&#xff0c;你想根据两门考试的结果&#xff0c;来决定每个申请人是否被录取。你有以前申请人的历史数据&#xff0c;可以将其用作逻辑回归训练集。对于每一个训练样本&#xff0c;你有申请人两次测评的分数以及录取的结果。为了完成这…