(四)详解RLHF

一直都特别好奇大模型的强化学习微调是怎么做的,网上虽然相关文章不少,但找到的文章都是浅尝辄止说到用PPO训练,再细致深入的就没有讲了。。。只能自己看一看代码,以前搞过一点用PPO做游戏,感觉和语言模型PPO的用法不太一样。在游戏场景,每个step给环境一个action之后,agent拿到的state都是会变化的,通常也会设计奖励函数使得每个step都会有reward;但是在用强化学习微调语言模型这里,prompt是state,只输入一次,然后输出一串action(回答的单词),得到一个reward,模型并没有在每个action之后得到新的state(感谢评论区大佬的点拨,对于answer的第二个词,可以把prompt+answer的一个词当作新的state,而不只是把prompt当作state,状态转移蕴含在transformer内部)

本篇文章并不会介绍太多PPO的原理,相关文章已经很多了,比如李宏毅介绍PPO的课程。大模型里边的PPO涉及到了critic model的概念,在李宏毅教程里只提了一下并没有细讲,如果想了解可以看一下这个文章,相当于利用一个critic model预测从t时刻到最后一个时刻的累加奖励值(强化学习里边的第t个时刻对标answer句子里边的第t个单词),而不是通过实际累加得到从t时刻到最后一个时刻的累加奖励值,这样可以降低奖励的方差。下文也结合代码介绍critic model输出的具体含义。同时RLHF是什么也会再详细介绍,相关文章已经很多了。

本篇文章涉及的代码均来自微软的deepspeed对RLHF的实现,可配合huggingface官方的博客一起食用。本文只对算法的一些有特点的关键点进行阐述,并不对整体实现进行介绍。先上一张经典的论文图。本文重点结合代码讲解奖励模型训练和强化学习训练部分。

奖励(reward)模型训练

首先要声明的是,在强化学习阶段,用到的reward model和critic model都使用同一个模型初始化,因此在训练reward模型的过程中,也是在训练critic model。其次对符号进行说明,大模型中间隐藏层的参数维度为(B,L,D),B为batch size大小,L为句子长度,D为embedding维度。在接下来的代码讲解中,我也会标明代码中各个变量的维度,以更好的理解其意义。

在进行RLHF时,需要一个奖励模型来评估语言大模型(actor model)回答的是好是坏,这个奖励模型通常比被评估的语言大模型小一些(deepspeed的示例中,语言大模型66B,奖励模型只有350M)。奖励模型的输入是prompt+answer的形式,让模型学会对prompt+answer进行打分。奖励模型最后一层隐藏层的输出维度为(B,L,D),通过一个D✖️1的全连接层将维度变为(B, L),在L这个维度上,第i个位置的数据表示:从第i个位置到最后一个位置输出所能获得的奖励分值的累加和(和DQN里边的Q值一个意义),这种形式的输出满足了critic model的输出要求。对应代码如下:

#huggingface模型返回值是个list,第0位是模型最后输出的hideen state
hidden_states = transformer_outputs[0]
# v_head为Dx1的全连接网络对最后一维压缩
rewards = self.v_head(hidden_states).squeeze(-1)

对于一个奖励模型来说,目标是给一个句子进行打分,按理说每个句子对应一个分值就行了,但是目前对于长度为L的句子,奖励模型输出了L个值。我们用L维度上的最后一个位置的值当作为本句话的奖励得分。奖励模型训练优化采用pair wiss loss,即同时输入模型关于同一个问题的两个回答,让模型学会这两个句子哪个分高哪个分低。之所以如此训练是因为,在给奖励模型进行数据标注的过程中,给同一个问题的不同回答量化的打具体分值比较难,但是对他们进行排序相对简单,代码如下:

# 同一个batch里边的句子需要等长,短句后边会被padding
# [divergence_ind:end_ind]索引了padding前一个位置的输出分值
# chosen_reward是同一个句子pair里分数高的句子,r_truncated_reward是句子pair里分数低的句子
c_truncated_reward = chosen_reward[divergence_ind:end_ind]
r_truncated_reward = rejected_reward[divergence_ind:end_ind]

