低成本微调长文本LLM
最近有一个需求微调长文本的大模型LLM。通常情况下,数据长度扩大后,需要的显存更大。在有限的设备资源上微调长文本的LLM显得很重要了。中文Llama2-7b支持的最大长度为4k,Qwen1.5-7b支持的最大长度为32k,Qwen1.5-14b支持的最大长度为32k,Qwen-14b支持的最大长度为8k,Baichuan2-13b支持的最大长度为4k。在原有的预训练模型上微调比支持的最大长度的长的文本面临两个问题:如何修改旋转位置编码支持更长的文本,一个是有限的资源上LLM微调长文本。修改旋转位置编码支持长文本微调,可以修改配置文件中的rope_theta值,或者采用动态的ntk。一般是将rope_theta的值设置大一些,使得超过预训练模型的最大长度的token都有相应的位置编码表征。可以采用LongLoRA(来自于LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models)和LoRA的方式对长文本进行微调,可以有限的资源情况下进行LLM微调长文本,不过LongLoRA的方式更加省显存。
LoRA
LoRA是对线性层的参数增加两个低维度的矩阵进行线性运算,从而降低了模型微调的参数量和显存消耗。通常低维度的参数rank 可以自行设置,相比于LLM的hidden_size的值要小很多。下面介绍LoRA的数学表达式:
h i d d e n _ s t a t e = ( w T + α l o r a _ A T l o r a _ B T ) x = w T x + α l o r a _ A T l o r a _ B T x hidden\_state = (w^{T} + \alpha lora\_{A}^{T}lora\_{B}^{T}) x\\ =w^{T} x+ \alpha lora\_{A}^{T}lora\_{B}^{T} x hidden_state=(wT+αlora_ATlora_BT)x=wTx+αlora_ATlora_BTx
其中 w ∈ R m × h w \in R^{m\times h} w∈Rm×h, l o r a _ A ∈ R r × h lora\_{A} \in R^{r\times h} lora_A∈Rr×h, l o r a _ B ∈ R m × r lora\_{B} \in R^{m\times r} lora_B∈Rm×r, r r r通常可以自行设置的值比较小,而且远远小于m, l o r a _ A ∈ R r × h , l o r a _ B ∈ R m × r lora\_{A} \in R^{r\times h},lora\_{B} \in R^{m\times r} lora_A∈Rr×h,lora_B∈Rm×r的参数量远小于 w ∈ R m × h w \in R^{m\times h} w∈Rm×h。在采用LoRA的方式微调LLM,梯度更新的参数为每一层的线性层相应的lora参数,例如表达式中的 l o r a _ A lora\_{A} lora_A和 l o r a _ B lora\_{B} lora_B,原模型的参数不进行梯度更新,这样做的目的是训练参数量减少,节约显存,加快训练速度。在显卡资源不足的情况下,可以选择LoRA的方式进行微调LLM。而且 l o r a _ A T l o r a _ B T ∈ R h × m lora\_{A}^{T}lora\_{B}^{T}\in R^{h\times m} lora_ATlora_BT∈Rh×m与 w T w^{T} wT的大小是一致的,方便将 l o r a _ A T l o r a _ B T lora\_{A}^{T}lora\_{B}^{T} lora_ATlora_BT参数合并到原始模型参数 w T w^{T} wT上,并未对原始模型增加新的参数,从而采用原始模型的推理方式进行推理。通常模型某一层hidden_state的值传递到下一层中参与层内计算,需要将某一层输入的hidden_state值与lora的运算结果求和后输入到下一层中。
LongLoRA
LongLoRA采用的是局部注意力机制Shift Short Attention( S 2 − A t t n {S}^{2}-Attn S2−Attn)和LoRA方式对微调长上文本 LLM。Shift Short Attention是对句子tokens分组进行组间的attention,而不是整个句子间tokens的attention,这样的方式相当于稀疏的,节约显存和计算时间。默认是将tokens 分成4组,注意的是句子长度要整除分组长度,也可以根据实际情况设置不同的分组数量。Shift Short Attention计算原理见下图。
在 S 2 − A t t n {S}^{2}-Attn S2−Attn计算的时候获取分组tokens方式有两种,其中一种方式是,从句子开始位置按照分组长度依次截取,第二种方式是从分组长度一半的位置按照分组长度依次截取。两种选取方式tokens有一部分重合,加强句子语义间的相关性学习。详细过程参见https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace_sft.py
应用
我们将LongLoRA和千问结合微调长文本32k。下面是千问结合flash-attention的shift attention 计算。
# attention计算将输入长度分成的份数
group_size_ratio = 1/4
def qwen_flash_self_attention_shift_atten(self, q, k, v, attention_mask=None):bsz, q_len = q.shape[0], q.shape[1]num_heads = q.shape[2]head_dim = q.shape[3]q = q.transpose(1, 2) # (bs, heade_num, seq_len, head_dim)k = k.transpose(1, 2)v = v.transpose(1, 2)qkv = torch.stack([q, k, v], dim=2) # [bsz, nh, 3, q_len, hd]qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]# We have disabled _prepare_decoder_attention_mask in LlamaModel# the attention_mask should be the same as the key_padding_mask# key_padding_mask = attention_mask.repeat(2, 1)attention_mask = torch.ones((bsz * 2, q_len), device=qkv.device)key_padding_mask = attention_masknheads = qkv.shape[-2]# 分组长度group_size = int(q_len * group_size_ratio)if q_len % group_size > 0:raise ValueError("q_len %d should be divisible by group size %d."%(q_len, group_size))# 将头分成两个部分qkv = qkv.reshape(bsz, q_len, 3, 2, num_heads // 2, head_dim).permute(0, 3, 1, 2, 4, 5).reshape(bsz * 2,q_len, 3,num_heads // 2,head_dim)x = rearrange(qkv, "b s three h d -> b s (three h d)")x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)cu_q_len_tmp = torch.arange(0, max_s, group_size, device=key_padding_mask.device, dtype=cu_q_lens.dtype)# 分组位置index平移 group_size // 2cu_q_len_tmp2 = cu_q_len_tmp + group_size // 2# 分组位置超多最大长度的修改为类型的最小值cu_q_len_tmp2[cu_q_len_tmp2 >= max_s] = torch.iinfo(cu_q_len_tmp2.dtype).mincu_q_len_tmp = torch.stack([cu_q_len_tmp, cu_q_len_tmp2]).repeat(bsz, 1) + cu_q_lens[:-1].unsqueeze(-1)cu_q_lens = torch.cat([cu_q_len_tmp, cu_q_lens[1:].unsqueeze(-1)], dim=-1).view(-1)cu_q_lens = cu_q_lens[cu_q_lens >= 0]x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads // 2)output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad, cu_q_lens, group_size, 0.0, softmax_scale=None, causal=True)output = rearrange(pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz * 2, q_len),"b s (h d) -> b s h d",h=nheads // 2,)output = output.reshape(bsz, 2, q_len, nheads // 2, head_dim).transpose(1, 2).reshape(bsz, q_len, nheads,head_dim)return output
以上是在LongLoRA的介绍,如有表述不当,请指证。
参考文献
[1]: LongLoRA: Efficient Fine-tuning of Long-Context Large Language Models 论文地址 代码地址