使用HF Trainer微调小模型

news/2024/11/15 15:56:11/文章来源:https://www.cnblogs.com/zrq96/p/18366846

本文记录HugginngFace的Trainer各种常见用法。

SFTTrainer的一个最简单例子

HuggingFace的各种Trainer能大幅简化我们预训练和微调的工作量。能简化到什么程度?就拿我们个人用户最常会遇到的用监督学习微调语言模型任务为例,只需要定义一个SFTrainer,给定我们想要训练的模型和数据集,就可以直接运行微调任务。

'''
The simplest way to supervised-finetune a small LM by SFTTrainerEnvironment:transformers==4.43.3datasets==2.20.0trl==0.9.6
'''
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import Dataset
from trl import SFTTrainermodel_path = 'Qwen/Qwen2-0.5B'
model = AutoModelForCausalLM.from_pretrained(model_path)
corpus = [{'prompt': 'calculate 24 x 99', 'completion': '24 x 99 = 2376'},{'prompt': 'Which number is greater, 70 or 68?', 'completion': '70 is greater than 68.'},{'prompt': 'How many vertices in a tetrahedron?', 'completion': 'A tetrahedron has 4 vertices.'},
]
dataset = Dataset.from_list(corpus)
trainer = SFTTrainer(model, train_dataset=dataset)  # 给定我们想要训练的模型和数据集
trainer.train()  # 就可以直接运行微调任务

使用Trainer不可或缺的参数只有两个:

  • model
  • train_dataset

是的,其他一切参数都是锦上添花,不可或缺的只有这两个。我们能够如此省心省力地去做微调,当然是因为SFTrainer帮我们做了很多事情,具体做了什么可以对比一下上面代码以和前文【一步一步微调小模型】中的完整代码。

要注意,我们的代码能如此简单很大程度上也是因为SFTTrainer帮我们预处理了数据集。它支持的数据集格式有两种,一种就是在上面代码里显示的那样,由promptcompletion组成的键值对:

{"prompt": "How are you", "completion": "I am fine, thank you."}
{"prompt": "What is the capital of France?", "completion": "It's Paris."}
{"prompt": "有志者事竟成", "completion": "Where there's a will, there's a way"}

这种格式就可以用于一问一答式的任务,包括只问答、翻译、摘要。另一种是更加灵活的多轮对话式的数据格式,每一个训练样本都以'messages'开头,然后式一列usersystem的对话记录:

{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "What's the capital of France?"}, {"role": "assistant", "content": "It's Paris."}{"role": "user", "content": "and how about Japan?"}, {"role": "assistant", "content": "It's Tokyo."}]}
{"messages": [{"role": "system", "content": "You are helpful"}, {"role": "user", "content": "Who wrote 'Romeo and Juliet'?"}, {"role": "assistant", "content": "..."}]}

这两种数据格式已经足以应付几乎所有微调任务的需求。比如说,我要把使用数学内容的预料训练/微调Qwen-0.5B使它能回答和数学相关的问题,使用到的数据集是微软的orca-math-word-problems-200k,这个数据集由questionanswer组成

>>> from datasets import load_dataset
>>> dataset = load_dataset('microsoft/orca-math-word-problems-200k', split='train')
>>> print(dataset)
Dataset({features: ['question', 'answer'],num_rows: 200035
})

具体来说里面的数据长这个样子:

虽然数据集的键值对不是promptcompletion,但我们只需要修改一下名字就可以,就像下面的代码一样:

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainermodel_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
model = AutoModelForCausalLM.from_pretrained(model_path)
dataset = load_dataset(data_path, split='train')dataset = dataset.rename_column('question', 'prompt')   # rename dataset features to "prompt" and "completion"
dataset = dataset.rename_column('answer', 'completion') # to fit in the SFTTrainertrainer = SFTTrainer(model, train_dataset=dataset)
trainer.train()

SFTTrainer的一些常见用法

如果我们微调语言模型的任务使用到的数据非常奇特,无法用这两种数据格式来表示(虽然我觉得不太可能),那我们还能怎么办?这时候我们就只能把训练样本转化成语言模型一定会支持的格式,也就是字符串。具体来说就是编写一个函数,这个函数把训练样本作为输入,输出转化后的字符串,然后再定义训练器的时候把这个函数也传给训练器。以下代码种的formatting_prompts_func就把orca-math-word-problems-200k数据集种的每个训练样本都转化成了字符串

