用通俗易懂的方式讲解:对 embedding 模型进行微调,我的大模型召回效果提升了太多了

QA对话目前是大语言模型的一大应用场景,在QA对话中,由于大语言模型信息的滞后性以及不包含业务知识的特点,我们经常需要外挂知识库来协助大模型解决一些问题。

在外挂知识库的过程中,embedding模型的召回效果直接影响到大模型的回答效果,因此,在许多场景下,我们都需要微调我们的embedding模型来提高我们的召回效果。

码字不易,喜欢记得收藏、点赞、关注,如果你希望技术交流&答疑,见文末

下面,我就基于llama-index对BAAI/bge-base-zh-v1.5模型进行微调,关于该模型的介绍,可以参考https://huggingface.co/BAAI/bge-base-zh-v1.5。

平台介绍

对embedding模型进行微调的过程中需要使用GPU加速训练,我这里就使用了Google colab提供的免费T4GPU进行微调测试。

如果大家没办法使用这个,可以使用国内一些公司的GPU云平台,租便宜的GPU就行,微调这个模型所耗费的GPU资源不多。

以下所有训练代码皆是在Jupter-notebook上编写并执行的。

依赖安装

安装一些依赖库,有些依赖需要制定版本,否则存在不兼容的问题。

!pip install langchain==0.0.300 llmx==0.0.15a0 openai==0.28.1 llama_index==0.8.23.post1 pypdf sentence-transformers

训练样本准备

我们当前的使用场景是QA问答场景,因此训练数据的格式最好也是问答的格式。我这里由于没有现成的问答样本(人工整理比较耗时),因此我就摘取了《明朝那些事儿》这个小说里面的部分章节,然后让GPT-3.5针对文章内容进行提问,从而形成问答对。代码如下

import json
import openai
import osfrom llama_index import SimpleDirectoryReader
from llama_index.node_parser import SimpleNodeParser
from llama_index.schema import MetadataMode
from llama_index import (VectorStoreIndex,SimpleDirectoryReader,ServiceContext,Response
)def load_corpus(docs, for_training=False, verbose=False):parser = SimpleNodeParser.from_defaults()if for_training:nodes = parser.get_nodes_from_documents(docs[:5], show_progress=verbose)else:nodes = parser.get_nodes_from_documents(docs[6:], show_progress=verbose)if verbose:print(f'Parsed {len(nodes)} nodes')return nodesSEC_FILE = ['embedding_test.txt'] print(f"Loading files {SEC_FILE}")reader = SimpleDirectoryReader(input_files=SEC_FILE)
docs = reader.load_data()
print(f'Loaded {len(docs)} docs')docs_nodes = load_corpus(docs, for_training=True, verbose=True)len(docs_nodes)train_nodes = docs_nodes[:75]  
print(f'Loaded {len(train_nodes)} train docs')
val_nodes = docs_nodes[76:] 
print(f'Loaded {len(val_nodes)} val docs')

构造训练集和测试集

使用GPT3.5基于小说内容生成对应的问题,最后生成train_dataset.json作为训练集,val_dataset.json作为验证集。

from llama_index.finetuning import (generate_qa_embedding_pairs,EmbeddingQAFinetuneDataset,
)
from llama_index.llms import OpenAIos.environ["OPENAI_API_KEY"] = "sk-************"
openai.api_key = os.environ["OPENAI_API_KEY"]
openai.api_base = "https://************"prompt="""下方是上下文信息。---------------------
{context_str}
---------------------根据提供的上下文信息和没有先验知识的原则,仅基于以下查询生成问题。你是一名教师/教授。你的任务是为即将到来的测验/考试设置{num_questions_per_chunk}个问题。这些问题应在文档中多样化,且仅限于所提供的上下文信息。
"""train_dataset = generate_qa_embedding_pairs(train_nodes, qa_generate_prompt_tmpl=prompt)
val_dataset = generate_qa_embedding_pairs(val_nodes, qa_generate_prompt_tmpl=prompt)train_dataset.save_json("train_dataset.json")
val_dataset.save_json("val_dataset.json")

