无循环无卷积!上海交大提出时空预测学习新里程碑PredFormer

news/2024/11/27 11:21:08/文章来源:https://www.cnblogs.com/wxkang/p/18572025
前言 时空预测,真的需要 RNN 吗?真的需要 CNN 吗?是否能够设计一个模型,可以自动地学习数据中的时空依赖,而不需要依赖于归纳偏置呢?

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

本文转载自PaperWeekly

仅用于学术分享,若侵权请联系删除

CV方向的准研究生们,未来三年如何度过?

招聘高光谱图像、语义分割、diffusion等方向论文指导老师

时空预测学习是一个拥有广泛应用场景的领域,比如天气预测,交通流预测,降水预测,自动驾驶,人体运动预测等。

提起时空预测,不得不提到经典模型 ConvLSTM 和最经典的 benchmark moving mnist,在 ConvLSTM 时代,对于 Moving MNIST 的预测存在肉眼可见的伪影和预测误差。而在最新模型 PredFormer 中,对 Moving MNIST 的误差达到肉眼难以分辨的近乎完美的预测结果。

▲ ConvLSTM

 

▲ PredFormer

 

在以前的时空预测工作中,主要分为两个流派,基于循环(自回归)的模型,以 ConvLSTM/PredRNN//E3DLSTM/SwinLSTM/VMRNN 等工作为代表;更近年来,研究者提出无需循环的 SimVP 框架,由 CNN Encoder-Decoder 结构和一个时间转换器组成,以 SimVP/TAU/OpenSTL 等工作为代表。

RNN 系列模型的缺陷在于,无法并行化,自回归速度慢,显存占用高,效率低;CNN 系列模型无需循环提高了效率,得益于归纳偏置,但往往以牺牲泛化性和可扩展性为代价,模型上限低。

于是作者提出了问题,时空预测,真的需要 RNN 吗?真的需要 CNN 吗?是否能够设计一个模型,可以自动地学习数据中的时空依赖,而不需要依赖于归纳偏置呢?

一个直觉的想法是利用 Transformer,因为它在各种视觉任务中的广泛成功,并且是 RNN 和 CNN 的有力替代者。

在此前的时空预测工作中,已有研究者把 Transformer 嵌入到上述两种框架中,比如 SwinLSTM (ICCV23) 融合了 Swin Transformer 和 LSTM,比如 OpenSTL (NeurIPS23) 把各种 MetaFormer 结构(比如 ViT,Swin Transformer 等)作为 SimVP 框架中的时间转换器。但是,纯 Transformer 结构的网络鲜有探索。

设计纯 Transformer 模型的挑战在于,如何在一个框架中同时处理时间和空间信息。一个简单的想法是合并空间序列和时间序列,计算时空全注意力,由于 Transformer 的计算复杂度是序列长度的二次复杂度,这样的做法会导致计算复杂度较大。

在这篇文章中,作者提出了用于时空预测学习的新框架 PredFormer,这是一个纯 ViT 模型,既没有自回归也没有任何卷积。作者利用精心设计的基于门控 Transfomer 模块,对 3D Attention 进行了全面的分析,包括时空全注意力,时空分解的注意力,和时空交错的注意力。

论文标题:

PredFormer: Transformers Are Effective Spatial-Temporal Predictive Learners

论文链接:

PredFormer 采用非循环、基于 Transformer 的设计,既简单又高效,更少参数量,Flops,更快推理速度,性能显著优于以前的方法。

在合成和真实数据集上进行的大量实验表明,PredFormer 实现了最先进的性能。在 Moving MNIST 上,PredFormer 相对于 SimVP 实现了 51.3% 的 MSE 降低,突破性地达到11.6。对于 TaxiBJ,该模型将 MSE 降低了 33.1%,并将 FPS 从 533 提高到 2364。

此外,在 WeatherBench 上,它将 MSE 降低了 11.1%,同时将 FPS 从 196 提高到 404。这些准确度和效率方面的性能提升证明了 PredFormer 在实际应用中的潜力。

