基于Sentence Transformer微调向量模型

Sentence Transformer库升级到了V3,其中对模型训练部分做了优化,使得模型训练和微调更加简单了,跟着官方教程走了一遍,顺利完成向量模型的微调,以下是对官方教程的精炼和总结。

一 所需组件

使用Sentence Transformer库进行向量模型的微调需要如下的组件:

  1. 数据数据: 用于训练和评估的数据。
  2. 损失函数 : 一个量化模型性能并指导优化过程的函数。
  3. 训练参数 (可选): 影响训练性能和跟踪/调试的参数。
  4. 评估器 (可选): 一个在训练前、中或后评估模型的工具。
  5. 训练器 : 将模型、数据集、损失函数和其他组件整合在一起进行训练。

二 数据集

大部分微调用到的数据都是本地的数据集,因此这里只提供本地数据的处理方法。如用其他在线数据可参考相对应的API。

1 数据类型

常见的数据类型为json、csv、parquet,可以使用load_dataset进行加载:

from datasets import load_datasetcsv_dataset = load_dataset("csv", data_files="my_file.csv")
json_dataset = load_dataset("json", data_files="my_file.json")
parquet_dataset = load_dataset("parquet", data_files="my_file.parquet")

2 数据格式

数据格式需要与损失函数相匹配。如果损失函数需要计算三元组,则数据集的格式为['anchor', 'positive', 'negative'],且顺序不能颠倒。如果损失函数计算的是句子对的相似度或者标签类别,则数据集中需要包含['label']或者['score'],其余列都会作为损失函数的输入。常见的数据格式和损失函数选择见表1。

三 损失函数

从链接整理了一些常见的数据格式和匹配的损失函数

Inputs Labels Appropriate Loss Functions
(sentence_A, sentence_B) pairs class SoftmaxLoss
(anchor, positive) pairs none MultipleNegativesRankingLoss
(anchor, positive/negative) pairs 1 if positive, 0 if negative ContrastiveLoss / OnlineContrastiveLoss
(sentence_A, sentence_B) pairs float similarity score CoSENTLoss / AnglELoss / CosineSimilarityLoss
(anchor, positive, negative) triplets none MultipleNegativesRankingLoss / TripletLoss

表1 常见的数据格式和损失函数

四 训练参数

配置训练参数主要是用于提升模型的训练效果,同时可以显示训练过程的进度或者其他参数信息,方便调试。

1 影响训练效果的参数

learning_rate lr_scheduler_type warmup_ratio num_train_epochs
max_steps per_device_train_batch_size per_device_evak_batch_size auto_find_batch_size
fp16 bf16 gradient_accumulation_steps gradient_checkpointing
eval_accmulation_steps optim batch_sampler multi_dataset_batch_sampler

2 观察训练过程的参数

eval_strategy eval_steps save_strategy save_steps
save_total_limit load_best_model_at_end report_to log_eval log_eval
logging_steps push_to_hub hub_model_id hub_strategy
hub_private_repo

五 评估器

评估器用于评估模型训练过程中的损失。同损失函数的选择一样,它也需要与数据格式相匹配,以下是评估器的选择依据。

Evaluator Required Data
BinaryClassificationEvaluator Pairs with class labels
EmbeddingSimilarityEvaluator Pairs with similarity scores
InformationRetrievalEvaluator Queries (qid => question), Corpus (cid => document), and relevant documents (qid => set[cid])
MSEEvaluator Source sentences to embed with a teacher model and target sentences to embed with the student model. Can be the same texts.
ParaphraseMiningEvaluator Mapping of IDs to sentences & pairs with IDs of duplicate sentences.
RerankingEvaluator List of {'query': '...', 'positive': [...], 'negative': [...]} dictionaries.
TranslationEvaluator Pairs of sentences in two separate languages.
TripletEvaluator (anchor, positive, negative) pairs.

六 训练器

训练器的作用是把先前的组件组合在一起使用。我们仅需要指定模型、训练数据、损失函数、训练参数(可选)、评估器(可选),就可以开始模型的训练。

from datasets import load_dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
from sentence_transformers.evaluation import TripletEvaluator# 1. Load a model to finetune with 2. (Optional) model card data
model = SentenceTransformer("microsoft/mpnet-base",model_card_data=SentenceTransformerModelCardData(language="en",license="apache-2.0",model_name="MPNet base trained on AllNLI triplets",)
)# 3. Load a dataset to finetune on
dataset = load_dataset("sentence-transformers/all-nli", "triplet")
train_dataset = dataset["train"].select(range(100_000))
eval_dataset = dataset["dev"]
test_dataset = dataset["test"]# 4. Define a loss function
loss = MultipleNegativesRankingLoss(model)# 5. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(# Required parameter:output_dir="models/mpnet-base-all-nli-triplet",# Optional training parameters:num_train_epochs=1,per_device_train_batch_size=16,per_device_eval_batch_size=16,learning_rate=2e-5,warmup_ratio=0.1,fp16=True,  # Set to False if you get an error that your GPU can't run on FP16bf16=False,  # Set to True if you have a GPU that supports BF16batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch# Optional tracking/debugging parameters:eval_strategy="steps",eval_steps=100,save_strategy="steps",save_steps=100,save_total_limit=2,logging_steps=100,run_name="mpnet-base-all-nli-triplet",  # Will be used in W&B if `wandb` is installed
)# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(anchors=eval_dataset["anchor"],positives=eval_dataset["positive"],negatives=eval_dataset["negative"],name="all-nli-dev",
)
dev_evaluator(model)# 7. Create a trainer & train
trainer = SentenceTransformerTrainer(model=model,args=args,train_dataset=train_dataset,eval_dataset=eval_dataset,loss=loss,evaluator=dev_evaluator,
)
trainer.train()# (Optional) Evaluate the trained model on the test set
test_evaluator = TripletEvaluator(anchors=test_dataset["anchor"],positives=test_dataset["positive"],negatives=test_dataset["negative"],name="all-nli-test",
)
test_evaluator(model)# 8. Save the trained model
model.save_pretrained("models/mpnet-base-all-nli-triplet/final")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/739818.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【攻防技术系列+代理转发】工具--netcat

