Llama2-Chinese项目:8-TRL资料整理

  TRL(Transformer Reinforcement Learning)是一个使用强化学习来训练Transformer语言模型和Stable Diffusion模型的Python类库工具集,听上去很抽象,但如果说主要是做SFT(Supervised Fine-tuning)、RM(Reward Modeling)、RLHF(Reinforcement Learning from Human Feedback)和PPO(Proximal Policy Optimization)等的话,肯定就很熟悉了。最重要的是TRL构建于transformers库之上,两者均由Hugging Face公司开发。

一.TRL类库
1.TRL类库介绍
简单理解就是可以通过TRL库做RLHF训练,如下所示:


(1)SFTTrainer:是一个轻量级、友好的transformers Trainer包装器,可轻松在自定义数据集上微调语言模型或适配器。
(2)RewardTrainer:是一个轻量级的transformers Trainer包装器,可轻松为人类偏好(奖励建模)微调语言模型。
(3)PPOTrainer:一个PPO训练器,用于语言模型,只需要(query, response, reward)三元组来优化语言模型。
(4)AutoModelForCausalLMWithValueHead & AutoModelForSeq2SeqLMWithValueHead:一个带有额外标量输出的transformer模型,每个token都可以用作强化学习中的值函数。
(5)Examples:使用BERT情感分类器训练GPT2生成积极的电影评论,仅使用适配器的完整RLHF,训练GPT-j以减少毒性,Stack-Llama例子等。
2.PPO工作原理
  通过PPO对语言模型进行微调大致包括三个步骤:
(1)Rollout:语言模型根据query生成response或continuation,query可以是一个句子的开头。
(2)Evaluation:使用函数、模型、人类反馈或它们的某些组合对查询和响应进行评估。重要的是,此过程应为每个query/response对生成一个标量值。
(3)Optimization:这是最复杂的部分。在优化步骤中,query/response对用于计算序列中token的对数概率。这是使用经过训练的模型和Reference model完成的,Reference model通常是微调前的预训练模型。两个输出之间的KL散度用作额外的奖励信号,以确保生成的response不会偏离Reference model太远。然后使用PPO训练Active model。

二.TRL安装和使用方式
1.TRL安装

# 直接安装包
pip install trl# 从源码安装
git clone https://github.com/huggingface/trl.git
cd trl/
pip install .

2.SFTTrainer使用方式
  SFTTrainer是围绕transformer Trainer的轻量级封装,可以轻松微调自定义数据集上的语言模型或适配器。如下所示:

# 导入Python包
from datasets import load_dataset
from trl import SFTTrainer# 加载imdb数据集
dataset = load_dataset("imdb", split="train")# 得到trainer
trainer = SFTTrainer("facebook/opt-350m",train_dataset=dataset,dataset_text_field="text",max_seq_length=512,
)# 开始训练
trainer.train()

3.RewardTrainer使用方式
  RewardTrainer是围绕transformers Trainer的封装,可以轻松在自定义偏好数据集上微调奖励模型或适配器。如下所示:

# 导入Python包
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardTrainer# 加载模型和数据集,数据集需要为指定格式
model = AutoModelForSequenceClassification.from_pretrained("gpt2", num_labels=1)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
...
# 得到trainer
trainer = RewardTrainer(model=model,tokenizer=tokenizer,train_dataset=dataset,
)# 开始训练
trainer.train()

4.PPOTrainer使用方式
  query通过语言模型输出一个response,然后对其进行评估。评估可以人类反馈,也可以是另一个模型的输出。如下所示:

# 导入Python包
import torch
from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model
from trl.core import respond_to_batch# 首先加载模型,然后创建参考模型
model = AutoModelForCausalLMWithValueHead.from_pretrained('gpt2')
model_ref = create_reference_model(model)
tokenizer = AutoTokenizer.from_pretrained('gpt2')# 初始化ppo配置对象
ppo_config = PPOConfig(batch_size=1,
)# 编码一个query
query_txt = "This morning I went to the "
query_tensor = tokenizer.encode(query_txt, return_tensors="pt")# 得到模型response
response_tensor  = respond_to_batch(model, query_tensor)# 创建一个ppo trainer
ppo_trainer = PPOTrainer(ppo_config, model, model_ref, tokenizer)# 为response定义一个reward(人类反馈或模型输出奖励) 
reward = [torch.tensor(1.0)]# 使用ppo训练一步模型
train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

