用Transformers实现简单的大模型文本生成

根据输入的prompt,生成一段指定长度的文字。Llama跑起来太慢了,这里用GPT-2作为列子。

from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torchtokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=tokenizer.eos_token_id)prompt_text = "This is a nice story that makes me"
max_gen_len = 9
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
prompt_len = input_ids.shape[-1]
print(f'length of prompt: {prompt_len}, length of generation: {max_gen_len}')print('>>> Way 1: Use `model.generate()` to generate tokens with KV cache')
generated_ids = model.generate(input_ids, max_length=prompt_len+max_gen_len, pad_token_id=tokenizer.eos_token_id)
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
print('generated_text:', generated_text)print('>>> Way 2: Use `for loop` to generate tokens with KV cache')
past_key_values = None
print('Prefill Stage..')
outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
past_key_values = outputs.past_key_values
pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
generated_ids = [pred_token_idx.item()]
print('Decoding/Generating Stage..')
for _ in range(max_gen_len - 1):outputs = model(input_ids=pred_token_idx, past_key_values=past_key_values, use_cache=True)past_key_values = outputs.past_key_values  # if use_cache=False, past_key_values=Nonepred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)generated_ids.append(pred_token_idx.item())
print('generated_ids:', generated_ids)
generated_text = tokenizer.decode(torch.Tensor(generated_ids), skip_special_tokens=True)
print('generated_text:', prompt_text + generated_text)

这里提供了两种方法实现文本生成:

  • model.generate():给模型输入prompt,一次性得到所有输出的token,最方便的写法
  • for loop:这是StreamingLLM中给的代码例子,也揭示了自回归生成的原理。首先是prefill阶段,输入prompt,得到KV cache和生成的第一个token;然后是decoding/generating阶段,开始自回归生成token,每次生成的模型输入是当前新token和KV cache,每生成一个token都会自动更新KV cache

最终,可以看到两种方法生成的文本是一模一样的:

在这里插入图片描述

进一步探究自回归过程中维度的变化:

在这里插入图片描述

这个就是标准的自回归生成任务了,不管是GPT还是Llama,都是如此(至少PyTorch版本都是这样的,Flax版本的KV cache有点奇怪,用的lax.dynamic_update_slice(cached_key.value, key, indices),KV cache的维度并没有随着token的生成而增加…不太明白)。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/697487.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

cookie、session、token、表单、json、jsonp、websocket、ajax都是什么

前后端数据交互的几种方式 1.cookie Cookie是服务器保存在客户端的一小段数据,(使用Cookie的前提是客户端浏览器允许使用Cookie并对此做出相应的设置。) cookie是一种存储在用户计算机上的小型数据文件,常用于在web应用程序中跟…

Prompt Engineering ,Fine-tuning , RAG ?

Prompt Engineering ,Fine-tuning , RAG 总结:1 prompt engineering2 RAG (Retrieval Augmented Generation)**RAG特点****RAG优势****RAG劣势** 3 微调(Fine-tuning)**微调特点****微调优势****微调劣势** 4 三者共性和区别5 RAG和微调的适应…

泽众一站式性能测试平台P-One,产品菜单和功能优化升级预告,用户操作更便捷

泽众一站式性能测试平台P-One为了进一步提升用户的使用体验,本轮将对P-One的产品菜单和功能进行优化升级。 首先给大家简单介绍一下P-One产品: P-One 是泽众软件自主研发的一套一站式性能测试平台软件产品。实现了集管理、设计、压测、监控以及分析于一体…

每日OJ题_贪心算法四⑧_力扣767. 重构字符串

目录 力扣767. 重构字符串 解析代码 力扣767. 重构字符串 767. 重构字符串 难度 中等 给定一个字符串 s ,检查是否能重新排布其中的字母,使得两相邻的字符不同。 返回 s 的任意可能的重新排列。若不可行,返回空字符串 "" 。 …

Docker-compose部署TRX节点

1、编写Dockerfile rootubuntu:~# mkdir /data/docker-compose/trx -p rootubuntu:~# cd /data/docker-compose/trx/ rootubuntu:/data/docker-compose/trx# ls rootubuntu:/data/docker-compose/trx# vim Dockerfile rootubuntu:/data/docker-compose/trx# cat Dockerfile FR…

数据结构(Java实现):顺序表

目录 1. 线性表2.顺序表2.1自己实现一个List接口2.2 IList接口的实现2.3 测试代码 1. 线性表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、…

《云原生安全攻防》-- 构建云原生攻防场景

在本节课程中,我们将学习云原生攻防场景的构建。为了研究云原生安全攻击案例,我们需要搭建一个云原生攻击测试环境,以便进行攻防研究和攻击手法的复现。 在这个课程中,我们将学习以下内容: 构建云原生攻防场景&#xf…

求两个数的差值

描述 根据输入信号a,b的大小关系,求解两个数的差值:输入信号a,b为8bit位宽的无符号数。如果a>b,则输出a-b,如果a≤b,则输出b-a。 接口信号图如下: 使用Verilog HDL实现以上功能并编写testbench验证。 …

高校普法|基于SSM+vue的高校普法系统的设计与实现(源码+数据库+文档)

高校普法系统 目录 基于SSM+vue的高校普法系统的设计与实现 一、前言 二、系统设计 三、系统功能设计 1系统功能模块 2管理员功能模块 3律师功能模块 4学生功能模块 四、数据库设计 五、核心代码 六、论文参考 七、最新计算机毕设选题推荐 八、源码获…

国标GB28181协议EasyGBS视频监控云平台端口正常却不能播放,是什么原因?

国标视频云服务EasyGBS支持设备/平台通过国标GB28181协议注册接入,并能实现视频的实时监控直播、录像、检索与回看、语音对讲、云存储、告警、平台级联等功能。平台部署简单、可拓展性强,支持将接入的视频流进行全终端、全平台分发,分发的视频…

号外!IP SSL证书申请只需十分钟!

IP SSL证书是一种专为IP地址设计的SSL证书,它使得基于IP地址的网站或服务能够实现HTTPS加密,确保数据在传输过程中的安全性和完整性。以下是关于IP SSL证书的一些技术性要点和申请流程概述: 一、IP SSL证书技术要点 1、适用场景&#xff1a…

ASP.NET Web Api 如何使用 Swagger 管理 API

前言 Swagger 是一个开源的框架,支持 OpenAPI 规范,可以根据 API 规范自动生成美观的、易于浏览的 API 文档页面,包括请求参数、响应示例等信息,并且,Swagger UI 提供了一个交互式的界面,可以帮助我们快速…