Video classification with UniFormer基于统一分类器的视频分类

本文主要介绍了UniFormer: Unified Transformer for Efficient Spatial-Temporal Representation Learning
代码:https://github.com/Sense-X/UniFormer/tree/main/video_classification

UNIFormer

动机

由于视频具有大量的局部冗余和复杂的全局依赖关系,因此从视频中学习丰富的、多尺度的时空语义是一项具有挑战性的任务。

最近的研究主要是由三维卷积神经网络和Vision Transformer驱动的。虽然三维卷积可以有效地聚集局部上下文来抑制来自小三维邻域的局部冗余,但由于感受域有限,它缺乏捕获全局依赖的能力。另外,vision Transformer通过自注意机制可以有效地捕获长时间依赖,但由于各层tokens之间存在盲目的相似性比较,限制了减少局部冗余。
在这里插入图片描述
视频transformer对浅层的局部特征编码比较低效,在时空上都只能学习到临近的信息

方法

提出了一种新型的统一Transformer(UniFormer),它以一种简洁的形式,将三维卷积和时空自注意的优点集成在一起,并在计算和精度之间取得了较好的平衡。与传统的Transformer不同的是,关系聚合器通过在浅层和深层中分别局部和全局tokens相关性来处理时空冗余和依赖关系。
在这里插入图片描述
由上图可知,UniFormer模型其中的特色组件是:动态位置嵌入(DPE)、多头关系聚合器(MHRA)和前馈网络(FFN)
在这里插入图片描述

动态位置嵌入(DPE)

之前的方法主要采用图像任务的绝对或相对位置嵌入。然而,当测试较长的输入帧时,绝对位置嵌入应该通过微调插值到目标输入大小。相对位置嵌入由于缺乏绝对位置信息而修改了自注意,表现较差。为了克服上述问题,扩展了条件位置编码(CPE)来设计DPE。
在这里插入图片描述

其中DWConv表示简单的三维深度卷积与零填充。由于卷积的共享参数和局部性,DPE可以克服置换不变性,并且对任意输入长度都很友好。此外,在CPE中已经证明,零填充可以帮助边界上的token意识到自己的绝对位置,因此所有token都可以通过查询其邻居来逐步编码自己的绝对时空位置信息

class SpeicalPatchEmbed(nn.Module):""" Image to Patch Embedding"""def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):super().__init__()img_size = to_2tuple(img_size)patch_size = to_2tuple(patch_size)num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])self.img_size = img_sizeself.patch_size = patch_sizeself.num_patches = num_patchesself.norm = nn.LayerNorm(embed_dim)self.proj = conv_3xnxn(in_chans, embed_dim, kernel_size=patch_size[0], stride=patch_size[0])def forward(self, x):B, C, T, H, W = x.shape# FIXME look at relaxing size constraints# assert H == self.img_size[0] and W == self.img_size[1], \#     f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."x = self.proj(x)B, C, T, H, W = x.shapex = x.flatten(2).transpose(1, 2)x = self.norm(x)x = x.reshape(B, T, H, W, -1).permute(0, 4, 1, 2, 3).contiguous()return x

多头关系聚合器(MHRA)

设计了一种替代的关系聚合器(RA),它可以将三维卷积和时空自注意灵活地统一在一个简洁的Transformer中,分别解决了浅层和深层的视频冗余和依赖问题。具体来说,MHRA通过多头融合进行tokens关系学习:
在这里插入图片描述

  1. 输入张量为 X ∈ R C × T × H × W , r e s h a p e 为 X ∈ R L × C X \in \mathbb{R}^{C \times T \times H \times W} , reshape为 \mathbf{X} \in \mathbb{R}^{L \times C} XRC×T×H×W,reshapeXRL×C L = T × H × W L=T \times H \times W L=T×H×W
  2. 通过线性转换,可以将 X \mathbf{X} X 转换为上下文信息 V n ( X ) ∈ R L × C N , n \mathrm{V}_{n}(\mathbf{X}) \in \mathbb{R}^{L \times \frac{C}{N}} , \mathrm{n} Vn(X)RL×NCn 表示第几个head。
  3. 然后关系聚合器 RA通过token affinity A n ∈ R L × L \mathrm{A}_{n} \in \mathbb{R}^{L \times L} AnRL×L 来融合上下文信息得到 R n ( X ) ∈ R L × C N \mathbf{R}_{n}(\mathbf{X}) \in \mathbb{R}^{L \times \frac{C}{N}} Rn(X)RL×NC
  4. 最后concat所有的head信息,并通过 U ∈ C L × C \mathbf{U} \in \mathbb{C}^{L \times C} UCL×C聚合所有head的信息。

