使用LLM-Tuning实现百川和清华ChatGLM的Lora微调

LLM-Tuning项目源码:

GitHub - beyondguo/LLM-Tuning: Tuning LLMs with no tears💦, sharing LLM-tools with love❤️.Tuning LLMs with no tears💦, sharing LLM-tools with love❤️. - GitHub - beyondguo/LLM-Tuning: Tuning LLMs with no tears💦, sharing LLM-tools with love❤️.icon-default.png?t=N7T8https://github.com/beyondguo/LLM-Tuning

1、环境准备

训练主机配置:

  • NVIDIA A100-PCIE-40GB
  • Python 3.8.10
  • CUDA: 11.2

安装依赖库:

pip install xformers==0.0.20
pip install torch==2.0.1
pip install transformers==4.29.1
pip install datasets==2.12.0
pip install accelerate==0.19.0
pip install sentencepiece==0.1.99
pip install tensorboard==2.13.0
pip install peft==0.3.0

2、数据准备

原始文件的准备

指令微调数据一般有输入和输出两部分,输出则是希望模型的回答,统一使用json的格式在整理数据,可以自定义输出输出的字段名。

{"q": "请计算:39 * 0 = 什么?", "a": "这是简单的乘法运算,39乘以0得到的是0"}
{"q": "题目:51/186的答案是什么?", "a": "这是简单的除法运算,51除以186大概为0.274"}
{"q": "鹿妈妈买了24个苹果,她想平均分给她的3只小鹿吃,每只小鹿可以分到几个苹果?", "a": "鹿妈妈买了24个苹果,平均分给3只小鹿吃,那么每只小鹿可以分到的苹果数就是总苹果数除以小鹿的只数。\n24÷3=8\n每只小鹿可以分到8个苹果。所以,答案是每只小鹿可以分到8个苹果。"}
...

整理好数据后,保存为.json或者.jsonl文件,然后放入目录中的data/文件夹中。

对数据集进行分词

为了避免每次训练的时候都要重新对数据集分词,先分好词形成特征后保存成可直接用于训练的数据集。

例如:

  • 原始指令微调文件为:data/ 文件夹下的 simple_math_4op.json 文件
  • 输入字段为q,输出字段为a
  • 希望经过 tokenize 之后保存到 data/tokenized_data/ 下名为 simple_math_4op 的文件夹中设定
  • 文本最大程度为 2000

则我们可以直接使用下面这段命令(即tokenize.sh文件)进行处理:

CUDA_VISIBLE_DEVICES=0,1 python tokenize_dataset_rows.py \--model_checkpoint THUDM/chatglm-6b \--input_file simple_math_4op.json \--prompt_key q \--target_key a \--save_name simple_math_4op \--max_seq_length 2000 \--skip_overlength False

处理完毕之后,在 data/tokenized_data/ 下生成名为 simple_math_4op 的文件夹,这就是下一步中可以直接用于训练的数据。

对比不同的 LLM,需在 tokenize.sh 文件里切换 model_checkpoint 参数。

3、使用 LoRA 微调

得到 tokenize 之后的数据集,就可以直接运行 chatglm_lora_tuning.py 来训练 LoRA 模型了。

对于不同的 LLM,需切换不同的 python 文件来执行:

  • ChatGLM-6B 应使用 chatglm_lora_tuning.py
  • ChatGLM2-6B 应使用 chatglm2_lora_tuning.py
  • baichuan-7B 应使用 baichuan_lora_tuning.py
  • baichuan2-7B 应使用 baichuan2_lora_tuning.py
  • internlm-chat/base-7b 应使用 intermlm_lora_tuning.py
  • chinese-llama2/alpaca2-7b 应使用 chinese_llama2_alpaca2_lora_tuning.py

具体可设置的主要参数包括:

  • tokenized_dataset, 分词后的数据集,即在 data/tokenized_data/ 地址下的文件夹名称
  • lora_rank, 设置 LoRA 的秩,推荐为4或8,显存够的话使用8
  • per_device_train_batch_size, 每块 GPU 上的 batch size
  • gradient_accumulation_steps, 梯度累加,可以在不提升显存占用的情况下增大 batch size
  • max_steps, 训练步数
  • save_steps, 多少步保存一次
  • save_total_limit, 保存多少个checkpoint
  • logging_steps, 多少步打印一次训练情况(loss, lr, etc.)
  • output_dir, 模型文件保存地址

例如我们的数据集为 simple_math_4op,希望保存到 weights/simple_math_4op ,则执行下面命令(即train.sh文件):

CUDA_VISIBLE_DEVICES=0,1 python chatglm_lora_tuning.py \--tokenized_dataset simple_math_4op \--lora_rank 8 \--per_device_train_batch_size 10 \--gradient_accumulation_steps 1 \--max_steps 100000 \--save_steps 200 \--save_total_limit 2 \--learning_rate 1e-4 \--fp16 \--remove_unused_columns false \--logging_steps 50 \--output_dir weights/simple_math_4op

训练完之后,可以在 output_dir 中找到 LoRA 的相关模型权重,主要是adapter_model.bin和adapter_config.json两个文件。

如何查看 tensorboard:

  • 在 output_dir 中找到 runs 文件夹,复制其中日期最大的文件夹的地址,假设为 your_log_path
  • 执行 tensorboard --logdir your_log_path 命令,就会在 http://localhost:6006/ 上开启tensorboard
  • 如果是在服务器上开启,则还需要做端口映射到本地。
  • 如果要自己手动进行端口映射,具体方式是在使用 ssh 登录时,后面加上 -L 6006:127.0.0.1:6006 参数,将服务器端的6006端口映射到本地的6006端口。

4、在本地大模型上加载LoRA并推理

