pytorch之诗词生成3--utils

先上代码:


import numpy as np
import settingsdef generate_random_poetry(tokenizer, model, s=''):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: 一个字符串,表示一首古诗"""# 将初始字符串转成tokentoken_ids = tokenizer.encode(s)# 去掉结束标记[SEP]token_ids = token_ids[:-1]while len(token_ids) < settings.MAX_LEN:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# print(_probas)# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)if target == 3:breakreturn tokenizer.decode(token_ids)def generate_acrostic(tokenizer, model, head):"""随机生成一首藏头诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param head: 藏头诗的头:return: 一个字符串,表示一首古诗"""# 使用空串初始化token_ids,加入[CLS]token_ids = tokenizer.encode('')token_ids = token_ids[:-1]# 标点符号,这里简单的只把逗号和句号作为标点punctuations = [',', '。']punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}# 缓存生成的诗的listpoetry = []# 对于藏头诗中的每一个字,都生成一个短句for ch in head:# 先记录下这个字poetry.append(ch)# 将藏头诗的字符转成token idtoken_id = tokenizer.token_to_id(ch)# 加入到列表中去token_ids.append(token_id)# 开始生成一个短句while True:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)# 只有不是特殊字符时,才保存到poetry里面去if target > 3:poetry.append(tokenizer.id_to_token(target))if target in punctuation_ids:breakreturn ''.join(poetry)

我们首先看第一个函数generate_random_poetry,用于随机生成古诗。

def generate_random_poetry(tokenizer, model, s=''):"""随机生成一首诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param s: 用于生成古诗的起始字符串,默认为空串:return: 一个字符串,表示一首古诗"""

 tokenizer是一个分词器,用于将文本转化为模型可以理解的token序列。

# 将初始字符串转成tokentoken_ids = tokenizer.encode(s)# 去掉结束标记[SEP]token_ids = token_ids[:-1]

是将我们传入的字符串转化为id序列。并使用切片操作切除token序列的最后一个元素。这个元素是用来标记结束标志[SEP]的。

我们接着看:

 while len(token_ids) < settings.MAX_LEN:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output

这里我们就可以看出为什么我们要对最后一个 token进行切片处理了,是为了控制预测的长度。

之后我们进行预测,通过model模型对当前的token_ids序列进行预测,model接受一个形状为(batch_size,sequence_length)的输入,这里batch_size为1,sequence_length为token_ids的长度。因此,我们将token_ids包装成一个形状为(1,sequence_length)的numpy数组,并将其传递给model进行预测。

输出的结果是一个3D数组,形状为(1,sequence_length,vocab_size)。其中vocab_size是词汇表,得到的是对下一个词的预测。

然后通过output.numpy()将输出转化为numpy数组,由于我们只关注当前序列的最后一个token的预测结果,所以使用[0,-1,3:]来获取这部分概率分布。具体的,[0,-1]表示获取第一个样例的最后一个token的预测,而[3:]表示去除前三个token,即[PAD],[UNK]和[CLS]。

最后,通过del output释放output变量所占用的内存空间,这样做是为了及时释放内存,避免内存占用过多。

解释一下四个特殊的token在自然语言处理任务中的不同作用。

  • [PAD](填充):在序列中用于填充长度不足的部分,使得序列长度统一,在很多深度学习模型中,要求输入的序列长度是相同的,因此当序列长度不足时,可以使用[PAD]进行填充,使得序列达到统一的长度。
  • [UNK](未知):用于表示模型未见过的或者无法识别的词汇,当模型在处理文本时遇到不在词汇表中的词汇,就会用[UNK]表示,这样可以避免模型对于未知词汇的处理错误。
  • [CLS](分类):在自然语言处理中的一些任务中,例如文本分类,文本语义相似度等,[CLS]被用作特殊的起始标记。模型会将[CLS]作为整个序列的表示,并用于后续任务的处理。例如:对于文本分类任务,模型可以根据[CLS]表示进行预测分类标签。
  • [SEP](分隔符):用于表示序列的分隔,常见于处理多个句子或多个文本之间的任务。他可以用来分隔句子,句子对,文本片段等,在一些任务中,如文本对分类,问答系统等,输入可能包含多个句子或文本片段,模型可以通过识别[SEP]来理解不同句子之间的关系。另外,[SEP]还可以用于标记序列的结束,在一些 情况下,序列的长度是固定的,使用[SEP]标记还可以表示序列的结束。

继续看我们的代码:

 # 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)

 首先,通过argsort()函数对得到的所有outputs进行从小到大的排序,然后经过[::-1]将排序结果进行反转,这样就变成了从大到小的排序,[:100]表示只选择排名前100的token。返回的是_probas中元素从小到大排序的索引数组。

得到索引数组之后,我们通过p=_probas[p_args],根据排名靠前的token的索引p_args,从预测的概率分布_probas中获取对应的概率值,这样得到的p就是排列后的概率顺序(概率值从大到小)。

之后对概率进行归一化,将概率值除以概率之和,使其概率和为1。

target_index = np.random.choice(len(p), p=p)根据参数p(给定的概率分布)在长度为len(p)的范围内进行随机选择,并返回选择的索引target_index。(显然,元素对应的p值越大,被选择的概率就越高)。

后面的target = p_args[target_index] + 3,其中+3是必不可少的,我刚刚因为没有+3就运行报错了,因为我们得到的target是对应于token_ids的索引,而token_ids的前三项是特殊字符,为了和实际中的模型相对应,我们需要将预测的token编号做一些偏离处理。

最后我们将我们的预测列表中加上我们预测的词的索引值,也就是target。

看着一段:

 if target == 3:break
return tokenizer.decode(token_ids)

如果生成的target的值是3,表示生成的token是特殊的[CLS]标记,这意味着生成的序列已结束。
返回我们解码之后的结果也就是生成的诗词。等于3的情况是存在的,3意味着结束标志SPE,当句子的长度合适时,其预测的下一个词很可能是3,我们规定3也属于可预测范围。

看生成藏头诗的代码:

def generate_acrostic(tokenizer, model, head):"""随机生成一首藏头诗:param tokenizer: 分词器:param model: 用于生成古诗的模型:param head: 藏头诗的头:return: 一个字符串,表示一首古诗"""# 使用空串初始化token_ids,加入[CLS]token_ids = tokenizer.encode('')token_ids = token_ids[:-1]# 标点符号,这里简单的只把逗号和句号作为标点punctuations = [',', '。']punctuation_ids = {tokenizer.token_to_id(token) for token in punctuations}# 缓存生成的诗的listpoetry = []# 对于藏头诗中的每一个字,都生成一个短句for ch in head:# 先记录下这个字poetry.append(ch)# 将藏头诗的字符转成token idtoken_id = tokenizer.token_to_id(ch)# 加入到列表中去token_ids.append(token_id)# 开始生成一个短句while True:# 进行预测,只保留第一个样例(我们输入的样例数只有1)的、最后一个token的预测的、不包含[PAD][UNK][CLS]的概率分布output = model(np.array([token_ids, ], dtype=np.int32))_probas = output.numpy()[0, -1, 3:]del output# 按照出现概率,对所有token倒序排列p_args = _probas.argsort()[::-1][:100]# 排列后的概率顺序p = _probas[p_args]# 先对概率归一p = p / sum(p)# 再按照预测出的概率,随机选择一个词作为预测结果target_index = np.random.choice(len(p), p=p)target = p_args[target_index] + 3# 保存token_ids.append(target)# 只有不是特殊字符时,才保存到poetry里面去if target > 3:poetry.append(tokenizer.id_to_token(target))if target in punctuation_ids:breakreturn ''.join(poetry)

我们看一下tokenizer.id_to_token(3,2,1,0)

这段代码和上面的随机生成古诗类似,不同的是:

上面的是完整生成一首古诗,其中就算遇到“,”或者“。”不会中断,直到遇到结束符号[SEP]才会中止。

而我们生成的藏头诗是对每一个字进行生成,这里遇到“,”,“。”就进行下一个字的继续生成。然后将每个字的自动生成进行拼接。得到完整的诗句。(当然,遇到[SEP]也是进行终止的,保证了句式的协调)。

return ''.join(poetry)

作用是将我们的peotry中的内容(短句),连接成一个字符串。并将其作为函数值返回。

如果不加join函数的话,输出是一个列表:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539895.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux常用操作命令和服务器硬件基础知识

&#x1f31f; 前言 欢迎来到我的技术小宇宙&#xff01;&#x1f30c; 这里不仅是我记录技术点滴的后花园&#xff0c;也是我分享学习心得和项目经验的乐园。&#x1f4da; 无论你是技术小白还是资深大牛&#xff0c;这里总有一些内容能触动你的好奇心。&#x1f50d; &#x…

区别于传统家!三翼鸟定制智慧家电家居一体化场景

在这个科技创新、智能AI主导的时代&#xff0c;寻求更便捷智慧、舒心适宜、一体化的居家场景&#xff0c;成为一个时代的命题和竞赛&#xff0c;也是家居行业共同奔赴的使命。在纷繁复杂的竞争格局和方向答案中&#xff0c;一条清晰坚定的路径正在显露出来…… AWE前一天&…

搭建谷歌Gemini

前言 Gemini是Google AI于2023年发布的大型语言模型&#xff0c;拥有强大的文本生成、理解和转换能力。它基于Transformer模型架构&#xff0c;并使用了大量文本和代码数据进行训练。Gemini可以执行多种任务&#xff0c;包括&#xff1a; 生成文本&#xff1a;可以生成各种类…

物联网技术助力智慧城市转型升级:智能、高效、可持续

目录 一、物联网技术概述及其在智慧城市中的应用 二、物联网技术助力智慧城市转型升级的路径 1、提升城市基础设施智能化水平 2、推动公共服务智能化升级 3、促进城市治理现代化 三、物联网技术助力智慧城市转型升级的成效与展望 1、成效显著 2、展望未来 四、物联网技…

Excel第26享:模糊查找之Hlookup函数与通配符的嵌套

1、需求描述 如下图所示&#xff0c;现第一行有三个参考值&#xff1a;人S、羊E、猪3&#xff0c;在第三行有5个字&#xff1a;马、牛、人、羊、猪&#xff0c;每个字如果出现在第一行的三个参考值中&#xff0c;就返回该单元格的数值。如&#xff0c;人&#xff0c;就返回“人…

画图实战-Python实现某产品全年销量数据多种样式可视化

画图实战-Python实现某产品全年销量数据多种样式可视化 学习心得Matplotlib说明什么是Matplotlib&#xff1f;Matplotlib特性Matplotlib安装 产品订单量-折线图某产品全年订单量数据数据提取和分析绘制折线图 产品订单&销售额-条形图某产品全年订单&销售额数据绘制条形…

圈子社交系统-多人语音-交友-陪玩-活动报名-商城-二手论坛-源码交付,支持二开!

圈子小程序适用于多种场景&#xff0c;涵盖了各个领域的社交需求。以下是一些常见的适用场景&#xff1a; 兴趣社区&#xff1a; 用户可以加入自己感兴趣的圈子&#xff0c;与志同道合的人一起讨论交流&#xff0c;分享经验和知识。 行业交流&#xff1a; 各个行业可以建立自…

【大模型系列】图片生成(DDPM/VAE/StableDiffusion/ControlNet/LoRA)

文章目录 1 DDPM(UC Berkeley, 2020)1.1 如何使用DDPM生成图片1.2 如何训练网络1.3 模型原理 2 VAE:Auto-Encoding Variational Bayes(2022&#xff0c;Kingma)2.1 如何利用VAE进行图像增广2.2 如何训练VAE网络2.3 VAE原理2.3.1 Auto-Encoder2.3.2 VAE编码器2.3.3 VAE解码器 3 …

【MySQL性能优化】- 一文了解MVCC机制

MySQL理解MVCC &#x1f604;生命不息&#xff0c;写作不止 &#x1f525; 继续踏上学习之路&#xff0c;学之分享笔记 &#x1f44a; 总有一天我也能像各位大佬一样 &#x1f3c6; 博客首页 怒放吧德德 To记录领地 &#x1f31d;分享学习心得&#xff0c;欢迎指正&#xff…

多媒体操作流程

&#xff01; 从左至右依次为&#xff1a;话筒、投影遥控器、ppt演讲笔、幕布升降遥控器、无线投屏连接器 主机箱 投影仪 二、操作流程 1、打开主机电源&#xff1a;最下面两台设备的开关打开 2、打开投影仪&#xff1a;用投影遥控器对准投影仪按开机键&#xff08;如无需用到…

linux信号的概念

目录 1.预备 2.信号如何产生 1.引入 2.原理 3.总结 3.接口 1.singal函数 2.kill函数 3.raise函数&#xff08;给自己发信号&#xff09; 4.abort函数&#xff08;给自己发送6号信号&#xff09; 4.异常 1.现象 2.原理 5.core和term区别 6.由软件条件产生信号 3.…

Linux系统搭建DataEase并结合内网穿透实现任意设备公网查看本地数据

&#x1f308;个人主页: Aileen_0v0 &#x1f525;热门专栏: 华为鸿蒙系统学习|计算机网络|数据结构与算法|MySQL| ​&#x1f4ab;个人格言:“没有罗马,那就自己创造罗马~” #mermaid-svg-HLq7cU4G5of6W4QU {font-family:"trebuchet ms",verdana,arial,sans-serif;f…