from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainermodel_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
model = AutoModelForCausalLM.from_pretrained(model_path)
dataset = load_dataset(data_path, split='train')def to_prompts_fn(batch) -> list[str]:'''take a batch of training samples, return a list of strings'''output_texts = []for i in range(len(batch['question'])):text = f"### Question: {batch['question'][i]}\n ### Answer: {batch['answer'][i]}"output_texts.append(text)return output_textstrainer = SFTTrainer(model, train_dataset=dataset, formatting_func=to_prompts_fn)
trainer.train()

虽说使用SFTTrainer省心省力,但有时我们也希望更加深入地掌控训练/微调过程,比如调整学习率,调整batch的大小,每隔几步就打印一下训练过程中的一些指标、再测试数据上看看模型的效果、保存一下模型,诸如此类的。如果想要更多的控制,只需要给训练器传入更多参数,这些参数都可以统一写在SFTConfig里面,再传给训练器。下面的代码示例展示了常用的一些配置参数,包括如何调整batch大小、设置频繁清空GPU缓存等来避免CUDAOutofMemory,还给了一个测试数据集来监控模型在测试集上的效果。

'''
Common usage of SFTrainer and SFTConfig to finetune a small LM
'''
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from trl import SFTTrainer, SFTConfigmodel_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
save_path = '/home/zrq96/checkpoints/qwen-0.5B-math-sft42'  # where checkpoints to outputmodel = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
dataset = load_dataset(data_path, split='train')
dataset = dataset.rename_column('question', 'prompt')
dataset = dataset.rename_column('answer', 'completion')
splited = dataset.train_test_split(test_size=0.01)sft_config = SFTConfig(output_dir=save_path,# max length of the total sequencemax_seq_length=min(tokenizer.model_max_length, 2048), per_device_train_batch_size=4, # by default 8learning_rate=1e-4, # by default 5e-5weight_decay=0.1, # by default 0.0num_train_epochs=2, # by default 3logging_steps=50, # by default 500save_steps=100, # by default 500torch_empty_cache_steps=10,  # empty GPU cache every 10 stepseval_strategy='steps', # by default 'no'eval_steps=100, 
)trainer = SFTTrainer(model,args=sft_config,train_dataset=splited['train'],eval_dataset=splited['test'],
)
trainer.train()

使用Trainer

使用SFTTrainer能这么省力,还是因为它帮我们做了很多事情,这其中最主要的事情就是帮我们处理好数据集。可以说,在大模型领域的编程里,甚至在当前的人工智能领域里数据处理占了一半以上的工作量。SFTTrainerTrainer的一个子类,而Trainer是HuggingFace所有训练器的父类,我们可以使用Trainer来做SFTTrainer能做的一切事情。我们下面演示一下怎样使用更加通用的Trainer来做微调,顺便展示一下SFTTrainer帮我们做了哪些数据处理工作。理解了这个过程,我们以后甚至可以定制自己的训练器。根据Trainer的文档,它的用法跟SFTTRainer类似(倒反天罡了属于是...),也是传入待训练/微调的模型和数据集,以及一些可能的训练参数:

from transformers import AutoModelForCausalLM, Trainer, TrainingArgumentsmodel = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2-0.5B')
dataset = ...
training_args = TrainingArguments(output_dir='./save_path')
trainer = Trainer(model, train_dataset=dataset, args=training_args)
trainer.train()

SFTTrainer最大的不同是对数据格式的支持。SFTTrainer可以接受一些直观的数据格式,但Trainer的数据集要严格按照model.forward函数所能接受的输入来设计,也就说Trainer会把数据集里的数据样本直接塞给模型,那么我们的数据集里的样本就要是能直接传给模型的。因此,想要对数据进行处理,我们就得研究一下我们的待训练模型究竟能接收怎样的输入数据。在我们的例子中,我们的模型是Qwen2ForCausalLM,其forward函数的签名是

