加速大模型落地!使用4-bit训练Transformer,比FP16快2.2倍,提速35.1%

点击蓝字 关注我们

关注并星标

从此不迷路

计算机视觉研究院

ca1a3476938ba826e96cb1863070dc8d.gif

7fe384227e2c6f85c982fbc102034ef9.gif

公众号ID计算机视觉研究院

学习群扫码在主页获取加入方式

62aa8167f2a5538862ec54580780b285.png

论文地址:https://arxiv.org/pdf/2306.11987.pdf

项目地址:https://github.com/xijiu9/Train_Transformers_with_INT4

计算机视觉研究院专栏

Column of Computer Vision Institute

将激活、权重和梯度量化为4-bit有望加速神经网络训练。然而,现有的4-bit训练方法需要定制的数字格式,这是当代硬件所不支持的。

8f783ea897db40e4fb93d17babf8f428.gif

01

概要简介

在这项工作中,研究者提出了一种用INT4算法实现所有矩阵乘法的transformers的训练方法。超低INT4精度的训练极具挑战性。为了实现这一点,我们仔细分析了transformer中激活和梯度的具体结构,为它们提出了专用的量化器。对于前向传播,我们识别了异常值的挑战,并提出了一种Hadamard量化器来抑制异常值。对于反向传播,我们通过提出比特分割和利用分数采样技术来精确量化梯度,从而利用梯度的结构稀疏性。我们的算法在包括自然语言理解、机器翻译和图像分类在内的广泛任务中实现了具有竞争力的准确性。

645d67baea776cc39948d817a4e1c826.png

【QLoRA本身讲的是模型本身用4bit加载,训练时把数值反量化到bf16后进行训练,利用LoRA[2]可以锁定原模型参数不参与训练,只训练少量LoRA参数的特性使得训练所需的显存大大减少。例如33B的LLaMA模型经过这种方式可以在24 GB的显卡上训练,也就是说单卡4090、3090都可以实现,大大降低了微调的门槛】QLORA: Efficient Finetuning of Quantized LLMs

与以前的4-bit训练方法不同,我们的算法可以在当前一代的GPU上实现。我们的原型线性算子实现速度是FP16的2.2倍,训练速度提高了35.1%。

d8b82bd9d5752d3334df05a9c73d9211.gif

02

背景介绍

训练神经网络在计算上要求很高。低精度算术训练(也称为全量化训练或FQT)有望提高计算和记忆效率。FQT方法在原来的全精度计算图中添加了一些量化器和反量化器,并用廉价的低精度运算取代了昂贵的浮点运算。FQT的研究旨在降低训练的数值精度,而不牺牲太多的收敛速度或精度。所需的数值精度已从FP16降低到FP8、INT32+INT8和INT8+INT5。FP8训练是在英伟达的H100 GPU和变压器引擎中实现的,为大型变压器的训练实现了令人印象深刻的加速。

最近,训练数值精度已被降低到4位。Sun等人成功地用INT4激活/权重和FP4梯度训练了几个现代网络;和Chmiel等人提出了一种自定义的4位对数数字格式,以进一步提高精度。然而,这些4位训练方法不能直接用于加速,因为它们需要现代硬件不支持的自定义数字格式。在极低的4位水平上训练神经网络存在重大的优化挑战。首先,前向传播中的不可微量化器使损失景观变得崎岖不平,其中基于梯度的优化器很容易陷入局部最优。其次,梯度仅以低精度近似计算。这种不精确的梯度减缓了训练过程,甚至导致训练不稳定或偏离。

Fully Quantized Training

全量化训练(FQT)方法通过将激活、权重和梯度量化到低精度来加速训练,因此训练过程中的线性和非线性算子可以用低精度算法实现。FQT的研究设计了新的数值格式和量化算法,可以更好地逼近全精度张量。目前的研究前沿是4位FQT。由于梯度的巨大数值范围和从头开始训练量化网络的优化问题,FQT具有挑战性。由于这些挑战,现有的4位FQT算法在某些任务上的精度仍有1-2.5%的下降,并且它们无法支持当代硬件。

Other Efficient Training Methods

Mixture-of-experts【Outrageously large neural networks: The sparsely-gated mixture-of-experts layer】在不增加训练预算的情况下提高了模型的能力。结构丢弃利用计算上有效的方法来正则化模型。有效的注意力减少了计算注意力的二次时间复杂度。分布式训练系统通过利用更多的计算资源来减少训练时间。我们降低数值精度的工作与这些方向正交。

