低成本微调长文本LLM

低成本微调长文本LLM

最近有一个需求微调长文本的大模型LLM。通常情况下,数据长度扩大后,需要的显存更大。在有限的设备资源上微调长文本的LLM显得很重要了。中文Llama2-7b支持的最大长度为4k,Qwen1.5-7b支持的最大长度为32k,Qwen1.5-14b支持的最大长度为32k,Qwen-14b支持的最大长度为8k,Baichuan2-13b支持的最大长度为4k。在原有的预训练模型上微调比支持的最大长度的长的文本面临两个问题:如何修改旋转位置编码支持更长的文本,一个是有限的资源上LLM微调长文本。修改旋转位置编码支持长文本微调,可以修改配置文件中的rope_theta值,或者采用动态的ntk。一般是将rope_theta的值设置大一些,使得超过预训练模型的最大长度的token都有相应的位置编码表征。可以采用LongLoRA(来自于LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models)和LoRA的方式对长文本进行微调,可以有限的资源情况下进行LLM微调长文本,不过LongLoRA的方式更加省显存。

LoRA

LoRA是对线性层的参数增加两个低维度的矩阵进行线性运算,从而降低了模型微调的参数量和显存消耗。通常低维度的参数rank 可以自行设置,相比于LLM的hidden_size的值要小很多。下面介绍LoRA的数学表达式:
h i d d e n _ s t a t e = ( w T + α l o r a _ A T l o r a _ B T ) x = w T x + α l o r a _ A T l o r a _ B T x hidden\_state = (w^{T} + \alpha lora\_{A}^{T}lora\_{B}^{T}) x\\ =w^{T} x+ \alpha lora\_{A}^{T}lora\_{B}^{T} x hidden_state=(wT+αlora_ATlora_BT)x=wTx+αlora_ATlora_BTx
其中 w ∈ R m × h w \in R^{m\times h} wRm×h l o r a _ A ∈ R r × h lora\_{A} \in R^{r\times h} lora_ARr×h l o r a _ B ∈ R m × r lora\_{B} \in R^{m\times r} lora_BRm×r r r r通常可以自行设置的值比较小,而且远远小于m, l o r a _ A ∈ R r × h , l o r a _ B ∈ R m × r lora\_{A} \in R^{r\times h},lora\_{B} \in R^{m\times r} lora_ARr×h,lora_BRm×r的参数量远小于 w ∈ R m × h w \in R^{m\times h} wRm×h。在采用LoRA的方式微调LLM,梯度更新的参数为每一层的线性层相应的lora参数,例如表达式中的 l o r a _ A lora\_{A} lora_A l o r a _ B lora\_{B} lora_B,原模型的参数不进行梯度更新,这样做的目的是训练参数量减少,节约显存,加快训练速度。在显卡资源不足的情况下,可以选择LoRA的方式进行微调LLM。而且 l o r a _ A T l o r a _ B T ∈ R h × m lora\_{A}^{T}lora\_{B}^{T}\in R^{h\times m} lora_ATlora_BTRh×m w T w^{T} wT的大小是一致的,方便将 l o r a _ A T l o r a _ B T lora\_{A}^{T}lora\_{B}^{T} lora_ATlora_BT参数合并到原始模型参数 w T w^{T} wT上,并未对原始模型增加新的参数,从而采用原始模型的推理方式进行推理。通常模型某一层hidden_state的值传递到下一层中参与层内计算,需要将某一层输入的hidden_state值与lora的运算结果求和后输入到下一层中。

LongLoRA

LongLoRA采用的是局部注意力机制Shift Short Attention( S 2 − A t t n {S}^{2}-Attn S2Attn)和LoRA方式对微调长上文本 LLM。Shift Short Attention是对句子tokens分组进行组间的attention,而不是整个句子间tokens的attention,这样的方式相当于稀疏的,节约显存和计算时间。默认是将tokens 分成4组,注意的是句子长度要整除分组长度,也可以根据实际情况设置不同的分组数量。Shift Short Attention计算原理见下图。
Shift Short Attention
S 2 − A t t n {S}^{2}-Attn S2Attn计算的时候获取分组tokens方式有两种,其中一种方式是,从句子开始位置按照分组长度依次截取,第二种方式是从分组长度一半的位置按照分组长度依次截取。两种选取方式tokens有一部分重合,加强句子语义间的相关性学习。详细过程参见https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace_sft.py

应用

我们将LongLoRA和千问结合微调长文本32k。下面是千问结合flash-attention的shift attention 计算。

