【llm 微调code-llama 训练自己的数据集 一个小案例】

这也是一个通用的方案,使用peft微调LLM。

准备自己的数据集

根据情况改就行了,jsonl格式,三个字段:context, answer, question

import pandas as pd
import random
import jsondata = pd.read_csv('dataset.csv')
train_data = data[['prompt','Code']]
train_data = train_data.values.tolist()random.shuffle(train_data)train_num = int(0.8 * len(train_data))with open('train_data.jsonl', 'w') as f:for d in train_data[:train_num]:d = {'context':'','question':d[0],'answer':d[1]}f.write(json.dumps(d)+'\n')
with open('val_data.jsonl', 'w') as f:for d in train_data[train_num:]:d = {'context':'','question':d[0],'answer':d[1]}f.write(json.dumps(d)+'\n')

初始化

from datetime import datetime
import os
import sysimport torchfrom peft import (LoraConfig,get_peft_model,get_peft_model_state_dict,prepare_model_for_int8_training,
)
from transformers import (AutoTokenizer, AutoModelForCausalLM, LlamaForCausalLM,TrainingArguments, Trainer, DataCollatorForSeq2Seq)# 加载自己的数据集
from datasets import load_datasettrain_dataset = load_dataset('json', data_files='train_data.jsonl', split='train')
eval_dataset = load_dataset('json', data_files='val_data.jsonl', split='train')# 读取模型
base_model = 'CodeLlama-7b-Instruct-hf'model = AutoModelForCausalLM.from_pretrained(base_model,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",low_cpu_mem_usage=True
)tokenizer = AutoTokenizer.from_pretrained(base_model)

微调前的效果

tokenizer.pad_token = tokenizer.eos_token
prompt = """You are programming coder.Now answer the question:{}"""
prompts = [prompt.format(train_dataset[i]['question']) for i in [1,20,32,45,67]]model_input = tokenizer(prompts, return_tensors="pt", padding=True).to("cuda")model.eval()
with torch.no_grad():outputs = model.generate(**model_input, max_new_tokens=300)outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)print(outputs)

进行微调

