大模型LLM的微调技术:LoRA

0 引言

LoRA(Low-Rank Adaptation)出自2021年的论文“LoRA: Low-Rank Adaptation of Large Language Models”  LoRA技术冻结预训练模型的权重,并在每个Transformer块中注入可训练层(称为秩分解矩阵),即在模型的Linear层的旁边增加一个“旁支”A和B。其中,A将数据从d维降到r维,这个r是LoRA的秩,是一个重要的超参数;B将数据从r维升到d维,B部分的参数初始为0。模型训练结束后,需要将A+B部分的参数与原大模型的参数合并在一起使用。

论文地址:https://arxiv.org/pdf/2106.09685.pdf

github链接:https://github.com/microsoft/LoRA

 

1 大模型微调的困境

随着模型规模的不断扩大,模型会"涌现"出各种能力。特别是对大语言模型(LLM)来说,随着规模的扩大其在zero-shot、常识推理等能力上会有大幅度的提高。相比于规模较小的模型,大模型的微调成本和部署成本都非常高。例如,GPT-3 175B模型微调需要1.2TB的显存。此外,若针对不同下游任务微调多个模型,那么就需要为每个下游任务保存一份模型权重,成本非常高。在某些场景下,甚至可能需要针对不同的用户微调不同的模型,那么模型微调和部署的成本将不可接受。因此,如何降低大模型微调和部署成本,将是大模型商用的重要一环。

2 LoRA之前的方法

在LoRA方法提出之前,也有很多方法尝试解决大模型微调困境的方法。其中有两个主要的方向:(1) 添加adapter层;(2) 由于某种形式的输入层激活,但是这两种方法都有局限性。

2.1 Adapter层会引入推理时延

简单来说,adapter就是固定原有的参数,并添加一些额外参数用于微调。上图中会在原始的transformer block中添加2个adapter,一个在多头注意力后面,另一个这是FFN后面。

显然,adapter会在模型中添加额外的层,这些层会导致大模型在推理时需要更多的GPU通信,而且也会约束模型并行。这些问题都将导致模型推理变慢。

2.2 prefix-tuning难以优化

论文地址:https://aclanthology.org/2021.acl-long.353.pdf

prefix-tuning方法是受语言模型in-context learning能力的启发,只要有合适的上下文则语言模型可以很好的解决自然语言任务。但是,针对特定的任务找到离散token的前缀需要花费很长时间,prefix-tuning提出使用连续的virtual token embedding来替换离散token。