根据上下文的域大小,可以将MHRA分为 local MHRA 和global MHRA
在网络浅层中,目标是学习小三维时空中局部时空背景下的详细视频表示:值仅依赖于token之间的相对3D位置
在这里插入图片描述

class CBlock(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.pos_embed = conv_3x3x3(dim, dim, groups=dim)self.norm1 = bn_3d(dim)self.conv1 = conv_1x1x1(dim, dim, 1)self.conv2 = conv_1x1x1(dim, dim, 1)self.attn = conv_5x5x5(dim, dim, groups=dim)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout hereself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = bn_3d(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)def forward(self, x):x = x + self.pos_embed(x)x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x)))))x = x + self.drop_path(self.mlp(self.norm2(x)))return x   

在网络深层中,关注于在全局视频帧中捕获长远token依赖关系:通过比较全局视图中所有token的内容相似性
在这里插入图片描述

class SABlock(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):super().__init__()self.pos_embed = conv_3x3x3(dim, dim, groups=dim)self.norm1 = norm_layer(dim)self.attn = Attention(dim,num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,attn_drop=attn_drop, proj_drop=drop)# NOTE: drop path for stochastic depth, we shall see if this is better than dropout hereself.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)def forward(self, x):x = x + self.pos_embed(x)B, C, T, H, W = x.shapex = x.flatten(2).transpose(1, 2)x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))x = x.transpose(1, 2).reshape(B, C, T, H, W)return x

模型代码

