浅谈LLAMA2核心函数generate源码

在学习LLAMA2的generate源码之前,先介绍Temperature超参数及sample_top_p的原理。

Temperature

Temperature 是一个超参数,可用于控制生成语言模型中生成文本的随机性和创造性。用于调整模型的softmax输出层中预测词的概率。

softmax函数:
p ( x i ) = e x i ∑ j = 1 V e x j p\left(x_i\right)=\frac{e^{x_i}}{\sum_{j=1}^V e^{x_j}} p(xi)=j=1Vexjexi

Temperature 参数(T)添加到softmax函数:
p ( x i ) = e x i T ∑ j = 1 V e x j T p\left(x_i\right)=\frac{e^{\frac{x_i}{T}}}{\sum_{j=1}^V e^{\frac{x_j}{T}}} p(xi)=j=1VeTxjeTxi
Temperature参数通常设置为 0.1 到 1.0 之间(T=1时形变为标准的Softmax函数),下图分别显示了 x i / T x_i/T xi/T在5:0.5和5:0.1时的图像(紫线为softmax,黑线为添加T参数的softmax),可以看到:

  • 当T值更大时,函数图像会变的更加的平缓,预测词的概率被拉平,这意味着所有词被选择的可能性更大。 这会产生更有创意和多样化的文本,因为模型更有可能生成不寻常或意想不到的词。

  • 当T值更小时,函数图像会变的更加的陡峭,预测词的概率会变尖锐,这意味着选择最有可能的词的概率更高。 这会产生更保守和可预测的文本,因为模型不太可能生成意想不到或不寻常的词。

在这里插入图片描述

x i / T x_i/T xi/T=5:0.5

在这里插入图片描述

x i / T 5 = 0.1 x_i/T5=0.1 xi/T5=0.1

小结:Temperature 参数是文本生成模型中用于控制生成文本的随机性和创造性的一个重要的超参数。

sample_top_p

在这里插入图片描述

平缓和陡峭的概率分布图-文献【2】

采样意味着根据当前条件概率分布随机选择输出词 ,使用采样方法时文本生成本身不再是确定性的。对单词序列进行采样时的大问题: 模型通常会产生不连贯的乱码。在LLAMA2中,缓解这一问题的方式是通过top_p(也称:nucleus sampling)

def sample_top_p(probs, p):probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)probs_sum = torch.cumsum(probs_sort, dim=-1)mask = probs_sum - probs_sort > pprobs_sort[mask] = 0.0# 归一化probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))# multinomial为多项式抽样函数next_token = torch.multinomial(probs_sort, num_samples=1)next_token = torch.gather(probs_idx, -1, next_token)return next_token_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

sample_top_p函数的作用:每个时间步,按照字出现的概率由高到底排序,当概率之和大于top-p的时候,就不取后面的样本了。然后对取到的这些字的概率重新归一化后,进行采样。这样做的好处是,既保证了质量,又增加了适当的随机性。

核心函数generate()

这一块直接在代码中进行注释:

def generate(self,prompt_tokens: List[List[int]],  # 输入的提示max_gen_len: int,  # 最大生成长度temperature: float = 0.6,  # 影响生成文本的随机性top_p: float = 0.9,  # 用于决定采样过程中保留的 token 集合的概率阈值logprobs: bool = False,  # 是否返回每个 token 的对数概率echo: bool = False,  # 是否返回输入的提示
) -> Tuple[List[List[int]], Optional[List[List[float]]]]:# ---------------------------初始化长度为 total_len tokens张量,并填充 pad_id----------------------------------params = self.model.paramsbsz = len(prompt_tokens)assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)min_prompt_len = min(len(t) for t in prompt_tokens)max_prompt_len = max(len(t) for t in prompt_tokens)assert max_prompt_len <= params.max_seq_lentotal_len = min(params.max_seq_len, max_gen_len + max_prompt_len)pad_id = self.tokenizer.pad_idtokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")# 将prompt_tokens中的token复制到tokens张量中。for k, t in enumerate(prompt_tokens):tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")if logprobs:# 创建一个与tokens相同形状的token_logprobs张量,并用0填充token_logprobs = torch.zeros_like(tokens, dtype=torch.float)prev_pos = 0eos_reached = torch.tensor([False] * bsz, device="cuda")input_text_mask = tokens != pad_id# -------------------------------------------------------------for cur_pos in range(min_prompt_len, total_len):# 调用模型的forward方法获取logitslogits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)if logprobs:# 计算token level的logprobstoken_logprobs[:, prev_pos + 1: cur_pos + 1] = -F.cross_entropy(input=logits.transpose(1, 2),target=tokens[:, prev_pos + 1: cur_pos + 1],reduction="none",ignore_index=pad_id,)# 根据温度参数和top_p参数对logits进行softmax和采样,得到下一个tokenif temperature > 0:# sample_top_p函数对probs进行采样probs = torch.softmax(logits[:, -1] / temperature, dim=-1)next_token = sample_top_p(probs, top_p)else:# 将logits中概率最大的token作为下一个token。next_token = torch.argmax(logits[:, -1], dim=-1)next_token = next_token.reshape(-1)# only replace token if prompt has already been generatednext_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)# tokens张量更新tokens[:, cur_pos] = next_tokeneos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id)prev_pos = cur_pos# 检查是否已经生成了所有的eos token,如果是则停止生成if all(eos_reached):breakif logprobs:# token_logprobs列表化token_logprobs = token_logprobs.tolist()out_tokens, out_logprobs = [], []for i, toks in enumerate(tokens.tolist()):# cut to max gen len# 对于 tokens 张量中的每一行(即每一个生成的序列),如果 echo 参数为假,则去掉提示部分start = 0 if echo else len(prompt_tokens[i])toks = toks[start: len(prompt_tokens[i]) + max_gen_len]probs = Noneif logprobs:probs = token_logprobs[i][start: len(prompt_tokens[i]) + max_gen_len]# cut to eos tok if any# 存在结束标记,则去掉结束标记之后的部分if self.tokenizer.eos_id in toks:eos_idx = toks.index(self.tokenizer.eos_id)toks = toks[:eos_idx]probs = probs[:eos_idx] if logprobs else Noneout_tokens.append(toks)out_logprobs.append(probs)# 返回生成的tokens和对数概率(如果logprobs参数为真)return (out_tokens, out_logprobs if logprobs else None)