pair wise loss代码如下,如果给pair里边好的句子打分高(c_truncated_reward),坏的句子(r_truncated_reward)打分低,loss就会小:

loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean()

在训练强化学习的过程中,会用到reward model(critic model,再次提醒,critic model和reward model是同一个模型的两个副本)的推理过程,通过调用forward_value实现,具体代码如下,返回的值中有两种值,values表示每个位置i,从第i个位置到最后一个位置的奖励累加值,供强化学习过程中critic model使用;“chosen_end_scores”指的是对每个prompt+answer的打分,供reward model使用。

def forward_value(...):...if return_value_only:#(B,L)return valueselse:...return {"values": values,# (B,)"chosen_end_scores": torch.stack(chosen_end_scores),}

强化学习微调

强化学习微调阶段,会用到4个模型,actor model, ref_model,reward model和critic model(好费显存啊!!!)。其中actor model和ref_model是RLHF第一个阶段有监督微调模型的两个副本,reward model和critic model是本文第一部分训练出来的模型的两个副本。整体流程见这篇文档,整体流程图如下所示(没画出critic model):

首先说明actor model的训练模式和推理模式的区别( 后边会用到)。训练模式是用teacher force的方式(不明白的同学知乎搜一下),将整句话输入到模型中,并通过mask机制在保证不泄漏未来的单词情况下预测下一个单词。推理模式是真正的自回归,预测出下一个单词之后,当作下一步输入再预测下下个单词,原理如下图所示:

首先用actor model在推理模式下根据prompt生成一个answer(prompt对应强化学习里边的state,answer对应一些列的action),代码如下:

# 保证不触发反向传播
with torch.no_grad():seq = self.actor_model.module.generate(prompts,max_length=max_min_length,min_length=max_min_length)

然后利用reward model和ciric model对输出的prompt+answer进行打分(PPO训练时使用的奖励值并不单单是reward model的输出还要考虑kl散度,后文介绍):

# 奖励模型返回的是个字典,key为chosen_end_scores位置存储数据维度为(B,),表示对于prompt+answer的打分
reward_score = self.reward_model.forward_value(seq, attention_mask,prompt_length=self.prompt_length)['chosen_end_scores'].detach()
#critic model返回的数据维度为(B,L),L维度上第i个位置代表从i位置到最后的累积奖励
#舍去最后一个位置是因为句子“终止符”无意义 
values = self.critic_model.forward_value(seq, attention_mask, return_value_only=True).detach()[:, :-1]

actor model是我们想通过强化学习微调的大模型,但是强化学习过程很容易把模型训练“坏”,因此需要另外一个不会参数更新的 ref_model来当作标的,别让actor mode跑偏太远。我们在训练模式下,将prompt+answer分别输入到actor mode和ref model,用KL散度来衡量 ref model和actor mode输出的差别。同时将KL散度(衡量数据分布差距大小)纳入损失函数(KL散度本质是纳入到奖励值里边的,奖励值被纳入到了损失函数),进而来约束 ref_model和actor mode的输出分布别差距太大。具体代码如下:

# 得到两个模型的输出
output = self.actor_model(seq, attention_mask=attention_mask)
output_ref = self.ref_model(seq, attention_mask=attention_mask)
logits = output.logits
logits_ref = output_ref.logits
...
return {
...
# 分别得到两个模型在真实单词上的预测概率
'logprobs': gather_log_probs(logits[:, :-1, :], seq[:, 1:]),
'ref_logprobs': gather_log_probs(logits_ref[:, :-1, :], seq[:,1:]),
...
}
...
# 计算kl散度,log_probs里边存的数字经过log变化了,因此减法就对应除法
kl_divergence_estimate = -self.kl_ctl * (log_probs - ref_log_probs)