具体来说,对于transformer中的每一层,都在句子表征前面插入可训练的virtual token embedding。对于自回归模型(GPT系列),在句子前添加连续前缀,即z=[PREFIX;x,y]。对于Encoder-Decoder模型(T5),则在Ecoder和Decoder前都添加连续前缀z=[PREFIX;x|PREFIX^{'};y]。添加前缀的过程如上图所示。虽然,prefix-tuning并没有添加太多的额外参数。但是,prefix-tuning难以优化,且会减少下游任务的序列长度。

3 LoRA解析

3.1 重参数与本征维度

  • 重参数

也就是结构重参数化,首先构造原始网络结构(一般使用带有预训练权重网络),将其权重参数等价转换为另一组参数(推理),从而将这一系列结构等价转换为另一系列结构。举个简单的例子,用卷积的方式去理解:

  • 本征维度

通常是指一个数据集的有效维度数量,即可以用最少的维度来表达数据集的大部分信息。确定一个数据集的本征维度一般使用主成成分分析(PCA),独立成分分析(ICA),多维缩放(MDS)等。此处补充一些问题:

预训练的好坏与本征维度的关系:预训练模型的表征能力越强(训练得越好),本征维度越小。

预训练模型参数与本征维度的关系模型越大本征维度越小,即越强的模型本征维度越低。

本征维度与泛化能力的关系:本征维度低的模型,训练出来的模型准确率是更高的。也就是说本征维度越低,泛化性能越好。

3.2 LoRA原理

首先给出微调的定义,即在尽量不改变推理速度的前提下,使用少量数据就能使预训练大模型达到原始推理精度的90%以上,来实现各种下游任务的应用。也就是说,在进行大量推理阶段时,网络结构是不允许被修改的,这就是重参数的重要性。

神经网络包含许多密集的层,这些层执行矩阵乘法。这些层中的权重矩阵通常具有全秩。当适应特定的任务时,预训练的语言模型往往具有较低的“instrisic dimension”,尽管随机投影到较小的子空间,但仍然可以有效地学习。
LoRA的实现原理在于,冻结预训练模型权重,并将可训练的秩分解矩阵注入到Transformer层的每个权重中,大大减少了下游任务的可训练参数数量。直白的来说,实际上是增加了右侧的“旁支”,也就是先用一个Linear层A,将数据从 d维降到r,再用第二个Linear层B,将数据从r变回d维。最后再将左右两部分的结果相加融合,得到输出的hidden_state。
如上图所示,左边是预训练模型的权重,输入输出维度都是d,在训练期间被冻结,不接受梯度更新。右边部分对A使用随机的高斯初始化,B在训练开始时为零,r是秩,会对△Wx做缩放 α/r。

LoRA的公式:

BA则是低秩分解,对应的图解:

  • Lora 性质 1:全面微调的推广

通过将 LoRA 秩 r 设置为预训练的权重矩阵的秩,大致恢复了完整微调的表达性。换句话说,当增加可训练参数的数量时,训练LoRA大致收敛于训练原始模型。

  • Lora 性质 2:没有额外的推断延迟

在生产中部署时,可以显式地计算和存储 W = W 0 + B A W = W_0 + BA W=W0​+BA,并像往常一样执行推理,也就是将 LoRA 权重和原始模型权重合并,不增加任何的推断耗时 W 0 W_0 W0​ 和 B A BA BA 都是 R d × k R^{d×k} Rd×k。当需要切换到另一个下游任务时,可以通过减去 BA 然后添加不同的 B ′ A ′ B'A' B′A′ 来恢复 W 0 W_0 W0​,这是一个内存开销很小的快速操作。

3.3 LoRA的优势

  • 一个预先训练好的模型可以被共享,并用于为不同的任务建立许多小的LoRA模块。可以冻结共享模型,并通过替换图中的矩阵A和B来有效地切换任务,大大降低了存储需求和任务切换的难度。

  • 在使用自适应优化器时,LoRA使训练更加有效,并将硬件进入门槛降低了3倍,因为我们不需要计算梯度或维护大多数参数的优化器状态。相反,我们只优化注入的、小得多的低秩矩阵。

  • 简单的线性设计允许在部署时将可训练矩阵与冻结权重合并,与完全微调的模型相比,在结构上没有引入推理延迟。

  • LoRA与许多先前的方法是正交的,可以与许多方法结合,如前缀调整。

 4 LoRA的使用

HuggingFace的PEFT(Parameter-Efficient Fine-Tuning)中提供了模型微调加速的方法,参数高效微调(PEFT)方法能够使预先训练好的语言模型(PLMs)有效地适应各种下游应用,而不需要对模型的所有参数进行微调。

对大规模的PLM进行微调往往成本过高,在这方面,PEFT方法只对少数(额外的)模型参数进行微调,基本思想在于仅微调少量 (额外) 模型参数,同时冻结预训练 LLM 的大部分参数,从而大大降低了计算和存储成本,这也克服了灾难性遗忘的问题,这是在 LLM 的全参数微调期间观察到的一种现象PEFT 方法也显示出在低数据状态下比微调更好,可以更好地泛化到域外场景。

## 1、引入组件并设置参数
from transformers import AutoModelForSeq2SeqLM
from peft import get_peft_config, get_peft_model, get_peft_model_state_dict, LoraConfig, TaskType
import torch
from datasets import load_dataset
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from transformers import default_data_collator, get_linear_schedule_with_warmup
from tqdm import tqdm## 2、搭建模型peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)## 3、加载数据
dataset = load_dataset("financial_phrasebank", "sentences_allagree")
dataset = dataset["train"].train_test_split(test_size=0.1)
dataset["validation"] = dataset["test"]
del dataset["test"]classes = dataset["train"].features["label"].names
dataset = dataset.map(lambda x: {"text_label": [classes[label] for label in x["label"]]},batched=True,num_proc=1,
)## 4、训练数据预处理
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)def preprocess_function(examples):inputs = examples[text_column]targets = examples[label_column]model_inputs = tokenizer(inputs, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt")labels = tokenizer(targets, max_length=3, padding="max_length", truncation=True, return_tensors="pt")labels = labels["input_ids"]labels[labels == tokenizer.pad_token_id] = -100model_inputs["labels"] = labelsreturn model_inputsprocessed_datasets = dataset.map(preprocess_function,batched=True,num_proc=1,remove_columns=dataset["train"].column_names,load_from_cache_file=False,desc="Running tokenizer on dataset",
)train_dataset = processed_datasets["train"]
eval_dataset = processed_datasets["validation"]train_dataloader = DataLoader(train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
)
eval_dataloader = DataLoader(eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True)## 5、设定优化器和正则项
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(optimizer=optimizer,num_warmup_steps=0,num_training_steps=(len(train_dataloader) * num_epochs),
)## 6、训练与评估model = model.to(device)for epoch in range(num_epochs):model.train()total_loss = 0for step, batch in enumerate(tqdm(train_dataloader)):batch = {k: v.to(device) for k, v in batch.items()}outputs = model(**batch)loss = outputs.losstotal_loss += loss.detach().float()loss.backward()optimizer.step()lr_scheduler.step()optimizer.zero_grad()model.eval()eval_loss = 0eval_preds = []for step, batch in enumerate(tqdm(eval_dataloader)):batch = {k: v.to(device) for k, v in batch.items()}with torch.no_grad():outputs = model(**batch)loss = outputs.losseval_loss += loss.detach().float()eval_preds.extend(tokenizer.batch_decode(torch.argmax(outputs.logits, -1).detach().cpu().numpy(), skip_special_tokens=True))eval_epoch_loss = eval_loss / len(eval_dataloader)eval_ppl = torch.exp(eval_epoch_loss)train_epoch_loss = total_loss / len(train_dataloader)train_ppl = torch.exp(train_epoch_loss)print(f"{epoch=}: {train_ppl=} {train_epoch_loss=} {eval_ppl=} {eval_epoch_loss=}")## 7、模型保存
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
model.save_pretrained(peft_model_id)## 8、模型推理预测from peft import PeftModel, PeftConfig
peft_model_id = f"{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)model = PeftModel.from_pretrained(model, peft_model_id)
model.eval()inputs = tokenizer(dataset["validation"][text_column][i], return_tensors="pt")
print(dataset["validation"][text_column][i])
print(inputs)
with torch.no_grad():outputs = model.generate(input_ids=inputs["input_ids"], max_new_tokens=10)print(outputs)print(tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True))

