基于lora技术对Gemma(2B)大模型的微调实践

一、概述

本文主要基于Lora技术,在Google colab上用A100对Gemma 2B大模型进行了指令微调,第一次指令微调是采用databricks-dolly-15k 作为数据集,取得了不错的微调效果,能准确用英文回答问题,但databricks-dolly-15k 毕竟是英文数据集,微调后的模型对中文的理解并不好。为了使模型对中文有更好的理解,笔者采用COIG-CQIA数据集对模型进行了指令微调,并展示了微调前后的效果对比。

《两个数据集说明》

databricks-dolly-15k 是一个开源数据集,其中包含数千名 Databricks 员工在 InstructGPT 论文中概述的几个行为类别中生成的指令跟踪记录,包括头脑风暴、分类、封闭式 QA、生成、信息提取、开放式 QA 和摘要。

图:databricks-dolly-15k Dataset card

图:databricks-dolly-15k 具体内容

COIG-CQIA全称为Chinese Open Instruction Generalist - Quality is All You Need, 是一个开源的高质量指令微调数据集,旨在为中文NLP社区提供高质量且符合人类交互行为的指令微调数据。COIG-CQIA以中文互联网获取到的问答及文章作为原始数据,经过深度清洗、重构及人工审核构建而成。

图:COIG-CQIA-full Dataset card

图:COIG-CQIA-full 具体内容

二、前置条件

获得模型访问权,选择Colab运行时,配置训练环境。

先在Kaggle上注册,然后获得Gemma 2B 的访问权;

然后在Google colab 配置环境,主要是GPU的选择,免费的是T4,建议采用付费的A100(为了节省时间,微调训练耗时T4需要30分钟左右,A100只需要2分钟左右)

最后 在Kaggle 上的account上生成令牌文件(主要是usename 和 API Key),并将令牌文件配置到colab环境。

三、微调步骤

因为直接使用databricks-dolly-15k进行微调时,可以基于原有代码进行快速验证,为了采用COIG-CQIA-full数据集进行微调,最直接的想法就是把代码中databricks-dolly-15k部分替换为COIG-CQIA-full,然后对代码进行稍微修改,但这一步花费了笔者非常多的时间,因为无论如何修改,总会报错。最总采用了以下办法:

1、先通过代码上传COIG-CQIA-full代码,

2、然后将其转换为和databricks-dolly-15k同格式的内容,并将新内容命名为databricks-dolly-15k-fb,

3、然后下载到本地检测下内容是否正确。在正确的前提下,取databricks-dolly-15k-fb内容的前1600行,保存到另外一个文件databricks-dolly-15k-fb1,

4、最后使用databricks-dolly-15k-fb1进行微调。

通过此方法,可以基本不修改原有微调语料处理代码的前提下,完成微调训练。

图:COIG-CQIA-full 格式转化后内容

本次选择1600行,主要是为了减少上传的时间,当然也可以更少,openai建议的50行

Example count recommendations计数建议示例

To fine-tune a model, you are required to provide at least 10 examples. We typically see clear improvements from fine-tuning on 50 to 100 training examples with gpt-3.5-turbo but the right number varies greatly based on the exact use case.
要微调模型,您需要提供至少 10 个示例。我们通常会看到对 50 到 100 个训练示例进行微调的明显改进, gpt-3.5-turbo 但正确的数量会根据确切的用例而有很大差异。

We recommend starting with 50 well-crafted demonstrations and seeing if the model shows signs of improvement after fine-tuning. In some cases that may be sufficient, but even if the model is not yet production quality, clear improvements are a good sign that providing more data will continue to improve the model. No improvement suggests that you may need to rethink how to set up the task for the model or restructure the data before scaling beyond a limited example set.
我们建议从 50 个精心制作的演示开始,看看模型在微调后是否显示出改进的迹象。在某些情况下,这可能就足够了,但即使模型尚未达到生产质量,明显的改进也是一个好兆头,表明提供更多数据将继续改进模型。没有改进表明,在扩展到有限的示例集之前,您可能需要重新考虑如何为模型设置任务或重组数据。