# attention计算将输入长度分成的份数
group_size_ratio = 1/4
def qwen_flash_self_attention_shift_atten(self, q, k, v, attention_mask=None):bsz, q_len = q.shape[0], q.shape[1]num_heads = q.shape[2]head_dim = q.shape[3]q = q.transpose(1, 2) # (bs, heade_num, seq_len, head_dim)k = k.transpose(1, 2)v = v.transpose(1, 2)qkv = torch.stack([q, k, v], dim=2)  # [bsz, nh, 3, q_len, hd]qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]# We have disabled _prepare_decoder_attention_mask in LlamaModel# the attention_mask should be the same as the key_padding_mask# key_padding_mask = attention_mask.repeat(2, 1)attention_mask = torch.ones((bsz * 2, q_len), device=qkv.device)key_padding_mask = attention_masknheads = qkv.shape[-2]# 分组长度group_size = int(q_len * group_size_ratio)if q_len % group_size > 0:raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))# 将头分成两个部分qkv = qkv.reshape(bsz, q_len, 3, 2, num_heads // 2, head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,q_len, 3,num_heads // 2,head_dim)x = rearrange(qkv, "b s three h d -> b s (three h d)")x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)# 分组位置index平移 group_size // 2cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2# 分组位置超多最大长度的修改为类型的最小值cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).mincu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)cu_q_lens = cu_q_lens[cu_q_lens >= 0]x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2)output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True)output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len),"b s (h d) -> b s h d",h=nheads // 2,)output = output.reshape(bsz, 2, q_len, nheads // 2, head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,head_dim)return output

以上是在LongLoRA的介绍,如有表述不当,请指证。

参考文献

[1]: LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models 论文地址 代码地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/660541.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营DAY44|C++动态规划Part6|完全背包理论基础、518.零钱兑换II、377. 组合总和 Ⅳ

文章目录 完全背包理论基础完全背包问题的定义与01背包的核心区别为什么完全背包的循环顺序可以互换?CPP代码 518.零钱兑换II思路CPP代码 377. 组合总和 Ⅳ思路CPP代码扩展题 完全背包理论基础 卡码网第52题 文章链接:完全背包理论基础 视频链接&#xf…

Flutter笔记:Widgets Easier组件库(2)阴影盒子

Flutter笔记 Widgets Easier组件库(2):阴影盒子 - 文章信息 - Author: 李俊才 (jcLee95) Visit me at CSDN: https://jclee95.blog.csdn.netMy WebSite:http://thispage.tech/Email: 291148484163.com. Shenzhen ChinaAddress o…

SpringBoot之自定义注解参数校验

SpringBoot之自定义注解参数校验 为什么要自定义注解 我这里先引入一个例子,就比如我现在要写文章,文章也许写完正要发布,也可以是还没写完正要存草稿,前端往后端发送数据,如果前端的state不是草稿或者已发布状态&…

vue3、element-plus递归实现动态菜单

vue3、element-plus递归实现动态菜单 使用场景:动态菜单为什么使用递归递归在动态菜单中的实现 使用场景:动态菜单 动态菜单是指菜单项的数量和层次结构可能是动态的,通常来自后端或用户输入。这些菜单的特征包括: 多层嵌套&…

笔记-PPT绘图导出高清无失真图片

问题描述:PPT绘图已经用了高清图(jpg、tif格式),但论文图片还是不清晰,打印出来还是有点糊 以下是PPT导出高清不失真图片(emf格式)的具体描述。 目录 一、绘图工具二、操作步骤 一、绘图工具 …

SSH远程登录实操实验!

ssh远程登录协议:默认端口号22 以下实验7-2是服务端,7-1是客户端 服务器的相关信息: 服务名称:sshd 服务端主程序:/usr/sbin/sshd 服务端配置文件:/etc/ssh/sshd_config 客户端相关信息: …

SQL如何利用Bitmap思想优化array_contains()函数

目录 0 问题描述 1 位图思想 2 案例实战 3 小结 0 问题描述 在工作中,我们往往使用array_contains()函数来进行存在性问题分析,如判断某个数是否在某个数组中,但是当表数据量过多,存在大量array_contains()函数时,…

未来已来:深入探索LLAMA3驱动的人工智能革命

大家好!相信大家对于AI(人工智能)的发展已经有了一定的了解,但你是否意识到,到了2024年,AI已经变得如此强大和普及,带来了我们从未想象过的便利和创新呢?让我们一起来看看AI在这个时…

Open CASCADE学习|BRepFill_SectionPlacement

BRepFill_SectionPlacement 是一个与计算机辅助设计(CAD)相关的术语,通常用于指代一个几何对象或操作,它是Open CASCADE Technology(OCCT)中的一个类。Open CASCADE Technology是一个开源的CAD内核&#xf…

HOOPS Exchange导入数据时如何使用CATIA缓存选项?

1、什么是CATIA缓存选项和CGR文件? CATIA V5默认的工作方式是加载几何图形。加载大型程序集时,这可能会导致性能下降,因为所需的内存很重要。 在这种情况下,我们可能需要使用缓存选项。这将生成仅包含曲面细分数据而不包含几何图…

图片懒加载:提升网页性能的秘诀

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

Centos7+Hadoop3.3.4+KDC1.15+Ranger2.4.0集成

一、集群规划 本次测试采用3台虚拟机,操作系统版本为centos7.6。 kerberos采用默认YUM源安装,版本为:1.15.1-55 Ranger版本为2.4.0 系统用户为ranger:ranger IP地址主机名KDCRanger192.168.121.101node101.cc.localKDC masterRanger Admin…