class Uniformer(nn.Module):"""Vision Transformer一个PyTorch实现:`一个图像值16x16词:大规模图像识别的Transformer` - https://arxiv.org/abs/2010.11929"""def __init__(self, cfg):super().__init__()# 从配置中提取各种参数depth = cfg.UNIFORMER.DEPTH  # 模型深度num_classes = cfg.MODEL.NUM_CLASSES  # 类别数量img_size = cfg.DATA.TRAIN_CROP_SIZE  # 图像尺寸in_chans = cfg.DATA.INPUT_CHANNEL_NUM[0]  # 输入通道数embed_dim = cfg.UNIFORMER.EMBED_DIM  # 嵌入维度head_dim = cfg.UNIFORMER.HEAD_DIM  # 头部维度mlp_ratio = cfg.UNIFORMER.MLP_RATIO  # MLP比例qkv_bias = cfg.UNIFORMER.QKV_BIAS  # QKV偏置qk_scale = cfg.UNIFORMER.QKV_SCALE  # QKV缩放representation_size = cfg.UNIFORMER.REPRESENTATION_SIZE  # 表示维度drop_rate = cfg.UNIFORMER.DROPOUT_RATE  # Dropout率attn_drop_rate = cfg.UNIFORMER.ATTENTION_DROPOUT_RATE  # 注意力Dropout率drop_path_rate = cfg.UNIFORMER.DROP_DEPTH_RATE  # 随机深度衰减率split = cfg.UNIFORMER.SPLIT  # 是否分裂std = cfg.UNIFORMER.STD  # 是否标准化self.use_checkpoint = cfg.MODEL.USE_CHECKPOINT  # 使用检查点self.checkpoint_num = cfg.MODEL.CHECKPOINT_NUM  # 检查点数量logger.info(f'Use checkpoint: {self.use_checkpoint}')  # 日志:使用检查点logger.info(f'Checkpoint number: {self.checkpoint_num}')  # 日志:检查点数量self.num_classes = num_classesself.num_features = self.embed_dim = embed_dim  # 为了与其他模型保持一致,设置特征数量和嵌入维度norm_layer = partial(nn.LayerNorm, eps=1e-6)  # 层标准化函数# 创建不同尺寸的Patch嵌入层self.patch_embed1 = SpeicalPatchEmbed(img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0])  # Patch嵌入层1self.patch_embed2 = PatchEmbed(img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1], std=std)  # Patch嵌入层2self.patch_embed3 = PatchEmbed(img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2], std=std)  # Patch嵌入层3self.patch_embed4 = PatchEmbed(img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3], std=std)  # Patch嵌入层4self.pos_drop = nn.Dropout(p=drop_rate)  # 位置Dropout层dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depth))]  # 随机深度衰减规则num_heads = [dim // head_dim for dim in embed_dim]  # 头部数量# 创建Transformer块并组成模型的不同部分self.blocks1 = nn.ModuleList([CBlock(dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) for i in range(depth[0])])  # 第一个部分的Transformer块self.blocks2 = nn.ModuleList([CBlock(dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]], norm_layer=norm_layer) for i in range(depth[1])])  # 第二个部分的Transformer块if split:self.blocks3 = nn.ModuleList([SplitSABlock(dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)for i in range(depth[2])])  # 如果拆分,创建第三个部分的Split Self-Attention块self.blocks4 = nn.ModuleList([SplitSABlock(dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)for i in range(depth[3])])  # 如果拆分,创建第四个部分的Split Self-Attention块else:self.blocks3 = nn.ModuleList([SABlock(dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]], norm_layer=norm_layer)for i in range(depth[2])])  # 创建第三个部分的Self-Attention块self.blocks4 = nn.ModuleList([SABlock(dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+depth[0]+depth[1]+depth[2]], norm_layer=norm_layer)for i in range(depth[3])])  # 创建第四个部分的Self-Attention块self.norm = bn_3d(embed_dim[-1])  # 3D批标准化层# 表示层if representation_size:self.num_features = representation_sizeself.pre_logits = nn.Sequential(OrderedDict([('fc', nn.Linear(embed_dim, representation_size)),  # 全连接层('act', nn.Tanh())  # Tanh激活函数]))else:self.pre_logits = nn.Identity()  # 如果没有设置表示维度,则为恒等映射# 分类器头部self.head = nn.Linear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()  # 分类器线性层或恒等映射self.apply(self._init_weights)  # 初始化权重# 初始化某些参数的权重for name, p in self.named_parameters():if 't_attn.qkv.weight' in name:nn.init.constant_(p, 0)  # 初始化t_attn.qkv.weight为常数0if 't_attn.qkv.bias' in name:nn.init.constant_(p, 0)  # 初始化t_attn.qkv.bias为常数0if 't_attn.proj.weight' in name:nn.init.constant_(p, 1)  # 初始化t_attn.proj.weight为常数1if 't_attn.proj.bias' in name:nn.init.constant_(p, 0)  # 初始化t_attn.proj.bias为常数0def _init_weights(self, m):"""初始化权重函数"""if isinstance(m, nn.Linear):trunc_normal_(m.weight, std=.02)  # 使用截断正态分布初始化权重if isinstance(m, nn.Linear) and m.bias is not None:nn.init.constant_(m.bias, 0)  # 初始化偏置为常数0elif isinstance(m, nn.LayerNorm):nn.init.constant_(m.bias, 0)  # 初始化偏置为常数0nn.init.constant_(m.weight, 1.0)  # 初始化权重为常数1.0@torch.jit.ignoredef no_weight_decay(self):"""指定不进行权重衰减的参数"""return {'pos_embed', 'cls_token'}def get_classifier(self):"""获取分类器头部"""return self.headdef reset_classifier(self, num_classes, global_pool=''):"""重置分类器Args:num_classes (int): 新的类别数量global_pool (str): 全局池化方式"""self.num_classes = num_classesself.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()  # 重新设置分类器头部def inflate_weight(self, weight_2d, time_dim, center=False):"""权重膨胀Args:weight_2d: 二维权重张量time_dim: 时间维度center (bool): 是否中心化Returns:Tensor: 膨胀后的三维权重张量"""if center:weight_3d = torch.zeros(*weight_2d.shape)weight_3d = weight_3d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)middle_idx = time_dim // 2weight_3d[:, :, middle_idx, :, :] = weight_2delse:weight_3d = weight_2d.unsqueeze(2).repeat(1, 1, time_dim, 1, 1)weight_3d = weight_3d / time_dimreturn weight_3ddef get_pretrained_model(self, cfg):"""获取预训练模型Args:cfg: 配置文件Returns:dict: 预训练模型参数字典"""if cfg.UNIFORMER.PRETRAIN_NAME:checkpoint = torch.load(model_path[cfg.UNIFORMER.PRETRAIN_NAME], map_location='cpu')if 'model' in checkpoint:checkpoint = checkpoint['model']elif 'model_state' in checkpoint:checkpoint = checkpoint['model_state']state_dict_3d = self.state_dict()for k in checkpoint.keys():if checkpoint[k].shape != state_dict_3d[k].shape:if len(state_dict_3d[k].shape) <= 2:logger.info(f'Ignore: {k}')  # 忽略不匹配的参数continuelogger.info(f'Inflate: {k}, {checkpoint[k].shape} => {state_dict_3d[k].shape}')  # 膨胀参数形状time_dim = state_dict_3d[k].shape[2]checkpoint[k] = self.inflate_weight(checkpoint[k], time_dim)if self.num_classes != checkpoint['head.weight'].shape[0]:del checkpoint['head.weight'] del checkpoint['head.bias'] return checkpointelse:return Nonedef forward_features(self, x):"""前向传播特征提取Args:x (tensor): 输入张量Returns:tensor: 特征提取结果"""x = self.patch_embed1(x)x = self.pos_drop(x)for i, blk in enumerate(self.blocks1):if self.use_checkpoint and i < self.checkpoint_num[0]:x = checkpoint.checkpoint(blk, x)else:x = blk(x)x = self.patch_embed2(x)for i, blk in enumerate(self.blocks2):if self.use_checkpoint and i < self.checkpoint_num[1]:x = checkpoint.checkpoint(blk, x)else:x = blk(x)x = self.patch_embed3(x)for i, blk in enumerate(self.blocks3):if self.use_checkpoint and i < self.checkpoint_num[2]:x = checkpoint.checkpoint(blk, x)else:x = blk(x)x = self.patch_embed4(x)for i, blk in enumerate(self.blocks4):if self.use_checkpoint and i < self.checkpoint_num[3]:x = checkpoint.checkpoint(blk, x)else:x = blk(x)x = self.norm(x)x = self.pre_logits(x)return xdef forward(self, x):"""前向传播Args:x (tensor): 输入张量Returns:tensor: 输出结果"""x = x[0]x = self.forward_features(x)x = x.flatten(2).mean(-1)x = self.head(x)return x