此外数据格式的检测也非常关键,openai官网有专门的格式检查代码。

Check data formatting检查数据格式

Once you have compiled a dataset and before you create a fine-tuning job, it is important to check the data formatting. To do this, we created a simple Python script which you can use to find potential errors, review token counts, and estimate the cost of a fine-tuning job.
编译数据集后,在创建微调作业之前,检查数据格式非常重要。为此,我们创建了一个简单的 Python 脚本,您可以使用它来查找潜在错误、查看令牌计数以及估计微调作业的成本。

四、微调前后效果展示

基于databricks-dolly-15k微调的效果不做展示,主要展示下基于COIG-CQIA-full微调后的效果展示。

4.1微调前的表现

图:微调前问答表现

图:微调后问答表现

再看几个微调后的表现

图:微调后问答表现2

4.2微调前后训练参数对比

采用Lora微调,训练参数量由25亿降低到130万左右;

图:训练参数对比

4.3用A100微调训练的关键片段

训练耗时104s,80ms每步,这里如果采用T4,需要训练接近30分钟。

图:采用A100进行微调

五、关键源码

一、格式转换代码

import json
from google.colab import files  # 如果你是在Google Colab环境中运行,需要导入该模块进行文件上传# 提示上传文件
print("请上传 COIG-CQIA-full.jsonl 文件")
uploaded = files.upload()# 获取上传文件名
uploaded_filename = list(uploaded.keys())[0]# 读取上传的COIG-CQIA-full.jsonl文件
with open(uploaded_filename, "r", encoding="utf-8") as f:coig_data = f.readlines()# 转换格式
converted_data = []
for line in coig_data:coig_entry = json.loads(line.strip())converted_entry = {"instruction": coig_entry["instruction"],"context": coig_entry["input"],"response": coig_entry["output"],"category": "open_qa" if coig_entry["task_type"]["minor"] == ["问答"] else "classification"}converted_data.append(converted_entry)# 将转换后的数据写入databricks-dolly-15k-fb.jsonl文件
with open("databricks-dolly-15k-fb.jsonl", "w", encoding="utf-8") as f:for entry in converted_data:f.write(json.dumps(entry, ensure_ascii=False) + "\n")print("转换完成,结果已保存到 databricks-dolly-15k-fb.jsonl 文件中")

二、数据处理代码

import json
data = []
with open("databricks-dolly-15k-fb1.jsonl") as file:for line in file:features = json.loads(line)# Filter out examples with context, to keep it simple.if features["context"]:continue# Format the entire example as a single string.template = "Instruction:\n{instruction}\n\nResponse:\n{response}"data.append(template.format(**features))# Only use 1000 training examples, to keep it fast.
data = data[:1000]

三、Lora微调代码

# Enable LoRA for the model and set the LoRA rank to 4.
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()
# Limit the input sequence length to 512 (to control memory usage).
gemma_lm.preprocessor.sequence_length = 512
# Use AdamW (a common optimizer for transformer models).
optimizer = keras.optimizers.AdamW(learning_rate=5e-5,weight_decay=0.01,
)
# Exclude layernorm and bias terms from decay.
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])gemma_lm.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),optimizer=optimizer,weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
gemma_lm.fit(data, epochs=1, batch_size=1)

参考文档:

1、https://ai.google.dev/gemma/docs/lora_tuning

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/593944.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

无库存,无货源,怎么做视频号小店?

大家好,我是电商糖果 有朋友跟糖果说,这视频号小店非常火,自己想趁着这次的风口开店赚钱。 自己无库存,也无货源,能开店卖货不? 糖果告诉他,能卖货,而且还可以让他不到一个月时间…

6:算法基础--6.1:线性结构 ,6.2:查找算法

