LangChain之关于RetrievalQA input_variables 的定义与使用

最近在使用LangChain来做一个LLMs和KBs结合的小Demo玩玩,也就是RAG(Retrieval Augmented Generation)。
这部分的内容其实在LangChain的官网已经给出了流程图。在这里插入图片描述
我这里就直接偷懒了,准备对Webui的项目进行复刻练习,那么接下来就是照着葫芦画瓢就行。
那么我卡在了Retrieve这一步。先放有疑惑地方的代码:

if web_content:prompt_template = f"""基于以下已知信息,简洁和专业的来回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息",不允许在答案中添加编造成分,答案请使用中文。已知网络检索内容:{web_content}""" + """已知内容:{context}问题:{question}"""else:prompt_template = """基于以下已知信息,请简洁并专业地回答用户的问题。如果无法从中得到答案,请说 "根据已知信息无法回答该问题" 或 "没有提供足够的相关信息"。不允许在答案中添加编造成分。另外,答案请使用中文。已知内容:{context}问题:{question}"""prompt = PromptTemplate(template=prompt_template,input_variables=["context", "question"])......knowledge_chain = RetrievalQA.from_llm(llm=self.llm,retriever=vector_store.as_retriever(search_kwargs={"k": self.top_k}),prompt=prompt)knowledge_chain.combine_documents_chain.document_prompt = PromptTemplate(input_variables=["page_content"], template="{page_content}")knowledge_chain.return_source_documents = Trueresult = knowledge_chain({"query": query})return result

我对prompt_templateknowledge_chain.combine_documents_chain.document_prompt result = knowledge_chain({"query": query})这三个地方的input_key不明白为啥一定要这样设置。虽然我也看了LangChain的API文档。但是我并未得到详细的答案,那么只能一行行看源码是到底怎么设置的了。

注意:由于LangChain是一层层封装的,那么result = knowledge_chain({"query": query})可以认为是最外层,那么我们先看最外层。

result = knowledge_chain({“query”: query})

其实这部分是直接与用户的输入问题做对接的,我们只需要定位到RetrievalQA这个类就可以了,下面是RetrievalQA这个类的实现:

class RetrievalQA(BaseRetrievalQA):"""Chain for question-answering against an index.Example:.. code-block:: pythonfrom langchain.llms import OpenAIfrom langchain.chains import RetrievalQAfrom langchain.vectorstores import FAISSfrom langchain.schema.vectorstore import VectorStoreRetrieverretriever = VectorStoreRetriever(vectorstore=FAISS(...))retrievalQA = RetrievalQA.from_llm(llm=OpenAI(), retriever=retriever)"""retriever: BaseRetriever = Field(exclude=True)def _get_docs(self,question: str,*,run_manager: CallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return self.retriever.get_relevant_documents(question, callbacks=run_manager.get_child())async def _aget_docs(self,question: str,*,run_manager: AsyncCallbackManagerForChainRun,) -> List[Document]:"""Get docs."""return await self.retriever.aget_relevant_documents(question, callbacks=run_manager.get_child())@propertydef _chain_type(self) -> str:"""Return the chain type."""return "retrieval_qa"

可以看到其继承了BaseRetrievalQA这个父类,同时对_get_docs这个抽象方法进行了实现。

这里要扩展的说一下,_get_docs这个方法就是利用向量相似性,在vector Base中选择与embedding之后的query最近似的Document结果。然后作为RetrievalQA的上下文。具体只需要看BaseRetrievalQA这个方法的_call和就可以了。
接下来我们只需要看BaseRetrievalQA这个类的属性就可以了。

class BaseRetrievalQA(Chain):"""Base class for question-answering chains."""combine_documents_chain: BaseCombineDocumentsChain"""Chain to use to combine the documents."""input_key: str = "query"  #: :meta private:output_key: str = "result"  #: :meta private:return_source_documents: bool = False"""Return the source documents or not."""……def _call(self,inputs: Dict[str, Any],run_manager: Optional[CallbackManagerForChainRun] = None,) -> Dict[str, Any]:"""Run get_relevant_text and llm on input query.If chain has 'return_source_documents' as 'True', returnsthe retrieved documents as well under the key 'source_documents'.Example:.. code-block:: pythonres = indexqa({'query': 'This is my query'})answer, docs = res['result'], res['source_documents']"""_run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()question = inputs[self.input_key]accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters)if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)else:docs = self._get_docs(question)  # type: ignore[call-arg]answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child())if self.return_source_documents:return {self.output_key: answer, "source_documents": docs}else:return {self.output_key: answer}

可以看到其有input_key这个属性,默认值是"query"。到这里我们就可以看到result = knowledge_chain({"query": query})是调用的BaseRetrievalQA_call,这里的question = inputs[self.input_key]就是其体现。

knowledge_chain.combine_documents_chain.document_prompt

这个地方一开始我很奇怪,为什么会重新定义呢?
我们可以先定位到,combine_documents_chain这个参数的位置,其是StuffDocumentsChain的方法。

@classmethod
def from_llm(cls,llm: BaseLanguageModel,prompt: Optional[PromptTemplate] = None,callbacks: Callbacks = None,**kwargs: Any,
) -> BaseRetrievalQA:"""Initialize from LLM."""_prompt = prompt or PROMPT_SELECTOR.get_prompt(llm)llm_chain = LLMChain(llm=llm, prompt=_prompt, callbacks=callbacks)document_prompt = PromptTemplate(input_variables=["page_content"], template="Context:\n{page_content}")combine_documents_chain = StuffDocumentsChain(llm_chain=llm_chain,document_variable_name="context",document_prompt=document_prompt,callbacks=callbacks,)return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

可以看到原始的document_prompt中PromptTemplate的template是“Context:\n{page_content}”。因为这个项目是针对中文的,所以需要将英文的Context去掉。

扩展

  1. 这里PromptTemplate(input_variables=[“page_content”], template=“Context:\n{page_content}”)的input_variablestemplate为什么要这样定义呢?其实是根据Document这个数据对象来定义使用的,我们可以看到其数据格式为:Document(page_content=‘……’, metadata={‘source’: ‘……’, ‘row’: ……})
    那么input_variables的输入就是Document的page_content。
  2. StuffDocumentsChain中有一个参数是document_variable_name。那么这个类是这样定义的This chain takes a list of documents and first combines them into a single string. It does this by formatting each document into a string with the document_prompt and then joining them together with document_separator. It then adds that new string to the inputs with the variable name set by document_variable_name. Those inputs are then passed to the llm_chain. 这个document_variable_name简单来说就是在document_prompt中的占位符,用于在Chain中的使用。
    因此我们上文prompt_template变量中的“已知内容: {context}”,用的就是context这个变量。因此在prompt_template中换成其他的占位符都不能正常使用这个Chain。

prompt_template

在上面的拓展中其实已经对prompt_template做了部分的讲解,那么这个字符串还剩下“问题:{question}”这个地方没有说通
还是回归源码:

return cls(combine_documents_chain=combine_documents_chain,callbacks=callbacks,**kwargs,)

我们可以在from_llm函数中看到其返回值是到了_call,那么剩下的我们来看这个函数:


......
uestion = inputs[self.input_key]
accepts_run_manager = ("run_manager" in inspect.signature(self._get_docs).parameters
)
if accepts_run_manager:docs = self._get_docs(question, run_manager=_run_manager)
else:docs = self._get_docs(question)  # type: ignore[call-arg]
answer = self.combine_documents_chain.run(input_documents=docs, question=question, callbacks=_run_manager.get_child()
)
......

这里是在run这个函数中传入了一个字典值,这个字典值有三个参数。

注意:

  1. 这三个参数就是kwargs,也就是_validate_inputs的参数input;
  2. 此时已经是在Chain这个基本类了)