参考文献:
[1]https://github.com/huggingface/trl
[2]https://huggingface.co/docs/trl/v0.7.1/en/index

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/126044.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python脚本实现xss攻击

实验环境:zd靶场、vscode 知识点 requests.session() 首先我们需要先利用python requests模块进行登录,然后利用开启session记录,保持之后的操作处于同一会话当中 requests.session()用于创建一个会话(session)的实例对象。使用requests库…

【JavaEE】多线程进阶(一)饿汉模式和懒汉模式

多线程进阶(一) 文章目录 多线程进阶(一)单例模式饿汉模式懒汉模式 本篇主要引入多线程进阶的单例模式,为后面的大冰山做铺垫 代码案例介绍 单例模式 非常经典的设计模式 啥是设计模式 设计模式好比象棋中的 “棋谱”…

Python无废话-办公自动化Excel格式美化

设置字体 在使用openpyxl 处理excel 设置格式,需要导入Font类,设置Font初始化参数,常见参数如下: 关键字参数 数据类型 描述 name 字符串 字体名称,如Calibri或Times New Roman size 整型 大小点数 bold …

做外贸独立站选Shopify还是WordPress?

现在确实会有很多新人想做独立站,毕竟跨境电商平台内卷严重,平台规则限制不断升级,脱离平台“绑架”布局独立站,才能获得更多流量、订单、塑造品牌价值。然而,在选择建立外贸独立站的过程中,选择适合的建站…

TinyWebServer整体流程

从main主函数开始: 一、定义MySQL数据库的账号、密码和用到的数据库名称。 二、调用Config获得服务器初始化属性 在这一步确定触发模式端口等信息。 三、创建服务器实例对象 设置根目录、开辟存放http连接对象的空间,开辟定时器空间。 四、利用Confi…

《视觉 SLAM 十四讲》V2 第 5 讲 相机与图像

文章目录 相机 内参 && 外参5.1.2 畸变模型单目相机的成像过程5.1.3 双目相机模型5.1.4 RGB-D 相机模型 实践5.3.1 OpenCV 基础操作 【Code】OpenCV版本查看 5.3.2 图像去畸变 【Code】5.4.1 双目视觉 视差图 点云 【Code】5.4.2 RGB-D 点云 拼合成 地图【Code】 习题题…

Clion中使用C/C++开发stm32程序

前言 从刚开始学习阶段,一直是用的keil5开发stm32程序,自从看到稚晖君推荐的CLion开发嵌入式程序后,这次尝试在CLion上开发stm32程序。 1、配置CLion用于STM32开发的环境 这里我就不详细写了,没必要重新写,网上教程很多…

Http常见问题

说说 HTTP 常用的状态码及其含义? HTTP 状态码首先应该知道个大概的分类: 1XX:信息性状态码2XX:成功状态码3XX:重定向状态码4XX:客户端错误状态码5XX:服务端错误状态码 301:永久性…

突破封锁|华为芯片10年进化史:从K3V1到麒麟9000S

华为海思麒麟芯片过去10年研发历程回顾如下: 2009年:华为推出第一款手机芯片K3V1,采用65nm工艺制程,基于ARM11架构,主频600MHz,支持WCDMA/GSM双模网络。这款芯片搭载在华为U8800手机上,标志着华…

伟大不能被计划

假期清理书单,把这个书读完了,结果发现出奇的好,可以说是值得亲身去读的书,中间的一些论述提供了人工智能专业方面的视角来论证这这个通识观点,可信度很不错; 这篇blog也不是对书的总结,更多的是…

CSS小计

1:设置图片随窗缩放 使用百分比 width: 100%;height: 100%; 使用vmin: 将可视区域分为100vmin width: 100vmin;height: 100vmin; 2:设置字体颜色与背景色融合 mix-blend-mode: difference 3: 设置宽度自适应 width:fit-content 4:外边距合并 当两个相领的两个容…

Spring Cloud OpenFeign 性能优化的4个方法

OpenFeign 是 Spring 官方推出的一种声明式服务调用和负载均衡组件。它的出现就是为了替代已经进入停更维护状态的 Netflix Feign,是目前微服务间请求的常用通讯组件。 1.超时设置 OpenFeign 底层依赖Ribbon 框架,并且使用了 Ribbon 的请求连接超时时间…