【需求】现在想要实现两个不同网段的私网之间相互通信,我们该如何做呢?🔴实验环境:【kali(攻击端)】:192.168.10.131 【centos7(跳板机)】:192.168.10.39;172.16.80.130 【win7】:172.16.80.131 工具:netcat【kali】: 开启监听【centos7】:【kali】: 获得对方的…

基于FPGA的A律压缩解压缩verilog实现,包含testbench

1.算法仿真效果 VIVADO2019.2仿真结果如下(完整代码运行后无水印):RTL图如下所示:2.算法涉及理论知识概要A律压缩是一种广泛应用于语音编码的非均匀量化技术,尤其在G.711标准中被欧洲和中国等国家采纳。该技术的核心目的是在有限的带宽下高效传输语音信号,同时保持较高的…

LFU算法实现

LFU (Least Frequently Used) 是一种用于缓存管理的算法。它通过跟踪每个缓存项被访问的频率来决定哪些项应该被移除。LFU算法倾向于保留那些使用频率较高的项,而移除那些使用频率较低的项。以下是LFU算法的详细介绍: 工作原理计数器:每个缓存项都有一个计数器,用于记录该项…

灰色预测GM(1,1)模型的理论原理

灰色预测是对时间有关的灰色过程进行预测。通过建立相应的微分方程模型,从而预测事物未来发展趋势的状况。 由于笔者的水平不足,本章只是概括性地介绍GM(1,1)模型的理论原理,便于对初学者的初步理解 目录一、灰色系统二、GM(1,1)灰色预测模型1.生成累加数据与紧临均值生成…

JMonkeyEngine——材质文件备注

默认J3M编辑器不支持编辑纹理参数的Mag/Min滤波选项,只能配置Flip和Wrap模式,但是可以单独编辑J3M源码,如下: 添加你需要的Mag/Min滤波选项,参考源码的解析,就是Mag/Min+拼接对应的Filter值。 虽然打开J3M编辑器会报错: 但实际进游戏时并不会报错,而且一切正常,如下:…

04-JS中的面向对象ES5

JS面向对象基础01 JS对象中key的类型02 创建对象的方法03 对象的常见操作 3.1 访问对象的属性 <!DOCTYPE html> <html lang="en"> <head><meta charset="UTF-8"><meta http-equiv="X-UA-Compatible" content="I…

程序员的AI工作流

AI 工具在日常工作中的应用逐渐成为程序员必备利器。本文介绍了作者常用的一些 AI 工具及使用方式,涵盖需求文档分析、技术文档编写、编程、PR/CR 和技术调研等工作内容,为提升工作效率提供了有力支持。作为一名程序员, 我现在已经深刻的体会到了AI带来的巨大的工作提升 本文…

An Attentive Inductive Bias for Sequential Recommendation beyond the Self-Attention

目录概符号说明BSARec (Beyond Self-Attention for Sequential Recommendation)代码Shin Y., Choi J., Wi H. and Park N. An attentive inductive bias for sequential recommendation beyond the self-attention. AAAI, 2024.概 本文在 attention block 中引入高低频滤波. 符…

[Leetcode]经典算法

检测环 快慢指针法是一种用于检测链表中是否存在环的有效方法,同时也可以找到环的起点。该方法的原理基于两个指针在链表上同时移动,其中一个移动得更快,而另一个移动得更慢。检测环的存在:使用两个指针,一个称为快指针(fast),一个称为慢指针(slow)。 在每一步中,快…

关于import multiprocessing引用出错

关于import multiprocessing引用出错 0. 原因 当前文件名与python包体中关键词出现同名,导致循环引用 1. 排查过程 问题代码 import timefrom multiprocessing import Process, Queue # 这里提示错误def producer(queue):queue.put("a")time.sleep(2)def consumer(q…

进程信号

进程信号的产生,本质,进程信号的操作,进程信号的底层实现,以及阻塞信号,屏蔽信号1. 信号的产生 1.1 信号概念在生活中有很多的信号在我们身边围绕,例如红绿灯,发令枪,上课铃等等 在接受到信号,我们可以做出三种动作 1.立马去做对应信号的事情 2.等一会再做,有自己的…