大模型生成RAG评估数据集并计算hit_rate 和 mrr

文章目录

    • 背景
    • 简介
    • 代码实现
    • 公开
    • 参考资料

背景

最近在做RAG评估的实验,需要一个RAG问答对的评估数据集。在网上没有找到好用的,于是便打算自己构建一个数据集。

简介

本文使用大模型自动生成RAG 问答数据集。使用BM25关键词作为检索器,然后在问答数据集上评估该检索器的效果。
输入是一篇文本,使用llamaindex加载该文本,使用prompt让大模型针对输入的文本生成提问。
步骤如下:

  1. llamaindex 加载数据;
  2. 利用 chatglm3-6B 构建CustomLLM;
  3. 使用prompt和chatglm,结合文本生成对应的问题,构建RAG问答数据集;
  4. 使用BM25Retriever,构建基于关键词的检索器;
  5. 评估BM25Retriever在数据集上的hite_ratemrr结果;

由于在构建问答对时,让大模型结合文本生成对应的问题。笔者在测试时,发现关键词检索比向量检索效果要好

代码实现

导入包

from typing import List, Anyfrom llama_index.core import SimpleDirectoryReaderfrom llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.legacy.llms import (CustomLLM, CompletionResponse, CompletionResponseGen, LLMMetadata)
from llama_index.legacy.schema import NodeWithScore, QueryBundle, Node
from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.legacy.retrievers import BM25Retriever
from llama_index.core.evaluation import RetrieverEvaluator
from llama_index.core.evaluation import (generate_question_context_pairs,EmbeddingQAFinetuneDataset,
)

加载数据,使用llamaindex网站的paul_graham_essay.txt

# Load data
documents = SimpleDirectoryReader(input_files=["data/paul_graham_essay.txt"]
).load_data()# create the sentence window node parser w/ default settings
node_parser = SentenceWindowNodeParser.from_defaults(window_size=3,window_metadata_key="window",original_text_metadata_key="original_text",
)# Extract nodes from documents
nodes = node_parser.get_nodes_from_documents(documents)# by default, the node ids are set to random uuids. To ensure same id's per run, we manually set them.
for idx, node in enumerate(nodes):node.id_ = f"node_{idx}"

大模型加载
chatglm3-6B 使用half,显存占用12G

from modelscope import snapshot_download
from modelscope import AutoTokenizer, AutoModelmodel_name = "chatglm3-6b"
model_path = snapshot_download('ZhipuAI/chatglm3-6b')
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModel.from_pretrained(model_path, trust_remote_code=True).half().cuda()
model = model.eval()

本地自定义大模型

# set context window size
context_window = 2048
# set number of output tokens
num_output = 256class ChatGML(CustomLLM):@propertydef metadata(self) -> LLMMetadata:"""Get LLM metadata."""return LLMMetadata(context_window=context_window,num_output=num_output,model_name=model_name,)def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:prompt_length = len(prompt)# only return newly generated tokenstext, _ = model.chat(tokenizer, prompt, history=[])return CompletionResponse(text=text)def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:raise NotImplementedError()llm_model = ChatGML()

生成RAG测试数据集

