一步一步微调小模型

news/2024/11/13 16:45:48/文章来源:https://www.cnblogs.com/zrq96/p/18345298

本文记录一下,使用自顶向下的编程法一步步编写微调小语言模型的代码。

微调一个语言模型,本质上是把一个已经预训练过的语言模型在一个新的数据集上继续训练。那么一次微调模型的任务,可以分为下面三个大个步骤(不包含evaluation):

  • 加载已经预训练好的模型和新的数据集
  • 预处理模型和数据集
  • 开始循环训练
# ======== imports ========
import torch # ======== config  =========num_epochs = ...# ========= Load model, tokenizer, dataset  =========model = ...
tokenizer = ...
dataset = ...# ========= Preprocessing  =========train_dataloader = ...# ========= Finetuning/Training  =========def compute_loss(X, y): ...optimizer = ...
scheduler = ...for epoch in range(num_epochs):for batch in train_dataloader:# Training code hereout = model(batch['X'])loss = compute_loss(out, batch['y'])loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()torch.save(model.state_dict(), save_path)

本文的目标是微调一个小规模的千问模型Qwen-0.5B,使之具有更强的数学思考能力,使用微软的数据集microsoft/orca-math-word-problems-200k。因此,准备步骤就是导入相应的模块,加载相应的模型和数据集:

# ======== imports ========
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset, Dataset# ======== config  =========
model_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
save_path = './Qwen2-0.5B-math-1'# ========= Load model, tokenizer, dataset  =========
model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
dataset = load_dataset(data_path)# ========= Preprocessing  =========train_dataloader = ...# ========= Finetuning  =========def compute_loss(X, y): ...optimizer = ...
scheduler = ...for epoch in range(num_epochs):for batch in train_dataloader:# Training code hereout = model(batch['X'])loss = compute_loss(out, batch['y'])loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()torch.save(model.state_dict(), save_path)

优化器和调度器也按照惯例,使用Adam(或者SGD)和Cosine

# ========= Finetuning  =========
from transformers import get_cosine_schedule_with_warmup def compute_loss(X, y): ...optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9,0.99), eps=1e-5)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_training_steps=100, num_warmup_steps=10)for epoch in range(num_epochs):for batch in train_dataloader:# Training code hereout = model(batch['X'])loss = compute_loss(out, batch['y'])loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()torch.save(model.state_dict(), save_path)

现在我们聚焦于怎么预处理数据,也就是获得train_dataloader

# ======== config  =========
num_epochs = 5  # 量力而行
batch_size = 8  # 量力而行# ========= Preprocessing  =========
train_dataloader = DataLoader(dataset, batch_size=batch_size)

Qwen模型和GPT类似,都是根据一串token输入,预测下一个token。所以在训练中的X就应该是一串token,y则是一个token。但实际上,Qwen以及Huggingface的其他AutoModelForCausalLM的输出不止是一个token,而是一串token。准确来说AutoModelForCausalLM输出的logits是一个的型为(batch_size, sequence_length, config.vocab_size)三维张量。因为模型会把输入x中的每一个x[0:i]子串都当作输入来预测一下,所以相应的y也应当调整为每一个x[0:i]子串的后一个token

那么,train_dataloader里面的每个batch的X和y都是token,且基本上y可以看作x左移了一位,而y[-1]则是x这个输入序列的真实预测值

x, y = x[0:-1], x[1:len(x)]

现在的任务就是把数据集变换成我们想要的模样。考察数据集,发现每个样本有两个序列

>>>  print(dataset)
DatasetDict({train: Dataset({features: ['question', 'answer'],num_rows: 200035})
})

我们使用Qwentokenizer提供的方法,把这两个序列合并成一个x

from transformers import default_data_collatordef ds_generator():for item in ds['train']:messages = [{"role": "user", "content": item['question']},{"role": "system", "content": item['answer']}]text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt", padding='max_length', truncation=True)yield {'x': model_inputs['input_ids'][0][:-1],'y': model_inputs['input_ids'][0][1:]}
dataset = Dataset.map(ds_generator)
train_dataloader = DataLoader(new_ds, batch_size=batch_size,  collate_fn=default_data_collator)

我们这里只考虑语言模型,所以模型常规的输出应该是logits,则compute_loss则是计算模型输出与真实label之间的交叉熵。

def compute_loss(X, y):'''X: (batch_size, seq_len, vocab_size)y: (batch_size, seq_len)'''return torch.nn.functional.cross_entropy(X.view(-1, X.shape[-1]), y.view(-1))