总结

原文提出了一种新的UniFormer,它可以有效地统一3D卷积和时空自注意力在一个简洁的Transformer格式,以克服视频冗余和依赖。我们在浅层采用局部MHRA,大大减少了计算负担,在深层采用全局MHRA,学习全局令牌关系。大量的实验表明,我们的UniFormer在流行的视频基准测试Kinetics-400/600和Something-Something V1/V2上实现了准确性和效率之间的较好平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/329275.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

强化学习的数学原理学习笔记 - 时序差分学习(Temporal Difference)

文章目录 概览&#xff1a;RL方法分类时序差分学习&#xff08;Temporal Difference&#xff0c;TD&#xff09;TD for state valuesBasic TD&#x1f7e1;TD vs. MC &#x1f7e6;Sarsa (TD for action values)Basic Sarsa变体1&#xff1a;Expected Sarsa变体2&#xff1a;n-…

HUAWEI WATCH 系列 eSIM 全新开通指南来了

HUAWEI WATCH 系列手表提供了eSIM硬件能力&#xff0c;致力为用户提供更便捷、高效的通信体验。但eSIM 业务是由运营商管理并提供服务的&#xff0c;当前运营商eSIM业务集中全面恢复&#xff0c;电信已经全面恢复&#xff0c;移动大部分省份已经全面放开和多号App开通方式&…

解决Typora笔记上传到CSDN上图片无法显示的问题

