Mamba详细介绍和RNN、Transformer的架构可视化对比

Transformer体系结构已经成为大型语言模型(llm)成功的主要组成部分。为了进一步改进llm,人们正在研发可能优于Transformer体系结构的新体系结构。其中一种方法是Mamba(一种状态空间模型)。

Mamba: Linear-Time Sequence Modeling with Selective State Spaces一文中提出了Mamba,我们在之前的文章中也有详细的介绍。

在本篇文章中,通过将绘制RNN,transformer,和Mamba的架构图,并进行详细的对比,这样我们可以更详细的了解它们之间的区别。

为了说明为什么Mamba是这样一个有趣的架构,让我们先介绍Transformer。

Transformer

Transformer将任何文本输入视为由令牌组成的序列。

transformer的一个主要优点是,无论它接收到多长的输入,它都使用序列中的任何令牌信息(无论序列有多长)来对输入数据进行处理。

这就是我们在论文中看到的注意力机制的作用,但是为了获得全局信息,注意力机制在长序列上非常耗费显存,这个我们后面说。

Transformer由两个结构组成,一组用于表示文本的编码器块和一组用于生成文本的解码器块。这些结构可以用于多种任务,包括翻译。

我们可以采用这种结构来创建仅使用解码器的生成模型。比如基于Transformer的GPT,使用解码器块来完成一些输入文本。

单个解码器块由两个主要部分组成,一个是自注意力模块,另一个是前馈神经网络。

注意力创建一个矩阵,将每个令牌与之前的每个令牌进行比较。矩阵中的权重由令牌对之间的相关性决定。

它支持并行化,所以可以极大地加快训练速度!

但是当生成下一个令牌时,我们需要重新计算整个序列的注意力,即使我们已经生成了一些新的令牌。

为长度为L的序列生成令牌大约需要L²的计算量,如果序列长度增加,计算量可能会很大。并且在这里需要计算所有令牌的注意力,所以如果序列很长,那么内存占用也会很大。所以需要重新计算整个序列是Transformer体系结构的主要瓶颈。当然也有很多技巧来提升注意力机制的效率,这里我们暂时不提,只看最经典的原始论文。

RNN

下面我们介绍更早的序列模型RNN。循环神经网络(RNN)是一种基于序列的网络。它在序列的每个时间步长取两个输入,即时间步长t的输入和前一个时间步长t-1的隐藏状态,以生成下一个隐藏状态并预测输出。

RNN有一个循环机制,允许它们将信息从上一步传递到下一步。我们可以“展开”这个可视化,使它更明确。

在生成输出时,RNN只需要考虑之前的隐藏状态和当前的输入。这样不会重新计算以前的隐藏状态,这正Transformer没有的。

这种流程可以让RNN进行快速推理,因为的时间与序列长度线性扩展!并且可以有无限的上下文长度(理论上),因为每次推理他只取一个隐藏状态和当前输入,内存的占用是非常稳定的。

我们将RNN应用于之前使用过的输入文本。

每个隐藏状态都是以前所有隐藏状态的聚合。但是这里就出现了问题,在生成名称“Maarten”时,最后一个隐藏状态不再包含关于单词“Hello”的信息(或者说最早的信息会被坐进的信息覆盖)。这会导致随着时间的推移,rnn会忘记信息,因为它们只考虑前一个状态。

并且rnn的这种顺序性产生了另一个问题。训练不能并行进行,因为它需要按顺序完成每一步。

与Transformer相比,rnn的问题完全相反!它的推理速度非常快,但不能并行化导致训练很慢。

人们一直在寻找一种既能像Transformer那样并行化训练,能够记住先前的信息,并且在推理时间还是随序列长度线性增长的模型,Mamba就是这样宣传的。

在介绍Mamba之前,让我们还需要介绍以下状态空间模型

The State Space Model (SSM)

状态空间模型(SSM),像Transformer和RNN一样,可以处理序列信息,比如文本,也包括信号。

状态空间是包含能够完全描述一个系统的最少数量变量的概念。它是一种通过定义系统可能的状态来数学表示问题的方式。

比如说我们正在通过一个迷宫。“状态空间” 就是所有可能位置(状态)的地图。每个点代表迷宫中的一个独特位置,具有特定的细节,比如你离出口有多远。