(其实Qwen模型如果在喂输入x时同时把y喂进去,那么返回值就包含了loss:loss = model(x, y).loss


完整的微调代码如下

from types import SimpleNamespaceimport torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import default_data_collator, get_cosine_schedule_with_warmup
from datasets import load_dataset, Datasetmodel_path = 'Qwen/Qwen2-0.5B'
data_path = 'microsoft/orca-math-word-problems-200k'
save_path = './Qwen2-0.5B-math-1'device = 'cuda' if torch.cuda.is_available() else 'cpu'config = SimpleNamespace(batch_size=8,epochs=5,eps=1e-5,lr=2e-4,model_max_length=2048
)model = AutoModelForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, padding_side="right", model_max_length=config.model_max_length)
dataset = load_dataset(data_path)def ds_generator():for item in ds['train']:messages = [{"role": "user", "content": item['question']},{"role": "system", "content": item['answer']}]text = tokenizer.apply_chat_template(messages,tokenize=False,add_generation_prompt=True)model_inputs = tokenizer([text], return_tensors="pt", padding='max_length', truncation=True)yield {'x': model_inputs['input_ids'][0][:-1],'y': model_inputs['input_ids'][0][1:]}dataset = Dataset.map(ds_generator)
train_dataloader = DataLoader(dataset, batch_size=config.batch_size,  collate_fn=default_data_collator)def compute_loss(X, y):'''X: (batch_size, seq_len, vocab_size)y: (batch_size, seq_len)'''return torch.nn.functional.cross_entropy(X.view(-1, X.shape[-1]), y.view(-1))optimizer = torch.optim.Adam(model.parameters(), lr=config.lr, betas=(0.9,0.99), eps=config.eps)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_training_steps=100, num_warmup_steps=10)for epoch in range(num_epochs):for batch in train_dataloader:# Training code hereout = model(batch['x'])loss = compute_loss(out.logits, batch['y'])loss.backward()optimizer.step()optimizer.zero_grad()scheduler.step()torch.save(model.state_dict(), save_path)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/781383.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

最全MySQL面试20题和答案(三)

接第二期的MySQL面试二十题,这是之后的20题!视图 1. 为什么要使用视图?什么是视图?为了提高复杂 SQL 语句的复用性和表操作的安全性,MySQL 数据库管理系统提供了视图特性。所谓视图,本质上是一种虚拟表,在物理上是不存在的,其内容与真实的表相似,包含一系列带有名称的…

实战-行业攻防应急响应

实战-行业攻防应急响应简介: 服务器场景操作系统 Ubuntu 服务器账号密码:root/security123 分析流量包在/home/security/security.pcap 相关jar包在/home/security/ruoyi/ruoyi-admin.jar 应急主机: 192.168.0.211 网关: 192.168.0.1/24 其它傀儡机: 段内 本次环境来自某次行…

033.Vue3入门,多个插槽Slot的命名调用和#号简写

1、App.vue代码如下:<template><div><h3>插槽学习</h3><Slot001><!-- 插槽1--><template v-slot:s2><p>{{ msg1 }}</p></template><!-- 插槽2--><template #s1><p>{{ msg2 }}</p>…

个人Blog的第一篇博文

个人Blog的第一篇博文正式加入"博客园"大家庭了,希望以后可以一直坚持下去欸。

Day 15

今天留了个小遗憾......并没有学会已崩溃......啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊!!!!!!!!!!!!!!!!!!!!!!!!!!!!!不驰于空想,不骛于虚声

032.Vue3入门,插槽Slot的作用域和默认内容

1、App.vue代码如下:<template><div><h3>插槽学习</h3><!-- 插槽1--><Slot001><p>{{ msg }}</p></Slot001><!-- 插槽2--><Slot001><!-- <p>{{ msg }}</p>--></Slot001>…

bugbountyhunter scope BARKER:第九滴血 存储型 Storage Cross-Site Scripting XSS 头像处SVG文件上传 报告

登录后来到My profile页面,页面里存在一个Edit Profile头像处可以上传SVG图片 检查 xss payload:https://github.com/swisskyrepo/PayloadsAllTheThings/tree/master/XSS Injection#xss-in-files使用SVG进行图片上传,发现SVG文件上传成功并返回图片地址poc:https://cfceb12f…

bugbountyhunter scope BARKER:第九滴血 存储型 Storage Cross-Site Scripting XSS SVG文件上传 报告

登录后来到My profile页面,页面里存在一个Edit Profile头像处可以上传SVG图片 检查 xss payload:https://github.com/swisskyrepo/PayloadsAllTheThings/tree/master/XSS Injection#xss-in-files使用SVG进行图片上传,发现SVG文件上传成功并返回图片地址poc:https://cfceb12f…

WPS Office 2023专业版 v12.8.2.17149v2 精简优化版

概述 WPS Office是由金山软件股份有限公司自主研发的一款办公软件套装,可以实现办公软件最常用的文字、表格、演示等多种功能。具有内存占用低、运行速度快、体积小巧、强大插件平台支持、免费提供海量在线存储空间及文档模板、支持阅读和输出PDF文件、全面兼容微软Office97-2…

030.Vue3入门,父页面给子页面传递attribute属性

1、App.vue代码如下:<template><Father/> </template><script setup> import Father from ./view/Father.vue </script><style> </style>2、Father.vue代码如下:<template><h3>父页面</h3><p>父页面属性&…

C基础篇 文件操作

1. EOF宏,C语言EOF宏详解2. gets和fgets函数及其区别,C语言gets和fgets函数详解3. puts和fputs函数及其区别,C语言puts和fputs函数详解4. feof和ferror函数,C语言feof和ferror函数详解5. setbuf与setvbuf函数,C语言setbuf与setvbuf函数详解6. fseek、ftell和rewind函数,C…

029.Vue3入门,父页面自定义Event传给子页面,子页面通过此Event回传数据给父页面

1、App.vue代码:<template><Father/> </template><script setup> import Father from ./view/Father.vue </script><style> </style>2、Father代码如下:<template><h3>父页面</h3><p>搜索内容为: {{ msg …