入门微调预训练Transformer模型

大家好,HuggingFace 为众多开源的自然语言处理(NLP)模型提供了强大的支持平台,让这些模型能够通过训练和微调来更好地服务于各种特定的应用场景。在大型语言模型(LLM)迅猛发展的今天,HuggingFace 提供的核心工具,特别是 Trainer 类,极大地优化了 NLP 模型的训练过程,开发者得以更加高效地实现模型定制和优化。

HuggingFace 的 Trainer 类是为 Transformer 模型量身打造的,不仅优化了模型的交互体验,还与 Datasets 和 Evaluate 等库实现了紧密集成,支持更高级的分布式训练,并能无缝对接 Amazon SageMaker 等基础设施服务。通过这种方式,可以更加便捷地进行模型训练和部署。

本文将通过一个实例,展示如何利用 HuggingFace 的 Trainer 类在本地环境中对 BERT 模型进行微调,以处理文本分类任务。并且重点介绍如何使用 HuggingFace 模型中心的预训练模型,而不是深入机器学习的理论基础。

 1.设置

示例将在 SageMaker Studio(https://aws.amazon.com/cn/sagemaker/studio/) 环境下进行操作,利用 ml.g4dn.12xlarge 实例搭载的 conda_python3 内核来完成任务。需要提醒的是,可以选择使用更小型的实例,但这可能会影响训练速度,具体取决于可用的 CPU/工作进程的数量。

使用 HuggingFace 数据集库下载数据集。

import datasets
from datasets import load_dataset

这里指定了训练数据集和评估数据集,会在训练循环中进行使用。

train_dataset = load_dataset("imdb", split="train")
test_dataset = load_dataset("imdb", split="test")
test_subset = test_dataset.select(range(100)) # 取数据的一个子集进行评估

对于任何文本数据,必须指定一个分词器,将数据预处理成模型可以理解的格式。在这种情况下,这里指定了我们使用的 BERT 模型的 HuggingFace 模型中心 ID。

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")# 分词文本数据
def tokenize_function(examples):return tokenizer(examples["text"], padding="max_length", truncation=True)

然后使用内置的 map 函数处理我们的训练和评估数据集。

tokenized_train = train_dataset.map(tokenize_function, batched=True)
tokenized_test = test_subset.map(tokenize_function, batched=True)

图片

预处理后的数据

2.微调 BERT

数据准备就绪后,利用先前选定的模型ID来加载BERT模型。需要注意的是,针对文本分类任务,还定义了标签的总数。在此案例中设定了两个标签,分别用0和1来表示,0代表负面,1代表正面。

from transformers import AutoModelForSequenceClassification
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)

接下来在训练循环中,需要定义一个TrainingArguments对象。在这个对象中,可以设置训练过程中的各种参数,比如训练周期的数量、分布式训练的策略等。

from transformers import TrainingArguments, Trainer
training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch", num_train_epochs=1)

对于评估,使用 Evaluate 库内置的评估函数。

import numpy as np
import evaluate
metric = evaluate.load("accuracy")# 评估函数
def compute_metrics(eval_pred):logits, labels = eval_predpredictions = np.argmax(logits, axis=-1)return metric.compute(predictions=predictions, references=labels)

然后将 TrainingArguments、分词数据集和评估指标函数传递给 Trainer 对象。可以使用 train 方法启动训练运行,这将需要大约 10-15 分钟的时间,具体取决于现有硬件。

trainer = Trainer(model=model,args=training_args,train_dataset=tokenized_train,eval_dataset=tokenized_test, # 使用测试作为评估compute_metrics=compute_metrics,tokenizer=tokenizer
)
trainer.train()

图片

训练完成

对于推理,可以直接使用微调后的 trainer 对象,并在用于评估的分词测试数据集上进行预测:

trainer.predict(tokenized_test)

图片

输出

在更为实际的应用场景中,可以使用 trainer 对象将模型工件保存到本地目录中。

trainer.save_model("./custom_model")

图片

模型工件

然后可以加载这些模型工件,指定训练的模型类型,并在单个数据点上进行推理。

loaded_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path="custom_model/")# 样本推理
encoding = tokenizer("I am super delighted", return_tensors="pt")
res = loaded_model(**encoding)
predicted_label_classes = res.logits.argmax(-1)
predicted_label_classes

图片

正面分类

在现实应用场景中,可以将训练好的模型工件部署到像 Amazon SageMaker 这样的服务堆栈上。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/600239.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

qt 打印日志

在 Qt Creator 中,将 QDebug、QInfo、QWarning、QCritical 和 QFatal 打印的日志输出到指定文件,需要设置 Qt 的消息处理机制。这通常涉及到安装一个自定义的消息处理器,该处理器将日志消息重定向到文件。以下是一个基本的步骤指南&#xff1…