“状态空间表示” 是对这个地图的简化描述。它展示了你当前所处的位置(当前状态),以及下一步可以去哪里(可能的未来)。

虽然状态空间模型使用方程和矩阵来跟踪这种行为,描述状态的变量,在我们的例子中是X和Y坐标以及到出口的距离,可以表示为“状态向量”。

听起来熟悉吗?这不就是强化学习中的状态吗,我个人认为是可以这么理解的,那么怎么和序列有关呢?

因为语言模型中的嵌入或向量也经常用于描述输入序列的“状态”。例如,你当前位置的向量(状态向量)可能看起来像这样:

在神经网络中,“状态”通常是指其隐藏状态,在大型语言模型的背景下,这是生成新标记的一个最重要的方面之一。

状态空间模型(SSMs)是用于描述这些状态表示并根据某些输入进行下一个状态预测的模型。

在时间t,状态空间模型(SSMs):

  • 将输入序列x(t)(例如,在迷宫中向左和向下移动)映射到潜在状态表示h(t)(例如,到出口的距离和x/y坐标),
  • 并推导出预测的输出序列y(t)(例如,再次向左移动以更快地到达出口)。

这里就与强化学习中使用离散序列(如仅向左移动一次)不同,它将连续序列作为输入并预测输出序列。

ssm假设动态系统,例如在三维空间中移动的物体,可以通过两个方程从时间t的状态预测。

通过求解这些方程,假设可以揭示基于观测数据(输入序列和先前状态)预测系统状态的统计原理。

它的目标是找到这个状态表示h(t)这样我们就可以从一个输入序列到一个输出序列。

这两个方程就是是状态空间模型的核心。状态方程描述了基于输入如何影响状态(通过矩阵B)的状态变化(通过矩阵A)。

h(t)表示任意时刻t的潜在状态表示,而x(t)表示某个输入。

输出方程描述了状态如何转化为输出(通过矩阵C),以及输入如何影响输出(通过矩阵D)。

矩阵A、B、C和D通常被称为参数,因为它们是可学习的。将这两个方程可视化,我们可以得到如下架构:

下面我们看看这些矩阵如何影响学习过程。

假设我们有一个输入信号x(t)这个信号首先乘以矩阵B它描述了输入如何影响系统。

更新状态(h)是包含环境核心“知识”的潜在空间。我们将状态与矩阵A相乘,矩阵A描述了所有内部状态是如何连接的,因为它们代表了系统的潜在表示。

这里可以看到,在创建状态表示之前应用矩阵A,并在状态表示更新之后更新矩阵A。

然后使用矩阵C来描述如何将状态转换为输出。

最后利用矩阵D提供从输入到输出的直接信号。这通常也被称为跳过(残差)连接。

由于矩阵D类似于跳过连接,所以SSM通常被视为为不进行跳过连接的部分

回到我们的简化视图,现在可以将重点放在矩阵A、B和C上,它们是SSM的核心。

更新原始方程并添加一些颜色来表示每个矩阵的目的

这两个方程根据观测数据预测系统的状态。由于期望输入是连续的,SSM是连续时间表示。

但是因为文字都是离散的输入,我们还需要将模型离散化。这里就要使用* Zero-order hold * 技术

每次我们接收到一个离散信号,都会保证他的值不变,直到接收到一个新的离散信号再改变。这个过程创建了一个SSM可以使用的连续信号:

我们保持该值的时间由一个新的可学习参数表示,称为步长∆。这样就得到了一个连续的信号并且可以只根据输入的时间步长对值进行采样。

这些采样值就是我们的离散输出!在数学上,我们可以应用Zero-order hold如下:

因为我们SSM处理的是离散信号,所以这里不是一个函数到函数,x(t)→y(t),而是一个序列到序列,xₖ→yₖ,我们用公式表示如下:

矩阵A和B现在表示模型的离散参数,用k代替t来表示离散的时间步长。

离散化的SSM允许在特定的时间步中处理信息。就像我们之前在循环神经网络(RNNs)中看到的那样,循环方法在这里也非常有用,可以将问题重新表述为时间步骤:

在每个时间步长,我们计算当前输入(Bxₖ)如何影响前一个状态(Ahₖ₁),然后计算预测输出(Chₖ)。

