几种预训练模型微调方法和peft包的使用介绍

文章目录

  • 微调方法
    • Lora(在旁边添加训练参数)
    • Adapter(在前面添加训练参数)
    • Prefix-tuning(在中间添加训练参数)
    • Prompt tuning
  • PEFT
    • PEFT 使用
      • PeftConfig
      • PeftModel
      • 保存和加载模型

微调方法

现流行的微调方法有:Lora,prompt,p-tunning v1,p-tunning v2,prefix,adapter等等,下面抱着学习的心态进行宏观层面的介绍
如有错误,欢迎指出

Lora(在旁边添加训练参数)

LoRA(Low-Rank Adaptation)是一种技术,通过低秩分解将权重更新表示为两个较小的矩阵(称为更新矩阵),从而加速大型模型的微调,并减少内存消耗。

为了使微调更加高效,LoRA的方法是通过低秩分解,使用两个较小的矩阵(称为更新矩阵)来表示权重更新。这些新矩阵可以通过训练适应新数据,同时保持整体变化的数量较少。原始的权重矩阵保持冻结,不再接收任何进一步的调整。为了产生最终结果,同时使用原始和适应后的权重进行合并。
在这里插入图片描述
假设权重的更新在适应过程中也具有较低的“内在秩”。对于一个预训练的权重矩阵W0 ∈ Rd×k,我们通过低秩分解W0 + ∆W = W0 + BA来表示其更新,其中B ∈ Rd×r,A ∈ Rr×k,且秩r ≤ min(d,k)。在训练过程中,W0被冻结,不接收梯度更新,而A和B包含可训练参数。需要注意的是,W0和∆W = BA都与相同的输入进行乘法运算,它们各自的输出向量在坐标上求和。前向传播公式如下:h = W0x + ∆Wx = W0x + BAx

在上图中我们对A使用随机高斯随机分布初始化,对B使用零初始化,因此在训练开始时∆W = BA为零。然后,通过αr对∆Wx进行缩放,其中α是r中的一个常数。在使用Adam优化时,适当地缩放初始化,调整α的过程与调整学习率大致相同。因此,只需将α设置为我们尝试的第一个r,并且不对其进行调整。这种缩放有助于在改变r时减少重新调整超参数的需求。

Adapter(在前面添加训练参数)

2019 年,Houlsby N 等人将 Adapter 引入 NLP 领域,作为全模型微调的一种替代方案。Adapter 主体架构下图所示。
在这里插入图片描述
在预训练模型每一层(或某些层)中添加 Adapter 模块(如上图左侧结构所示),微调时冻结预训练模型主体,由 Adapter 模块学习特定下游任务的知识。每个 Adapter 模块由两个前馈子层组成,第一个前馈子层将 Transformer 块的输出作为输入,将原始输入维度 d 投影到 m,通过控制 m 的大小来限制 Adapter 模块的参数量,通常情况下 m<<d。在输出阶段,通过第二个前馈子层还原输入维度,将 m 重新投影到 d,作为 Adapter 模块的输出(如上图右侧结构)。通过添加 Adapter 模块来产生一个易于扩展的下游模型,每当出现新的下游任务,通过添加 Adapter 模块来避免全模型微调与灾难性遗忘的问题。Adapter 方法不需要微调预训练模型的全部参数,通过引入少量针对特定任务的参数,来存储有关该任务的知识,降低对模型微调的算力要求。

Prefix-tuning(在中间添加训练参数)

前缀微调(prefix-tunning),用于生成任务的轻量微调。前缀微调将一个连续的特定于任务的向量序列添加到输入,称之为前缀,如下图中的红色块所示。与提示(prompt)不同的是,前缀完全由自由参数组成,与真正的 token 不对应。相比于传统的微调,前缀微调只优化了前缀。因此,我们只需要存储一个大型 Transformer 和已知任务特定前缀的副本,对每个额外任务产生非常小的开销。

Prompt tuning

提示通过包括描述任务的文本提示或甚至演示任务示例的文本提示来为特定的下游任务准备一个冻结的预训练模型。具体的,给每个任务定义 Prompt,拼接到数据上作为输入,同时 freeze 预训练模型进行训练,在没有加额外层的情况下,可以看到随着模型体积增大效果越来越好,最终追上了精调的效果。