解决Typora笔记上传到CSDN上图片无法显示的问题 一、发现问题二、分析问题三、解决问题图床介绍所需工具PicGo软件安装操作下载安装PicGo配置PicGo 设置Typora 四、总结 一、发现问题 当我们使用Typora这款强大的Markdown编辑器记录笔记时&#xff0c;经常会遇到一个让人困扰的…

JVM中虚拟机栈和本地方法栈等

jvm Java虚拟机栈本地方法栈 Java虚拟机栈 Java虚拟机栈&#xff08;VM Stack&#xff09; ​ 虚拟机栈是线程执行Java程序时&#xff0c;处理Java方法中内容的内存区域。虚拟机栈也是线程私有的区域&#xff0c;每个Java方法被调用的时候&#xff0c;都会在虚拟机栈中创建出…

c++学习第八讲---类和对象---继承

继承&#xff1a; 使子类&#xff08;派生类&#xff09;拥有与父类&#xff08;基类&#xff09;相同的成员&#xff0c;以节约代码量。 1.继承的基本语法&#xff1a; class 子类名&#xff1a;继承方式 父类名{} &#xff1b; 例&#xff1a; class father { public:in…

Hadolint:Lint Dockerfile 的完整指南

想学习如何使用 Hadolint 对 Dockerfile 进行 lint 处理吗&#xff1f;这篇博文将向您展示如何操作。这是关于 Dockerfile linting 的完整指南。 通过对 Dockerfile 进行 lint 检查&#xff0c;您可以及早发现错误和问题&#xff0c;并确保它们遵循最佳实践。 什么是Hadolint…

深入理解C指针

深入理解C指针 ​#C语言 #​ #C指针 #​ 1 认识指针 指针&#xff1a;一个存放内存地址的变量 1.1 指针和内存 ​​ ‍ 阅读指针声明时候&#xff0c;可以选择倒过来读&#xff0c;会更容易理解。 指针被赋值为NULL时候&#xff0c;会被解释为二进制0. void指针 具有和…

Java Swing手搓坦克大战遇到的问题和思考

1.游戏中的坐标系颇为复杂 像素坐标系还有行列坐标&#xff0c;都要使用&#xff0c;这之间的互相转化使用也要注意 2.游戏中坦克拐弯的处理&#xff0c;非常重要 由于坦克中心点是要严格对齐到一条网格线&#xff0c;并沿着这条线前进的&#xff0c;如果拐弯不做处理&#…

二刷Laravel 教程(构建页面)总结Ⅰ

L01 Laravel 教程 - Web 开发实战入门 ( Laravel 9.x ) 一、功能 1.会话控制&#xff08;登录、退出、记住我&#xff09; 2.用户功能&#xff08;注册、用户激活、密码重设、邮件发送、个人中心、用户列表、用户删除&#xff09; 3.静态页面&#xff08;首页、关于、帮助&am…

【二】使用create-vue创建vue3的helloworld项目(推荐)

create-vue 官网&#xff1a;快速上手 | Vue.js create-vue 是 Vue3 的专用脚手架&#xff0c;使用 vite 创建 Vue3 的项目&#xff0c;也可以选择安装需要的各种插件&#xff0c;使用更简单。 1、使用方式 npm create vuelatest这个命令会安装和执行 create-vue&#xff0…

【Project】TPC-Online Module (manuscript_2024-01-07)

PRD正文 一、概述 本模块实现隧道点云数据的线上汇总和可视化。用户可以通过注册和登录功能进行身份验证&#xff0c;然后上传原始隧道点云数据和经过处理的数据到后台服务器。该模块提供数据查询、筛选和可视化等操作&#xff0c;同时支持对指定里程的分段显示和点云颜色更改…

2022年多元统计分析期末试题

2023年多元统计分析期末试题 1.试论述系统聚类、动态聚类和有序聚类的异同之处。 2、设 X {X} X~ N 3 {N_3} N3​(μ&#xff0c;Σ)&#xff0c;其中 X {X} X ~ ( X 1 {X_1} X1​, X 2 {X_2} X2​, X 3 {X_3} X3​)&#xff0c;μ (1,-2,3)‘&#xff0c;Σ [ 1 1 1 1 3 2…