这种表示看起来是不是有点熟悉?其实他的处理方法和RNN一样

也可以这样展开:

这种技术与RNN类似,快速推理和慢速训练。

另一种ssm的表示是卷积的表示。我们应用过滤器(核)来获得聚合特征:

因为我们处理的是文本而不是图像,所以我只要一维的视角:

我们用来表示这个“过滤器”的核是由SSM公式推导出来的:

可以使用SSM核遍历每一组令牌并计算输出:

上图也说明了padding 可能对输出产生的影响,所以我们一般都会在末尾padding而不是在前面。第二步核被移动一次来执行下一步计算:

在最后一步,我们可以看到核的完整效果:

卷积的一个主要好处是它可以并行训练。但是由于核大小是固定,它们的推理不如rnn快速并且对序列长度有限制。

上面的三种SMM都有各自的优缺点

这里可以使用一个简单的技巧,即根据任务选择表示。在训练过程中使用可以并行化的卷积表示,在推理过程中,我们使用高效的循环表示:

听起来有点奇幻,但是有人就是实现出来了,这个模型叫做Linear State-Space Layer (LSSL)

https://proceedings.neurips.cc/paper_files/paper/2021/hash/05546b0e38ab9175cd905eebcc6ebb76-Abstract.html

它结合了线性动态系统理论和神经网络的概念,可以有效地捕获数据中的时序信息和动态特征。LSSL 基于线性动态系统理论,这种系统可以用状态空间模型表示。在这个模型中,系统的行为由状态变量的演化和外部控制信号的影响决定。状态变量是系统的内部表示,可以捕获系统的动态特性。

这些表示都有一个重要的特性,即线性时不变性(LTI)。LTI表示ssm参数A、B和C对于所有时间步长都是固定的。这意味着对于SSM生成的每个令牌,矩阵A、B和C都是相同的。

也就是说无论给SSM的序列是什么,A、B和C的值都保持不变。这样就得到了一个不感知内容的静态表示。但是静态表示没有任何意义对吧,所以Mamba解决的就是这个问题,但是在介绍Mamba之前,我们还有一个知识点需要强调,那就是矩阵A

因为SSM公式中最重要的就是矩阵a。正如我们之前在循环表示中看到的那样,它捕获了关于前一个状态的信息来构建新状态,如果矩阵a如果跟RNN一样会遗忘掉非常靠前的信息那么SMM将没有任何的意义,对吧。

矩阵A产生隐藏状态:

如何保留大上下文大小的方式创建矩阵A呢?

HiPPO 的模型结合了递归记忆(Recurrent Memory)和最优多项式投影(Optimal Polynomial Projections)的概念,这种投影技术可以显著改善递归记忆的性能,特别是在处理长序列和长期依赖关系时。

https://proceedings.neurips.cc/paper/2020/hash/102f0bb6efb3a6128a3c750dd16729be-Abstract.html

使用矩阵A来构建一个状态表示,该状态表示可以很好地捕获最近的令牌并衰减较旧的令牌。其公式可表示为:

具体的详细内容我们就不介绍了,有兴趣的查看原论文。

这样我们就基本上解决了所有的问题:1、状态空间模型;2、处理远程依赖关系;3、离散化和并行计算

如果想深入了解有关如何计算HiPPO矩阵和自己构建S4模型建议您阅读注释的S4。

https://srush.github.io/annotated-s4/

Mamba

上面介绍完所有必要的基础知识,最后就是我们的重点了

Mamba 有两个主要贡献:

1、选择性扫描算法,模型可以过滤有关和无关的信息

2、硬件感知算法,通过并行扫描、核融合和重计算有效地存储(中间)结果。

在探讨这两个主要贡献之前,我们先看看一下为什么它们是必要的。

状态空间模型,S4(Structured State Space Model),在语言建模和生成中的某些任务上表现不佳

比如在选择性复制任务中,SSM的目标是按顺序复制输入和输出的部分:

(循环/卷积)SSM在这个任务中表现不佳,因为它是线性时不变的。对于SSM生成的每个令牌,矩阵A、B和C都是相同的。

因为它将每个令牌平等地视为固定的a、B和C矩阵的结果,所以SSM不能执行内容感知推理