tokenizer.add_eos_token = True
tokenizer.pad_token_id = 0
tokenizer.padding_side = "left"def tokenize(prompt):result = tokenizer(prompt,truncation=True,max_length=512,padding=False,return_tensors=None,)# "self-supervised learning" means the labels are also the inputs:result["labels"] = result["input_ids"].copy()return resultdef generate_and_tokenize_prompt(data_point):full_prompt =f"""You are a powerful programming model. Your job is to answer questions about a database. You are given a question.You must output the code that answers the question.### Input:
{data_point["question"]}### Response:
{data_point["answer"]}
"""return tokenize(full_prompt)tokenized_train_dataset = train_dataset.map(generate_and_tokenize_prompt)
tokenized_val_dataset = eval_dataset.map(generate_and_tokenize_prompt)model.train() # put model back into training mode
model = prepare_model_for_int8_training(model)config = LoraConfig(r=16,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj",
],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",
)
model = get_peft_model(model, config)# keeps Trainer from trying its own DataParallelism when more than 1 gpu is available
if torch.cuda.device_count() > 1:model.is_parallelizable = Truemodel.model_parallel = Truebatch_size = 128
per_device_train_batch_size = 32
gradient_accumulation_steps = batch_size // per_device_train_batch_size
output_dir = "code-llama-ft"training_args = TrainingArguments(per_device_train_batch_size=per_device_train_batch_size,gradient_accumulation_steps=gradient_accumulation_steps,warmup_steps=100,max_steps=400,learning_rate=3e-4,fp16=True,logging_steps=10,optim="adamw_torch",evaluation_strategy="steps", # if val_set_size > 0 else "no",save_strategy="steps",eval_steps=20,save_steps=20,output_dir=output_dir,load_best_model_at_end=False,group_by_length=True, # group sequences of roughly the same length together to speed up trainingreport_to="none", # if use_wandb else "none", wandbrun_name=f"codellama-{datetime.now().strftime('%Y-%m-%d-%H-%M')}", # if use_wandb else None,)trainer = Trainer(model=model,train_dataset=tokenized_train_dataset,eval_dataset=tokenized_val_dataset,args=training_args,data_collator=DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True),
)

开始训练

model.config.use_cache = Falseold_state_dict = model.state_dict
model.state_dict = (lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())).__get__(model, type(model)
)
if torch.__version__ >= "2" and sys.platform != "win32":print("compiling the model")model = torch.compile(model)
trainer.train()

进行测试

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizerbase_model = 'CodeLlama-7b-Instruct-hf'
model = AutoModelForCausalLM.from_pretrained(base_model,load_in_8bit=True,torch_dtype=torch.float16,device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)output_dir = "code-llama-ft"
model = PeftModel.from_pretrained(model, output_dir)eval_prompt = """You are a powerful programming model. Your job is to answer questions about a database. You are given a question.You must output the code that answers the question.### Input:
Write a function in Java that takes an array and returns the sum of the numbers in the array, or 0 if the array is empty. Except the number 13 is very unlucky, so it does not count any 13, or any number that immediately follows a 13.### Response:
"""model_input = tokenizer(eval_prompt, return_tensors="pt").to("cuda")model.eval()
with torch.no_grad():outputs = model.generate(**model_input, max_new_tokens=100)[0]
print(tokenizer.decode(outputs, skip_special_tokens=True))

主要参考icon-default.png?t=N7T8https://zhuanlan.zhihu.com/p/660933421

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/419312.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于JAVA的CRM客户管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块三、系统设计3.1 用例设计3.2 E-R 图设计3.3 数据库设计3.3.1 客户表3.3.2 商品表3.3.3 客户跟踪表3.3.4 客户消费表3.3.5 系统角色表 四、系统展示五、核心代码5.1 查询客户5.2 新增客户跟踪记录5.3 新增客户消费订单5.4 查…

复现PointNet++(语义分割网络):Windows + PyTorch + S3DIS语义分割 + 代码

一、平台 Windows 10 GPU RTX 3090 CUDA 11.1 cudnn 8.9.6 Python 3.9 Torch 1.9.1 cu111 所用的原始代码:https://github.com/yanx27/Pointnet_Pointnet2_pytorch 二、数据 Stanford3dDataset_v1.2_Aligned_Version 三、代码 分享给有需要的人&#xf…

操作系统-操作系统体系结构(内核 外核 模块化 宏内核 微内核 分层结构)

文章目录 大内核与微内核总览操作系统的内核大内核与微内核的性能差异小结 分层结构与模块化与外核总览分层结构模块化宏内核,微内核外核 大内核与微内核 总览 操作系统的内核 操作系统的核心功能在内核中 对于与硬件关联程度的程序 由于进程管理,存…

L1-067 洛希极限(Java)

科幻电影《流浪地球》中一个重要的情节是地球距离木星太近时,大气开始被木星吸走,而随着不断接近地木“刚体洛希极限”,地球面临被彻底撕碎的危险。但实际上,这个计算是错误的。 洛希极限(Roche limit)是一…

OpenMV入门

1. 什么是OpenMV OpenMV 是一个开源,低成本,功能强大的 机器视觉模块。 OpenMV上的机器视觉算法包括 寻找色块、人脸检测、眼球跟踪、边缘检测、标志跟踪 等。 以STM32F427CPU为核心,集成了OV7725摄像头芯片,在小巧的硬件…

小程序学习-19

Vant Weapp - 轻量、可靠的小程序 UI 组件库 ​​​​​ Vant Weapp - 轻量、可靠的小程序 UI 组件库 安装出现问题:rollbackFailedOptional: verb npm-session 53699a8e64f465b9 解决办法:http://t.csdnimg.cn/rGUbe Vant Weapp - 轻量、可靠的小程序…

海外媒体发稿:满足要求的二十个爆款文案的中文标题-华媒舍

爆款文案是指在营销和推广方面非常受欢迎和成功的文案。它们能够吸引读者的眼球,引发浏览者的兴趣,最终促使他们采取行动。本文将介绍二十个满足要求的爆款文案的中文标题,并对每个标题进行拆解和描述。 1. "XX 绝对不能错过的十大技巧…

RHEL - 更新升级软件或系统

《OpenShift / RHEL / DevSecOps 汇总目录》 文章目录 小版本软件更新yum update 和 yum upgrade 的区别升级软件和升级系统检查软件包是否可升级指定升级软件使用的发行版本方法1方法2方法3方法4 查看软件升级类型更新升级指定的 RHSA/RHBA/RHEA更新升级指定的 CVE更新升级指定…

第36集《佛法修学概要》

请大家打开讲义第九十六面,我们讲到禅定的修学方便。 在我们发了菩提心,安住菩萨种性以后,我们开始操作六度的法门。六度的法门,它有两个不同的差别的内容,一种是成就我们的善业力,另外一种,是…

【Linux】第三十一站:管道的一些应用

文章目录 一、我们之前的|(竖划线)管道二、自定义shell三、使用管道实现一个简易的进程池1.详解2.代码3.一个小bug4.最终代码 一、我们之前的|(竖划线)管道 cat test.txt | head -10 | tail -5如上代码所示,是我们之前所用的管道 我们拿下面这个举个例子 当我们用…

深度探讨 Golang 中并发发送 HTTP 请求的最佳技术

目录 推荐 使用 Goroutines 的基本方法 Goroutine 入门 处理多个请求 并发 HTTP 请求的方法 基本 Goroutine WaitGroup Channels Worker Pools 使用通道限制 Goroutine 使用信号量限制 Goroutines 那么,最好的方法是什么? 评估你的需求 错误…

快速上手的AI工具-文心3.5vs文心4.0

前言 大家好晚上好,现在AI技术的发展,它已经渗透到我们生活的各个层面。对于普通人来说,理解并有效利用AI技术不仅能增强个人竞争力,还能在日常生活中带来便利。无论是提高工作效率,还是优化日常任务,AI工…