def forward(self,input_ids: torch.LongTensor = None,attention_mask: Optional[torch.Tensor] = None,position_ids: Optional[torch.LongTensor] = None,past_key_values: Optional[List[torch.FloatTensor]] = None,inputs_embeds: Optional[torch.FloatTensor] = None,labels: Optional[torch.LongTensor] = None,use_cache: Optional[bool] = None,output_attentions: Optional[bool] = None,output_hidden_states: Optional[bool] = None,return_dict: Optional[bool] = None,cache_position: Optional[torch.LongTensor] = None,) -> Union[Tuple, CausalLMOutputWithPast]: ...

其中,input_ids: torch.LongTensor是必须有的。因为我们要做训练/微调,所以labels: Optional[torch.LongTensor]也是必须的而非optional了。所以我们的数据集应当是含input_idslabels的样本,而且这两个特征的数据类型都是torch.LongTensor

>>> from datasets import load_dataset
>>> dataset = load_dataset('microsoft/orca-math-word-problems-200k', split='train')
>>> def preprocess_dataset(x: dict) -> dict:
...     # to be implemented
...     ...
>>> 
>>> new_ds = dataset.map(preprocess_dataset)
>>> print(new_ds)
Dataset({features: ['input_ids', 'labels'],num_rows: 200035
})

我们会使用一个preprocess_dataset函数把原始数据集种的每一个样本转化成model能给接受的样本。要做的也只是把question和answer的文本内容拼接在一起,然后tokenize一下,有需要的话就pad或者truncate,这就搞定了input_ids。而CausalLM的forward要做的事情都是预测下一个词,所以labels就是input_ids左移一个token的位置而已。然而,因为左移这一步model自己会做,所以我们的labels就只是复制一份input_ids而已:

左移labels这一步model自己会做

因此,我们的预处理函数就仅仅是把文本变成input_ids,然后复制一份作为labels:

def preprocess_data(x: dict) -> dict:'''take a training sample and return a preprocessed sample with the keys that the model expects, in our case:- input_ids: the tokenized input- labels: the right-shifted tokenized inputand optionally:- attention_mask: a mask indicating which tokens should be attended to- position_ids: the position of each token in the input- ...'''text = f"### Question: {x['question']}\n ### Answer: {x['answer']}"tokenized = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=1024)  # 量力而行,可以2014甚至4096,或者使用 data collatorreturn {'input_ids': tokenized['input_ids'][0],'labels': tokenized['input_ids'][0].clone(),}

将所有代码整合起来,下面就是使用Trainer对模型做微调的完整代码:

'''
Supervised-FineTuning a small LM by the vanilla Trainer
'''
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from transformers import TrainingArguments, Trainermodel_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
save_path = '/home/ricky/checkpoints/qwen-0.5B-math-sft42'model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
dataset = load_dataset(data_path, split='train')def preprocess_data(x: dict) -> dict:'''take a training sample and return a preprocessed sample with the keys that the model expects, in our case:- input_ids: the tokenized input- labels: the right-shifted tokenized inputand optionally:- attention_mask: a mask indicating which tokens should be attended to- position_ids: the position of each token in the input- ...'''text = f"### Question: {x['question']}\n ### Answer: {x['answer']}"tokenized = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=1024)return {'input_ids': tokenized['input_ids'][0],'labels': tokenized['input_ids'][0].clone(),}new_ds = dataset.map(function=preprocess_data,  # map all samples with this functionnum_proc=4                 # use 4 processes to speed up
)
splited = new_ds.train_test_split(test_size=0.01)training_args = TrainingArguments(output_dir=save_path,per_device_train_batch_size=2,per_device_eval_batch_size=2,torch_empty_cache_steps=2,num_train_epochs=2, # by default 3logging_steps=50, # by default 500save_steps=100,eval_strategy='steps', # by default 'no'eval_steps=100, 
)trainer = Trainer(model=model,args=training_args,train_dataset=splited['train'],eval_dataset=splited['test'],
)trainer.train()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/786991.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

题解:P10358 [PA2024] Obrazy

题解:P10358 [PA2024] Obrazy 题目传送门 即当最小的画框都不可能覆盖整个矩形墙面时,输出 −1。 [PA2024] Obrazy 题目背景 PA 2024 3C 题目描述 题目译自 PA 2024 Runda 3 Obrazy,感谢 Macaronlin 提供翻译 给定尺寸为 $h\times w$ 的矩形墙面,以及 $n$ 种尺寸的正方形画…