PPO训练时候的奖励值综合考虑KL散度和reward模型的输出,只考虑answer部分的KL散度,将reward model的输出加到KL散度L维度的最后一个位置上,得到最终的奖励值,代码如下:

rewards = kl_divergence_estimate
# 只考虑answer部分的奖励,不考虑prompt
start = prompts.shape[1] - 1
# 不考虑padding部分
ends = start + action_mask[:, start:].sum(1)
reward_clip = torch.clamp(reward_score, -self.clip_reward_value,self.clip_reward_value)
batch_size = log_probs.shape[0]
# 在L维度上,每个位置都有KL散度,但是只在最后一个位置加上奖励值
for j in range(batch_size):rewards[j, start:ends[j]][-1] += reward_clip[j]

接下来的内容就是PPO的训练过程的比较核心的内容了,目标是计算PPO更新公示里边的advantage,具体公式如下,V就是critic model的输出。如果原理不懂建议先到这个链接看看。我直接在代码中给注释了。

图片出处:https://huggingface.co/blog/deep-rl-a2c
def get_advantages_and_returns(self, values, rewards, start):# values(B,L) critic model输出# rewards(B,L)reward 包含kl散度# start answer开始的位置# Adopted from https://github.com/CarperAI/trlx/blob/main/trlx/models/modeling_ppo.py#L134lastgaelam = 0advantages_reversed = []length = rewards.size()[-1]# 计算每个时刻(序列位置)的critic model预测误差for t in reversed(range(start, length)):nextvalues = values[:, t + 1] if t < length - 1 else 0.0# critic model预测的是t到到最后一个时刻的奖励和,所以变化量delta可以用如下公式表示delta = (rewards[:, t] + self.gamma * nextvalues) - values[:, t]# self.gamma=1,self.lam=0.95是衰减因子,表示之前计算的delta对现在影响越来越小lastgaelam = delta + self.gamma * self.lam * lastgaelamadvantages_reversed.append(lastgaelam)advantages = torch.stack(advantages_reversed[::-1], dim=1)# 后续用来更新critic model用returns = advantages + values[:, start:]return advantages.detach(), returns

以上过程,我们已经拿到了PPO训练所需要的advantage以及actor model的输出,我先现在可以对actor model进行训练啦。具体代码如下。logprobs和old_logprobs这两个参数分别是“老actor(n个epoch才会更新一次)”和新actor(每个batch都会更新它)”在正确单词上出处的概率,这块时PPO import sampling相关的知识,就不在这重复介绍了,不明白的同学补习一下哈。借用一下李宏毅老师的PPO公式:

def actor_loss_fn(self, logprobs, old_logprobs, advantages, mask):## policy gradient loss#logprobs, old_logprobs都是经过log变化的单词概率,这里带着log做减法就相当于在做概率除法log_ratio = (logprobs - old_logprobs) * mask# 指数操作去掉logratio = torch.exp(log_ratio)pg_loss1 = -advantages * ratiopg_loss2 = -advantages * torch.clamp(ratio, 1.0 - self.cliprange,1.0 + self.cliprange)pg_loss = torch.sum(torch.max(pg_loss1, pg_loss2) * mask) / mask.sum()return pg_loss

同样的,我们也要对critic model进行训练,更新,loss就是mse loss。

def critic_loss_fn(self, values, old_values, returns, mask):## value loss# 用“老critic model”的输出约束“新critic model”不要步子太大,裁剪一下values_clipped = torch.clamp(values,old_values - self.cliprange_value,old_values + self.cliprange_value,)vf_loss1 = (values - returns)**2vf_loss2 = (values_clipped - returns)**2vf_loss = 0.5 * torch.sum(torch.max(vf_loss1, vf_loss2) * mask) / mask.sum()return vf_loss

至此,我们的RLHF训练流程就结束了。第二部分开头我们说过,共涉及actor model, ref_model,reward model和critic model这四个模型,其实更新参数的模型只有actor model和critic model。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/732461.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

模拟集成电路设计系列博客——8.1.1 锁相环基本介绍