SSM表现不佳的第二个任务是重现输入中发现的模式:

我们的提示在“教”模型在每个“Q:”之后提供“A:”响应。但是由于ssm是时间不变的,它不能选择从其历史中获取先前的令牌。

以矩阵B为例不管输入x是什么,矩阵B保持完全相同,并且与x无关:

同理无论输入是什么,A和C也不变,这就是我们上面说的静态。

而Transformers 可以根据输入序列动态地改变注意力。可以选择性地“看”或“注意”序列的不同部分,再加上位置编码,这使得Transformers对于这种任务非常的简单。

ssm在这些任务上的糟糕性能说明了定常ssm的潜在问题,矩阵A、B和C的静态特性导致了内容感知问题。

选择性地保留信息

SSM的循环表示创建了一个非常有效的小状态,因为它压缩了整个历史信息,所以与不压缩历史(注意力矩阵)的Transformer模型相比,它的功能要弱得多。

Mamba 的目标是获得Transformer一样强大的“小”状态

通过有选择地将数据压缩到状态,当输入一个句子时,通常会有一些信息,比如停顿词,这些信息没有太多的意义。

我们先看看SSM在训练期时的输入和输出维度:

在结构化状态空间模型(S4)中,矩阵a、B和C独立于输入,因为它们的维度N和D是静态的,不会改变。

而Mamba通过结合输入的序列长度和批量大小,使矩阵B和C,甚至步长∆依赖于输入:

这意味着对于每个输入标记,有不同的B和C矩阵,这解决了内容感知问题!这里矩阵A保持不变,因为希望状态本身保持静态,但影响它的方式(通过B和C)是动态的。

也就是说它们一起选择性地选择将什么保留在隐藏状态中,什么需要忽略,这都是由输入确定的。

较小的步长∆导致忽略特定的单词,而是更多地使用之前的上下文,而较大的步长∆则更多地关注输入单词而不是上下文:

扫描操作

这些矩阵现在是动态的了所以它们不能使用卷积表示来计算,只能使用循环进行处理,这就使得无法进行并行化。

为了实现并行化,我们先看看循环的输出:

每个状态都是前一个状态(乘以A)加上当前输入(乘以B)的和。这被称为扫描操作,可以很容易地通过for循环计算出来。但是并行化似乎是不可能的,因为每个状态只有在我们有前一个状态时才能计算出来。

但是Mamba使用并行扫描算,通过关联属性假定执行操作的顺序无关紧要。这样就可以计算部分序列并迭代组合它们:

这样还有一个好处是因为顺序不重要,也可以省略掉Transformer的位置编码。

硬件感知的算法

最近gpu的一个缺点是它们在小但高效的SRAM和大但效率稍低的DRAM之间的传输(IO)速度有限。在SRAM和DRAM之间频繁地复制信息成为瓶颈。

Mamba的DRAM和SRAM分配的具体实例如下:

中间状态不被保存,但对于反向传播计算梯度是必要的。作者重新计算了反向传递过程中的中间状态。尽管这看起来效率很低,但它比从相对较慢的DRAM读取所有这些中间状态的成本要低得多。

这里我们就不详细说明了,因为这部分我也没太研究过

Mamba 块

选择性SSM可以作为一个块,就像在Transformer中的的注意力模块一样。我们可以堆叠多个块,并使用它们的输出作为下一个曼巴块的输入:

最后一个端到端(输入到输出)的例子包含了归一化层和选择输出标记softmax。

这样就得到了快速的推理和训练,而且是“无限”长度上下文的模型

总结

看完这篇文章,我希望你能对Mamba 和状态空间模型有一定的了解,最后我们以作者的发现为结尾:

作者发现模型与相同尺寸的Transformer模型的性能相当,有时甚至超过了它们!

https://avoid.overfit.cn/post/94105fed36de4cd981da0b916c0ced47

作者:Maarten Grootendorst

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/487290.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

华为HCIP Datacom H12-831 卷23

单选题 1、某园区部署IS-IS实现网络互通,在所有IS-IS路由器的进程中配置命令flash-flood 6 max-timer-interval 100 Leve1-2,则以下关于该场景的描述,正确的是哪—项? A、若某IS-IS路由器LSDB内更新的LSP数量为5,则在100毫秒内且路由计算完成前&#…