PredFormer 模型遵循标准 ViT 的设计,先对输入进行 Patch Embedding,把输入为 [B, T, C, H, W] 的时空序列转换为 [B, T, N, D] 的张量。在位置编码环节,作者采用了不同于一般 ViT 设计的可学习的位置编码,而是采用了基于 sin 函数的绝对位置编码,作者在消融实验中进一步阐述了绝对位置编码在时空任务中的优越性。

PredFormer 的编码器部分,由门控 Transfomer 模块以不同的方式堆叠而成。由于编码器部分是纯 Transformer 结构,没有任何卷积,也没有分辨率的下降,每一个门控 Transformer 模块都建模了全局信息,这允许模型只需使用一个简单的解码器就可以构成一个性能强大的预测模型。作者采用了一个线性层作为解码器来进行 Patch Recovery,这让模型的输出从 [B, T, N, D] 恢复到 [B, T, C, H, W]。不同于标准 Transformer 模型采用 MLP 作为 FFN,PredFormer 采用了 Gated Linear Unit (GLU) 作为 FFN,这是受 GLU 在 NLP 任务中优于 MLP 启发的改进。作者在消融实验中进一步阐述了 GLU 相比于 MLP 在时空任务上的优越性。 作者对 3D Attention 进行了全面的分析,并提出了 9 种 PredFormer 变体。在以前用于视频分类的 Video ViT 设计中,TimesFormer (ICML21), ViviT (ICCV21), TSViT (CVPR23) 等工作也对时空分解进行了分析,但是 TimesFormer 是在 self-attention 层面进行分解,也就是 spatial attention 和 temporal attention 共用一个 MLP。ViviT 则是提出了在 Encoder 层面(先空间后时间),self-attention 层面和 head 层面进行时空分解。而 TSViT 发现先时间后空间的 Encoder 对卫星序列图像分类更有效。

不同于以上工作,PredFormer 是在 Gated Transformer Block (GTB) 层面(多了基于 Gated Linear Unit)进行时空分解。对时间和空间的 self-attention 都加 GLU 是至关重要的,因为它可以让学习到的时空特征相互作用并且增强非线性。

PredFormer 提出了时空全注意力 Encoder,时间在前和空间在前的 2 种分解 Encoder 和 6 种新颖的时空交错的 Encoder,一共 9 种模型。PredFormer 提出了 PredFormer Layer 的概念,即一个既能建模空间信息,又能建模时间信息的最小单元。

基于这种想法,作者提出了三种基本范式,二元组(由一个 Temporal GTB 和一个 Spatial GTB 组成,有 T-S 和 S-T 两种方式),三元组(T-S-T 和 S-T-S),四元组(两个二元组以相反的方向重组)。

这一设计源于不同的时空预测任务往往有着不同的空间分辨率和时间分辨率(时间间隔以及变化程度),这意味着不同的数据集上对时间信息和空间信息的依赖程度不同,作者设计了这些模型以提高 PredFormer 模型在不同任务上的适应性。

在实验部分,作者控制了提出的每种变体使用相同的 GTB 数目,这可以保证模型的参数量基本一致,从而对比不同模型的性能。对于三元组 GTB 个数无法整除的情况,选择最为接近的三元组个数。

实验发现了一些规律:

1. 时间在前的分解 Encoder 模型优于时空全注意力模型,由于空间在前的分解 Encoder 模型;

2. 时空交错的 6 种模型在大多数任务上表现都很好,都能达到 sota,但最优模型因为数据集本身的不同时空依赖特性而不同,这体现了 PredFormer 这种框架和时空交错设计的优势;

3. 作者在讨论环节提出了建议,在其他的时空预测任务上,从四元组-TSST 开始尝试,因为这个模型在三个数据集上都表现 sota,先调整 M 个 TSST(即 4M 个门控 Transformer)的 M 参数,然后尝试 M 个 TST 和 M 个 STS 以确定数据集是时间依赖更强或空间依赖更强的模型。

