文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens

  • 文献阅读:LONGNET: Scaling Transformers to 1,000,000,000 Tokens
    • 1. 文章简介
    • 2. 方法原理
      • 1. 方法思路
      • 2. Dilated Attention
        • 1. 具体原理
        • 2. 多头实现
        • 3. 复杂度分析
      • 3. 训练方法
    • 3. 实验结果
    • 4. 结论 & 思考
    • 5. 参考链接
  • 文献链接:https://arxiv.org/abs/2307.02486

1. 文章简介

这篇文章算是我司最近的一篇力作吧,即DeepNet, Foundation Transformer之后,大佬们终于还是盯上了attention layer,毕竟attention层 O ( N 2 ) O(N^2) O(N2)的计算复杂度一直是制约Transformer往长文本发展的主要原因。

想当年,像是线性化Attention的Linformer,或者以更直观的稀疏化attention的Reformer,亦或者结合局部与全局attention的Longformer,或者类似金字塔型的将长文本拆分为短文本然后各自做attention然后逐层往上的方式(不过这篇具体文章给忘了),总之当年零零碎碎有不少关于优化attention层计算量,使之可以拓展到长文本上的工作。

不过可惜的是,虽然当时大家都觉得这个方向很重要,结果以GPT3还有PALM等为代表的大模型反而从工程上发力,直接强行扩展文本长度,从头上干掉了这个问题……

这两年,感觉这方面的工作已经比较少听到了,不过我司的大佬们似乎还是重新抓出了这个方向,然后像是DeepNet那样直接干出了一个量级上碾压的工作,也是真的厉害……

在这里插入图片描述

2. 方法原理

1. 方法思路

LongNet的整体的一个思路其实和之前的Reformer,Linformer等一致,还是在attention层方面做文章,希望将attention layer的计算复杂度从原始的 O ( N 2 d ) O(N^2d) O(N2d)进行优化,使得其与句长 N N N呈线性关系而非平方关系,从而使得模型整体的计算复杂度得到缩减。

对于,文中提出了dilated attention的结构,成功地将attention layer的计算复杂度从 O ( N 2 d ) O(N^2d) O(N2d)降维至 O ( N d ) O(Nd) O(Nd)复杂度。

在这里插入图片描述

需要注意的是,这里的比较没有包含linear transformer,它虽然很早之前已经实现了 O ( N d ) O(Nd) O(Nd)复杂度的attention实现,不过貌似效果不佳,不算是主流的attention方法,因此文中弃用了linear transformer作为对照。

下面,我们就需要具体看一下Dilated Attention层的具体实现方法。

2. Dilated Attention

1. 具体原理

首先,我们给出Dilated Attention层的整体原理图如下:

在这里插入图片描述

具体来说,就是首先给出一个局部窗口长度 w w w和间隔距离 r r r,那么,就可以将总长为 N N N的序列拆分为 N / w N/w N/w个子序列,然后在每一个子序列当中按照间隔 r r r取出token,一共就能够取出 w / r w/r w/r个token,然后用着 w / r w/r w/r个token作为新的序列计算attention,然后把这 N / w N/w N/w个attention矩阵concat起来,就能得到一个 N × N N \times N N×N的稀疏attention矩阵。

考察对于固定的 w , r w,r w,r下的第 i i i个attention矩阵,有:

{ Q i = [ Q i w Q i w + r ⋯ Q ( i + 1 ) w − r ] K i = [ K i w K i w + r ⋯ K ( i + 1 ) w − r ] V i = [ V i w V i w + r ⋯ V ( i + 1 ) w − r ] \left\{ \begin{aligned} Q_i &= [Q_{iw} & Q_{iw+r} & \cdots & Q_{(i+1)w-r}] \\ K_i &= [K_{iw} & K_{iw+r} & \cdots & K_{(i+1)w-r}] \\ V_i &= [V_{iw} & V_{iw+r} & \cdots & V_{(i+1)w-r}] \end{aligned} \right. QiKiVi=[Qiw=[Kiw=[ViwQiw+rKiw+rViw+rQ(i+1)wr]K(i+1)wr]V(i+1)wr]

此时有:

O i = s o f t m a x ( Q i ⋅ K i T d ) V i O_i = \mathop{softmax}(\frac{Q_i \cdot K_i^T}{\sqrt{d}})V_i Oi=softmax(d QiKiT)Vi

当然,这样的一个attention矩阵事实上只包含了局部的attention信息,因此无法兼顾长距离和短距离的attention信息。因此,如果要令总的attention兼顾长距离和短距离的attention信息,就需要取出多组 w , r w,r w,r,分别计算attention然后进行矩阵加和。也就是上图中的合并部分,从而才能获得包含全局attention信息的矩阵。

具体实现上来说,文中采用的是等比数列的方式进行实现,比如如下的方式:

{ w = w , α w , α 2 w , ⋯ , α n w r = r , α r , α 2 r , ⋯ , α n r \left\{ \begin{aligned} w &= {w, \alpha w, \alpha^2 w, \cdots, \alpha^n w} \\ r &= {r, \alpha r, \alpha^2 r, \cdots, \alpha^n r} \end{aligned} \right. {wr=w,αw,α2w,,αnw=r,αr,α2r,,αnr

在上图的demo中,取用的 w , r w,r w,r就是 4 4 4 1 1 1 α \alpha α的取值为 2 2 2

当然,考虑到由于 w , r w,r w,r取值不同导致的attention的密度不同,因此加和的时候需要对权重进行调整,具体而言:

O = ∑ i = 1 k s i ∑ j s j O r i , w i O = \sum\limits_{i=1}^{k}\frac{s_i}{\sum_j s_j}O_{r_i, w_i} O=i=1kjsjsiOri,wi

其中, s i s_i si ( w i , r i ) (w_i, r_i) (wi,ri)这组参数下计算得到的attention矩阵( Q i ⋅ K i T d \frac{Q_i \cdot K_i^T}{\sqrt{d}} d QiKiT)在计算softmax时的分母部分,也就是:

∑ j e Q i ⋅ K i T d \sum\limits_{j} e^{\frac{Q_i \cdot K_i^T}{\sqrt{d}}} jed QiKiT

这样也就得到了一组 n n n维的系数向量,作为我们这里的 s s s

2. 多头实现

关于Dilated Attention的多头实现,整体来说和vanilla transformer的实现方式是一致的,还是在input的向量当中进行split,然后分别过一个上述介绍的Dilated Attention层,最后将output的结果concat起来即可。

不过,感谢作者Shuming大佬的解释,这里和vanilla transformer存在一定的区别,具体就在于对于每一个context window,我们事实上都是等间隔的sample了其中的几个token进行attention的计算,某种意义上来说总是会丢失掉一些信息的。

因此,在设计多头attention的时候,文中进行了一定的优化,即对于input的token位置在不同的head上面给了不同的位置偏移量,从而使得尽可能地覆盖更多的token之间的attention。

具体来说就是,对于第 j j j个head,选取的token为:

{ Q i = [ Q i w + j ( ≡ r ) Q i w + r + j ( ≡ r ) ⋯ Q ( i + 1 ) w − r + j ( ≡ r ) ] K i = [ K i w + j ( ≡ r ) K i w + r + j ( ≡ r ) ⋯ K ( i + 1 ) w − r + j ( ≡ r ) ] V i = [ V i w + j ( ≡ r ) V i w + r + j ( ≡ r ) ⋯ V ( i + 1 ) w − r + j ( ≡ r ) ] \left\{ \begin{aligned} Q_i &= [Q_{iw + j(\equiv r)} & Q_{iw+r + j(\equiv r)} & \cdots & Q_{(i+1)w-r + j(\equiv r)}] \\ K_i &= [K_{iw + j(\equiv r)} & K_{iw+r + j(\equiv r)} & \cdots & K_{(i+1)w-r + j(\equiv r)}] \\ V_i &= [V_{iw + j(\equiv r)} & V_{iw+r + j(\equiv r)} & \cdots & V_{(i+1)w-r + j(\equiv r)}] \end{aligned} \right. QiKiVi=[Qiw+j(r)=[Kiw+j(r)=[Viw+j(r)Qiw+r+j(r)Kiw+r+j(r)Viw+r+j(r)Q(i+1)wr+j(r)]K(i+1)wr+j(r)]V(i+1)wr+j(r)]

可以用文中的图3来对上述不同头的attention进行更为形象化的展示如下:

在这里插入图片描述

3. 复杂度分析

下面,我们来考察一下Dilated Attention层的算法复杂度。

我们首先来考察对于一组确定的 w , r w,r w,r对应的Dilated Attention层的算法复杂度,其对应的结果如下:

F L O P s = 2 N w ⋅ ( w r ) 2 d = 2 N w d r 2 FLOPs = \frac{2N}{w} \cdot (\frac{w}{r})^2d = \frac{2Nwd}{r^2} FLOPs=w2N(rw)2d=r22Nwd

因此,遍历 w , r w,r w,r,我们即可得到完整的Dilated Attention层的算法复杂度如下:

F L O P s = ∑ i = 0 k − 1 2 N w i d r i 2 = 2 N w 0 d r 0 2 ∑ i = 0 k − 1 1 α i < 2 N w 0 d r 0 2 ⋅ α α − 1 ∼ O ( N d ) FLOPs = \sum\limits_{i=0}^{k-1}\frac{2Nw_id}{r_i^2} = \frac{2Nw_0d}{r_0^2} \sum\limits_{i=0}^{k-1} \frac{1}{\alpha^i} < \frac{2Nw_0d}{r_0^2} \cdot \frac{\alpha}{\alpha-1} \sim O(Nd) FLOPs=i=0k1ri22Nwid=r022Nw0di=0k1αi1<r022Nw0dα1αO(Nd)

3. 训练方法

最后,我们看一下文中实际的训练过程。

注意到,这里由于极限的扩展了输入的context的序列长度,因此事实上如何将文本塞入GPU也就成了一个大问题,因此,这方面也需要有一些工程上的实现细节考察。

具体来说,文中给出的方法还是说先对sequence进行一下split,然后由不同的GPU分别计算,最后进行加总实现。

其原理图可以参考文中的图4:

在这里插入图片描述

不过需要注意的是,这里在不同的gpu当中计算完了不同的部分的input seq之后,在计算dilated attention的时候会有一个slice的过程,然后slice之后的得到的dilated attention会在不同的GPU之间进行聚合,从而确保不同的gpu上的token之间的attention能够相互计算和聚合。

由于这里只是slice之后的attention,因此可以避免掉由于过长的文本长度(比如文中给出的1B)导致的内存爆炸的问题。

3. 实验结果

文中使用torchscale作为基准库,然后替换attention layer之后train了一个768维,12层的模型进行实验考察。

得到结果如下:

在这里插入图片描述

而除了最终的ppl之外,文中还比较了transformer与LongNet在处理不同文本长度的文本时所需的计算量。

在这里插入图片描述

可以看到:

  • LongNet可以在更少的计算量下获得相较于原始的transformer更好的ppl。

此外,文中还对LongNet在不同的参数量以及不同的context window进行了一下考察,得到结果如下:

在这里插入图片描述

可以看到:

  • 随着参数量的增长,模型的ppl是在不断减小的,说明LongNet具有很好的扩展能力;
  • context window越大,模型的效果也能够不断地提升,说明LongNet对于长文本有较好的理解能力。

最后,文中还非常直观的给出了将输入文本长度扩展到1B之后vanilla transformer与LongNet的infer时间变化的比较:

在这里插入图片描述

其结果直观地证明了LongNet对于长文本处理能力的能力,较之Vanilla Transformer耗时的快速增长,Dilated Attention基本没有发生什么太大的变化。

4. 结论 & 思考

综上,整体而言这篇文章还是很惊艳的,至少从context length的角度来说这种突破性的震撼确实厉害,结合他之前的foundation transformer等工作,我觉得他们在transformer的基础架构上面确实花了不少的功夫来做优化,这一点确实是厉害。

不过考虑到工程上,这篇文章的主要贡献可能还是在于长文本的关联attention上面,也就意味着其优势必然还是需要长上下文+大语料的前提下才能充分发挥出它的效果,就目前我的工作而言,可能还是有点用不太到……

所以,就只能膜拜一下大佬了,后面有机会的话可以考虑一下在业余时间复现一下看看了,在工作上倒是觉得ROI应该是不会很大了……

5. 参考链接

  1. Longformer: 局部Attention和全局attention的混搭

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/160992.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android studio新版本多渠道打包配置

最近公司套壳app比较多 功能也都一样只有地址&#xff0c;和app名字还有icon不一样 签名文件也是一样的,所以就研究了多渠道打包 配置如下&#xff1a; 在app下build.gradle配置 因为最新版as中禁用了BuildConfig 所以我们需要手动配置一下 android { //TODO 其他省略buildFe…

CATIA环境编辑器用不了时创建项目快捷方式

CATIA环境编辑器用不了时创建项目快捷方式 一、参考适用情况示例二、 解决步骤(一) 先正确放置winb_64部署包(二) 添加环境文件(三) 修改加入的环境文件(四) 复制本机CATIA快捷方式后重命名(五) 修改快捷方式目标的值 一、参考适用情况示例 二、 解决步骤 (一) 先正确放置winb…

CoT: 思路链提示促进大语言模型的多步推理

CoT 总览摘要1 引言2 Chain-of-Thought Prompting3 算术推理 &#xff08;Arithmetic Reasoning&#xff09;3.1 实验设置3.2 结果3.3 消融实验3.4 CoT的鲁棒性 4 常识推理 &#xff08;Commonsense Reasoning&#xff09;5 符号推理 &#xff08;Symbolic Reasoning&#xff0…

基于8086的出租车计价器系统设计

**单片机设计介绍&#xff0c;1665基于8051单片机与1601LCD的计算器设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 一个基于8086的出租车计价器系统可以分为硬件和软件两部分。 硬件部分包括输入设备&#xff08;例如计价器…

Mac苹果电脑分辨率修改管理 安装SwitchResX 完美解决

SwitchResX for Mac是一款Mac应用程序&#xff0c;可帮助您更好地管理和控制显示器分辨率和其他显示设置。使用SwitchResX&#xff0c;您可以创建自定义分辨率、旋转屏幕、调整显示器色彩配置等。 1. 自定义分辨率&#xff1a;SwitchResX允许用户创建自定义的屏幕分辨率&#…

Oracle安全基线检查

一、账户安全 1、禁止SYSDBA用户远程连接 用户具备数据库超级管理员(SYSDBA)权限的用户远程管理登录SYSDBA用户只能本地登录,不能远程。REMOTE_LOGIN_PASSWORDFILE函数的Value值为NONE。这意味着禁止共享口令文件,只能通过操作系统认证登录Oracle数据库。 1)检查REMOTE…

mac装不了python3.7.6

今天发现一个很奇怪的问题 但是我一换成 conda create -n DCA python3.8.12就是成功的 这个就很奇怪

【Head First 设计模式】-- 策略模式

一、背景 Head First 设计模式第一章设计模式入门–策略模式 二、工具箱的工具&#xff08;本章&#xff09; 1、OO基础 封装 继承 多态 抽象 2、OO原则 封装变化 面向接口编程&#xff0c;而非面向实现编程 组合优于继承 3、OO模式 策略模式&#xff0c;所谓策略模式就是定义…

全链路压力测试的目的在于哪儿?

全链路压力测试(End-to-End Load Testing)是一种关键的性能测试方法&#xff0c;旨在评估一个应用程序或系统在真实使用情况下的性能表现。这种类型的测试模拟了用户在应用程序的各个组成部分之间执行各种操作的情景&#xff0c;以便了解系统在高负载下的表现如何。本文将介绍全…

SpringCloud Alibaba Demo(Nacos,OpenFeign,Gatway,Sentinel)

开源地址&#xff1a; ma/springcloud-alibaba-demo 简介 参考&#xff1a;https://www.cnblogs.com/zys2019/p/12682628.html SpringBoot、SpringCloud 、SpringCloud Alibaba 以及各种组件存在版本对应关系。可参考下面 版本对应 项目前期准备 启动nacos. ./startup.c…

wagtail的使用

文章目录 安装虚拟环境新建项目时指定虚拟环境打开已有项目添加虚拟环境 安装wagtail查看安装后的包 创建wagtail项目安装依赖迁移创建超级用户运行项目 管理工作台内容扩展首页的数据模型更新数据库修改模板页创建一个页面的过程 models中的基本字段templates字符型文本字段富…

【广州华锐互动】军用飞机VR实战训练系统

随着科技的飞速发展&#xff0c;虚拟现实(VR)技术为军事训练带来了前所未有的机遇。军用飞机VR实战训练系统&#xff0c;正是在这一背景下应运而生的一种创新的训练方法。该系统利用先进的虚拟现实技术&#xff0c;为飞行员提供真实且逼真的模拟飞行环境&#xff0c;使之能够在…