nginx-------- 高性能的 Web服务端 (三) 验证模块 页面配置

一、http设置 1.1 验证模块 需要输入用户名和密码 htpasswd 此命令来自于 httpd-tools 包,如果没有安装 安装一下即可 也可以安装httpd 直接yum install httpd -y 也一样 第一次生成文件htpasswd -c 文件路径 姓名 交互式生成密码 htpasswd -bc 文…

视频评论抓取软件|抖音数据抓取工具

最近我们推出了一款基于C#语言开发的工具。这款工具提供了丰富的功能,旨在帮助用户轻松获取抖音视频内容。让我们一起来详细介绍一下这款工具的主要功能模块: 1. 批量视频提取: 工具提供了便捷的批量视频提取功能,用户只需输入关…

算法打卡day1|数组篇|Leetcode 704.二分查找、27.移除元素

数组理论基础 数组是存放在连续内存空间上的相同类型数据的集合,可以方便的通过下标索引的方式获取到下标下对应的数据。 1.数组下标都是从0开始的。 2.数组内存空间的地址是连续的。 正是因为数组的在内存空间的地址是连续的,所以我们在删除或者增添…

基于DPU和HADOS-RACE加速Spark 3.x

背景简介 Apache Spark(下文简称Spark)是一种开源集群计算引擎,支持批/流计算、SQL分析、机器学习、图计算等计算范式,以其强大的容错能力、可扩展性、函数式API、多语言支持(SQL、Python、Java、Scala、R&#xff09…

sql注入 [极客大挑战 2019]FinalSQL1

打开题目 点击1到5号的结果 1号 2号 3号 4号 5号 这里直接令传入的id6 传入id1^1^1 逻辑符号|会被检测到,而&感觉成了注释符,&之后的内容都被替换掉了。 传入id1|1 直接盲注比较慢,还需要利用二分法来编写脚本 这里利用到大佬的脚…

idea如何在一个service窗口中显示多个服务教程

idea在service窗口中显示多个服务 展示效果如下: 找到.idea > workspace.xml 中找到 RunDashboard 替换成如下 <component name"RunDashboard"><option name"configurationTypes"><set><option value"SpringBootApplicatio…

2023最新盲盒交友脱单系统源码

源码获取方式 搜一搜&#xff1a;万能工具箱合集 点击资源库直接进去获取源码即可 如果没看到就是待更新&#xff0c;会陆续更新上 或 源码软件库 最新盲盒交友脱单系统源码&#xff0c;纸条广场&#xff0c;单独抽取/连抽/同城抽取/高质量盒子 新增功能包括心动推荐&#xff…

React18源码: reconcliler启动过程

Reconcliler启动过程 Reconcliler启动过程实际就是React的启动过程位于react-dom包&#xff0c;衔接reconciler运作流程中的输入步骤.在调用入口函数之前&#xff0c;reactElement(<App/>) 和 DOM对象 div#root 之间没有关联&#xff0c;用图片表示如下&#xff1a; 在启…

Go语言基础总结

一、Go语言结构 包声明 引入包 函数 变量 语句&表达式 注释 下面简单给出hello.go文件。 package src /*定义包名*/import "fmt" /*引入包*/func hello() { /*函数*/fmt.Println("Hello,World!") /*语句&表达式*/fmt.Println("菜鸟教…

跟着野火学FreeRTOS:第二段(事件组)

在小节里面介绍了二进制信号量&#xff0c;计数信号量&#xff0c;互斥量和递归互斥量等功能&#xff0c;其中二进制信号量和计数信号量&#xff08;也包括队列&#xff09;常用于任务和任务之间以及任务和中断之间的同步&#xff0c;她们具有以下属性&#xff1a; 当等待的事…

【QT】QTextEdit 常用方法汇总

目录 1.QTextEdit 限制文本输入数量 2.使用 QTextEdit&#xff0c;根据我希望一次可见的行数来设置高度 3.限制QTextEdit行数 4.判断QTextEdit当前行数 5.QTextEdit光标移至最后一行 6.QTextEdit删除光标的前一个字符 7.QTextEdit移动光标至上一行的起始位置 8.限制QTe…