得益于 Transformer 构架的可扩展性,不同于 SimVP 框架的 CNN Encoder-Decoder 模型,对 spatial 和 temporal 的 hidden dim 以及 block 数都设置了不同的值,PredFormer 对 spatial 和 temporal GTB 采用相同的固定的参数,因此只需要调整 M 的值,在比较少次数的调整后就可以达到最优性能。

ViT 模型的训练通常要求较大的数据集,在时空预测任务上,大多数据集在几千到几万的量级,数据量少,因此很容易过拟合。作者还探索了不同的正则化策略,包括 dropout 和 drop path,通过广泛的消融实验,作者发现同时使用 dropout 和 uniform 的 drop path(不同于一般 ViT 使用线性增加的 drop path rate)会产生最优的模型效果。

作者还进行了可视化比较,可以看到,在 PredFormer 相对于 TAU 明显减少了预测误差。作者还给出了一个特殊例子来证明 PredFormer 模型相比于 CNN 模型在泛化性上的优越性。

在交通流预测任务上,当第四帧相比前三帧明显减少流量时,TAU 受限于归纳偏置仍然预测了较高的流量,而 PredFormer 却能捕捉到这里的变化。PredFormer 预测剧烈变化的能力在交通流和天气预测中可能有非常宝贵的应用价值。

这篇文章提出了一个无循环,无卷积的纯 Transformer 模型,并对时空注意力分解进行了全面的分析。PredFormer 不仅提供了一个鲁棒的基线模型,还为以后的基于纯 Transformer 构架的创新工作铺平了道路。

欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。

计算机视觉入门1v3辅导班

【技术文档】《从零搭建pytorch模型教程》122页PDF下载

QQ交流群:470899183。群内有大佬负责解答大家的日常学习、科研、代码问题。

其它文章

分享一个CV知识库,上千篇文章、专栏,CV所有资料都在这了

明年毕业,还不知道怎么做毕设的请抓紧机会了

LSKA注意力 | 重新思考和设计大卷积核注意力,性能优于ConvNeXt、SWin、RepLKNet以及VAN

CVPR 2023 | TinyMIM:微软亚洲研究院用知识蒸馏改进小型ViT

ICCV2023|涨点神器!目标检测蒸馏学习新方法,浙大、海康威视等提出

ICCV 2023 Oral | 突破性图像融合与分割研究:全时多模态基准与多交互特征学习

听我说,Transformer它就是个支持向量机

HDRUNet | 深圳先进院董超团队提出带降噪与反量化功能的单帧HDR重建算法

南科大提出ORCTrack | 解决DeepSORT等跟踪方法的遮挡问题,即插即用真的很香

1800亿参数,世界顶级开源大模型Falcon官宣!碾压LLaMA 2,性能直逼GPT-4

SAM-Med2D:打破自然图像与医学图像的领域鸿沟,医疗版 SAM 开源了!

GhostSR|针对图像超分的特征冗余,华为诺亚&北大联合提出GhostSR

Meta推出像素级动作追踪模型,简易版在线可玩 | GitHub 1.4K星

CSUNet | 完美缝合Transformer和CNN,性能达到UNet家族的巅峰!​

AI最全资料汇总 | 基础入门、技术前沿、工业应用、部署框架、实战教程学习

计算机视觉入门1v3辅导班

计算机视觉交流群

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/842187.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

NeurIPS2024 | 提高专业生产力,让你的AI画作布局可控,360 AI Research开源新模型HiCo

前言 为了解决这一问题,360人工智能研究院在人工智能顶会NeurIPS2024上提出了布局可控AI绘画模型HiCo,并将于近期开源。基于HiCo模型,使用者可以对生成画面中的不同主体的布局进行自由控制和调整,实现“指哪打哪”的生成效果。 欢迎关注公众号CV技术指南,专注于计算机视觉…

【开源系列】Faraday : 渗透测试 IDE 和漏洞管理平台