2100600ca47ad5f1ab2e6f45644bd101.gif

03

新框架

神经网络训练是一种迭代优化过程,通过前向和后向传播计算随机梯度。我们使用4位整数(INT4)算法加速正向和反向传播。首先描述我们的训练程序的正向传播。前向传播可以公式化为线性和非线性(GeLU、归一化、softmax等)算子的组合。在我们的训练过程中,我们使用INT4算法加速所有线性算子,并将所有计算密集度较低的非线性算子保留为16位浮点(FP16)格式。变压器中的所有线性运算都可以写成矩阵乘法(MM)形式。

d96d73fc95fd56cc93930bb04c09ccba.png

Histogram of activation of the linear-1-2 layer in a BERT-base-uncased model. (a) Original activation distribution; (b) Hadamard-transformed activation distribution.

e3958035a01bb32b0b861cca29133e93.png

(a) The distribution of gradient norm along the token dimension. (b) The cumulative sum of the top X values as a percentage of the sum of all norms along the token dimension.

Hadamard Quantization

我们提出了一种Hadamard量化器(HQ)来解决异常值问题。它的主要思想是在另一个具有较少异常值的线性空间中量化矩阵。激活矩阵中的异常值形成了一个特征结构。它们通常集中在几个维度上,即只有少数X列比其他列大得多。Hadamard变换是一种线性变换,可以将异常值摊销为其他条目。具体地说,Hadamard变换Hk是2k×2k矩阵,其中:

d53710254d01558115dc2e7b10770f25.png

Hadamard矩阵是正交对称的:

b91c7fa4054e9f8d2f856f1f839d1d40.png

所以HkHk = I, ∀k ≥ 0。考虑任何坐标行向量e⊤i ∈ R2k。这证明了当单个异常值支配所有其他维度时的极端情况。在这种情况下,Hadamard变换有效地将矢量转变为量化友好的全一矢量。Hadamard变换在抑制激活异常值方面的实际效果如上图b所示。

e4704f677483b2e6e8646b1f9e6a49c3.png

结合量化矩阵,我们得到:

10945dad69f8c9733af9d2c67a3949fe.png

43b59e01a4638435af4d17c8c0b71ea6.png

ab1820be51c555205d70e0eeaf997fc9.gif

04

实验

26d23806d961466cff7d690daefa89ce.png

1c0b39eeacbdbd54e39cd1e7f0556c0f.png

© THE END 

转载请联系本公众号获得授权

b79915d8823c75607e6f907b29e84332.gif

计算机视觉研究院学习群等你加入!

ABOUT

计算机视觉研究院

计算机视觉研究院主要涉及深度学习领域,主要致力于目标检测、目标跟踪、图像分割、OCR、模型量化、模型部署等研究方向。研究院每日分享最新的论文算法新框架,提供论文一键下载,并分享实战项目。研究院主要着重”技术研究“和“实践落地”。研究院会针对不同领域分享实践过程,让大家真正体会摆脱理论的真实场景,培养爱动手编程爱动脑思考的习惯!

VX:2311123606

83936321f14ee7478c36be36ad3f7a1e.png

 往期推荐 

🔗

  • 中国提出的分割天花板 | 精度相当,速度提升50倍!

  • All Things ViTs:在视觉中理解和解释注意力

  • 基于LangChain+GLM搭建知识本地库

  • OVO:在线蒸馏一次视觉Transformer搜索

  • 最近几篇较好论文实现代码(附源代码下载)

  • AI大模型落地不远了!首个全量化Vision Transformer的方法FQ-ViT(附源代码)

  • CVPR 2023|EfficientViT:让ViT更高效部署实现实时推理(附源码)

  • VS Code支持配置远程同步了

  • 基于文本驱动用于创建和编辑图像(附源代码)

  • 基于分层自监督学习将视觉Transformer扩展到千兆像素图像

  • 霸榜第一框架:工业检测,基于差异和共性的半监督方法用于图像表面缺陷检测

  • CLCNet:用分类置信网络重新思考集成建模(附源代码下载)

  • YOLOS:通过目标检测重新思考Transformer(附源代码)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/16087.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

裸机搭建k8s报错记录

安装教程参考 修复一、 cd /etc/kubernetes/manifests vim kube-scheduler.yaml注释掉 重启 systemctl restart kubelet.service问题二、 https://github.com/kubernetes/kubernetes/issues/70202 一直处于创建中状态 网络原因 cat << EOF > /run/flannel/subnet.…

对卷积和全连接之间关系的学习(1*1卷积与全连接层可以互换吗?)