5 LoRA的应用场景

在Transformer架构里Lora经常应用在self-attention模块和MLP模块中。在 Stable Diffusion模型里,Lora被用在condition和图像表示建立关联的Cross-Attention层。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/310223.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Java EE 网络原理之HTTPS

文章目录 1. HTTPS 是什么?2. "加密" 是什么?3. HTTPS 的工作过程3.1 引入对称加密3.2 引入非对称加密3.3 中间人攻击3.4 引入证书 4. Tomecat4.1 tomcat 的作用 1. HTTPS 是什么? HTTPS也是⼀个应用层协议,是在 HTTP …

【计算机毕业设计】python+django数码电子论坛系统设计与实现

本系统主要包括管理员和用户两个角色组成;主要包括:首页、个人中心、用户管理、分类管理、数码板块管理、数码评价管理、数码论坛管理、畅聊板块管理、系统管理等功能的管理系统。 后端:pythondjango 前端:vue.jselementui 框架&a…

Java强软弱虚引用

面试: 1.强引用,软引用,弱引用,虚引用分别是什么? 2.软引用和弱引用适用的场景? 3.你知道弱引用的话,能谈谈WeakHashMap吗? 目录 一、Java引用 1、强引用(默认支持模式…

ES6的默认参数和rest参数

✨ 专栏介绍 在现代Web开发中,JavaScript已经成为了不可或缺的一部分。它不仅可以为网页增加交互性和动态性,还可以在后端开发中使用Node.js构建高效的服务器端应用程序。作为一种灵活且易学的脚本语言,JavaScript具有广泛的应用场景&#x…

编程羔手解决Maven引入多个版本的依赖包,导致包冲突了

最近升级了些依赖发现有个hutool的方法老报错,java.lang.NoSuchMethodError: cn.hutool.core.util.ObjectUtil.defaultIfNull(Ljava/lang/Object;Ljava/util/function/Supplier;) 在 Maven 项目中,当不同的依赖模块引入 Hutool 的不同版本时&#xff0c…

12.21自动售货机,单物品,多物品

自动售货机 if朴素方法 一种思路是用寄存器cnt记录已有的最小单位货币量,这里就是0.5 当d1时,cnt1;d2时,cnt2;d3时,cnt4; timescale 1ns/1ns module seller1(input wire clk ,input wire rst ,input wire d1 ,input wire d2 …

AI人工智能大模型讲师叶梓《基于人工智能的内容生成(AIGC)理论与实践》培训提纲

【课程简介】 本课程介绍了chatGPT相关模型的具体案例实践,通过实操更好的掌握chatGPT的概念与应用场景,可以作为chatGPT领域学习者的入门到进阶级课程。 【课程时长】 1天(6小时/天) 【课程对象】 理工科本科及以上&#xff0…

STM32F407-14.3.10-表73具有有断路功能的互补通道OCx和OCxN的输出控制位-00x01

如上表所示,MOE0,OSSI0,CCxE0,CCxNE1时,OCx与OCxN的输出状态取决于GPIO端口上下拉状态。 ---------------------------------------------------------------------------------------------------------------------…

ElasticSearch学习笔记(二)

通过前面的一阵胡乱操作,显然提升了我的学习兴趣,趁热打铁,接着往下学。还是先看看别人的教程吧。这里我看的是B站上【尚硅谷】的ElasticSearch教程,有兴趣的同学也可以去看看。 一、缘起–索引操作 看B站上的视频教程&#xff0…

LoongArch指令集-特权指令系统——摘抄自胡伟武体系结构和龙芯架构32位精简版参考手册

例外与中断 1 中断 1.1 中断类型 龙芯架构 32 位精简版下的中断采用线中断的形式。每个处理器核内部可记录 12 个线中断,分别是:1 个核间中断(IPI),1 个定时器中断(TI),8 个硬中断…

AI模型训练【偏差/方差】与【欠拟合/过拟合】

在我们拿到一个数据集,高高兴兴准备训练一个模型时,会遇到欠拟合或过拟合的问题,业内也喜欢用偏差和方差这两指标去定义它们,那这些词什么意思呢?有什么方法能避免/解决 欠拟合和过拟合呢? 这其实是非常非常…

Docker单点部署Seata(2.0.0) + Nacos(v2.3.0) + Mysql(5.7)

文章目录 一、部署Nacos二、部署Mysql三、Seata准备工作1. 记住nacos、mysql、宿主机的ip2. 建立数据库3. Nacos远程配置文件 四、部署Seata五、初步检验Seata部署情况六、微服务使用Seata1.引入依赖2. application.yml配置 七、遇到的坑1. Nacos显示Seata服务的ip为容器内网ip…