# Prompt to generate questions
qa_generate_prompt_tmpl = """\
Context information is below.---------------------
{context_str}
---------------------Given the context information and not prior knowledge.
generate only questions based on the below query.You are a Professor. Your task is to setup \
{num_questions_per_chunk} questions for an upcoming \
quiz/examination. The questions should be diverse in nature \
across the document. The questions should not contain options, not start with Q1/ Q2. \
Restrict the questions to the context information provided.\
"""
# The questions should be solely based on the provided context information, and please pose them in Chinese.\
qa_dataset = generate_question_context_pairs(nodes,llm=llm_model,num_questions_per_chunk=2,qa_generate_prompt_tmpl=qa_generate_prompt_tmpl
)
qa_dataset.save_json("pg_eval_dataset.json")
# qa_dataset = EmbeddingQAFinetuneDataset.from_json("pg_eval_dataset.json")
import pandas as pddef display_results(eval_results):"""计算hit_rate和mrr的平均值"""metric_dicts = []for eval_result in eval_results:metric_dict = eval_result.metric_vals_dictmetric_dicts.append(metric_dict)full_df = pd.DataFrame(metric_dicts)hit_rate = full_df["hit_rate"].mean()mrr = full_df["mrr"].mean()metric_df = pd.DataFrame({"hit_rate": [hit_rate], "mrr": [mrr]})return metric_df
class JieRetriever(BM25Retriever, BaseRetriever):def _get_scored_nodes(self, query: str):tokenized_query = self._tokenizer(query)doc_scores = self.bm25.get_scores(tokenized_query)nodes = []for i, node in enumerate(self._nodes):node_new = Node.from_dict(node.to_dict())node_with_score = NodeWithScore(node=node_new, score=doc_scores[i])nodes.append(node_with_score)return nodesdef _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:if query_bundle.custom_embedding_strs or query_bundle.embedding:logger.warning("BM25Retriever does not support embeddings, skipping...")scored_nodes = self._get_scored_nodes(query_bundle.query_str)# Sort and get top_k nodes, score range => 0..1, closer to 1 means more relevantnodes = sorted(scored_nodes, key=lambda x: x.score or 0.0, reverse=True)return nodes[:self._similarity_top_k]
retriever = JieRetriever.from_defaults(
# retriever = BM25Retrieve r.from_defaults(nodes=nodes,similarity_top_k=10)

现在llamaindex在使用BM25Retrieve会报错,故笔者创建了JieRetriever,具体请点击查看链接

from llama_index.core.base.base_retriever import BaseRetrieverretriever_evaluator = RetrieverEvaluator.from_metric_names(["mrr", "hit_rate"], retriever=retriever)
eval_results = await retriever_evaluator.aevaluate_dataset(qa_dataset)
for idx, item in enumerate(eval_results):if idx == 15:breakd = item.metric_vals_dictmrr, hit_rate = d['mrr'], d['hit_rate']if mrr != 1 or hit_rate != 1:print(mrr, hit_rate, item.expected_ids, item.retrieved_ids)

下图展示了hit_rate 和 mrr 的计算:
在这里插入图片描述

结合下述结果,分析一下 hit_rate 和 mrr:

0.5 1.0 ['node_2'] ['node_71', 'node_2', 'node_0', 'node_199', 'node_126', 'node_419', 'node_446', 'node_218', 'node_1', 'node_70']
  • ['node_2'] 是 label
  • ['node_71', 'node_2', 'node_0', 'node_199', 'node_126', 'node_419', 'node_446', 'node_218', 'node_1', 'node_70'] 是检索器召回的候选列表;
  • mrr : 0.5;'node_2' 在候选列表的第二个位置,故mrr为 二分之一。在第几位就是几分之一;
  • hit_rate:代表label是否在候选集中,在就是1,不在就是0;
def display_results(eval_results):"""计算平均 hit_rate 和 mrr"""metric_dicts = []for eval_result in eval_results:metric_dict = eval_result.metric_vals_dictmetric_dicts.append(metric_dict)full_df = pd.DataFrame(metric_dicts)hit_rate = full_df["hit_rate"].mean()mrr = full_df["mrr"].mean()metric_df = pd.DataFrame({"hit_rate": [hit_rate], "mrr": [mrr]})return metric_df
display_results(eval_results)

在这里插入图片描述

公开

生成的评估数据集和相应示例代码,已上传到modelscope平台;

https://www.modelscope.cn/datasets/jieshenai/paul_graham_essay_rag/files

在这里插入图片描述

参考资料

  • https://www.llamaindex.ai/blog/boosting-rag-picking-the-best-embedding-reranker-models-42d079022e83

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/593826.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

WPS 不登录无法使用基本功能的解决办法

使用wps时,常常有个比较让人烦恼的事,在不登录的情况下,新建或者打开文档时,wps不让你使用其基本的功能,如设置字体等,相关界面变成灰色,这时Wps提示用户登录注册或登录,但我又不想登…

UTONMOS:AI+Web3+元宇宙数字化“三位一体”将触发经济新爆点

人工智能、元宇宙、Web3,被称为数字化的“三位一体”,如何看待这三大技术所扮演的角色? 3月24日,2024全球开发者先锋大会“数字化的三位一体——人工智能、元宇宙、Web3.0”论坛在上海漕河泾开发区举行,首次提出&…

【性能测试】接口测试各知识第2篇:学习目标,1. 理解接口的概念【附代码文档】

接口测试完整教程(附代码资料)主要内容讲述:接口测试,学习目标学习目标,2. 接口测试课程大纲,3. 接口学完样品,4. 学完课程,学到什么,5. 参考:,1. 理解接口的概念。学习目标,RESTFUL1. 理解接口的概念,2.什么是接口测试…

Mybatis plue(二) 扩展功能、插件功能

扩展功能 P12 扩展功能-代码生成器 方法一:mybatisplus官方文档中的代码生成配置 方法二:插件mybatsx 方法三:插件mybatisplus P13 DB静态工具 iservice中的方法是非静态的,db方法是静态的。 静态方法无法读取到类的泛型的…

2024年DeFi的四大主导趋势:Restaking、Layer3、AI和DePin

DeFi(去中心化金融)行业在2024年将继续呈现快速增长的势头,驱动这一增长的主要因素将是四大主导趋势:Restaking、Layer3、AI和DePin。这些趋势将推动DeFi生态系统的发展,为用户提供更多的机会和创新。 趋势1&#xff…

【Linux】第二个小程序--简易shell

请看上面的shell,其本质就是一个字符串,我们知道bash本质上就是一个进程,只不过命令行就是一个输出的字符串, 我们输入的命令“ls -a -l”实际上是我们在输入行输入的字符串,所以,如果我们想要做一个简易的…

Redis从入门到精通(五)Redis实战(二)商户查询缓存

↑↑↑请在文章头部下载测试项目原代码↑↑↑ 文章目录 前言4.2 商户查询缓存4.2.1 缓存介绍4.2.2 查询商户信息的传统做法4.2.2.1 接口文档4.2.2.2 代码实现4.2.2.3 功能测试 4.2.3 查询商户信息添加Redis缓存4.2.3.1 逻辑分析4.2.3.2 代码实现4.2.3.3 功能测试 4.2.3 数据一致…

MySQL基础【语句执行顺序】

一个SQL语句它的执行顺序对于我们思考题意有着很重要的关系 题意就是:找出哪些只逛超市不买单的人(买单0元也算哦,可能是使用的是代金券吧) 看到此题关键找出两个数据 参观过的人 和 买单的人 他们的差就是白嫖的人(支…

H.264 压缩与编解码原理

H.264 压缩与编解码原理 H.264 压缩与编解码原理H.264 简介视频编码的总体思路H.264 压缩技术帧内预测压缩什么是空间冗余?具体预测方法 帧间预测压缩什么是时间冗余?具体预测方法:运动估计 概念:Group of Pictures(GO…

LC 96.不同的二叉搜索树

96.不同的二叉搜索树 给你一个整数 n ,求恰由 n 个节点组成且节点值从 1 到 n 互不相同的 二叉搜索树 有多少种?返回满足题意的二叉搜索树的种数。 示例 1: 输入: n 3 输出: 5 示例 2: 输入:…

【JavaSE】接口 详解(上)

前言 本篇会讲到Java中接口内容,概念和注意点可能比较多,需要耐心多看几遍,我尽可能的使用经典的例子帮助大家理解~ 欢迎关注个人主页:逸狼 创造不易,可以点点赞吗~ 如有错误,欢迎指出~ 目录 前言 接口 语法…

SQL Server详细安装使用教程

1.安装环境 现阶段基本不用SQL Server数据库了,看到有这样的分析话题,就把多年前的存货发一下,大家也可以讨论看看,思路上希望还有价值。 SQL Server 2008 R2有32位版本和64位版本,32位版本可以安装在Windows XP及以上…