def run(self,*args: Any,callbacks: Callbacks = None,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,**kwargs: Any,) -> Any:"""Convenience method for executing chain.The main difference between this method and `Chain.__call__` is that thismethod expects inputs to be passed directly in as positional arguments orkeyword arguments, whereas `Chain.__call__` expects a single input dictionarywith all the inputs"""

接下来调用__call__:

def __call__(self,inputs: Union[Dict[str, Any], Any],return_only_outputs: bool = False,callbacks: Callbacks = None,*,tags: Optional[List[str]] = None,metadata: Optional[Dict[str, Any]] = None,run_name: Optional[str] = None,include_run_info: bool = False,) -> Dict[str, Any]:"""Execute the chain.Args:inputs: Dictionary of inputs, or single input if chain expectsonly one param. Should contain all inputs specified in`Chain.input_keys` except for inputs that will be set by the chain'smemory.return_only_outputs: Whether to return only outputs in theresponse. If True, only new keys generated by this chain will bereturned. If False, both input keys and new keys generated by thischain will be returned. Defaults to False.callbacks: Callbacks to use for this chain run. These will be called inaddition to callbacks passed to the chain during construction, but onlythese runtime callbacks will propagate to calls to other objects.tags: List of string tags to pass to all callbacks. These will be passed inaddition to tags passed to the chain during construction, but onlythese runtime tags will propagate to calls to other objects.metadata: Optional metadata associated with the chain. Defaults to Noneinclude_run_info: Whether to include run info in the response. Defaultsto False.Returns:A dict of named outputs. Should contain all outputs specified in`Chain.output_keys`."""inputs = self.prep_inputs(inputs)......

这里的prep_inputs会调用_validate_inputs函数

def _validate_inputs(self,inputs: Dict[str, Any]) -> None:"""Check that all inputs are present."""missing_keys = set(self.input_keys).difference(inputs)if missing_keys:raise ValueError(f"Missing some input keys: {missing_keys}")

这里的input_keys通过调试,看到的就是有多个输入,分别是"input_documents"和"question"
这里的"input_documents"是来自于BaseCombineDocumentsChain

class BaseCombineDocumentsChain(Chain, ABC):"""Base interface for chains combining documents.Subclasses of this chain deal with combining documents in a variety ofways. This base class exists to add some uniformity in the interface these typesof chains should expose. Namely, they expect an input key related to the documentsto use (default `input_documents`), and then also expose a method to calculatethe length of a prompt from documents (useful for outside callers to use todetermine whether it's safe to pass a list of documents into this chain or whetherthat will longer than the context length)."""input_key: str = "input_documents"  #: :meta private:output_key: str = "output_text"  #: :meta private:

那为什么有两个呢,“question”来自于哪里?
StuffDocumentsChain继承BaseCombineDocumentsChain,其input_key是这样定义的:

  @propertydef input_keys(self) -> List[str]:extra_keys = [k for k in self.llm_chain.input_keys if k != self.document_variable_name]return super().input_keys + extra_keys

原来是重写了input_keys函数,其是对llm_chain的input_keys进行遍历。

那么llm_chain的input_keys是用其prompt的input_variables。(这里的input_variables是PromptTemplate中的[“context”, “question”])

	@propertydef input_keys(self) -> List[str]:"""Will be whatever keys the prompt expects.:meta private:"""return self.prompt.input_variables

至此,我们StuffDocumentsChain的input_keys有两个变量了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/163978.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

web3 dapp React项目引入 antd 对 balance 用户token信息组件进行样式改造

好 上文 web3 React dapp中编写balance组件从redux取出并展示用户资产 我们简单处理了用户资产的展示 那么 我们继续 先启动 ganache 环境 终端输入 ganache -d然后 打开我们的项目 将合约发布到区块链上 truffle migrate --reset然后 我们启动项目 确认一切正常 还原到上文…

MCU常见通信总线串讲(一)—— UART和USART

🙌秋名山码民的主页 😂oi退役选手,Java、大数据、单片机、IoT均有所涉猎,热爱技术,技术无罪 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 获取源码,添加WX 目录 前言一…

网络编程套接字(2)——简单的TCP网络程序

文章目录 一.简单的TCP网络程序1.服务端创建套接字2.服务端绑定3.服务端监听4.服务端获取连接5.服务端处理请求6.客户端创建套接字7.客户端连接服务器8.客户端发起请求9.服务器测试10.单执行流服务器的弊端 二.多进程版的TCP网络程序1.捕捉SIGCHLD信号2.让孙子进程提供服务 三.…

el-table中的el-input标签修改值,但界面未更新,解决方法

el-table中的el-input标签修改值,界面未更新 在el-table中的el-input里面写的change事件根本不触发,都不打印,试了网络上各种方法都没用 然后换成input事件,input事件会触发,但界面也未更新。我在触发事件的时候&…

构建强大的Web应用之Django详解

引言: Django是一个功能强大且灵活的Python Web框架,它提供了一套完整的工具和功能,帮助开发者快速构建高效的Web应用。本篇文章将带您逐步了解Django的基本概念和使用方法,并通过实际的代码案例,帮助您从零开始构建自…

vue + axios + mock

参考来源:Vue mock.js模拟数据实现首页导航与左侧菜单功能_vue.js_AB教程网 记录步骤:在参考资料来源添加axios步骤 1、安装mock依赖 npm install mock -D //只在开发环境使用 下载完成后,项目文件package.json中的devDependencies就会加…

基于java+springboot+vue的幼儿园信息网站

项目介绍 随着科学技术的飞速发展,各行各业都在努力与现代先进技术接轨,通过科技手段提高自身的优势;对于幼儿园管理系统当然也不能排除在外,随着网络技术的不断成熟,带动了幼儿园管理系统,它彻底改变了过…

【排序算法】 快速排序(快排)!图解+实现详解!

🎥 屿小夏 : 个人主页 🔥个人专栏 : 算法—排序篇 🌄 莫道桑榆晚,为霞尚满天! 文章目录 📑前言🌤️快速排序的概念☁️快速排序的由来☁️快速排序的思想☁️快速排序的实…

高校为什么需要大数据挖掘平台?

目前数据挖掘已经成为各种应用领域的重要技术,大学数据挖掘课程的开放已经出现。数据挖掘课程整合了多门学科知识。该课程包括各种理论知识,也离不开相关的实用技术。整个教学过程是培养和提高学生全面创新和解决问题的能力。过去,教学过程理…

ZYNQ_project:led

本次实验完成:led流水间隔0.5s 闪烁间隔0.25s。 名词解释: analysis分析:对源文件进行全面的语法检查。 synthesis综合:综合的过程是由 FPGA 综合工具箱 HDL 原理图或其他形式源文件进行分析,进而推演出由 FPGA 芯…

【C语言】数据结构——无头单链表实例探究

💗个人主页💗 ⭐个人专栏——数据结构学习⭐ 💫点击关注🤩一起学习C语言💯💫 目录 导读:1. 单链表1.1 什么是单链表1.2 优缺点 2. 实现单链表基本功能2.1 定义结构体2.2 单链表打印2.3 销毁单链…

大数据之LibrA数据库系统告警处理(ALM-12030 无合法license存在)

告警解释 系统在安装集群后和每天零点检查当前系统中是否存在合法的license文件,如果没有则产生该告警。 导入合法license文件时,告警恢复。 说明: 如果当前集群使用节点数小于等于10节点(不包含管理节点)&#xf…