转上一节: http://t.csdnimg.cn/ql5Cdhttp://t.csdnimg.cn/ql5Cd 课程内容提要: 6:知识点考点详解 6.1:线性结构 通常分析时间复杂度的方法是从算法中选取-种对于所研究的问题来说是基本运算的操作,以 该操作重…

用顺序表实现通讯录

前言 这次的通讯录是基于上一篇的动态顺序表的基础上实现的,如果对动态顺序表不熟悉,可以打开这个链接阅读http://t.csdnimg.cn/9zJ5g,这里我们会调用动态顺序表的函数。 如果想看静态顺序表实现通讯录,可以打开这个链接阅读http:…

Redis缓存穿透和缓存雪崩

一、缓存穿透 1 什么是缓存穿透 缓存穿透说简单点就是大量请求的 key 根本不存在于缓存中,导致请求直接到了数据库上,根本没有经过缓存这一层。举个例子:某个黑客故意制造我们缓存中不存在的 key 发起大量请求,导致大量请求落到数…

1 【机器学习】统计学习的概念

🎈个人主页:豌豆射手^ 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:李航统计学习笔记 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习…

如何借助Idea创建多模块的SpringBoot项目

目录 1.1、前言1.2、开发环境1.3、项目多模块结构1.4、新建父工程1.5、创建子模块1.6、编辑父工程的pom.xml文件 1.1、前言 springmvc项目,一般会把项目分成多个包:controler、service、dao、utl等,但是随着项目的复杂性提高,想复用其他一个模…

蓝桥-时间显示

目录 题目链接 代码 题目链接 1.时间显示 - 蓝桥云课 (lanqiao.cn) 代码 #include <bits/stdc.h> using namespace std;int main() {long long x;cin>>x;int h,m,s;x x / 1000 % (3600*24); // 毫秒化秒&#xff0c;并且保留最后一天的时间h x / 3600; //求得…

深入理解计算机系统 家庭作业 2.84

这题没有这个要求所以可以用 ? > : < 这种运算 以下代码用的是位级运算.因为我误解了题意 呜呜呜 想看用判断的代码请自行百度 ((((ux<<9>>9)<<((ux<<1>>24)-127)) - ((uy<<9>>9)<<((uy<<1>>24)-127)))>…

如何利用FLUENT计算流体力学方法解决大气与环境领域流动问题

ANSYS FLUENT是目前全球领先的商用CFD 软件&#xff0c;市场占有率达70%左右&#xff0c;是工程师和研究者不可多得的有力工具。由于采用了多种求解方法和多重网格加速收敛技术&#xff0c;因而FLUENT能达到最佳的收敛速度和求解精度。灵活的非结构化网格和基于解的自适应网格技…

学习网安(21)

第20章存在疑问&#xff0c;待开学后和老师求证改动后发布 中间件之一——apache 先说一下http协议——超文本传输协议 全称为&#xff1a;Hyper Text Transfor Protocol 用途&#xff1a;让用户通过浏览器发送请求到服务器端&#xff0c;接收客户端返回的数据&#xff0c;…

SRS 实时视频服务器搭建及使用

一、SRS 介绍 SRS是一个开源的&#xff08;MIT协议&#xff09;简单高效的实时视频服务器&#xff0c;支持RTMP、WebRTC、HLS、HTTP-FLV、SRT、MPEG-DASH和GB28181等协议。 SRS媒体服务器和FFmpeg、OBS、VLC、 WebRTC等客户端配合使用&#xff0c;提供流的接收和分发的能力&am…

独角数卡对接支付卡跳转问题解决方法

问题描述 最近在用独角数卡搭建了一个测试版的商店程序&#xff0c;结果却在对接易支付的过程中出现了卡跳转的问题&#xff0c;支付能正常完成&#xff0c;订单发卡也正常&#xff0c;就是会卡在这个弹窗页面无法正常跳转至订单查看页面。 本来这种BUG无关痛痒&#xff0c;但…