什么?穷哥们没钱RLHF?跟我一起DPO吧,丐版一样用

本次DPO训练采用TRL的方式来进行训练

Huggingface TRL是一个基于peft的库,它可以让RL步骤变得更灵活、简单,你可以使用这个算法finetune一个模型去生成积极的评论、减少毒性等等。

本次进行DPO的模型是一个500M的GPT-2,目的是训练快,少占资源,快速看到结果。

下载Tokenizer:

from transformers import AutoTokenizer

AutoTokenizer.from_pretrained('gpt2').save_pretrained('tokenizer/gpt2')

  下载Datasets:

from datasets import load_dataset

load_dataset('b-mc2/sql-create-context').save_to_disk(

'dataset/b-mc2/sql-create-context')

下载Model:

from transformers import AutoModelForCausalLM

AutoModelForCausalLM.from_pretrained('gpt2').save_pretrained('model/gpt2')

图片

图 下载Tokenizer,model,数据

首先我们看一下原始数据集,原始数据集的构成分为3部分,一个是question,代表想提出的问题,一个是answer代表回答,第三部分是context代表参考的表结构。

图片

图 原始数据集

图片

图 数据集样例

实际数据样例,我们进一步规范了三种数据类型:

·第一个prompt,包含了context表结构和问题。

·第二个chose,表示希望训练之后的模型按着什么范式来回答问题。

·第三个reject,表示不希望用什么方式来回答,这里就留空了,代表隐式确认,如果有条件也可以整理不喜欢的回答范式。

这个训练的目的就是不管回答什么问题,都要用SQL语句的形式来回答,强调一种受欢迎回答的范式,这也是RLHF/DPO训练的主要目的。

下面开始训练部分,首先load tokenizer。

图片

图8-9 load tokenizer

按照需求来整理数据格式。

图片

图 整理数据格式

读取模型。

from transformers import AutoTokenizer

import random

import torch

tokenizer = AutoTokenizer.from_pretrained('/data2/DPO/tokenizer/gpt2')

tokenizer.pad_token_id = 0

tokenizer

from transformers import AutoModelForCausalLM

model_dpo = AutoModelForCausalLM.from_pretrained('/data2/DPO/model/gpt2').to('cuda')

model_dpo_ref = AutoModelForCausalLM.from_pretrained('/data2/DPO/model/gpt2').to('cuda')

先做个测试看看模型目前是怎么回答的。

图片

图 训练前的回答方式

如上图所示,很显然这个回答方式不是我们要求的方式,我们需要它把问题都按着SQL语句来进行回答。

最后一步就是正式训练了。

图片

图片

图片

如上图所示,随着训练的开展,模型回复对话的方式,基本就越来越向着正规SQL的方向演进。

这就是DPO训练所达成的目的。

图片

也没有多废资源,我是点auto-map技能点了,正常也就一张A100够了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/591847.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数字乡村创新实践探索:科技赋能农业现代化与乡村治理体系现代化同步推进

随着信息技术的飞速发展,数字乡村作为乡村振兴的重要战略方向,正日益成为推动农业现代化和乡村治理体系现代化的关键力量。科技赋能下的数字乡村,不仅提高了农业生产的效率和品质,也为乡村治理带来了新的机遇和挑战。本文旨在探讨…

APP渗透总结

APP渗透测试和Web渗透测试本质上没有区别。目前APP应用主要分为Android和IOS,但是由于苹果的IOS操作系统不开源,所以一般对IOS系统进行渗透和反编译会比较困难,所以一般对APP系统进行渗透测试都是对Android进行测试。 目录 安装安卓模拟器抓…

Hadoop Yarn

首先先从Yarn开始讲起,Yarn是Hadoop架构的资源管理器,可以管理mapreduce程序的资源分配和任务调度。 Yarn主要有ResourceManager、NodeManage、ApplicationMaster,Container ResourceMange负责管理全局的资源 NodeManage(NM&a…

【LangChain学习之旅】—(13) 代理(中):AgentExecutor究竟是怎样驱动模型和工具完成任务的?

【LangChain学习之旅】—(13) 代理(中):AgentExecutor究竟是怎样驱动模型和工具完成任务的? Agent 的关键组件深挖 AgentExecutor 的运行机制开始 Debug第一轮思考:模型决定搜索第二轮思考:模型决定计算上节了解了 ReAct 框架的原理,LangChain 中的“代理”和“链”的…

在线生成占位图片工具:简便快捷的设计利器

title: 在线生成占位图片工具:简便快捷的设计利器 date: 2024/4/4 17:36:41 updated: 2024/4/4 17:36:41 tags: 占位图片网页设计开发工具图片生成页面布局效率提升预览调整 在网页开发或设计过程中,经常会遇到需要临时使用占位图片的情况。占位图片是指…

pycharm和Spyder多行注释快捷键

1.选取注释内容 2.pycharm:使用Ctrl/ 3.Spyder:使用Ctrl1 效果图

【opencv】教程代码 —video(2) optical_flow (稀疏光流、稠密光流)

1. optical_flow.cpp 稀疏光流 #include <iostream> // 引入输入输出流库 #include <opencv2/core.hpp> // 引入OpenCV的核心功能模块 #include <opencv2/highgui.hpp> // 引入OpenCV的高级GUI模块&#xff0c;提供显示图像的功能 #include <opencv2/imgp…

基于深度学习的条形码二维码检测系统(网页版+YOLOv8/v7/v6/v5代码+训练数据集)

摘要&#xff1a;本文深入研究了基于YOLOv8/v7/v6/v5的条形码二维码检测系统。核心采用YOLOv8并整合了YOLOv7、YOLOv6、YOLOv5算法&#xff0c;进行性能指标对比&#xff1b;详述了国内外研究现状、数据集处理、算法原理、模型构建与训练代码&#xff0c;及基于Streamlit的交互…

数据湖概述:大数据演进阶段-数据湖

文章目录 一. 大数据发展过程1. 离线大数据平台2. Lambda架构&#xff1a;速度层批层3. Kappa架构&#xff1a;流批一体4. 大数据架构痛点总结 二. 数据湖助力于解决数据仓库痛点问题1. 数据湖特点2. 开源数据湖的架构 三. 数据湖和数据仓库理念的对比1. 数据湖和数据仓库对比2…

自动驾驶中各种坐标系辨析

坐标系辨析 0. 地球椭圆体1. 大地坐标系2. eci地心惯性坐标系3. 地心地固坐标系(ECEF坐标系&#xff0c;E系)4. 站心坐标系(ENU坐标系)5. UTM坐标系6. LTM坐标系7. IMU坐标系8. 代码部分8.1 LLA(大地坐标系坐标、经纬度海拔)坐标转LTM系(ENU系)下的三维笛卡尔坐标8.2 LLA坐标转…

Hadoop: word count,并将reduce结果写入ES

一、依赖&#xff0c;其中ES版本为7.6.2 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http…

STL中各类容器详细介绍

STL介绍 STL&#xff08;Standard Template Library&#xff09;&#xff0c;即标准模板库&#xff0c;是一个具有工业强度的&#xff0c;高效的C程序库。它被容纳于C标准程序库&#xff08;C Standard Library&#xff09;中&#xff0c;是ANSI/ISO C标准中最新的也是极具革命…