前言 时空预测,真的需要 RNN 吗?真的需要 CNN 吗?是否能够设计一个模型,可以自动地学习数据中的时空依赖,而不需要依赖于归纳偏置呢?
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
本文转载自PaperWeekly
仅用于学术分享,若侵权请联系删除
CV方向的准研究生们,未来三年如何度过?
招聘高光谱图像、语义分割、diffusion等方向论文指导老师
时空预测学习是一个拥有广泛应用场景的领域,比如天气预测,交通流预测,降水预测,自动驾驶,人体运动预测等。
提起时空预测,不得不提到经典模型 ConvLSTM 和最经典的 benchmark moving mnist,在 ConvLSTM 时代,对于 Moving MNIST 的预测存在肉眼可见的伪影和预测误差。而在最新模型 PredFormer 中,对 Moving MNIST 的误差达到肉眼难以分辨的近乎完美的预测结果。
▲ ConvLSTM
▲ PredFormer
在以前的时空预测工作中,主要分为两个流派,基于循环(自回归)的模型,以 ConvLSTM/PredRNN//E3DLSTM/SwinLSTM/VMRNN 等工作为代表;更近年来,研究者提出无需循环的 SimVP 框架,由 CNN Encoder-Decoder 结构和一个时间转换器组成,以 SimVP/TAU/OpenSTL 等工作为代表。
RNN 系列模型的缺陷在于,无法并行化,自回归速度慢,显存占用高,效率低;CNN 系列模型无需循环提高了效率,得益于归纳偏置,但往往以牺牲泛化性和可扩展性为代价,模型上限低。
于是作者提出了问题,时空预测,真的需要 RNN 吗?真的需要 CNN 吗?是否能够设计一个模型,可以自动地学习数据中的时空依赖,而不需要依赖于归纳偏置呢?
一个直觉的想法是利用 Transformer,因为它在各种视觉任务中的广泛成功,并且是 RNN 和 CNN 的有力替代者。
在此前的时空预测工作中,已有研究者把 Transformer 嵌入到上述两种框架中,比如 SwinLSTM (ICCV23) 融合了 Swin Transformer 和 LSTM,比如 OpenSTL (NeurIPS23) 把各种 MetaFormer 结构(比如 ViT,Swin Transformer 等)作为 SimVP 框架中的时间转换器。但是,纯 Transformer 结构的网络鲜有探索。
设计纯 Transformer 模型的挑战在于,如何在一个框架中同时处理时间和空间信息。一个简单的想法是合并空间序列和时间序列,计算时空全注意力,由于 Transformer 的计算复杂度是序列长度的二次复杂度,这样的做法会导致计算复杂度较大。
在这篇文章中,作者提出了用于时空预测学习的新框架 PredFormer,这是一个纯 ViT 模型,既没有自回归也没有任何卷积。作者利用精心设计的基于门控 Transfomer 模块,对 3D Attention 进行了全面的分析,包括时空全注意力,时空分解的注意力,和时空交错的注意力。
论文标题:
PredFormer: Transformers Are Effective Spatial-Temporal Predictive Learners
论文链接:
https://arxiv.org/abs/2410.04733
PredFormer 采用非循环、基于 Transformer 的设计,既简单又高效,更少参数量,Flops,更快推理速度,性能显著优于以前的方法。
在合成和真实数据集上进行的大量实验表明,PredFormer 实现了最先进的性能。在 Moving MNIST 上,PredFormer 相对于 SimVP 实现了 51.3% 的 MSE 降低,突破性地达到11.6。对于 TaxiBJ,该模型将 MSE 降低了 33.1%,并将 FPS 从 533 提高到 2364。
此外,在 WeatherBench 上,它将 MSE 降低了 11.1%,同时将 FPS 从 196 提高到 404。这些准确度和效率方面的性能提升证明了 PredFormer 在实际应用中的潜力。
PredFormer 模型遵循标准 ViT 的设计,先对输入进行 Patch Embedding,把输入为 [B, T, C, H, W] 的时空序列转换为 [B, T, N, D] 的张量。在位置编码环节,作者采用了不同于一般 ViT 设计的可学习的位置编码,而是采用了基于 sin 函数的绝对位置编码,作者在消融实验中进一步阐述了绝对位置编码在时空任务中的优越性。
PredFormer 的编码器部分,由门控 Transfomer 模块以不同的方式堆叠而成。由于编码器部分是纯 Transformer 结构,没有任何卷积,也没有分辨率的下降,每一个门控 Transformer 模块都建模了全局信息,这允许模型只需使用一个简单的解码器就可以构成一个性能强大的预测模型。作者采用了一个线性层作为解码器来进行 Patch Recovery,这让模型的输出从 [B, T, N, D] 恢复到 [B, T, C, H, W]。不同于标准 Transformer 模型采用 MLP 作为 FFN,PredFormer 采用了 Gated Linear Unit (GLU) 作为 FFN,这是受 GLU 在 NLP 任务中优于 MLP 启发的改进。作者在消融实验中进一步阐述了 GLU 相比于 MLP 在时空任务上的优越性。 作者对 3D Attention 进行了全面的分析,并提出了 9 种 PredFormer 变体。在以前用于视频分类的 Video ViT 设计中,TimesFormer (ICML21), ViviT (ICCV21), TSViT (CVPR23) 等工作也对时空分解进行了分析,但是 TimesFormer 是在 self-attention 层面进行分解,也就是 spatial attention 和 temporal attention 共用一个 MLP。ViviT 则是提出了在 Encoder 层面(先空间后时间),self-attention 层面和 head 层面进行时空分解。而 TSViT 发现先时间后空间的 Encoder 对卫星序列图像分类更有效。
不同于以上工作,PredFormer 是在 Gated Transformer Block (GTB) 层面(多了基于 Gated Linear Unit)进行时空分解。对时间和空间的 self-attention 都加 GLU 是至关重要的,因为它可以让学习到的时空特征相互作用并且增强非线性。
PredFormer 提出了时空全注意力 Encoder,时间在前和空间在前的 2 种分解 Encoder 和 6 种新颖的时空交错的 Encoder,一共 9 种模型。PredFormer 提出了 PredFormer Layer 的概念,即一个既能建模空间信息,又能建模时间信息的最小单元。
基于这种想法,作者提出了三种基本范式,二元组(由一个 Temporal GTB 和一个 Spatial GTB 组成,有 T-S 和 S-T 两种方式),三元组(T-S-T 和 S-T-S),四元组(两个二元组以相反的方向重组)。
这一设计源于不同的时空预测任务往往有着不同的空间分辨率和时间分辨率(时间间隔以及变化程度),这意味着不同的数据集上对时间信息和空间信息的依赖程度不同,作者设计了这些模型以提高 PredFormer 模型在不同任务上的适应性。
在实验部分,作者控制了提出的每种变体使用相同的 GTB 数目,这可以保证模型的参数量基本一致,从而对比不同模型的性能。对于三元组 GTB 个数无法整除的情况,选择最为接近的三元组个数。
实验发现了一些规律:
1. 时间在前的分解 Encoder 模型优于时空全注意力模型,由于空间在前的分解 Encoder 模型;
2. 时空交错的 6 种模型在大多数任务上表现都很好,都能达到 sota,但最优模型因为数据集本身的不同时空依赖特性而不同,这体现了 PredFormer 这种框架和时空交错设计的优势;
3. 作者在讨论环节提出了建议,在其他的时空预测任务上,从四元组-TSST 开始尝试,因为这个模型在三个数据集上都表现 sota,先调整 M 个 TSST(即 4M 个门控 Transformer)的 M 参数,然后尝试 M 个 TST 和 M 个 STS 以确定数据集是时间依赖更强或空间依赖更强的模型。
得益于 Transformer 构架的可扩展性,不同于 SimVP 框架的 CNN Encoder-Decoder 模型,对 spatial 和 temporal 的 hidden dim 以及 block 数都设置了不同的值,PredFormer 对 spatial 和 temporal GTB 采用相同的固定的参数,因此只需要调整 M 的值,在比较少次数的调整后就可以达到最优性能。
ViT 模型的训练通常要求较大的数据集,在时空预测任务上,大多数据集在几千到几万的量级,数据量少,因此很容易过拟合。作者还探索了不同的正则化策略,包括 dropout 和 drop path,通过广泛的消融实验,作者发现同时使用 dropout 和 uniform 的 drop path(不同于一般 ViT 使用线性增加的 drop path rate)会产生最优的模型效果。
作者还进行了可视化比较,可以看到,在 PredFormer 相对于 TAU 明显减少了预测误差。作者还给出了一个特殊例子来证明 PredFormer 模型相比于 CNN 模型在泛化性上的优越性。
在交通流预测任务上,当第四帧相比前三帧明显减少流量时,TAU 受限于归纳偏置仍然预测了较高的流量,而 PredFormer 却能捕捉到这里的变化。PredFormer 预测剧烈变化的能力在交通流和天气预测中可能有非常宝贵的应用价值。
这篇文章提出了一个无循环,无卷积的纯 Transformer 模型,并对时空注意力分解进行了全面的分析。PredFormer 不仅提供了一个鲁棒的基线模型,还为以后的基于纯 Transformer 构架的创新工作铺平了道路。
欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。
计算机视觉入门1v3辅导班
【技术文档】《从零搭建pytorch模型教程》122页PDF下载
QQ交流群:470899183。群内有大佬负责解答大家的日常学习、科研、代码问题。
其它文章
分享一个CV知识库,上千篇文章、专栏,CV所有资料都在这了
明年毕业,还不知道怎么做毕设的请抓紧机会了
LSKA注意力 | 重新思考和设计大卷积核注意力,性能优于ConvNeXt、SWin、RepLKNet以及VAN
CVPR 2023 | TinyMIM:微软亚洲研究院用知识蒸馏改进小型ViT
ICCV2023|涨点神器!目标检测蒸馏学习新方法,浙大、海康威视等提出
ICCV 2023 Oral | 突破性图像融合与分割研究:全时多模态基准与多交互特征学习
听我说,Transformer它就是个支持向量机
HDRUNet | 深圳先进院董超团队提出带降噪与反量化功能的单帧HDR重建算法
南科大提出ORCTrack | 解决DeepSORT等跟踪方法的遮挡问题,即插即用真的很香
1800亿参数,世界顶级开源大模型Falcon官宣!碾压LLaMA 2,性能直逼GPT-4
SAM-Med2D:打破自然图像与医学图像的领域鸿沟,医疗版 SAM 开源了!
GhostSR|针对图像超分的特征冗余,华为诺亚&北大联合提出GhostSR
Meta推出像素级动作追踪模型,简易版在线可玩 | GitHub 1.4K星
CSUNet | 完美缝合Transformer和CNN,性能达到UNet家族的巅峰!
AI最全资料汇总 | 基础入门、技术前沿、工业应用、部署框架、实战教程学习
计算机视觉入门1v3辅导班
计算机视觉交流群