使用向量检索和rerank 在RAG数据集上实验评估hit_rate和mrr

文章目录

    • 背景
    • 简介
    • 代码实现
      • 自定义检索器
      • 向量检索实验
      • 向量检索和rerank 实验
    • 代码开源

背景

在前面部分 大模型生成RAG评估数据集并计算hit_rate 和 mrr 介绍了使用大模型生成RAG评估数据集与评估;

在 上文 使用到了BM25 关键词检索器。接下来,想利用向量检索器测试一下在RAG评估数据集上的 hit_rate 和 mrr;

简介

使用 向量检索 和 rerank 在给定RAG评估数据集上的实验计算 hit_rate 和 mrr;

对比了使用 rerank 和 不使用 rerank的实验结果;

步骤:

  1. 基于RAG评估数据集,构建nodes节点;
  2. 构建 CustomRetriever 自定义的检索器,在检索器中实现 向量检索和 rerank;
  3. 实验评估;

代码实现

from typing import Listfrom llama_index.core import SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.core.indices.postprocessor import SentenceTransformerRerank
from llama_index.core.indices.vector_store import VectorIndexRetriever
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.settings import Settings
from llama_index.legacy.embeddings import HuggingFaceEmbedding
# from llama_index.legacy.schema import NodeWithScore, QueryBundle
from llama_index.core.schema import NodeWithScore, QueryBundle, QueryType, Node
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset

利用数据集中的数据,构建nodes
pg_eval_dataset.json的下载地址: https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/files

qa_dataset = EmbeddingQAFinetuneDataset.from_json("pg_eval_dataset.json")nodes = []
for key, value in qa_dataset.corpus.items():nodes.append(Node(id_=key, text=value))

m3e 向量编码模型
若想使用其他的编码模型,直接进行修改即可,modelscope和huggingface的编码模型都行;

from modelscope import snapshot_download
model_dir = snapshot_download('AI-ModelScope/m3e-base')
Settings.embed_model = HuggingFaceEmbedding(model_dir)
Settings.llm = None

由于huggingface被墙了,笔者使用的是 modelscope平台,model_dir 为编码模型在本地的绝对路径

自定义检索器

tok_k: 表示召回的节点数量,可自定义设置;

top_k = 10

定义向量检索器,还实现了rerank;

class CustomRetriever(BaseRetriever):"""Custom retriever that performs both Vector search and Knowledge Graph search"""def __init__(self, vector_retriever: VectorIndexRetriever, reranker=None) -> None:"""Init params."""super().__init__()self._vector_retriever = vector_retrieverself.reranker = rerankerdef _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:"""Retrieve nodes given query."""# print(query_bundle, isinstance(QueryBundle))retrieved_nodes = self._vector_retriever.retrieve(query_bundle)if self.reranker != 'None':retrieved_nodes = self.reranker.postprocess_nodes(retrieved_nodes, query_bundle)else:retrieved_nodes = retrieved_nodes[:top_k]return retrieved_nodesasync def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:"""Asynchronously retrieve nodes given query.Implemented by the user."""return self._retrieve(query_bundle)async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]:if isinstance(str_or_query_bundle, str):str_or_query_bundle = QueryBundle(str_or_query_bundle)return await self._aretrieve(str_or_query_bundle)

eval_results包含每个query的 hit_rate 和 mrr,display_results 计算平均;

import pandas as pd
def display_results(eval_results):"""计算平均 hit_rate 和 mrr"""metric_dicts = []for eval_result in eval_results:metric_dict = eval_result.metric_vals_dictmetric_dicts.append(metric_dict)full_df = pd.DataFrame(metric_dicts)hit_rate = full_df["hit_rate"].mean()mrr = full_df["mrr"].mean()metric_df = pd.DataFrame({"hit_rate": [hit_rate], "mrr": [mrr]})return metric_df

向量检索实验

index = VectorStoreIndex(nodes)
vector_retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k)
retriever_evaluator = RetrieverEvaluator.from_metric_names(["mrr", "hit_rate"], retriever=vector_retriever
)
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
display_results(eval_results)

在这里插入图片描述

向量检索和rerank 实验

bge_reranker_base = SentenceTransformerRerank(model=snapshot_download("Xorbits/bge-reranker-base"),top_n=top_k)retriever = CustomRetriever(vector_retriever=vector_retriever,reranker=bge_reranker_base)retriever_evaluator = RetrieverEvaluator.from_metric_names(["mrr", "hit_rate"], retriever=retriever
)
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
display_results(eval_results)

在这里插入图片描述
若想使用其他的rerank模型,更换Xorbits/bge-reranker-base

若使用modelscope平台的rerank模型,直接修改模型名即可;
若使用huggingface 平台的rerank模型,自行修改代码;

上述对比了,在向量检索下,对比了添加rerank和不添加rerank的实验结果;
如上图所示,相比只有向量检索的实验,加了rerank mrr 反而还下降了,这是一个比较反常的实验结果;