8.1.1 锁相环基本介绍 几乎所有的数字,射频电路以及大部分的模拟电路。不幸的是,集成电路振荡器本身并不适合用于高性能电路中的频率/时间参考源。一个主要的问题是它们的震荡频率并不能精确知道。更进一步的,集成电路振荡器的时钟抖动(可以被认为是频率上的随机波动)对于…

(三)使用 PPO 算法进行 RLHF 的 N 步实现细节

使用 PPO 算法进行 RLHF 的 N 步实现细节 当下,RLHF/ChatGPT 已经变成了一个非常流行的话题。我们正在致力于更多有关 RLHF 的研究,这篇博客尝试复现 OpenAI 在 2019 年开源的原始 RLHF 代码库,其仓库位置位于 openai/lm-human-preferences。尽管它具有 “tensorflow-1.x” …

JMeter安装目录简单说明

一 前言 环境: window 10 JMeter5.3 JMeter安装目录的文件通常容易被忽略,注意力全放在JMeter本身的各个功能的使用上。 但在前面的学习中我们发现了熟悉安装目录的必要性。 如jmeter.properties这个文件,之前的文章中就经常查看或者修改,还有一些日志文件也在安装目录中 二…

G61【模板】线性基 P3812 线性基

视频链接: G23 线性方程组 高斯消元法 - 董晓 - 博客园 (cnblogs.com) P3812 【模板】线性基 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)// 线性基 O(63*n) #include <iostream> #include <cstring> #include <algorithm> using namespace std;typede…

寿司

寿司 题目描述解析 合法的结果只有两种情况:\(B\) 都在两边、\(R\) 都在两边,至于是最左边还是最右边或者都有,无所谓,因为是环。 而每个 \(B\) 移到最左边的代价就是它左边 \(R\) 的个数,移到最右边就是它右边 \(R\) 的个数。 按环形 dp 的套路,我们可以把串复制二倍,然…

(一)ChatGPT 背后的“功臣”——RLHF 技术详解