1.对于卷积和全连接 首先我们看一张图&#xff0c;它是一张关于卷积的操作&#xff1a; 然后在看关于全连接的操作&#xff1a; 从上面两张图中可以看出卷积的过程和全连接的过程&#xff0c;我们利用粉色的卷积核在image上进行卷积&#xff0c;进行内积计算得到输出值3&#…

uniapp zjy-calendar日历,uni-calendar日历增强版

一、zjy-calendar简介 zjy-calendar日历是对uniapp uni-calendar日历的增强&#xff0c;支持圆点和文字自定义颜色。 二、使用方法 源使用说明&#xff1a;https://uniapp.dcloud.net.cn/component/uniui/uni-calendar.html 1、下载导入 https://ext.dcloud.net.cn/plugin?…

django框架中使用ORM设计数据库的模型

ORM关联数据的逻辑是&#xff1a; Django 中常见的模型字段类型及其含义&#xff1a; AutoField&#xff1a;一个自动递增的整型字段&#xff0c;添加记录时它会自动增长。BigAutoField&#xff1a;一个自动递增的 biginteger字段&#xff0c;添加记录时它会自动增长。CharFie…

RPC 框架架构设计

RPC 框架架构设计 RPC 又称远程过程调用&#xff08;Remote Procedure Call&#xff09;&#xff0c;用于解决分布式系统中服务之间的调用问题。通俗地讲&#xff0c;就是开发者能够像调用本地方法一样调用远程的服务。下面我们通过一幅图来说说 RPC 框架的基本架构。 RPC 框架…

Nginx学习

文章目录 Nginx什么是NginxLinux安装与配置Nginx编译安装Nginxnignx使用nginx默认首页配置案例 localtion的匹配规则Nginx虚拟主机基于多IP的虚拟主机基于多端口的虚拟主机基于域名的虚拟机主机 反向代理案例①案例② 负载均衡案例①案例②分配策略 动静分离案例 配置Nginx网关…

分布式监控之Zabbix6.0监控系统一

分布式监控之Zabbix6.0监控系统 前言一、Zabbix1、介绍2、zabbix监控原理3、Zabbix6.0版本新特性4、Zabbix6.0功能组件5、Zabbix与Prometheus对比 二、Zabbix6.0部署1、部署zabbix服务端2、添加 zabbix 客户端主机3、自定义监控内容4、zabbix 自动发现5、zabbix 自动注册 前言 …

从零搭建一台基于ROS的自动驾驶车-----4.定位

系列文章目录 北科天绘 16线3维激光雷达开发教程 基于Rplidar二维雷达使用Hector_SLAM算法在ROS中建图 Nvidia Jetson Nano学习笔记–串口通信 Nvidia Jetson Nano学习笔记–使用C语言实现GPIO 输入输出 Autolabor ROS机器人教程 从零搭建一台基于ROS的自动驾驶车-----1.整体介…

C语言库函数strcpy学习

strcpy是C语言的一个标准库函数&#xff1b; strcpy把含有\0结束符的字符串复制到另一个地址空间&#xff0c;返回值的类型为char*。 原型声明&#xff1a;char *strcpy(char* dest, const char *src); 头文件&#xff1a;#include <string.h> 和 #include <stdio.h&g…

MySQL:我的从库竟是我自己!?

本文将通过复制场景下的异常分析&#xff0c;介绍手工搭建MySQL主从复制时需要注意的关键细节。 作者&#xff1a;秦福朗 爱可生 DBA 团队成员&#xff0c;负责项目日常问题处理及公司平台问题排查。热爱互联网&#xff0c;会摄影、懂厨艺&#xff0c;不会厨艺的 DBA 不是好司机…

当某个微服务重启后,GateWay网关访问服务出现503的问题

因为开发阶段可能需要经常重启微服务&#xff0c;但有时会莫名奇妙返回503 Service Unavailable 由于从springcloud2020版本开始&#xff0c;弃用了Ribbon&#xff0c;因此Alibaba在2021及之后版本的nacos中删除了Ribbon的jar包&#xff0c;因此无法通过loadbalancer路由到指定…

LSTD: A Low-Shot Transfer Detector for Object Detection论文阅读笔记

LSTD: A Low-Shot Transfer Detector for Object Detection论文阅读笔记 提出low-shot Transfer detector&#xff0c;来解决标注样本数据不足的情况。利用source domain知识&#xff0c;来构建高效的target-domain检测器&#xff0c;仅需要很少的训练样本。 提出了一个高效的…