微调Embedding模型

这里的微调都是使用的默认参数,在实际微调过程中,可根据实际情况进行调整。

from llama_index.finetuning import SentenceTransformersFinetuneEngine
train_dataset = EmbeddingQAFinetuneDataset.from_json("train_dataset.json")
val_dataset = EmbeddingQAFinetuneDataset.from_json("val_dataset.json")
finetune_engine = SentenceTransformersFinetuneEngine(train_dataset,model_id="BAAI/bge-base-zh-v1.5",model_output_path="test_model",val_dataset=val_dataset,
)
finetune_engine.finetune() 
embed_model = finetune_engine.get_finetuned_model()
embed_model

评估微调后的模型

在评估阶段,我们对比了微调前、后的BAAI/bge-base-zh-v1.5模型以及OPENAI的ada002的Embedding模型的召回效果,代码如下:

from llama_index.embeddings import OpenAIEmbedding
from llama_index import ServiceContext, VectorStoreIndex
from llama_index.schema import TextNode
from tqdm.notebook import tqdm
import pandas as pd
def evaluate(dataset,embed_model,top_k=5,verbose=False,
):corpus = dataset.corpusqueries = dataset.queriesrelevant_docs = dataset.relevant_docsservice_context = ServiceContext.from_defaults(embed_model=embed_model)nodes = [TextNode(id_=id_, text=text) for id_, text in corpus.items()]index = VectorStoreIndex(nodes, service_context=service_context, show_progress=True)retriever = index.as_retriever(similarity_top_k=top_k)eval_results = []for query_id, query in tqdm(queries.items()):retrieved_nodes = retriever.retrieve(query)retrieved_ids = [node.node.node_id for node in retrieved_nodes]expected_id = relevant_docs[query_id][0]is_hit = expected_id in retrieved_ids  eval_result = {"is_hit": is_hit,"retrieved": retrieved_ids,"expected": expected_id,"query": query_id,}eval_results.append(eval_result)return eval_results

注意,在执行下面的代码前,需要先在当前项目的目录下创建results文件夹,否则会导致程序执行失败。

from sentence_transformers.evaluation import InformationRetrievalEvaluator
from sentence_transformers import SentenceTransformerdef evaluate_st(dataset,model_id,name,
):corpus = dataset.corpusqueries = dataset.queriesrelevant_docs = dataset.relevant_docsevaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs, name=name)model = SentenceTransformer(model_id)return evaluator(model, output_path="results/")

OPENAI-ada002

ada = OpenAIEmbedding()
ada_val_results = evaluate(val_dataset, ada)
df_ada = pd.DataFrame(ada_val_results)
hit_rate_ada = df_ada['is_hit'].mean()
hit_rate_ada

ada002模型的最终评测结果为0.9285714285714286

原始BAAI/bge-base-zh-v1.5

bge = "local:BAAI/bge-base-zh-v1.5"
bge_val_results = evaluate(val_dataset, bge)
df_bge = pd.DataFrame(bge_val_results)
hit_rate_bge = df_bge['is_hit'].mean()
hit_rate_bge

原始的bge-base-zh-v1.5模型的评测结果为0.7663744588744589

微调后的BAAI/bge-base-zh-v1.5

finetuned = "local:test_model"
val_results_finetuned = evaluate(val_dataset, finetuned)
df_finetuned = pd.DataFrame(val_results_finetuned)
hit_rate_finetuned = df_finetuned['is_hit'].mean()
hit_rate_finetuned

微调后模型的最终评测结果为0.975。即微调后,我们的embedding模型在当前数据集的召回效果由0.766上升到0.975注意,得分并不是越高越好,需考虑是否过拟合,可以在其他数据集上再评测下。

以上,即是一次简单的微调过程。感谢技术的发展和开源大佬们的贡献,使得人工智能的应用门槛越来越低。

技术交流

技术要学会分享、交流,不建议闭门造车。一个人可以走的很快、一堆人可以走的更远。