把上面的 output_dir 打包带走,假设文件夹为 weights/simple_math_4op, 其中(至少)包含 adapter_model.bin 和 adapter_config.json 两个文件,用下面的方式直接加载,并推理

from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.generation.utils import GenerationConfig
import torchdevice = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")# 加载原始 LLM
model_path = "baichuan-inc/Baichuan2-7B-Chat"
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).half().to(device)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)model.generation_config = GenerationConfig.from_pretrained(model_path)
messages = [{"role": "user", "content": "什么是视觉生成艺术"}]
output = model.chat(tokenizer, messages)
print(output)# 给原始 LLM 安装上你的 LoRA tool
model = PeftModel.from_pretrained(model, "weights/chat_messagge_4op").half().to(device)
output = model.chat(tokenizer, messages)
print(output)

理论上可以通过多次执行 model = PeftModel.from_pretrained(model, "weights/simple_math_4op").half()  的方式,加载多个 LoRA 模型,从而混合不同Tool的能力,但实际测试的时候,由于暂时还不支持设置不同 LoRA weights的权重,往往效果不太好,存在覆盖或者遗忘的情况。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/171895.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Django框架FAQ

文章目录 问题1:Django数据库恢复问题2:null和blank的区别3.报错 django.db.utils.IntegrityError: (1062, “Duplicate entry ‘‘ for key ‘mobile‘“)4.报错 Refused to display ‘url‘ in a frame because it set ‘X-Frame-Options‘ to deny5.报错 RuntimeError: cryp…

K8S篇之etcd数据备份与恢复

一、etcd备份与恢复 基本了解: 1、k8s 使用etcd数据库实时存储集群中的数据,安全起见,一定要备份。 2、备份只需要在一个节点上备份就可以了,每个节点上的数据是同步的;但是数据恢复是需要在每个节点上进行。 3、etcd…

C++ VS2015安装教程,下载和安装(下载地址+图解+详细步骤)

说明:VS2015的三个版本分别为: Visual Studio Community(社区版):满足大部分程序员的需求(推荐) Visual Studio Professional(专业版) Visual Studio Enterprise(企业版) 1、下载地址(这里只提供Community版) htt…

Python爬虫过程中DNS解析错误解决策略

在Python爬虫开发中,经常会遇到DNS解析错误,这是一个常见且也令人头疼的问题。DNS解析错误可能会导致爬虫失败,但幸运的是,我们可以采取一些策略来处理这些错误,确保爬虫能够正常运行。本文将介绍什么是DNS解析错误&am…

Shell编程入门--概念、特性、bash配置文件

目录 一、Shell概念1.定义2.分类和使用场景2.1.分类和切换2.2.使用场景 3.特性3.1.文件描述符与输出重定向3.2.历史命令---history3.3.别名 --alias3.4.命令排序执行3.5.部分快捷键3.6.通配符置换 4.脚本规范5.脚本运行方式5.1.bash脚本执行5.2.bash脚本测试 二、bash配置文件1…

燃气管网监测系统|全面保障燃气安全

根据新华日报的报道,2023年上半年,我国共发生了294起燃气事故,造成了57人死亡和190人受伤,燃气事故的发生原因有很多,其中涉及到燃气泄漏、设备故障等因素。因此,加强燃气安全管理,提高城市的安…

好消息!2023年汉字小达人市级比赛在线模拟题大更新:4个组卷+11个专项,助力孩子更便捷、有效、有趣地备赛

自从《中文自修》杂志社昨天发通知,官宣了2023年第十届汉字小达人市级比赛的日期和安排后,各路学霸们闻风而动,在自己本就繁忙的日程中又加了一项:备赛汉字小达人市级比赛,11月30日,16点-18点。 根据这几年…

MySQL中修改注释+报错1067错误时的解决方法

修改某字段的注释内容的mysql语句 ALTER TABLE consumption_table MODIFY COLUMN risk_level tinyint(1) NOT NULL DEFAULT 0 COMMENT 0-低 1-中 2-高;修改某字段的注释内容的mysql语句时报错1067的解决方法 首先执行MySQL语句:SET sql_mode ‘ALLOW_INVALID_DAT…

拼图游戏,开源代码

拼图游戏 最近玩了玩孩子的拼图游戏,感觉还挺好玩的,心血来潮要不动手做一个吧,偷懒摸鱼的时候可以来一把。 以下就是拼图游戏的界面截图。 体验地址 代码开源地址 心得体会 虽说是一个小游戏,但是需要注意的地方还是挺多的 方…

机器视觉目标检测 - opencv 深度学习 计算机竞赛

文章目录 0 前言2 目标检测概念3 目标分类、定位、检测示例4 传统目标检测5 两类目标检测算法5.1 相关研究5.1.1 选择性搜索5.1.2 OverFeat 5.2 基于区域提名的方法5.2.1 R-CNN5.2.2 SPP-net5.2.3 Fast R-CNN 5.3 端到端的方法YOLOSSD 6 人体检测结果7 最后 0 前言 &#x1f5…

Python---综合案例:通讯录管理系统---涉及点:列表、字典、死循环

需求: 开个一个通讯录的管理系统,主要用于实现存储班级中同学的信息(姓名、年龄、电话) 涉及点:列表、字典、死循环 相关链接:Python--列表及其应用场景---增、删、改、查。-CSDN博客 Python---字典---…

高防CDN:构筑网络安全的钢铁长城

在当今数字化的世界里,网络安全问题日益突显,而高防CDN(高防御内容分发网络)正如一座坚不可摧的钢铁长城,成为互联网安全的不可或缺之物。本文将深入剖析高防CDN在网络安全环境中的关键作用,探讨其如何构筑…