CMake构建学习笔记4-libjpeg库的构建

介绍了通过CMake构建libjpeg库的关键步骤。libjpeg是一个广泛使用的开源库,用于处理JPEG(Joint Photographic Experts Group)图像格式的编码、解码、压缩和解压缩功能,是许多图像处理软件和库的基础。 libjpeg本身的构建没什么特别的,不过值得说道的是libjpeg存在一个高性…

第一个selenium测试

一、环境搭建 使用语言:python 1、python解释器:python.exe 版本 3.11.4 下载地址:[https://www.python.org/downloads/release/python-3114/]设置环境变量:复制python.exe安装路径--高级系统设置--环境变量--PATH中添加--粘贴python.exe安装路径--确定 目的是确保接下来系…

博客园OpenApi管理平台

简介 博客园(Cnblogs)提供了OpenAPI服务,允许开发者通过API来获取博客园中的数据。使用这个API,可以实现从博客园抓取文章、评论等信息的功能,这对于想要集成博客园内容到自己网站或应用的开发者来说是非常有用的。 网址 https://api.cnblogs.com/结束

【论文阅读】TBA Faster Large Language Model Training Using SSD Based Activation Offloading

摘要 GPU内存容量的增长速度跟不上大型语言模型(llm)的增长速度,阻碍了模型的训练过程。特别是,激活——在前向传播过程中产生的中间张量,并在后向传播中重用——主导着GPU内存的使用。为了应对这一挑战,我们建议TBA将激活有效地卸载到高容量NVMe ssd上。这种方法通过自适应…

隧道代理ip使用

简介 隧道代理(Tunnel Proxy)是一种特殊的代理服务,它的工作方式是在客户端与远程服务器之间建立一条“隧道”。这种技术常被用来绕过网络限制或提高网络安全性。 主要功能IP地址变换:隧道代理能够改变客户端的IP地址,使得客户端访问的目标服务器看到的是代理服务器的IP地…

使用C#爬取快手作者主页,并下载视频/图集

最近发现一些快手的作者,作品还不错,出于学习研究的目的,决定看一下怎么爬取数据。现在网上有一些爬虫工具,不过大部分都失效了,或者不开源。于是自己就写了一个小工具。先看一下成果:软件只需要填写作者uid以及网页版的请求Cookie,即可实现自动下载,下载目录在程序根目…

Python安装教程

第一步:先去官网上下载python安装包 系统64位下载地址:系统32位下载地址:第二步:点击安装包 要点击Add python.exe to PATH 这个是配置系统变量的然后点install Now就开始跑安装进度了打开终端输入代码 python 再输入print(123),如果输入123那么安装完成,可以整活了

用空间清理调理风水4什么是杂乱5杂物如何影响你6为什么保留杂物

4 什么是杂乱? 《牛津英语词典》将杂乱定义为 “乱七八糟的东西的集合”。是的,这是其中的一部分,但它只是从纯粹的物理层面来描述杂乱。 在我的定义中,凌乱分为四类:你不使用或不喜欢的东西杂乱无章的东西太多东西放在太小的空间里任何未完成的东西让我们逐一来看看这些东…

go免杀学习记录

本文主要介绍了go语言的加载器基本框架与绕过360与火绒的通用方法。题记最近剑来动漫上线,虽然观感不如我的预期,感觉节奏过快。但是也是一种进步了,愿各位道友都能找到自己的宁姚。"我喜欢的姑娘啊,她眉如远山,浩然天下所有好看的山,好看的水,加起来都不如她。她睫…

看图学 - Swift actor

本文首发于 Ficow Shens Blog,原文地址: 看图学 - Swift actor。想第一时间获取对于自己有帮助的新内容? 欢迎关注 Ficow 的公众号: 看图学 Swift actor如需获取PDF版本思维导图、示例代码,请查阅公众号内容: 《看图学 - Swift actor》Stay hungry,stay foolish.

VulNyx - Ceres 靶机

有80端口访问看看他这个挺奇葩的看了wp才知道 file.php的参数是file 他会自动给你加上php 也就是说file=secret.php读不到数据要file=secret才能读到数据伪协议读取文件<?php include($_GET["file"].".php"); ?><?php system("id…