ChatGPT 背后的“功臣”——RLHF 技术详解 OpenAI 推出的 ChatGPT 对话模型掀起了新的 AI 热潮,它面对多种多样的问题对答如流,似乎已经打破了机器和人的边界。这一工作的背后是大型语言模型 (Large Language Model,LLM) 生成领域的新训练范式:RLHF (Reinforcement Learnin…

Jetpack Compose(8)——嵌套滚动

目录前言一、Jetpack Compose 中处理嵌套滚动的思想二、Modifier.nestedScroll2.1 NestedScrollConnection2.2 NestedScrollDispatcher三、实操讲解3.1 父组件消费子组件给过来的事件——NestedScrollConnection3.2 子组件对事件进行分发——NestedScrollDispatcher3.2 按照分发…

Unity Address Asset System:Assembly-CSharp - 可用Assembly-CSharp.Player - 不可用

在使用Unity的Addressables插件进行游戏资源分包管理的时候,报了这个错误: 反编译查看发现是unity与.net版本不匹配导致的问题 解决方案: 在Unity中打开Edit->Project Settings->Player,更改.Net版本 微软官方文档: 在 Unity 中使用 .NET 4 和更高版本 | Microsoft …

CC2分析与利用

CC2分析与利用 环境配置 一、 CC2 需要使用commons-collections-4.0版本,因为3.1-3.2.1版本中TransformingComparator没有实现Serializable接口,不能被序列化,所以CC2不用3.x版本。还需要 javassist 依赖,利用链2需要。 pom.xml 添加:<dependency><groupId>or…

【计算机网络】TCP连接三次握手和四次挥手

三次握手建立连接 TCP(传输控制协议)的三次握手机制是一种用于在两个 TCP 主机之间建立一个可靠的连接的过程。这个机制确保了两端的通信是同步的,并且在数据传输开始前,双方都准备好了进行通信。①、第一次握手:SYN(最开始都是 CLOSE,之后服务器进入 LISTEN)发起连接:…

原型设计

原型设计的重要性 网页原型显示了网页的骨架结构,因此可以更好地了解用户将去哪里以及如何导航,通过视觉方式表达产品的要求。 网页原型还有助于交流想法和规划网页,提高团队沟通的效率和质量,进行高效协作。 设计团队与客户沟通变得容易,能够有效地减少了返工和误解,降低…

weblogic 漏洞复现

1.环境地址信息http://192.168.116.112:7001/console/ 2.使用漏洞检测工具,检测对应漏洞 选中对应漏洞检查,发现存在对应漏洞 3.漏洞利用 命令执行 内存马上传使用冰蝎连接 连接成功

详细解析ORB-SLAM3的源码

随着计算机视觉和机器人技术的发展,SLAM(同步定位与地图构建)技术在自动导航、机器人和无人机等领域中起着至关重要的作用。作为当前最先进的SLAM系统之一,ORB-SLAM3因其卓越的性能和开源特性,备受关注。本文将详细解析ORB-SLAM3的源码 ,帮助读者更好地理解其内部机制。 …

H3C之IRF典型配置举例(BFD MAD检测方式)

IRF典型配置举例(BFD MAD检测方式) 1、组网需求由于网络规模迅速扩大,当前中心设备(Device A)安全业务处理能力已经不能满足需求,现在需要另增一台设备Device B,将这两台设备组成一个IRF(如图所示),并配置BFD MAD进行分裂检测。2、组网图 IRF典型配置组网图(BFD MAD…

【攻防技术系列+反溯源】入侵痕迹清理

#溯源 #入侵痕迹清理 #攻防演练 在授权攻防演练中,攻击结束后,如何不留痕迹的清除日志和操作记录,以掩盖入侵踪迹,这其实是一个细致的技术活。 在蓝队的溯源中,攻击者的攻击路径都将记录在日志中,所遗留的工具也会被蓝队进行分析,在工具中可以查找特征,红队自研工具更容…

ProfibusDP主站转Modbus模块连接综合保护装置配置案例

常见的协议有Modbus协议,ModbusTCP协议,Profinet协议,Profibus协议,Profibus DP协议,EtherCAT协议,EtherNET协议等。本案例描述了如何使用ProfibusDP主站转Modbus模块(XD-MDPBM20)来连接综合保护装置(综保),实现数据交换和远程控制。通过配置ProfibusDP主站和Modbus…

HL集训日记(更新ing)

Day -inf 听说又要去海亮,感到恐慌,想起了被xxs碾压的日子,遂卷; Day -1 与学校说再见 Day 0 去机场,这次倒是没有人迟到力; 下大雨,冷,明明天气预报上气温是比DL热的,却这么冷!!! 到了HL,这回住进了24小时摆烂中心(确信,空调吹得好难受,,, gg并没有收手机,…

【计算机网络】TCP如何保证稳定性

连接管理 校验和 序列号/确认应答 流量控制 最大消息长度 超时重传 拥塞控制资料来源连接管理 TCP 使用三次握手和四次挥手保证可靠地建立连接和释放连接。 校验和 TCP 将保持它首部和数据的检验和。这是一个端到端的检验和,目的是检测数据在传输过程中的任何变化。如果接收端…

数据分析 | 数据清理的方法

数据清理的步骤# 一、读取数据 导入NumPy和Pandas数据库,用Pandas的read_csv函数读取原始数据集’e_commerce.csv’,使其转换成DataFrame格式,并赋值给变量df。 展示数据集的前5行和后5行。# 二、评估数据(整洁度、干净度) 创建一个新的变量cleaned_data = df(相当于复制…

iMovie视频剪辑入门

iMovie学习笔记自己不是摄影爱好者📹(也许以后是,说不准),想学视频剪辑的原因如下:大一的一些小组作业有拍视频的任务,有时需要我承担剪辑的工作。因为不熟练,只能用剪映瞎折腾,浪费不少时间。系统地学习可以让我更好地完成剪辑工作。 想了解iMovie本身。本文是我的i…