什么是 Faraday ? Faraday 是一个开源的漏洞管理平台,它旨在帮助安全团队有效地管理和协作处理漏洞。Faraday 提供了一个集中的平台,用于收集、分析和报告漏洞信息。它支持多种集成,可以与各种安全工具和扫描器无缝对接,从而提高漏洞管理的效率和准确性。 Faraday 的功能特…

【开源系列】OpenEMR:开源免费的医院管理系统

今天给大家分享一款完全开源的电子病历和医疗管理系统【OpenEMR】 什么是 OpenEMR ? OpenEMR 是一款免费开源的电子健康记录(EHR)和医疗实践管理系统。它提供了全面的医疗信息系统功能,包括患者信息管理、日程安排、处方开具、账单处理、报告生成等。OpenEMR 支持多种平台,…

【windows环境搭建】Windows下安装使用JMETER

一、插件驱动安装1.1 安装JDK环境1.2 安装插件1.3 添加驱动包二、JMeter压测2.1 创建压测线程组2.2 创建JDBC request2.3 创建JDBC Connection Configuration2.4 创建汇总报告2.5 创建查看结果树2.6 创建jp@gc - Transactions per Second(TPS)一、插件驱动安装 1.1 安装JDK环境…

【windows安装教程】Windows下安装使用JMETER

一、插件驱动安装1.1 安装JDK环境1.2 安装插件1.3 添加驱动包二、JMeter压测2.1 创建压测线程组2.2 创建JDBC request2.3 创建JDBC Connection Configuration2.4 创建汇总报告2.5 创建查看结果树2.6 创建jp@gc - Transactions per Second(TPS)一、插件驱动安装 1.1 安装JDK环境…

如何设计好分布式数据库,这个策略很重要(GaussDB)

​ 数据库是应用和计算机的核心组成,试想,如果没有数据库,就像人的大脑没有了记忆一样,信息也得不到共享,那么,对开发者来说,如何设计一款高效易用的数据库至关重要。 GaussDB是企业级分布式数据库,具备分布式强一致、有效降低容灾成本、支持PB级海量数据、智能诊断等优…

[Linux]缓冲区的理解

缓冲区的理解 先来看这段代码 #include <stdio.h> #include <unistd.h> #include <string.h>int main() {//C接口printf("hello printf\n");fprintf(stdout, "hello fprintf\n");fputs("hello fputs\n", stdout);//系统接口co…

Ollama本地部署Qwen2.5 14B(使用docker实现Nvidia GPU支持)

通过docker部署支持Nvidia GPU加速的本地大模型前提条件:已经本地安装好了Ollama。 如果没有安装Ollama或者想部署其他的模型或者不想使用docker,,可以参考之前的这篇文章: https://www.cnblogs.com/Chenlead/p/18571005 安装过程参考:https://docs.openwebui.com/getting…

Jmeter 临界部分控制器 Critical Section Controller

Jmeter必知利器-临界部分控制器-腾讯云开发者社区-腾讯云 Jmeter之临界部分控制器使用-CSDN博客 使用前,线程执行顺序随机 使用后,线程执行顺序从上到下

dedecms提示500错误解决方法

查看网站程序版本:打开 /data/admin/ver.txt 文件查看 查看主机PHP版本:在主机面板查看或创建一个 p.php 文件,内容为 <?php phpinfo(); ?>,上传到网站根目录,访问 http://域名/p.php 查看PHP版本,完成后删除 p.php 低版本织梦(2014、2015、2016、2017开头)无法…

自动检测工作人员工服穿戴规范行为

自动检测工作人员工服穿戴规范行为利用现场安装的高清摄像机,自动检测工作人员工服穿戴规范行为对采集到的视频进行预处理,识别出图像中的员工,并检测其工服穿戴情况,一旦系统判断出工服穿戴异常,将立即发出警报,通知管理人员或自动启动相应的安全措施。通过实时监测,及…

Docker Logs清理

查看docker日志路径 docker inspect --format={{.LogPath}} <container_name_or_id>清理docker日志 echo |sudo tee $(docker inspect --format={{.LogPath}} <container_name_or_id>).zstitle { width: 280px; text-align: center; font-size: 26px } .zsimgwei…