第九届蓝桥杯大赛个人赛省赛(软件类)真题C 语言 A 组-航班时间

#include<iostream> using namespace std;int getTime(){int h1, h2, m1, m2, s1, s2, d 0;//d一定初始化为0&#xff0c;以正确处理不跨天的情况 scanf("%d:%d:%d %d:%d:%d (%d)", &h1, &m1, &s1, &h2, &m2, &s2, &d);return d …

【JAVASE】带你了解面向对象三大特性之一(多态)

✅作者简介&#xff1a;大家好&#xff0c;我是橘橙黄又青&#xff0c;一个想要与大家共同进步的男人&#x1f609;&#x1f609; &#x1f34e;个人主页&#xff1a;再无B&#xff5e;U&#xff5e;G-CSDN博客 1.多态 1.1 多态的概念 多态的概念&#xff1a;通俗来说&#…

MySQL的基本查询

&#x1f4df;作者主页&#xff1a;慢热的陕西人 &#x1f334;专栏链接&#xff1a;MySQL &#x1f4e3;欢迎各位大佬&#x1f44d;点赞&#x1f525;关注&#x1f693;收藏&#xff0c;&#x1f349;留言 本博客主要内容介绍了mysql的基本查询部分的知识&#xff0c;包括Crea…

WebAPI(一)之DOM操作元素属性和定时器

webAPI之DOM操作元素属性和定时器 介绍概念DOM 树DOM 节点document 获取DOM对象操作元素内容操作元素属性常用属性修改控制样式属性操作表单元素属性自定义属性 间歇函数今日单词 了解 DOM 的结构并掌握其基本的操作&#xff0c;体验 DOM 的在开发中的作用 知道 ECMAScript 与 …

大厂高频面试题复习JAVA学习笔记-JUC多线程及高并发(下)

目录 7 阻塞队列知道吗? 概念​编辑 synchronized和lock的区别 虚假唤醒情况 ​编辑​编辑 8 线程池用过吗?ThreadPoolExecutor谈谈你的理解? Callable接口 线程池 Executors工具类 线程池底层原理 线程池七大参数 七大参数 底层原理 9 线程池用过吗?生产上你如…

Cyber Weekly #1

赛博新闻 1、弱智吧竟成最佳中文AI训练数据&#xff1f;&#xff01;中科院等&#xff1a;8项测试第一&#xff0c;远超知乎豆瓣小红书 使用弱智吧数据训练的大模型&#xff0c;跑分超过百科、知乎、豆瓣、小红书等平台&#xff0c;甚至是研究团队精心挑选的数据集。弱智吧数…

基于单片机数码管20V电压表仿真设计

**单片机设计介绍&#xff0c;基于单片机数码管20V电压表仿真设计 文章目录 一 概要二、功能设计设计思路 三、 软件设计原理图 五、 程序六、 文章目录 一 概要 基于单片机数码管20V电压表仿真设计的主要目的是通过单片机和数码管显示电路实现一个能够测量0到20V直流电压的电…

漏洞挖掘 | 两个src案例分享

案例一 - 存储型XSS 文前废话:某天正在刷着**社区的帖子,突然间评论区的一条评论引起了我的注意,类似于下面这样 其中字体是蓝色的&#xff0c;这种评论在html标签中代码格式是<a>这是文字</a>这样的链接个格式。 同时评论区XSS漏洞的高发区,想着可能会有操作点 …

嵌入式工程师为什么要时刻保持危机感?

昨天有个特训营的铁子&#xff0c;说白天上班&#xff0c;晚上学习&#xff0c;有点干不动了&#xff0c;让我给他打打鸡血。 我突然词穷了&#xff0c;灵感枯竭了&#xff0c;就让徐工来发挥。 徐工的鸡汤是&#xff1a; 为什么大多数都很平庸&#xff0c;是因为想要进步&…

头盔检测 | 基于Caffe-SSD目标检测算法实现的建筑工地头盔检测

项目应用场景 面向建筑工地头盔检测场景&#xff0c;使用深度学习 Caffe SSD 目标检测算法&#xff0c;基于 C 实现。 项目效果 项目细节 > 具体参见项目 README.md (1) 安装 Caffe SSD(2) 执行训练 sh examples/Hardhat/SSD300/train_SSD300.sh (3) 部署算法 项目获取 h…

蓝桥杯-油漆面积

代码及其解析:(AC80%&#xff09; 思路:是把平面划成单位边长为1&#xff08;面积也是1&#xff09;的方格。每读入一个矩形&#xff0c;就把它覆盖的方格标注为已覆盖&#xff1b;对所有矩形都这样处理&#xff0c;最后统计被覆盖的方格数量即可。编码极其简单&#xff0c;但…