文章目录
- 0、基本信息
- 1、研究动机
- 2、创新点
- 3、准备
- 3.1、知识图谱
- 3.2、多项选择问答
- 3.3、提示词工程(prompt engineering)
- 4、具体实现
- 4.1、提示LLMs用于问答
- 4.2、子图检索
- 4.3、Graph Neural Prompting
- 4.3.1、GNN Encoder
- 4.3.2、Cross-modality Pooling
- 4.3.3、Domain Projector
- 4.3.4、Self-supervised Link Prediction
- 4.4、模型整体框架图
0、基本信息
- 会议:2024-AAAI
- 作者:Yijun Tian, Huan Song, Zichen Wang
- 文章链接:Graph Neural Prompting with Large Language Models
- 代码链接:Graph Neural Prompting with Large Language Models
1、研究动机
尽管大语言模型在各种各样的自然语言任务上取得了令人瞩目的成就,但是,当你在使用ChatGPT或者其他LLMs时有没有发现这样的一个问题,就是LLM的回答很宽泛,每个问题的回答不够精确,不够具体,这说明,LLMs仅仅是理解了语言的结构形式,,但没有理解语义信息。
另一方面,知识图谱(KGs)中包含大量的语义信息,作为一种系统化的知识表达方式,但是目前的方法都是利用KGs联合训练来增强语言模型或是定制化的模型结构,这样的方法导致了大量的参数,需要额外的计算资源。如果采用直接的方法,将KGs用于检索增强生成,将KG三元组直接输入到LLMs,但是这样会引入噪声。
那么,到底能不能从知识图谱中学习到拥有的知识并整合到与训练的LLM中呢?为此,本文提出了图神经提示(Graph Neural Prompt),解决上述问题,帮助预训练的LLMs从知识图谱中学习有用的知识。
主要解决的问题:
- LLMs不能准确得到相应的知识,容易产生毫无根据的回答;
- 训练阶段需要大量的计算资源;
- 解决如何处理知识图谱为LLMs提供有益的信息,同时避免噪声干扰;
2、创新点
- 一个即插即用的方法,帮助预训练LLMs从知识图谱中学习有用知识,据作者而言,这是首次研究。
- 提出Graph Neural Prompting(GNP)方法,包含四个模块,GNN Encoder,Cross-modality Pooling,Domain Projector和Self-supervised Link Prediction
这篇文章其实也是A+B的过程,常见的图提示是基于文本的,本篇提示,则是使用GNNs对KGs进行嵌入作为提示,来微调语言模型。属于知识图谱增强语言模型的范畴。
3、准备
3.1、知识图谱
知识图谱就是把不同种类的信息连接在一起而形成的一个关系网络,知识图谱由结点和边组成,每个结点表示现实世界中存在的“实体”,每条边表示实体与实体之间的“关系”。比如百度知识图谱,社交网络。
一个知识图谱定义为 G = ( E , R , T ) \mathcal{G}=(\mathcal{E},\mathcal{R},\mathcal{T}) G=(E,R,T), E \mathcal{E} E表示实体集合, R \mathcal{R} R表示关系集合, T \mathcal{T} T表示事实三元组 { ( e h , r , e t ) } ∈ E × R × T \{(e_h,r,e_t)\}\in \mathcal{E} \times \mathcal{R}\times \mathcal{T} {(eh,r,et)}∈E×R×T, e h e_h eh定义为头实体, r r r定义为关系, e t e_t et定义为尾实体。
3.2、多项选择问答
对于一个问题 Q Q Q,选项集合定义为 A = { a k } k = 1 K A=\{a_k\}^K_{k=1} A={ak}k=1K, K K K是回答选项的总个数, a k a_k ak定义为第 k k k个回答选项。可选的上下文 C C C取决于开卷闭卷。任务是设计一个机器学习模型 F θ \mathcal{F}_{\theta} Fθ( θ \theta θ是参数)选择最好的选项去回答问题。真实标签 y ∈ A y\in A y∈A是问题 Q Q Q的正确答案。本文,希望使用知识图谱 G \mathcal{G} G来提供丰富的知识,并协助模型回答问题。
3.3、提示词工程(prompt engineering)
提示工程是一个较新的学科,应用于开发和优化提示词(Prompt),帮助用户有效地将语言模型用于各种应用场景和研究领域。掌握了提示工程相关技能将有助于用户更好地了解大型语言模型的能力和局限性。研究人员可利用提示工程来提高大语言模型处理复杂任务场景的能力,如问答和算术推理能力。开发人员可通过提示工程设计和研发出强大的技术,实现和大语言模型或其他生态工具的高效接轨。
简单理解,Prompt指的是用户给大型语言模型发出的指令,它可以是一个问题、一段文字描述,甚至可以是带有一堆参数的文字描述。LLM会基于 prompt 所提供的信息,生成对应的文本,亦或者图片。
来自:https://zhuanlan.zhihu.com/p/631967998
4、具体实现
4.1、提示LLMs用于问答
一个常见的简单方法为,给定一个问题 Q Q Q,可选择的文本 C C C,答案选项 A A A,首先将 C , Q , A C,Q,A C,Q,A拼接并标记为输入文本序列X,然后设计一系列的提示文本tokens, P P P,并将其放在输入文本序列X的前面,之后,作为LLM模型的输入医生称预测 y ′ = f ( [ P , X ] ) y' = f([P,X]) y′=f([P,X])。LLM模型可以使用teacher forcing和交叉熵损失来训练以适应下游任务:
L l l m = − l o g p ( y ∣ X , θ ) \mathcal{L}_{llm}=-log\;p(y|X,\theta) Lllm=−logp(y∣X,θ)
其中, p p p是模型参数化的概率分布。
提示P,要么是来自文本输入形式的硬提示,要么是可学习嵌入向量的软提示。
本文使用的方法为软提示,将知识图谱中的结构和真实信息编码到软提示 P P P中,软提示 P P P嵌入到的可训练的向量序列X中。可学习的 P P P提供丰富的结构信息和知识以及为每个数据实例提供任务指令。
teacher-forcing 在训练网络过程中,每次不使用上一个state的输出作为下一个state的输入,而是直接使用训练数据的标准答案(ground truth)的对应上一项作为下一个state的输入。
prompt模板的制作分为手工创建模板和自动化生成模板,而自动化生成模板又分为离散提示(又叫做硬提示)和连续提示(又叫做软提示)。离散prompt中,prompt是一个实际的文本字符串;连续prompt中,prompt直接在底层语言模型的嵌入空间中进行描述.
4.2、子图检索
首先,为什么需要子图检索呢?我们要明确,知识图谱包含数百万个结点以及更多的关系,然而并不是每个结点对于我们的任务有帮助,因此我们需要检索出与我们任务相关的结点以及他的子图(因为子图包含丰富的语义信息)。所以我们需要在检索子图,这子图中包含于 X X X中的标记相关额实体。
对于每个答案选项 a k a_k ak以及它对应的上下文 C C C和问题 Q Q Q,首先通过KGs中实体之间的链接获得一组匹配的实体集合 E m a t c h \mathcal{E}_{match} Ematch,将 X X X中标记的实体与知识图谱 G \mathcal{G} G中的实体匹配。然后,基于集合 E m a t c h \mathcal{E}_{match} Ematch检索子图,包括他们的两跳邻居以及他们之间的关系。检索到的子图包含必要的内容知识帮助模型回答问题 Q Q Q。
4.3、Graph Neural Prompting
在第2节介绍了Graph Neural Prompting(GNP)主要包含了四个部分:
- GNN encoder:将知识图谱(KGs)作嵌入;
- Cross-modality pooling module:确定合适的结点嵌入;
- Domain projector:建立起图与文本之间的桥梁;
- Self-supervised link prediction objective:使模型能够识别结构信息;
下面对上述四个模块分别介绍。
4.3.1、GNN Encoder
为什么要用GNN Encoder呢?
尽管检索的子图 G ′ \mathcal{G}' G′包含了关于问题和答案选择的丰富上下文信息,但是一些实体和关系对于最终的答案并不相关。如果将子图 G ′ \mathcal{G}' G′中的每个三元组直接输入,这样不可避免地引入了噪声,对LLM的预测产生影响。
为此,使用GNN去编码最相关的知识并进一步整合(聚合)实体中复杂的关系。首先,使用与训练的实体嵌入来初始化结点嵌入(这里我也不是很懂)。然后,使用GAT作为对检索子图 G ′ \mathcal{G}' G′的编码器,编码过程如下:
H 1 = f G N N ( G ′ ) H_1 = f_{GNN}(\mathcal{G}') H1=fGNN(G′)
其中, H 1 ∈ R d g H_1\in\mathbb{R}_{d_g} H1∈Rdg表示子图 G ′ \mathcal{G}' G′中每个结点通过GNN学习到的结点嵌入向量, d g d_g dg表示GNN编码器输出的维度。
4.3.2、Cross-modality Pooling
为什么要设计Cross-modality Pooling这个模块呢?
如果不使用,之前的设计有什么问题?
文中讲的是,为了识别与问题最相关的结点,并将结点嵌入合并为一个整体的图集表示以便后续使用。
NOTE:如何实现?
1、识别结点的重要性
引入一个自注意层,利用内部图的特征和节点间的隐式交互来动态识别节点的重要性。
H 2 = Self-Attn ( H 1 ) H_2=\text{Self-Attn}(H_1) H2=Self-Attn(H1)
其中, H 2 H_2 H2是经过自注意力计算后的结点嵌入。
然后利用文本提示去计算图中结点的重要性。利用LLM中的字典(???)来获得输入文本中每个标记的嵌入文本 T ∈ R d t \mathcal{T}\in \mathbb{R}^{d_t} T∈Rdt, d t d_t dt表示LLM字典的维度。具体来说,首先对嵌入文本 T \mathcal{T} T进行变换,并获得变换后的文本嵌入 T ′ \mathcal{T}' T′,确保 T ′ \mathcal{T}' T′的维数与节点嵌入 H 2 H_2 H2的维数 d g d_g dg匹配。然后计算cross-modality attention, H 2 H_2 H2作为query, T ′ \mathcal{T}' T′作为key和value,计算过程如下:
T ′ = F F N 1 ( σ ( F F N 2 ( T ) ) ) , H 3 = s o f t m a x [ H 2 ⋅ ( T ′ ) T / d g ] ⋅ T ′ \begin{aligned}\mathcal{T}'&=\mathrm{FFN}_1(\sigma(\mathrm{FFN}_2(\mathcal{T}))),\\H_3&=\mathrm{softmax}[H_2\cdot(\mathcal{T}')^T/\sqrt{d_g}]\cdot\mathcal{T}'\end{aligned} T′H3=FFN1(σ(FFN2(T))),=softmax[H2⋅(T′)T/dg]⋅T′
其中, σ \sigma σ为GELU激活函数, FFN 1 \text{FFN}_1 FFN1和 FFN 2 \text{FFN}_2 FFN2是前馈神经网络, H 3 H_3 H3为最终的嵌入。接下来,通过pooling操作生成图级的嵌入:
H 4 = POOL ( H 3 ) H_4 = \text{POOL}(H_3) H4=POOL(H3)
其中, H 4 H_4 H4表示考虑 G ′ \mathcal{G}' G′中节点重要性的图级嵌入。
4.3.3、Domain Projector
为什么要设计Domain Project呢?
目的是为了建立起图级前嵌入和文本域之间的映射关系,以便LLM理解,弥补了图和文本之间固有的差异,允许更无缝的集成。此外,projector将图级嵌入映射到和LLM的相同维度 d t d_t dt,这确保了与LLM的固有结构对接时的兼容性。projector设计如下:
Z = FFN 3 ( σ ( FFN 4 ( H 4 ) ) ) Z = \text{FFN}_3(\sigma(\text{FFN}_4(H_4))) Z=FFN3(σ(FFN4(H4)))
Z Z Z定义为Graph Neural Prompt(GNP)最终的输出, FFN 3 \text{FFN}_3 FFN3和 FFN 4 \text{FFN}_4 FFN4是前馈神经网络
4.3.4、Self-supervised Link Prediction
为什么要设计Self-supervised Link Prediction?
尽管交叉熵目标使模型能够学习和适应下游任务的目标数据集,但是又设计了一个链接预测任务,以进一步完善其对实体之间关系的理解,并以自监督的方式捕获图知识。具体来说,掩盖子图 G ′ \mathcal{G}' G′中的一些边,并用模型去预测他们。使模型学习使用部分图的内容和结构来推理丢失的链接。
掩盖边的集合记为 E m a s k ⊆ E \mathcal{E}_{mask} \subseteq \mathcal{E} Emask⊆E,对于给定的三元组中的头实体,尾实体 { h 3 , t 3 } ∈ H \{h_3,t_3\} \in H {h3,t3}∈H,采用一种广泛使用的知识图嵌入方法DistMult,将KG中的实体嵌入和关系映射为向量, h , r , t h,r,t h,r,t。然后定义评分函数 ϕ ( e h , e t ) = < h , r , t > \phi(e_h,e_t)=<h,r,t> ϕ(eh,et)=<h,r,t>来生成每个三元组的评分, < ⋅ , ⋅ , ⋅ > <·,·,·> <⋅,⋅,⋅>表示三元线性点积。 r r r表示为KGs中的关系。 ϕ \phi ϕ越高,表明 ( e h , r , e t ) (e_h,r,e_t) (eh,r,et)成为一个正确的正例三元组的机会越大,而不是一个不正确的负例三元组。该模型将 E m a s k \mathcal{E}_{mask} Emask中的掩盖边预测为正,将其他随机边预测为负。(作者定义)
邻接预测的损失函数为:
L l p = ∑ ( e h , r , e t ) ∈ E m a s k ( S p o s + S n e g ) , \mathcal{L}_{lp}=\sum_{(e_{h},r,e_{t})\in\mathcal{E}_{mask}}(S_{pos}+S_{neg}), Llp=(eh,r,et)∈Emask∑(Spos+Sneg),
其中, S p o s = − log σ s ( ϕ ( e h , e t ) + γ ) S_{pos}\quad=\quad-\log\sigma_{s}(\phi(e_{h},e_{t})+\gamma) Spos=−logσs(ϕ(eh,et)+γ)表示正确正例三元组的得分, γ \gamma γ为margin, σ s \sigma_{s} σs为sigmoid激活函数, { ( e h ′ , r ~ , e l ′ ) } \{(e_h^{\prime},\tilde{r},e_l^{\prime})\} {(eh′,r~,el′)}是对应于正例三元组 ( e h , r , e t ) (e_h,r,e_t) (eh,r,et)的负例三元组。 S n e g = 1 n ∑ ( e h ′ , r , e t ′ ) log σ s ( ϕ ( e h ′ , e t ′ ) ^ + γ ^ ) S_{neg}=\frac{1}{n}\sum_{(e_{h}^{\prime},r,e_{t}^{\prime})}\log\sigma_{s}(\hat{\phi(e_{h}^{\prime},e_{t}^{\prime})}+\hat{\gamma}) Sneg=n1∑(eh′,r,et′)logσs(ϕ(eh′,et′)^+γ^)为不正确负例三元组的得分。最终的目标函数 L \mathcal{L} L定义为 L l l m \mathcal{L}_{llm} Lllm与 L l p \mathcal{L}_{lp} Llp的加权和:
L = L l l m + λ L l p , \mathcal{L}=\mathcal{L}_{llm}+\lambda\mathcal{L}_{lp}, L=Lllm+λLlp,
其中, λ \lambda λ为平衡参数。
4.4、模型整体框架图
5、实验结果