从理论到实践---实现LLM微调的7个步骤

原文地址:7-steps-to-mastering-large-language-model-fine-tuning

From theory to practice, learn how to enhance your NLP projects with these 7 simple steps.

2024 年 3 月 27 日

在过去的一年半里,自然语言处理(NLP)领域发生了显著的变化,这主要归功于像OpenAI的GPT系列这样的大型语言模型(LLMs)的崛起。

这些强大的模型彻底改变了我们处理自然语言任务的方式,在翻译、情感分析和自动文本生成方面提供了前所未有的能力。它们理解和生成类似人类文本的能力,打开了曾经被认为无法实现的可能性。

然而,尽管这些模型功能强大,但训练它们的过程充满了挑战,如需要投入大量的时间和资金。

这让我们不得不重视LLM微调的关键作用。

通过精炼这些预训练模型,使其更好地适应特定的应用或领域,我们可以显著增强它们在特定任务上的表现。这一步骤不仅提升了它们的质量,还扩展了它们在众多领域的实用性。

本指南旨在将这一过程分解为7个简单步骤,以便任何LLM都能为特定任务进行微调。

了解预训练的大型语言模型

LLM 是一类专门的 ML 算法,旨在根据前面单词提供的上下文来预测序列中的下一个单词。这些模型建立在 Transformers 架构之上,这是机器学习技术的一项突破,并首先在 Google 的All you need is focus文章中进行了解释。

像GPT(生成式预训练Transformer)这样的模型是预训练语言模型的例子,它们已经接触过大量的文本数据。这种广泛的训练使它们能够捕获语言使用的潜在规则,包括单词如何组合成连贯的句子。

这些模型的一个关键优势在于它们不仅能够理解自然语言,还能够基于给定的输入产生类似于人类写作的文本。

那么,这些模型最大的优势是什么呢?

这些模型已经通过API向大众开放。

微调是什么,为什么它很重要?

微调是一个过程,即选择一个预训练模型,并使用特定领域的数据集对其进行进一步训练以改进其性能。

大多数LLM模型在自然语言技能和通用知识表现方面都非常出色,但在特定的任务导向型问题上却表现不佳。微调过程提供了一种方法,可以在不必从头开始构建模型的情况下,提高模型在特定问题上的性能,同时降低计算成本。

简而言之,微调可以调整模型以更好地执行特定任务,使其在现实应用中更加有效和灵活。对于改进特定任务或领域的现有模型而言,这一步骤至关重要。

微调LLM的逐步指南

让我们通过一个实例,仅在7个步骤中微调一个真实的模型来阐明这个概念。

第一步:明确具体目标

假设我们想要推断任何文本的情感,并决定尝试使用GPT-2来完成这项任务。

我确信,我们很快就会发现它在执行这项任务时表现相当糟糕。然后,一个自然的问题是:

我们能做些什么来改进它的性能吗?

当然,答案是肯定的!我们可以做到!

通过利用微调,我们将使用Hugging Face Hub上预训练的GPT-2模型和一个包含推文及其对应情感的数据集进行训练,以提高性能。

因此,我们的最终目标是拥有一个擅长从文本中推断情感的模型。

第二步:选择预训练模型和数据集

第二步是选择一个作为基准模型的模型。在我们的案例中,我们已经选择了模型:GPT-2。因此,我们将对它进行一些简单的微调。

微信截图_20240401111158

始终牢记要选择适合你的任务的模型。

第三步:加载要使用的数据

现在我们有了模型和主要任务,我们需要一些数据来工作。

但不用担心,Hugging Face已经安排好了!

这就是他们的数据集库发挥作用的地方。

在本例中,我们将利用Hugging Face数据集库导入一个包含标有相应情感(积极、中性或消极)的推文的数据集。

from datasets import load_datasetdataset = load_dataset("mteb/tweet_sentiment_extraction")
df = pd.DataFrame(dataset['train'])

数据如下:

微信截图_20240401111215

第四步:分词器

现在我们有了模型和要微调的数据集。因此,接下来的自然步骤是加载一个分词器。由于LLM使用令牌(而不是单词!),我们需要一个分词器将数据发送到我们的模型。

我们可以很容易地通过利用map方法来对整个数据集进行分词来做到这一点。

from transformers import GPT2Tokenizer# Loading the dataset to train our model
dataset = load_dataset("mteb/tweet_sentiment_extraction")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_tokendef tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)tokenized_datasets = dataset.map(tokenize_function, batched=True)

额外奖励:为了提高我们的处理性能,我们生成了两个较小的子集:

  • 训练集:用于微调我们的模型。
  • 测试集:用于评估模型。
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
第五步:初始化基础模型

一旦我们有了要使用的数据集,我们就加载模型并指定预期标签的数量。从推文的情感数据集中,你可以知道有三种可能的标签:

  • 0 或 消极
  • 1 或 中立
  • 2 或 积极
from transformers import GPT2ForSequenceClassificationmodel = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=3)
第六步:评估方法

Transformers库提供了一个名为“Trainer”的类,用于优化我们的模型的训练和评估。因此,在实际训练开始之前,我们需要定义一个函数来评估微调后的模型。

import evaluatemetric = evaluate.load("accuracy")def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)
第七步:使用Trainer方法进行微调

最后一步是微调模型。为此,我们设置训练参数以及评估策略,并执行Trainer对象。

要执行Trainer对象,我们只需使用train()命令。