总结

本文介绍了Temperature以及sample_top_p的原理,并且阅读了LLAMA2的核心生成函数的源码。关于更多细节实现,请关注llama源码。

参考文献

【1】https://github.com/facebookresearch/llama/blob/main/llama/generation.py

【2】The Curious Case of Neural Text Degeneration

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/64000.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

日常BUG—— maven编译报错

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;日常BUG、BUG、问题分析☀️每日 一言 &#xff1a;存在错误说明你在进步&#xff01; 一、问题描述 一个maven项目在由于在代码中书写了如下代码&#xff1a; public static ConcurrentMap<…

shell和反弹shell

文章目录 是什么&#xff1f;bash是什么&#xff1f;反弹shell 是什么&#xff1f; Shell 是一个用 C 语言编写的程序&#xff0c;它是用户使用 Linux 的桥梁。Shell 既是一种命令语言&#xff0c;又是一种程序设计语言。 Shell 是指一种应用程序&#xff0c;这个应用程序提供了…

《甲午》观后感——GPT-3.5所写

《甲午》是一部令人深思的纪录片&#xff0c;通过生动的画面和真实的故事&#xff0c;向观众展示了中国历史上的一段重要时期。观看这部纪录片&#xff0c;我深受触动&#xff0c;对历史的认识也得到了深化。 首先&#xff0c;这部纪录片通过精心搜集的历史资料和珍贵的影像资料…

仅使用 CSS 创建打字机动画效果

创建打字机效果比您想象的要容易。虽然实现这种效果的最常见方法是使用 JavaScript&#xff0c;但我们也可以使用纯 CSS 来创建我们的打字机动画。 在本文中&#xff0c;我们将了解如何仅使用 CSS 创建打字机动画效果。它简单、漂亮、容易。我们还将看看使用 CSS 与 JavaScrip…

【LeetCode】102. 二叉树的层序遍历、107. 二叉树的层序遍历 II

作者&#xff1a;小卢 专栏&#xff1a;《Leetcode》 喜欢的话&#xff1a;世间因为少年的挺身而出&#xff0c;而更加瑰丽。 ——《人民日报》 102. 二叉树的层序遍历 102. 二叉树的层序遍历 给你二叉树的根节点 root &#xff0c;返回其节…

IOC容器

DI&#xff08;依赖注入&#xff09;&#xff1a;DI&#xff08;Dependency Injection&#xff09;是一种实现松耦合和可测试性的软件设计模式。它的核心思想是将依赖关系的创建与管理交给外部容器&#xff0c;使得对象之间只依赖于接口而不直接依赖于具体实现类。通过依赖注入…

C语言三子棋小游戏--数组的应用

注&#xff1a;在最后面&#xff0c;完整源码会以两种形式展现。在讲解时&#xff0c;以三个源文件的形式。 前言&#xff1a;三子棋&#xff0c;顾名思义&#xff0c;就是三个子连在一起就可以胜出。在本节我们要介绍的三子棋模式是这样子的&#xff1a;在键盘输入坐标&#x…

剑指offer14-I.剪绳子

昨天写的那道题是数组中除了一个元素外其余元素的乘积&#xff0c;这道题自然就想到了把一个数分成两个的和&#xff0c;然后积就是这两个数的积&#xff0c;而这两个数中的每个数又可以分成两个数&#xff0c;所以可以用动态规划的方法&#xff0c;dp[i] dp[j]*dp[i-j]。但是…

预测知识 | 神经网络、机器学习、深度学习

预测知识 | 预测技术流程及模型评价 目录 预测知识 | 预测技术流程及模型评价神经网络机器学习深度学习参考资料 神经网络 神经网络&#xff08;neural network&#xff09;是机器学习的一个重要分支&#xff0c;也是深度学习的核心算法。神经网络的名字和结构&#xff0c;源自…

Java版工程行业管理系统源码-专业的工程管理软件-em提供一站式服务 em

​ Java版工程项目管理系统 Spring CloudSpring BootMybatisVueElementUI前后端分离 功能清单如下&#xff1a; 首页 工作台&#xff1a;待办工作、消息通知、预警信息&#xff0c;点击可进入相应的列表 项目进度图表&#xff1a;选择&#xff08;总体或单个&#xff09;项目…

24届近3年青岛理工大学自动化考研院校分析

今天给大家带来的是青岛理工大学控制考研分析 满满干货&#xff5e;还不快快点赞收藏 一、青岛理工大学 学校简介 青岛理工大学是一所以工为主&#xff0c;土木建筑、机械制造、环境能源学科特色鲜明&#xff0c;理工经管文法艺等学科协调发展的多科性大学。是国家首批地方…

NIDS网络威胁检测系统-Golang

使用技术&#xff1a; Golang Gin框架 前端三件套 演示画面&#xff1a; 可以部署在linux和window上 目前已在Kali2021和Window10上进行测试成功