本文完整代码、相关资料、技术交流&答疑,均可加我们的交流群获取,群友已超过2000人,添加时最好的备注方式为:来源+兴趣方向,方便找到志同道合的朋友。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:来自CSDN + 技术交流

在这里插入图片描述

通俗易懂讲解大模型系列

  • 用通俗易懂的方式讲解:大模型 RAG 在 LangChain 中的应用实战

  • 用通俗易懂的方式讲解:一文讲清大模型 RAG 技术全流程

  • 用通俗易懂的方式讲解:如何提升大模型 Agent 的能力?

  • 用通俗易懂的方式讲解:使用 Mistral-7B 和 Langchain 搭建基于PDF文件的聊天机器人

  • 用通俗易懂的方式讲解:ChatGPT 开放的多模态的DALL-E 3功能,好玩到停不下来!

  • 用通俗易懂的方式讲解:结合检索和重排序模型,改善大模型 RAG 效果明显

  • 用通俗易懂的方式讲解:基于扩散模型(Diffusion),文生图 AnyText 的效果太棒了

  • 用通俗易懂的方式讲解:在 CPU 服务器上部署 ChatGLM3-6B 模型

  • 用通俗易懂的方式讲解:ChatGLM3-6B 功能原理解析

  • 用通俗易懂的方式讲解:使用 LangChain 和大模型生成海报文案

  • 用通俗易懂的方式讲解:一个强大的 LLM 微调工具 LLaMA Factory

  • 用通俗易懂的方式讲解:ChatGLM3-6B 部署指南

  • 用通俗易懂的方式讲解:LangChain Agent 原理解析

  • 用通俗易懂的方式讲解:HugggingFace 推理 API、推理端点和推理空间使用详解

  • 用通俗易懂的方式讲解:使用 LangChain 封装自定义的 LLM,太棒了

  • 用通俗易懂的方式讲解:使用 FastChat 部署 LLM 的体验太爽了

  • 用通俗易懂的方式讲解:基于 Langchain 和 ChatChat 部署本地知识库问答系统

  • 用通俗易懂的方式讲解:使用 Docker 部署大模型的训练环境

  • 用通俗易懂的方式讲解:在 Ubuntu 22 上安装 CUDA、Nvidia 显卡驱动、PyTorch等大模型基础环境

  • 用通俗易懂的方式讲解:Llama2 部署讲解及试用方式

  • 用通俗易懂的方式讲解:LangChain 知识库检索常见问题及解决方案

  • 用通俗易懂的方式讲解:基于 LangChain 和 ChatGLM2 打造自有知识库问答系统

  • 用通俗易懂的方式讲解:代码大模型盘点及优劣分析

  • 用通俗易懂的方式讲解:Prompt 提示词在开发中的使用

  • 用通俗易懂的方式讲解:万字长文带你入门大模型

参考资料

1.https://github.com/wenqiglantz/nvidia-sec-finetuning/tree/main/embedding-finetuning

2.https://colab.research.google.com/github/wenqiglantz/nvidia-sec-finetuning/blob/main/embedding-finetuning/finetune_embedding_nvidia_sec.ipynb

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/342083.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四、C++运算符(5)逻辑运算符

作用&#xff1a;用于根据表达式的值返回真值或假值 #define _CRT_SECURE_NO_WARNINGS #include<iostream> #include<string> using namespace std; int main() {//逻辑运算符 非&#xff01;int a 10;int b 20;//在c中除了0都是真cout << !a << end…

小红书种草类型有哪些,小红书营销攻略

我们都知道小红书是个内容平台。用户来这可以看到各种类型的笔记&#xff0c;从笔记中获取自己想要了解的内容。这也就意味着平台上有着许多种不同的笔记类型。今天我们和大家分享下小红书种草类型有哪些&#xff0c;小红书营销攻略&#xff01; 1. 明星带货类 顾名思义&#x…

职称为什么要提前报名?⬇️ ⬇️