from transformers import TrainingArguments, Trainertraining_args = TrainingArguments(output_dir="test_trainer",#evaluation_strategy="epoch",per_device_train_batch_size=1,  # Reduce batch size hereper_device_eval_batch_size=1,    # Optionally, reduce for evaluation as wellgradient_accumulation_steps=4)trainer = Trainer(model=model,args=training_args,train_dataset=small_train_dataset,eval_dataset=small_eval_dataset,compute_metrics=compute_metrics,)trainer.train()

一旦我们的模型被微调,我们使用测试集来评估其性能。Trainer对象已经包含了一个优化的evaluate()方法。

import evaluatetrainer.evaluate()

这是一个对任何LLM进行微调的基本过程。

此外,请记住,微调LLM的过程对计算资源的需求很大,因此你的本地计算机可能没有足够的算力来执行它。

结论

如今,对像GPT这样的大型预训练语言模型进行特定任务的微调,对于提升LLM在特定领域的性能至关重要。它允许我们利用它们的自然语言处理能力,同时提高它们的效率和定制潜力,使该过程变得易于访问且成本效益高。

遵循这七个简单步骤——从选择合适的模型和数据集到训练和评估微调后的模型——我们可以在特定领域实现更出色的模型性能。

对于那些想要查看完整代码的人,可以在我的大型语言模型 GitHub 存储库中找到它。 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/586425.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

深入解析大数据体系中的ETL工作原理及常见组件

** 引言 关联阅读博客文章:探讨在大数据体系中API的通信机制与工作原理 关联阅读博客文章:深入理解HDFS工作原理:大数据存储和容错性机制解析 ** 在当今数字化时代,大数据处理已经成为了企业成功的重要组成部分。而在大数据处…

分月饼 java题解

import java.util.Scanner;public class Main {public static void main(String[] args) {Scanner sc new Scanner(System.in); int m sc.nextInt(); // 读取员工数量mint n sc.nextInt(); // 读取月饼数量n// 调用distribute方法并打印返回的分配方法总数//先默认每人分一个…

简单使用bootstrap-datepicker日期插件

目录 下载datepicker 方式一: 方式二: 下载依赖 下载bootstarp.js 下载jquery 使用示例 日期选择 单独选择年 单独选择月 单独选择日 设置截止日期 设置默认日期 总结 下载datepicker 方式一: 下载地址 GitHub - uxsolution…

软件测试(Junit5 单元测试框架)(五)

1. Junit单元测试框架 Junit 是 Java 的一个单元测试框架, 使用Selenium写自动化测试用例, 使用Junit 管理写好的测试用例. 2. 注解&#xff1a; Test 表示当前的这个方法是一个测试用例. 示例: 添加依赖 <!-- https://mvnrepository.com/artifact/org.junit.jupiter/junit-…

【THM】SQL Injection(SQL注入)-初级渗透测试

简介 SQL(结构化查询语言)注入,通常称为 SQLi,是对 Web 应用程序数据库服务器的攻击,导致执行恶意查询。当 Web 应用程序使用未经正确验证的用户输入与数据库进行通信时,攻击者有可能窃取、删除或更改私人数据和客户数据,并攻击 Web 应用程序身份验证方法以获取私有数据…

python对接百度云车牌识别

注册百度智能云&#xff0c;选择产品服务。 https://console.bce.baidu.com/ 每天赠送200次&#xff0c;做开发测试足够了。 在应用列表复制 AppID , API Key ,Secret Key 备用。 SDK下载地址 https://ai.baidu.com/sdk#ocr 下载SDK文件&#xff0c;解压&#xff0c;…

python+scrapy电影推荐系统可视化分析系统

在本系统的开发过程中&#xff0c;研究学习了如何使用scrapy、Django这两大框架&#xff0c;体会到了python语言的“极简至优美”&#xff0c;我接触到了这几个框架的前沿知识&#xff0c;对自己可以站在巨人的肩膀上兴奋不已。我在系统开发过程中&#xff0c;经历了由抓取数据…

67、yolov8目标检测和旋转目标检测算法batchsize=1/6部署Atlas 200I DK A2开发板上

基本思想:需求部署yolov8目标检测和旋转目标检测算法部署atlas 200dk 开发板上 一、转换模型 链接: https://pan.baidu.com/s/1hJPX2QvybI4AGgeJKO6QgQ?pwd=q2s5 提取码: q2s5 from ultralytics import YOLO# Load a model model = YOLO("yolov8s.yaml") # buil…

Linux之ssh服务

目录 一、ssh简介 ssh组件 二、配置文件 三、相关的命令 ssh scp 四、密钥认证 一、ssh简介 远程登陆linux用的就是ssh服务 ssh服务特点就是数据会机密传输 ssh组件 组件&#xff1a;openssh 服务器&#xff1a;sshd 默认端口&#xff1a;22 二、配置文件 /etc/ssh/ss…

如何在CentOS安装StackEdit Markdown编辑器并实现无公网IP远程访问使用

最近&#xff0c;我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念&#xff0c;而且内容风趣幽默。我觉得它对大家可能会有所帮助&#xff0c;所以我在此分享。点击这里跳转到网站。 文章目录 前言1. ubuntu安装VNC2. 设置vnc开机启动3. windows 安…

镭速如何解决UDP传输不通的问题

我们之前有谈到过企业如果遇到UDP传输不通的情况&#xff0c;常见的一些解决方式&#xff0c;同时也介绍了一站式企业文件传输方式-镭速相关优势&#xff0c;如果在实际应用中&#xff0c;若镭速UDP传输出现不通的情况&#xff0c;需要按照网络通信的一般性排查方法以及针对镭速…

男裤哪个品牌质量好?一次教你学会怎么选男生裤子

相信大家每次在选衣服和裤子的时候都希望能够买到好看质量又好的&#xff0c;但现在市面却太多不同的品牌&#xff0c;并且质量也参差不齐&#xff0c;十分容易选择到一些质量不好的裤子。那么今天就专门对现在市面热门的几款男装裤子进行测评&#xff0c;并根据具体结果进行推…