这个并不能说明rerank没有用,笔者在其他的RAG数据集测试时,rerank确实能提升mrr;本例子这里的情况大家忽略即可。
在本实验这里仅仅是给读者展示如何使用rerank;这也说明了rerank模型,也并不都能提升所有的mrr;

代码开源

本项目的完整代码,已发布到modelscope平台上;
点击下述链接查看代码:
https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/file/view/master/vector_rerank_eval.ipynb?status=1

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/596325.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

wireshark数据流分析学习日记day3-从 Pcap 导出对象

从 HTTP 流量导出文件 过滤http请求 发现get请求上传了两个文件 保存即可 也可以保存网页 点击保存 改文件名为html结尾以便于访问 请谨慎使用此方法。如果从 pcap 中提取恶意 HTML 代码并在 Web 浏览器中查看它,则 HTML 可能会调用恶意域,这就是为什么…

递归实现组合型枚举(acwing)

题目描述: 从 1∼n 这 n 个整数中随机选出 m 个,输出所有可能的选择方案。 输入格式: 两个整数 n,m ,在同一行用空格隔开。 输出格式: 按照从小到大的顺序输出所有方案,每行 1 个。 首先,同一行内的数…

机器学习 - multi-class 数据集训练 (含代码)

直接上代码 # Multi-class datasetimport numpy as np RANDOM_SEED 42 np.random.seed(RANDOM_SEED) N 100 # number of points per class D 2 # dimensionality K 3 # number of classes X np.zeros((N*K, D)) y np.zeros(N*K, dtypeuint8) for j in range(K):ix rang…

摆动序列(力扣376)

文章目录 题目前知题解一、思路二、解题方法三、Code 总结 题目 Problem: 376. 摆动序列 如果连续数字之间的差严格地在正数和负数之间交替,则数字序列称为 摆动序列 。第一个差(如果存在的话)可能是正数或负数。仅有一个元素或者含两个不等元…

C语言程序编译全流程,从源代码到二进制

源程序 对于一个最简单的程序&#xff1a; int main(){int a 1;int b 2;int c a b;return 0; }预处理 处理源代码中的宏指令&#xff0c;例如#include等 clang -E test.c处理结果&#xff1a; # 1 "test.c" # 1 "<built-in>" 1 # 1 "&…

本地Windows打包启动前端后台

本地Windows打包启动前端后台 1、安装jdk Windows JDK安装 2、Nginx 2.1、将 nginx-1.16.1文件夹复制到D:\home\jisapp目录下 2.2、域名证书配置&#xff1a; 将域名证书放到D:\home\jisapp\ssl\2023目录下->配置nginx.conf文件&#xff08;D:\home\jisapp\nginx-1.22.0…

智能感应门改造工程

今天记录一下物联网专业学的工程步骤及实施过程 智能感应门改造工程 1 规划设计1.1 项目设备清单1.2项目接线图 软件设计信号流 设备安装与调试工程函数 验收 1 规划设计 1.1 项目设备清单 1.2项目接线图 软件设计 信号流 设备安装与调试 工程函数 工程界面: using System; …

新手如何开始运营朋友圈?分享朋友圈前5条内容运营技巧!

最近加入的新伙伴比较多&#xff0c;不少伙伴反馈一个问题&#xff1a;作为新人&#xff0c;前期我们的朋友圈要如何发&#xff1f;要怎么开始发朋友&#xff1f;要怎么配图&#xff0c;怎么配文案&#xff1f; 为了解决新伙伴们的这个问题&#xff0c;今天同伙伴们分享&#…

7.二叉树的遍历方式及二叉树习题

4.二叉树链式结构的实现 二叉树是&#xff1a; 空树 非空&#xff1a;根节点&#xff0c;根节点的左子树、根节点的右子树组成的。 4.1二叉树的遍历 4.2.1 前序、中序以及后序遍历 前序遍历(Preorder Traversal 亦称先序遍历)——访问根结点的操作发生在遍历其左右子树之前…

华清远见STM32MP157开发板助力嵌入式大赛ST赛道MPU应用方向项目开发

第七届&#xff08;2024&#xff09;全国大学生嵌入式芯片与系统设计竞赛&#xff08;以下简称“大赛”&#xff09;已经拉开帷幕&#xff0c;大赛的报名热潮正席卷而来。嵌入式大赛截止今年已连续举办了七届&#xff0c;为教育部认可的全国普通高校大学生国家级A类赛事&#x…

数据结构和算法:分治

分治算法 分治&#xff08;divide and conquer&#xff09;&#xff0c;全称分而治之&#xff0c;是一种非常重要且常见的算法策略。分治通常基于递归实现&#xff0c;包括“分”和“治”两个步骤。 1.分&#xff08;划分阶段&#xff09;&#xff1a;递归地将原问题分解为两个…

Spring源码解析-容器基本实现

spring源码解析 整体架构 defaultListableBeanFactory xmlBeanDefinitionReader 创建XmlBeanFactory 对资源文件进行加载–Resource 利用LoadBeandefinitions(resource)方法加载配置中的bean loadBeandefinitions加载步骤 doLoadBeanDefinition xml配置模式 validationMode 获…