提示方法可以分为两类:

  • 硬提示(Hard Prompts):手工制作的具有离散输入标记的文本提示;缺点是需要花费很多精力来创建一个好的提示。
  • 软提示(Soft Prompts):可与输入嵌入连接并进行优化以适应数据集的可学习张量;缺点是它们不太易读,因为您不是将这些“虚拟标记”与实际单词的嵌入进行匹配。

PEFT

PEFT(Parameter-Efficient Fine-Tuning,参数高效微调),是一个用于在不微调所有模型参数的情况下,高效地将预训练语言模型(PLM)适应到各种下游应用的库。

PEFT方法仅微调少量(额外的)模型参数,显著降低了计算和存储成本,因为对大规模PLM进行完整微调的代价过高。最近的最先进的PEFT技术实现了与完整微调相当的性能。

代码:
https://github.com/huggingface/peft
文档:
https://huggingface.co/docs/peft/index

PEFT 使用

接下来将展示 PEFT 的主要特点,并帮助在消费设备上通常无法访问的情况下训练大型预训练模型。您将了解如何使用LoRA来训练1.2B参数的bigscience/mt0-large模型,以生成分类标签并进行推理。

PeftConfig

每个 PEFT 方法由一个PeftConfig类来定义,该类存储了用于构建PeftModel的所有重要参数。

由于您将使用LoRA,您需要加载并创建一个LoraConfig类。在LoraConfig中,指定以下参数:

task_type,在本例中为序列到序列语言建模
inference_mode,是否将模型用于推理
r,低秩矩阵的维度
lora_alpha,低秩矩阵的缩放因子
lora_dropout,LoRA层的dropout概率
from peft import LoraConfig, TaskType
peft_config = LoraConfig(task_type=TaskType.SEQ_2_SEQ_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)

有关您可以调整的其他参数的更多详细信息,请参阅LoraConfig参考。

PeftModel

使用 get_peft_model() 函数可以创建PeftModel。它需要一个基础模型 - 您可以从 Transformers 库加载 - 以及包含配置特定 PEFT 方法的PeftConfig。

首先加载您要微调的基础模型。

from transformers import AutoModelForSeq2SeqLMmodel_name_or_path = "bigscience/mt0-large"
tokenizer_name_or_path = "bigscience/mt0-large"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path)

使用get_peft_model函数将基础模型和peft_config包装起来,以创建PeftModel。要了解您模型中可训练参数的数量,可以使用print_trainable_parameters方法。在这种情况下,您只训练了模型参数的0.19%!

from peft import get_peft_modelmodel = get_peft_model(model, peft_config)
model.print_trainable_parameters()
输出示例: trainable params: 2359296 || all params: 1231940608 || trainable%: 0.19151053100118282

至此,我们已经完成了!现在您可以使用Transformers的Trainer、 Accelerate,或任何自定义的PyTorch训练循环来训练模型。

保存和加载模型

在模型训练完成后,您可以使用save_pretrained函数将模型保存到目录中。您还可以使用push_to_hub函数将模型保存到Hub(请确保首先登录您的Hugging Face帐户)。

model.save_pretrained("output_dir")

如果要推送到Hub

from huggingface_hub import notebook_loginnotebook_login()
model.push_to_hub("my_awesome_peft_model")

这只保存了已经训练的增量PEFT权重,这意味着存储、传输和加载都非常高效。例如,这个在RAFT数据集的twitter_complaints子集上使用LoRA训练的bigscience/T0_3B模型只包含两个文件:adapter_config.json和adapter_model.bin,后者仅有19MB!

使用from_pretrained函数轻松加载模型进行推理:

from transformers import AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfigpeft_model_id = "smangrul/twitter_complaints_bigscience_T0_3B_LORA_SEQ_2_SEQ_LM"
config = PeftConfig.from_pretrained(peft_model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path)
model = PeftModel.from_pretrained(model, peft_model_id)

参考文章:
http://lihuaxi.xjx100.cn/news/1428773.html?action=onClick

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/131531.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【车载开发系列】嵌入式开发之中断向量表

【车载开发系列】嵌入式开发之中断向量表 嵌入式中断向量表 【车载开发系列】嵌入式开发之中断向量表一. 中断向量表的概念1&#xff09;中断向量2&#xff09;中断向量表3&#xff09;中断向量表的存放 二. 中断向量表的特点三. 中断向量表的作用四. 上电后第一条指令五. 芯片…

Filter(过滤器)Intercerptor(拦截器)