评职称需要一堆的材料&#xff0c;比如&#xff1a; 论文发表&#xff0c;从写作到选期刊到发表&#xff0c;需要3-12个月的时间继续教育&#xff0c;职称评审对于继续教育是有学时要求的&#xff0c;这一点申报职称的人想必都清楚&#xff0c;但具体学时要求,就要根据当时的要…

5. 属性自动填充

《阿里巴巴Java开发手册》&#xff0c;在第 5 章 MySQL 数据库可以看到这样一条规范&#xff1a; 对于一张数据表&#xff0c;它必须具备三个字段&#xff1a; id : 唯一IDgmt_create : 保存的是当前数据创建的时间gmt_modified : 保存的是更新时间 改造一下数据表&#xf…

python炒股自动化(0),申请券商API接口

上次发了量化交易接口的区别&#xff0c;发现很多人根本不知道券商提供的API交易接口&#xff0c;这里补充一篇&#xff0c;关于券商接口的介绍。 现在市面上可以给个人账户接入的股票交易接口&#xff0c;用的最多的也就是QMT和Ptrade&#xff0c;以前接入量化交易需要机构或…

java通过okhttp方式实现https请求的工具类(绕过证书验证)

目录 一、引入依赖包二、okhttp方式实现的https请求工具类2.1、跳过证书配置类2.2、okhttp方式的 https工具类 三、测试类 一、引入依赖包 引入相关依赖包 <!--okhttp依赖包--> <dependency><groupId>com.squareup.okhttp3</groupId><artifactId>…

Linux(Centos7)安装 jenkins(jdk11+jenkins2.375),并配置JDK,Maven,Git,GitLab

安装步骤 1. JDK11安装2. Maven安装3. git安装4. Jenkins2.375安装4.1 设置中文显示4.2 端口,用户权限修改4.3 插件下载4.4 全局工具配置4.4.1 Maven配置4.4.2 JDK配置4.4.3 Git配置 4.5 系统配置4.5.1 Gitee配置 4.6 构建测试 1. JDK11安装 #下载 yum -y install fontconfig …

SpringBoot 把PageHelper分页信息返回给前端

第1步&#xff1a;定义线程容器收纳HttpHeaders和HttpStatus import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus;public class ResponseUtils {private static ThreadLocal<HttpHeaders> ThreadLocalHeaders new InheritableT…

SQL-DCL-如何用户管理,如何给用户权限?

&#x1f389;欢迎您来到我的MySQL基础复习专栏 ☆* o(≧▽≦)o *☆哈喽~我是小小恶斯法克&#x1f379; ✨博客主页&#xff1a;小小恶斯法克的博客 &#x1f388;该系列文章专栏&#xff1a;重拾MySQL &#x1f379;文章作者技术和水平很有限&#xff0c;如果文中出现错误&am…

计算机网络-2019期末考试解析

【前言】 从内容上看比较像计算机网络课程了&#xff0c;先做了。 一&#xff0e;填空选择题&#xff08;共 20 分&#xff0c;每空 1 分&#xff09; 1 、双绞线由两根相互绝缘的、绞合成均匀的螺纹状的导线组成&#xff0c;下列关于双绞线的叙述&#xff0c;不正确的是___ __…

多端多用户万能DIY商城系统源码:自营+多商户入驻商城系统 独立部署 带完整的安装代码包以及搭建教程

电子商务行业日新月异&#xff0c;许多企业希望能够通过线上商城拓展业务。但是&#xff0c;传统商城系统往往无法满足多样化、个性化的需求&#xff0c;而且开发周期长、成本高。罗峰就来给大家分享一款多端多用户万能DIY商城系统源码&#xff0c;搭建简单。 以下是部分代码示…

win系统搭建Minecraft世界服务器,MC开服教程,小白开服教程

Windows系统搭建我的世界世界服务器&#xff0c;Minecraft开服教程&#xff0c;小白开服教程&#xff0c;MC 1.19.4版本服务器搭建教程。 此教程使用 Mohist 1.19.4 服务端&#xff0c;此服务端支持Forge模组和Bukkit/Spigot/Paper插件&#xff0c;如果需要开其他服务端也可参…