前言
大型语言模型,尤其是像ChatGPT这样的模型,尽管在自然语言处理领域展现了强大的能力,但也伴随着隐私泄露的潜在风险。在模型的训练过程中,可能会接触到大量的用户数据,其中包括敏感的个人信息,进而带来隐私泄露的可能性。此外,模型在推理时有时会无意中回忆起训练数据中的敏感信息,这一点也引发了广泛的关注。
隐私泄露的风险主要来源于两个方面:一是数据在传输过程中的安全性,二是模型本身的记忆风险。在数据传输过程中,如果没有采取充分的安全措施,攻击者可能会截获数据,进而窃取敏感信息,给用户和组织带来安全隐患。此外,在模型的训练和推理阶段,如果使用了个人身份信息或企业数据等敏感数据,这些数据可能会被模型运营方窥探或收集,存在被滥用的风险。
过去已经发生了多起与此相关的事件,导致许多大公司禁止员工使用ChatGPT。此前的研究表明,当让大模型反复生成某些特定词汇时,它可能会在随后的输出中暴露出训练数据中的敏感内容。
学术研究表明,对模型进行训练数据提取攻击是切实可行的。攻击者可以通过与预训练模型互动,从而恢复出训练数据集中包含的个别示例。例如,GPT-2曾被发现能够记住训练数据中的一些个人信息,如姓名、电子邮件地址、电话号码、传真号码和实际地址。这不仅带来了严重的隐私风险,还对语言模型的泛化能力提出了质疑。
本文要探讨的就是可以高效从大模型中提取出用于训练的隐私数据的技巧与方法,主要来自《Bag of Tricks for Training Data Extraction from Language Models》,这篇论文发在了人工智能顶级会议ICML 2023上。
背景知识
尽管大模型在各种下游语言任务中展现了令人瞩目的性能,但其内在的记忆效应使得训练数据可能被提取出来。这些训练数据可能包含敏感信息,如姓名、电子邮件地址、电话号码和物理地址,从而引发隐私泄露问题,阻碍了大模型在更广泛应用中的推进。
之前谷歌举办了一个比赛,链接如下
https://github.com/google-research/lm-extraction-benchmark/tree/master
这是一个针对性数据提取的挑战赛,目的是测试参赛者是否能从给定的前缀中准确预测后缀,从而构成整个序列,使其包含在训练数据集中。这与无针对性的攻击不同,无针对性的攻击是搜索训练数据集中出现的任意数据。
针对性提取被认为更有价值和具有挑战性,因为它可以帮助恢复与特定主题相关的关键信息,而不是任意的数据。此外,评估针对性提取也更容易,只需检查给定前缀的正确后缀是否被预测,而无针对性攻击需要检查整个庞大的训练数据集。
这个比赛使用1.3B参数的GPT-Neo模型,以1-eidetic记忆为目标,即模型能够记住训练数据中出现1次的字符串。这比无针对性和更高eidetic记忆的设置更具有挑战性。
比赛的基准测试集包含从The Pile数据集中选取的20,000个示例,这个数据集已被用于训练许多最新的大型语言模型,包括GPT-Neo。每个示例被分为长度为50的前缀和后缀,攻击的任务是在给定前缀的情况下预测正确的后缀。这些示例被设计成相对容易提取的,即存在一个前缀长度使得模型可以准确生成后缀。
训练数据提取
从预训练的语言模型中提取训练数据,即所谓的"语言模型数据提取",是一种恢复用于训练模型的示例的方法。这是一个相对较新的任务,但背后的许多技术和分析方法,如成员资格推断和利用网络记忆进行攻击,早就已经被引入。
Carlini等人是最早定义模型知识提取和κ-eidetic记忆概念的人,并提出了有希望的数据提取训练策略。关于记忆的理论属性以及在敏感领域应用模型提取(如临床笔记分析)等,已经成为这个领域后续研究的焦点。
最近的研究也有一些重要发现:
-
Kandpal等人证明,在语言模型中,数据提取的效果经常归因于常用网络抓取训练集中的重复。
-
Jagielski等人使用非确定性为忘记记忆示例提供了一种解释。
-
Carlini等人分析了影响训练数据记忆的三个主要因素。
-
Feldman指出,为了达到接近最优的性能,在自然数据分布下需要记忆标签。
-
Lehman等人指出,预训练的BERT在训练临床笔记时存在敏感数据泄露的风险,特别是当数据表现出高水平的重复或"笔记膨胀"时。
总的来说,这个新兴领域正在深入探讨如何从语言模型中提取训练数据,以及这种提取带来的安全和隐私风险。最新的研究成果为进一步理解和应对这些挑战提供了重要的洞见。
成员推理攻击
成员资格推断攻击(MIA)是一种与训练数据提取密切相关的对抗性任务,目标是在只能对模型进行黑盒访问的情况下,确定给定记录是否在模型的训练数据集中。MIA已被证明在各种机器学习任务中都是有效的,包括分类和生成模型。
MIA使用的方法主要分为两类:
-
基于分类器的方法:这涉及训练一个二元分类器来识别成员和非成员之间的复杂模式关系,影子训练是一种常用的技术。
-
基于度量的方法:这通过首先计算模型预测向量上的度量(如欧几里得距离或余弦相似度)来进行成员资格推断。
这两类方法都有各自的优缺点,研究人员正在不断探索新的MIA攻击方法,以更有效地从机器学习模型中推断训练数据。这突出了训练数据隐私保护在模型部署和应用中的重要性。对MIA技术的深入理解,有助于设计更加安全和隐私保护的机器学习模型训练和部署策略,这对于广泛应用尤其是在敏感领域的应用至关重要。
其他基于记忆的攻击
大型预训练模型由于容易记住训练数据中的信息,因此面临着各种潜在的安全和隐私风险。除了训练数据提取攻击和成员资格推断攻击之外,还有其他基于模型记忆的攻击针对这类模型。
其中,模型提取攻击关注于复制给定的黑盒模型的功能性能。在这类攻击中,对手试图构建一个具有与原始黑盒模型相似预测性能的第二个模型,从而可以在不获取原始模型的情况下复制其功能。针对模型提取攻击的保护措施,集中在如何限制模型的功能复制。
另一类攻击是属性推断攻击,其目标是从模型中提取特定的个人属性信息,如地点、职业和兴趣等。这些属性信息可能是模型生产者无意中共享的训练数据属性,例如生成数据的环境或属于特定类别的数据比例。
与训练数据提取攻击不同,属性/属性推断攻击不需要事先知道要提取的具体属性。而训练数据提取攻击需要生成与训练数据完全一致的信息,这更加困难和危险。
总之,这些基于模型记忆的各类攻击,都突显了大型预训练模型在隐私保护方面的重大挑战。如何有效应对这些攻击,成为当前机器学习安全研究的一个重要焦点。
【----帮助网安学习,以下所有学习资料免费领!加vx:dctintin,备注 “博客园” 获取!】
① 网安学习成长路径思维导图
② 60+网安经典常用工具包
③ 100+SRC漏洞分析报告
④ 150+网安攻防实战技术电子书
⑤ 最权威CISSP 认证考试指南+题库
⑥ 超1800页CTF实战技巧手册
⑦ 最新网安大厂面试题合集(含答案)
⑧ APP客户端安全检测指南(安卓+IOS)
威胁模型
数据集是从 Pile 训练数据集中抽取的 20,000 个样本子集。每个样本由一个 50-token 的前缀和一个 50-token 的后缀组成。
攻击者的目标是给定前缀时,尽可能准确地预测后缀。
这个数据集中,所有 100-token 长的句子在训练集中只出现一次。
采用了 HuggingFace Transformers 上实现的 GPT-Neo 1.3B 模型作为语言模型。这是一个基于 GPT-3 架构复制品,针对 Pile 数据集进行过训练的模型。
GPT-Neo 是一个自回归语言模型 fθ,通过链式规则生成一系列token。
这个场景中,攻击者希望利用语言模型对训练数据的记忆,来尽可能准确地预测给定前缀的后缀。由于数据集中每个句子在训练集中只出现一次,这就给攻击者提供了一个机会,试图从模型中提取这些罕见句子的信息。
在句子层面,给定一个前缀p,我们表示在前缀p上有条件生成某个后缀s的概率为fθ(s|p)。
我们专注于针对性提取 κ-eidetic 记忆数据的威胁模型,我们选择 κ=1。根据 Carlini定义的模型知识提取,我们假设语言模型通过最可能的标准生成后缀 s。然后我们可以将针对性提取的正式定义写为:
给定一个包含在训练数据中的前缀 p 和一个预训练的语言模型 fθ。针对性提取是通过下式来生成后缀
至于 κ-eidetic 记忆数据,我们遵循 Carlini的定义,即句子 [p, s] 在训练数据中出现不超过 κ 个示例。在实践中,生成句子的长度通常使用截断和连接技术固定在训练数据集上。如果生成的句子短于指定长度,使用填充 token 将其增加到所需长度。
流程
第一阶段 - 后缀生成:
-
利用自回归语言模型 fθ 计算词汇表中每个 token 的生成概率分布。
-
从这个概率分布中采样生成下一个 token,采用 top-k 策略限制采样范围,将 k 设为10。
-
不断重复这个采样过程,根据前缀生成一组可能的后缀。
第二阶段 - 后缀排名:
-
使用成员资格推断攻击,根据每个生成后缀的困惑度进行排序。
-
只保留那些概率较高(困惑度较低)的后缀。
这样的两阶段流程,首先利用语言模型生成可能的后缀候选,然后通过成员资格推断攻击对这些候选进行评估和筛选,从而尽可能还原出训练数据中罕见的完整句子。
这个训练数据提取攻击的关键在于,利用语言模型对训练数据的"记忆"来生成接近训练样本的内容,再结合成员资格推断技术进一步挖掘出高概率的真实训练样本。
其中 N 是生成句子中的 token 数量。
改进策略
为了改进后缀生成,我们可以来看看真实和生成token的logits分布。如下图所示,这两种分布之间存在显著差异。
为了解决这个问题,我们可以采用一系列技术进行改进
采样策略
在自然语言处理的条件生成任务中,最常见的目标是最大化解码,即给定前缀,找到具有最高概率的后缀序列。这种"最大似然"策略同样适用于训练数据提取攻击场景,因为模型会试图最大化生成的内容与真实训练数据的相似性。
然而,从模型中直接找到理论上的全局最优解(argmax序列)是一个不切实际的目标。原因在于,语言模型通常是auto-regressive的,每个token的生成都依赖于前面生成的内容,因此搜索全局最优解的计算复杂度会随序列长度呈指数级上升,实际上是不可行的。
因此,常见的做法是采用束搜索(Beam Search)作为一种近似解决方案。束搜索会在每一步保留若干个得分最高的部分解,而不是简单地选择概率最高的单一路径。这种方式可以有效降低计算复杂度,但同时也存在一些问题:
-
束搜索可能会缺乏生成输出的多样性,因为它总是倾向于选择得分最高的少数几个路径。
-
尽管增大束宽度可以提高性能,但当束宽超过一定程度时,性能增益会迅速下降,同时也会带来更高的内存开销。
为了克服束搜索的局限性,我们可以采用随机采样的方法,引入更多的多样性。常见的采样策略包括:
-
Top-k 采样:只从概率最高的k个token中进行采样,k是一个超参数。这种方法可以控制生成输出的多样性,但过大的k可能会降低输出的质量和准确性。
-
Nucleus 采样(Nucleus Sampling):从概率总和达到设定阈值的token集合中进行采样,可以自适应地调整采样空间的大小。
-
典型采样(Typical Sampling):从完整的概率分布中采样,偏向采样接近平均概率的token,可以在保持输出质量的同时引入更多的多样性。
总的来说,条件生成任务中的解码策略需要在生成质量、多样性和计算复杂度之间进行权衡。束搜索作为一种近似解决方案,能够有效控制计算成本,但缺乏生成多样性。而随机采样方法则可以引入更多的多样性,但需要在采样策略上进行细致的调整。这些技术在训练数据提取攻击中都有重要的应用价值。
Nucleus采样的核心思想是从总概率达到一定阈值η的token集合中进行采样,而不是简单地从概率最高的k个token中采样。
在故事生成任务中,研究表明较低的η值(如0.6左右)更有利于生成更为多样化和创造性的内容。这说明在生成任务中,保留一定程度的低概率token是有益的,可以引入更多的多样性。但在训练数据提取攻击这样的任务中,较大的η值(约0.6)效果更好,相比基线提升了31%的提取精度。这表明对于数据提取这类任务,我们需要更加关注生成内容与训练数据的相似性,而不是过度强调多样性。
如下图示进一步说明了这一点,即η值过大或过小都会导致性能下降。存在一个最优的η值区间,需要根据具体任务进行调整。
Typical-ϕ是一种用于自然语言生成任务的采样策略。它的核心思想是选择与预期输出内容相似的token,从而保证在典型解码中能够考虑到原始分布的概率质量。这种策略可以提高生成句子的一致性,同时减少一些容易出现的退化重复等问题。Typical-ϕ 策略在数学上等价于一个带有熵率约束的子集优化问题。这种策略在一定程度上可以控制生成文本的多样性和流畅性,平衡了文本质量和创造性。
Typical-ϕ 策略在不同任务中表现可能会有所不同。例如,在抽象摘要和故事生成任务中,Typical-ϕ 策略展现出一定的非单调趋势,即随着ϕ值的变化,生成文本的质量并非线性提升。这说明Typical-ϕ需要根据具体任务进行合适的参数调整,以达到最佳的生成效果。
概率分布调整
温度控制(Temperature)
-
这是一种直接调整概率分布的策略,通过引入温度参数T来重新归一化语言模型的输出概率分布。较高的温度T > 1会降低模型预测的确信度,但可以增加生成文本的多样性。研究发现,在生成过程中逐渐降低温度是有益的,可以在多样性和生成效率之间达到平衡。但过高的温度也可能导致生成的文本偏离真实分布,降低效率。因此需要合理调节温度参数。
重复惩罚(Repetition Penalty)
-
这是一种基于条件语言模型的策略,通过修改每个token的生成概率来抑制重复token的出现。具体做法是,重复token的logit在进入softmax层之前被除以一个值r。当r > 1时会惩罚重复,r < 1则会鼓励重复。研究发现,重复惩罚对训练数据提取任务通常有负面影响,因为它可能会抑制一些有用的重复信息。因此在使用重复惩罚时,需要根据具体任务和数据特点来合理设置参数r,在抑制不必要重复和保留有意义重复之间寻求平衡。
总的来说,温度控制和重复惩罚是两种常见的直接调整概率分布的策略,可以在一定程度上提高自然语言生成的质量和多样性。但它们也存在一些局限性,需要根据实际应用场景进行合理的参数调整和组合使用,以达到最佳的生成效果。
为了有效的向量化,通常在训练语言模型时将多个句子打包成固定长度的序列。例如,句子"Yu的电话号码是12345"可能在训练集中被截断,或与另一个句子拼接成前缀,如"Yu的地址在XXX。Yu的电话号码是12345"。训练集中的这些前缀序列并不总是完整的句子。为了更好地模拟这种训练设置,我们可以调整上下文窗口大小和位置偏移。
动态上下文窗口
训练窗口的长度可能与提取窗口的长度不同。因此,提出调整上下文窗口的大小,即之前生成的token的数量,如下所示。
此外,鼓励不同上下文窗口大小的结果在确定下一个生成的token时进行协作:
其中 hW 表示集成方法,W 表示集成超参数,包括不同上下文窗口大小的数量 m 和每个窗口大小 w_i。我们在代码中使用 m = 4 和 w_i ∈ {n, n - 1, n - 2, n - 3}。
动态位置偏移
位置嵌入被添加到像 GPT-Neo 这样的模型中的 token 特征中。在训练过程中,这是按句子批次添加的,导致相同的句子在不同的训练批次和生成过程中具有不同偏移的位置嵌入。
为了改进对记忆后缀的提取,可以通过评估不同偏移位置并选择 "最佳" 的一个来恢复训练期间使用的位置。具体来说,对于给定的前缀 p,评估不同的偏移位置 C = c_i,其中 c_i 是一系列连续自然数的列表,c_i = {c_i1, ...},使得 |c_i| = |p|,并计算相应的困惑度值。然后选择具有最低困惑度值的位置作为生成后缀的位置。
通过评估不同的位置偏移来选择最佳的位置嵌入,来提高模型对记忆后缀的提取能力。这种方法可以很好地补充原有的位置嵌入方法,增强模型的性能。
其中 ψ(·) 表示位置编码层,φ(·) 表示特征映射函数,𝜙^ϕ^ 表示包含位置编码的特征映射函数,P 计算前缀的困惑度。
前瞻(Look-Ahead)
有时候在生成过程中只有一个或两个token被错误生成或者放置在不适当的位置。为了解决这个问题,可以使用一种技术,它涉及向前看ν步,并使用后续token的概率来通知当前token的生成。前瞻的目标是使用后验分布来帮助计算当前token的生成概率。后验被计算为:
设 Track(xstart, xend | xcond) 表示从 xstart 开始到 xend 结束,在 xcond 条件下的轨迹的概率乘积。那么我们可以写ν步后验为:
其中 Track 被计算为:
超参数优化
以上提到的技巧涉及到各种超参数,简单地使用最佳参数通常是次优的。
手动搜索最佳超参数,也称为 "babysitting",可能非常耗时。
所以其实可以使用多功能的架构自动调整方法,结合了高效的搜索和剪枝策略,根据先进的框架来确定优化的超参数。作为搜索算法,比如可以确定搜索目标为 MP(精确度),搜索的参数包括 top-k、nucleus-η、typical-ϕ、温度 T 和重复惩罚 r。
后缀排名改进
在生成多个后缀之后,会进行一个排名过程,使用困惑度 P 作为度量来消除那些不太可能的后缀。然而,下图的统计分析揭示了真实句子并不总是具有最低困惑度值
句子级标准
文本的熵,由 Zlib 压缩算法用位数来确定,是序列信息内容的量化指标。使用由 GPT-Neo 模型计算的给定句子的困惑度与相同句子的 Zlib 熵的比率作为成员推断的度量。此外还可以分析困惑度和 Zlib 熵的乘积的潜在效用,因为当模型对其预测有高度信心时,这两种度量都趋于减少。实验表明这两种度量在成员推断任务的整体性能上只产生了边际改进。
词级别标准
对高置信度的奖励。记忆数据的高置信度存在是被称为 "记忆效应"的现象的明确特征之一。我们对高置信度的 token 进行奖励。如果句子包含置信度高的 token,那么生成的 token 的可能性高于某个阈值,并且生成的 token 与其他 token 之间的差异也高于某个阈值,我们会将其排名提高。具体来说,对于生成后缀中的 token 𝑥𝑛x**n,如果其概率高于阈值 0.9,那么我们会从后缀 𝑠𝑖s**i 的分数中减去一个给定的数值 0.1(原始分数 𝑠𝑖s**i 是其困惑度)。
鼓励惊讶模式。根据最近的研究,人类文本生成经常表现出一种模式,即高困惑度的 token 被间歇性地包含,而不是一直选择低困惑度的 token。为了解决这个问题,通过只基于大多数 token 计算生成提示的困惑度来鼓励惊讶 token(高困惑度 token)的存在:
其中 µ 和 σ 分别表示一批中 𝑝(𝑥𝑛∣𝑥[0:𝑛−1])p(x**n∣x[0:n−1]) 的均值和标准差。使用这种方法,生成中包含的惊讶 token 不会在整体句子困惑度上产生负面影响,从而在成员推断期间增加了它们被选择的可能性。
实战
分析关键的函数
如下函数通过批处理方式高效地生成文本,并计算每个生成文本的损失,以评估模型在生成任务中的表现。这样可以帮助分析和改进生成文本的质量和模型的泛化能力。
该函数的主要目的是从给定的提示中生成文本,并计算生成文本的概率(或损失)。
输入参数
-
prompts
: 一个包含提示的numpy数组。 -
batch_size
: 每次处理的提示数量,默认值为32。
主要步骤
-
初始化:
-
初始化空列表用于存储生成的文本和相应的损失。
-
确定生成文本的总长度,这包括前缀和后缀的长度。
-
-
批次处理:
-
将提示按批次进行处理,批次大小由
batch_size
决定。 -
将每个批次的提示堆叠成一个批次,并转换为PyTorch张量。
-
-
生成文本:
-
使用模型生成文本。生成过程中:
-
将输入提示移至GPU。
-
设置生成文本的最大长度。
-
进行随机采样(
do_sample=True
),并只考虑概率最高的10个标记(top_k=10
)。 -
处理生成过程中可能出现的填充标记。
-
-
-
计算概率:
-
将生成的文本再次输入模型,计算每个标记的概率。
-
提取模型输出的logits,重新整形为二维张量。
-
使用交叉熵计算每个标记的损失。
-
将损失重新整形,并提取后缀部分的损失。
-
计算每个生成序列的平均损失,作为生成文本的概率。
-
-
存储结果:
-
将生成的文本和损失转换为numpy数组,并分别存储在列表中。
-
-
返回结果:
-
返回生成的文本和相应的损失,以numpy数组的形式返回。
-
如下函数组合在一起用于评估和比较语言模型的生成质量。write_array
函数保存生成结果,hamming
函数计算生成文本与真实文本之间的汉明距离,gt_position
函数计算真实答案的损失,compare_loss
函数比较生成文本与真实文本的损失,plot_hist
函数则用于可视化损失分布。通过这些步骤,可以全面评估模型在生成任务中的表现和准确性。
1. write_array
-
功能: 将numpy数组保存到文件中,文件名包含一个唯一标识符。
-
输入: 文件路径(包含格式化标记)、数组、唯一标识符(整数或字符串)。
-
实现: 使用给定的格式化标记生成文件名,然后将数组保存到该文件中。
2. hamming
-
功能: 计算生成序列与真实序列之间的汉明距离。
-
输入: 真实序列和生成的序列。
-
实现:
-
如果生成的序列是二维的,逐行计算每行的汉明距离。
-
否则,计算生成序列第一行与真实序列的汉明距离。
-
返回平均汉明距离和汉明距离的形状。
-
3. gt_position
-
功能: 计算真实答案序列的损失。
-
输入: 真实答案序列列表和批次大小(默认为50)。
-
实现:
-
将答案分批处理。
-
计算每个标记的logits。
-
使用交叉熵计算每个标记的损失。
-
提取后缀部分的损失,并计算平均损失。
-
返回每个序列的损失列表。
-
4. compare_loss
-
功能: 比较真实序列和生成序列的损失。
-
输入: 真实序列的损失和生成序列的损失。
-
实现:
-
将两组损失拼接在一起。
-
对每个序列的损失进行排序。
-
获取排序后的索引。
-
返回排序后的损失,排序索引和排名第一的索引。
-
5. plot_hist
-
功能: 绘制损失的直方图。
-
输入: 损失数组。
-
实现: 该函数目前为空,未实现绘图逻辑。
如下函数组合在一起用于处理和评估语言模型的生成任务。load_prompts
函数加载提示数据,is_memorization
函数评估生成模型是否记住了训练数据,error_100
函数计算在发生100次错误之前的匹配次数,precision_multiprompts
函数计算多提示生成序列的精确度,prepare_data
函数则准备实验所需的数据和目录结构。这些步骤帮助全面评估和改进模型的生成质量和泛化能力。
1. load_prompts
-
功能: 从指定目录加载numpy文件并转换为64位整数类型的numpy数组。
-
输入:
-
dir_
: 文件所在的目录路径。 -
file_name
: 文件名。
-
-
实现: 通过拼接目录路径和文件名构造完整文件路径,加载文件并转换数据类型。
2. is_memorization
-
功能: 计算生成的序列与真实序列完全匹配的比例,以确定模型是否记住了训练数据。
-
输入:
-
guesses
: 生成的序列。 -
answers
: 真实序列。
-
-
实现:
-
对比生成的序列和真实序列是否完全相同,统计完全匹配的次数。
-
计算匹配次数在所有生成序列中的比例。
-
3. error_100
-
功能: 计算在前100个错误之前的正确匹配次数。
-
输入:
-
guesses_order
: 按顺序排列的生成序列。 -
order
: 序列顺序索引。 -
answers
: 真实序列。
-
-
实现:
-
遍历生成序列,统计与真实序列匹配的次数,直到发生100次错误为止。
-
返回在发生100次错误之前的总遍历次数和超出100次错误的匹配数。
-
4. precision_multiprompts
-
功能: 计算多提示生成序列的精确度。
-
输入:
-
generations
: 多提示生成的序列。 -
answers
: 真实序列。 -
num_perprompt
: 每个提示生成的序列数量。
-
-
实现:
-
截取每个提示生成的前
num_perprompt
个序列。 -
检查每个提示生成的序列是否与真实序列匹配。
-
计算匹配的提示数量占总提示数量的比例。
-
5. prepare_data
-
功能: 准备数据和目录结构以进行实验。
-
输入:
-
val_set_num
: 验证集的数量。
-
-
实现:
-
构造实验目录和生成结果、损失结果的子目录。
-
加载提示数据,并提取验证集部分的提示数据。
-
返回构造的目录路径和提示数据。
-
### 如下函数组合在一起用于处理和评估语言模型的生成任务。
-
write_guesses_order
函数将生成的序列按顺序写入CSV文件,便于进一步分析。 -
edit_dist
函数计算生成序列和真实序列之间的编辑距离,这是评估生成质量的重要指标。 -
metric_print
函数计算并打印各种评估指标,包括精度、多提示精度、前100个错误之前的正确匹配数、汉明距离和编辑距离。这些指标帮助全面评估模型在生成任务中的表现和准确性。
1. write_guesses_order
-
功能: 将生成的序列按顺序写入CSV文件。
-
输入:
-
generations_per_prompt
: 每个提示生成的序列数。 -
order
: 序列的顺序索引。 -
guesses_order
: 生成的序列按顺序排列。
-
-
实现:
-
打开CSV文件进行写操作,文件名包含
generations_per_prompt
。 -
写入表头。
-
遍历序列索引和生成的序列,将每个序列按指定格式写入CSV文件。
-
2. edit_dist
-
功能: 计算生成序列和真实序列之间的编辑距离。
-
输入:
-
answers
: 真实序列。 -
generations_one
: 生成的单个序列。
-
-
实现:
-
初始化编辑距离总和为0。
-
遍历真实序列和生成序列,计算每对序列的编辑距离并累加。
-
返回平均编辑距离。
-
3. metric_print
-
功能: 计算并打印各种评估指标。
-
输入:
-
generations_one
: 单个生成序列。 -
all_generations
: 所有生成序列。 -
generations_per_prompt
: 每个提示生成的序列数。 -
generations_order
: 按顺序排列的生成序列。 -
order
: 序列的顺序索引。 -
val_set_num
: 验证集的数量。
-
-
实现:
-
加载真实答案数据。
-
打印生成序列和真实序列的形状。
-
计算生成序列的精度并打印。
-
计算多提示生成序列的精度并打印。
-
计算前100个错误之前的正确匹配数并打印。
-
计算生成序列和真实序列的汉明距离并打印。
-
计算生成序列和真实序列的编辑距离并打印。
-
返回各种评估指标。
-
我们首先来看基线的攻击效果
我们在前面提到Zlib 压缩算法,可以用来衡量文本的熵,即信息内容的量化指标。在这项研究中,Zlib 用于与语言模型计算的困惑度相结合,作为成员推断的一个度量标准。具体地,使用 GPT-Neo 模型对给定句子计算的困惑度与相同句子的 Zlib 熵的比值,来评估句子是否可能属于模型的训练数据集。但是 Zlib 方法的效果是有限的。尽管 Zlib 熵和困惑度都是衡量模型对句子预测信心的指标,且两者在模型高度自信时趋于减少,但它们在成员推断任务的整体性能上只产生了边际(即很小的)改进。这表明,尽管 Zlib 方法在理论上是一个有趣的尝试,但在实际应用中可能不是最有效的手段。所以我们可以来看看是否如此
首先来看看zlib在实现上的不同
generate_for_prompts
函数用于生成给定提示的输出序列,并计算每个生成序列的损失
输入参数
-
prompts
: 一个包含提示序列的numpy数组。 -
batch_size
: 每个批次处理的提示数量,默认值为32。
输出
-
生成的序列数组和对应的损失数组。
步骤
-
初始化:
-
generations
和losses
用于存储生成的序列和计算的损失。 -
generation_len
计算生成序列的长度,该长度为后缀和前缀的总和。
-
-
批次处理:
-
将提示序列按批次进行处理。
-
对每个批次,提取相应的提示序列,并将其转换为PyTorch张量。
-
-
生成序列:
-
在禁用梯度计算的上下文中,使用模型生成序列。
-
max_length
设置为生成序列的总长度。 -
do_sample=True
和top_k=10
控制生成策略。 -
pad_token_id=50256
设置填充标记ID,避免警告。
-
-
计算损失:
-
生成序列后,计算每个生成序列的概率。
-
将生成的序列作为输入和标签传递给模型。
-
提取logits并重新形状,以适应交叉熵损失计算。
-
计算每个标记的损失,只考虑后缀部分的损失。
-
-
压缩长度调整:
-
使用zlib库对每个生成的序列进行压缩,并获取压缩后的长度。
-
调整每个生成序列的损失,使其与压缩长度成正比。
-
-
结果存储:
-
将生成的序列和对应的损失添加到结果列表中。
-
最后,将结果转换为至少二维的numpy数组并返回。
-
该函数通过以下几个步骤生成序列并计算损失:
-
按批次加载提示序列。
-
使用预训练模型生成序列。
-
计算生成序列的损失。
-
通过压缩调整损失。
-
存储并返回生成的序列和损失。
这种方法既考虑了生成序列的质量(通过损失计算),又通过压缩长度的调整,间接考虑了序列的复杂性和压缩率。
执行后效果如下
之前还提到了动态上下文窗口(Dynamic Context Window)技术。
在语言模型生成文本时,如果生成了一个错误的token,可能会因为语言模型的自回归特性而导致后续的token也生成错误。通过使用动态上下文窗口,可以从不同长度的历史上下文中获取信息,这有助于减少这种错误传播。通过调整上下文窗口的大小,即考虑不同数量的之前生成的token,可以帮助模型更好地理解前缀的上下文,从而提高生成后缀的准确性。文中提到的实验结果显示,使用动态上下文窗口可以显著提高数据提取的准确性。动态上下文窗口允许模型在生成每个token时考虑不同长度的上下文,这增加了生成过程的灵活性,使模型能够根据当前的上下文信息选择最合适的token。
有两种实现动态上下文窗口的方法。第一种是加权平均策略(Weighted Average Strategy),第二种是基于投票机制的策略(Voting Strategy)。两种方法都旨在结合不同窗口大小生成的概率,以提高生成后缀的准确性。
我们首先来看代码上的不同
1. winlen_logits_output
-
功能: 计算输入序列的一部分(从
win_len
到input_len
的片段)的模型输出logits。 -
输入:
-
input_batch
: 输入序列的批次。 -
win_len
: 截断窗口的起始位置。 -
input_len
: 截断窗口的结束位置。 -
answer_batch
: 真实答案的批次。
-
-
实现:
-
禁用梯度计算以提高效率。
-
截取输入序列的指定部分并传递给模型,计算logits。
-
初始化一个空列表
val
,准备存储一些计算结果(但在此函数中并未实际使用)。 -
根据训练标志决定如何处理logits。
-
返回最后一层logits和空的
val
列表。
-
2. zlib_filter
-
功能: 预留的过滤函数,目前没有实现任何功能。
3. vote_for_the_one
-
功能: 通过投票机制选择最可能的输出序列。
-
输入:
-
last_logits
: 最后一层的logits。 -
k
: 用于投票的前k个logits。 -
answers
: 真实答案。 -
input_len
: 输入序列的长度。
-
-
实现:
-
初始化投票计数数组。
-
获取logits中每个序列的前k个最高值的索引。
-
为每个索引分配线性权重。
-
打印预测结果和原始结果的比较。
-
返回投票计数最高的索引作为最终预测。
-
4. logits_add
-
功能: 通过加权求和的方式整合logits,得到最终的预测。
-
输入:
-
last_logits
: 多个窗口的logits。 -
weight_win
: 每个窗口的权重。
-
-
实现:
-
使用权重加权求和各个窗口的logits。
-
返回加权求和后的logits中概率最高的索引作为最终预测。
-
这些函数用于处理和评估生成模型的输出:
-
winlen_logits_output
提取并计算输入序列部分片段的logits,帮助理解模型对不同输入片段的响应。 -
vote_for_the_one
使用投票机制从logits中选择最可能的输出,提高预测的准确性。 -
logits_add
通过加权求和不同窗口的logits,进一步优化预测结果。 -
zlib_filter
目前未实现,可能预留用于将来对数据进行某种过滤处理。
这用于生成给定提示的输出序列,并计算每个生成序列的损失的函数
输入参数
-
prompts
: 包含提示序列的numpy数组。 -
batch_size
: 每个批次处理的提示数量。 -
_SUFFIX_LEN
,_PREFIX_LEN
: 后缀和前缀的长度。 -
_DATASET_DIR.value
: 数据集的目录路径。 -
_val_set_num.value
: 用于加载的验证集数量。
输出
-
生成的序列数组 (
generations
) 和对应的损失数组 (losses
)。
主要步骤
-
初始化:
-
generations
和losses
初始化为空列表。 -
generation_len
计算生成序列的长度,为后缀和前缀长度之和。 -
answers
加载验证集的答案数据。
-
-
循环处理提示序列:
-
根据设定的批次大小,循环处理提示序列。
-
每次循环中,提取并准备输入的提示批次 (
prompt_batch
) 和对应的答案批次 (answers_batch
)。
-
-
生成序列:
-
使用带有截断窗口的方法生成序列,通过调用
gene_next_token
函数获取每次生成的下一个标记。 -
将生成的标记 (
generated_tokens
) 拼接在一起形成完整的生成序列。 -
将生成序列转换为PyTorch张量,并在禁用梯度计算的上下文中生成模型输出 (
generated_tokens
是最终的生成序列)。
-
-
计算损失:
-
计算每个生成序列的logits。
-
使用交叉熵损失函数计算损失。
-
将损失加入到
losses
列表中。
-
-
返回结果:
-
将
generations
和losses
转换为至少二维的numpy数组,并返回。
-
执行后效果如下
在上图可以看到指标有极大的提升(可以看precision,精确度是指正确生成的后缀占给定前缀总数的比例。这是通过比较生成的后缀和实际的训练数据后缀来计算的。精确度反映了模型生成正确后缀的能力。这个值越高说明效果越好;或者也可以看hamming dist,汉明距离是用来衡量两个等长字符串之间差异的指标,计算为两个字符串对应位置上不同符号的数量。在训练数据提取的上下文中,汉明距离用来定量评估生成后缀与真实后缀之间的相似度,提供了一个在token级别上对提取方法性能的评估。这个值越小,说明效果越好)
在来看看我们在上文提到的另一个改进策略:一种基于词级别的排名方法,称为 "Reward on high confidence"(简称 highconf 方法)。这种方法的核心思想是奖励那些在生成后缀中包含高置信度 token 的候选后缀。具体来说,如果一个生成的后缀中的某个 token 具有高于特定阈值(例如 0.9)的概率,那么这个后缀在排名时会被赋予更高的分数。这种策略的目的是利用语言模型对其预测的置信度来提高提取任务的性能。
对应的代码如下
这段代码的功能是生成给定提示的输出序列,并计算每个生成序列的损失。
输入参数
-
prompts
: 包含提示序列的numpy数组。 -
batch_size
: 每个批次处理的提示数量。默认为32。
输出
-
生成的序列数组 (
generations
) 和对应的损失数组 (losses
)。
主要步骤
-
初始化:
-
generations
和losses
初始化为空列表。 -
generation_len
计算生成序列的长度,为后缀和前缀长度之和。 -
将输入的
batch_size
设置为32,这个值在后续循环中使用。
-
-
循环处理提示序列:
-
根据设定的批次大小,循环处理提示序列。
-
每次循环中,提取并准备输入的提示批次 (
prompt_batch
),并将其转换为PyTorch张量 (input_ids
)。
-
-
生成序列:
-
使用带有截断的方法生成序列,通过调用
_MODEL.generate
函数获取生成的标记 (generated_tokens
)。 -
在生成的标记上禁用梯度计算,并通过计算模型输出 (
outputs.logits
) 获得每个标记的logits值。
-
-
损失计算:
-
计算每个标记的损失 (
loss_per_token
),使用交叉熵损失函数 (torch.nn.functional.cross_entropy
)。 -
对损失进行后处理:
-
使用标准差过滤异常值,如果损失超出3倍标准差范围,则设置为1。
-
根据前两个最高的logits分数之间的差异和是否大于0.5来调整损失值。
-
最后,计算每个生成序列的平均损失 (
likelihood
)。
-
-
-
结果整理:
-
将生成的序列 (
generated_tokens
) 和损失 (likelihood
) 添加到generations
和losses
列表中。
-
-
返回结果:
-
将
generations
和losses
转换为至少二维的numpy数组,并返回。
-
执行后如下所示
在上图中,也是用我们之前说的方法,看指标,precision,hamming dist等都相比基线方法有了较大提升。也就表明我们在本文中所说的这些策略都是有效的。
更多网安技能的在线实操练习,请点击这里>>