Filter过滤器 顾名思义&#xff0c;Filter可以对请求进行过滤&#xff0c;当浏览器发送请求时&#xff0c;首先先会被Filter进行拦截&#xff0c;Filter可以决定此次拦截是否放行&#xff0c;如果选择放行&#xff0c;放行之后还会返回Filter执行剩下的代码。 使用方法&…

Arduino程序设计(十四)舵机控制实验(SG90)

舵机控制实验 前言一、SG90舵机1、SG90舵机简介2、硬件电路连线3、Servo库常用函数 二、舵机实验1、舵机0~180来回转动2、串口控制舵机转动固定角度 总结 前言 本文介绍SG90舵机控制原理及实验&#xff0c;主要内容有&#xff1a;1、介绍SG90舵机&#xff1b;2、舵机0~180来回…

198、RabbitMQ 的核心概念 及 工作机制概述; Exchange 类型 及 该类型对应的路由规则

JMS 也是一种消息机制 AMQP ( Advanced Message Queuing Protocol ) 高级消息队列协议 ★ RabbitMQ的核心概念 Connection&#xff1a; 代表客户端&#xff08;包括消息生产者和消费者&#xff09;与RabbitMQ之间的连接。 Channel&#xff1a; 连接内部的Channel。 Exch…

bash上下键选择选项demo脚本

效果如下&#xff1a; 废话不多说&#xff0c;上代码&#xff1a; #!/bin/bashoptions("111" "222" "333" "444") # 选项列表 options_index0 # 默认选中第一个选项 options_len${#options[]}echo "请用上下方向键进行选择&am…

【例题】逆波兰表达式求值(图解+代码)

【例题】逆波兰表达式求值(图解代码) 这里写目录标题 【例题】逆波兰表达式求值(图解代码)逆波兰表达式解释优点转换计算代码 题目描述 : 逆波兰表示法是一种将运算符&#xff08;operator&#xff09;写在操作数&#xff08;operand&#xff09;后面的描述程序&#xff08;算式…

Ubuntu16.04apt更新失败

先设置网络设置 换成nat、桥接&#xff0c;如果发现都不行&#xff0c;那么就继续下面操作 1.如果出现一开始就e&#xff0c;检查源&#xff0c;先换源 2.换完源成功之后&#xff0c;ping网络&#xff0c;如果ping不通就是网络问题 如果ping baidu.com ping不通但是ping 112…

汽车烟雾测漏仪(EP120)

【汽车烟雾测漏仪&#xff08;EP120&#xff09;】 此烟雾测漏仪专为车辆管道&#xff08;油道、气道、冷却管道&#xff09; 的泄露检测而设计。适用于所有轻型 汽车、摩托车、轻卡、游艇等。 【特点】 具有空气模式和烟雾模式。空气模式&#xff0c;无需烟雾&#xff0c;检测…

C进阶-自定义类型:结构体、枚举、联合

本章重点&#xff1a; 结构体&#xff1a; 结构体类型的声明 结构的自引用 结构体变量的定义和初始化 结构体内存对齐 结构体传参 结构体实现位段&#xff08;位段的填充&可移植性&#xff09; 1 结构体的声明 1.1 结构的基础知识 结构是一些值的集合&#xff0c;这些值称…

VGG卷积神经网络实现Cifar10图片分类-Pytorch实战

前言 当涉足深度学习&#xff0c;选择合适的框架是至关重要的一步。PyTorch作为三大主流框架之一&#xff0c;以其简单易用的特点&#xff0c;成为初学者们的首选。相比其他框架&#xff0c;PyTorch更像是一门易学的编程语言&#xff0c;让我们专注于实现项目的功能&#xff0…

TCP/IP(十一)TCP的连接管理(八)socket网络编程

一 socket网络编程 socket 基本操作函数 bind、listen、connect、accept、recv、send、select、close 说明: 本文需要C语言、syscall系统调用、OS 操作系统支持,如果不了解可以暂时跳过备注&#xff1a; 知道对应库函数的更底层机制思考&#xff1a; socket函数与FIN、ACK等…

【Vue面试题十六】、Vue.observable你有了解过吗?说说看

文章底部有个人公众号&#xff1a;热爱技术的小郑。主要分享开发知识、学习资料、毕业设计指导等。有兴趣的可以关注一下。为何分享&#xff1f; 踩过的坑没必要让别人在再踩&#xff0c;自己复盘也能加深记忆。利己利人、所谓双赢。 面试官&#xff1a;Vue.observable你有了解…