导出谷歌gemma模型为ONNX

参考代码如下(从GitHub - luchangli03/export_llama_to_onnx: export llama to onnx修改而来,后面会合入进去)

模型权重链接参考:

https://huggingface.co/google/gemma-2b-it

可以对modeling_gemma.py进行一些修改(transformers升级为最新版本内置该模型代码),从而提升导出的onnx性能:

1,GemmaForCausalLM中原始的logits计算为:

        hidden_states = outputs[0]logits = self.lm_head(hidden_states)

修改为:

        hidden_states = outputs[0]hidden_states = hidden_states[:,-1:,:]logits = self.lm_head(hidden_states)

这样使得降低prefill阶段lm_head的计算量。

2,模型使用了GemmaSdpaAttention,导出的onnx模型从一个很大的张量中索引向量仅仅用作attention mask:

causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

这里即增加了存储又增加了计算。实际上可以直接把扩展后的attention mask作为onnx输入传入进来,从而完全消除这个存储和计算。

不知为何很多模型(例如千问等)都输入一个[1, seq_len]的向量,然后内部扩展为一个[1,1, seq_len, sumN]的mask,这些操作都可以直接替换为模型直接采用[1,1, seq_len, sumN]的mask输入。

这里对modeling_gemma.py修改方法为:

class GemmaModel(GemmaPreTrainedModel):def forward(# causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)causal_mask = attention_maskclass GemmaSdpaAttention(GemmaAttention):def forward(# if attention_mask is not None and cache_position is not None:#     causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]

模型导出代码(进行了上述修改,如果不想修改的话,修改下这里面的atten mask的shape,dtype即可):

import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizerclass LLMForCausalLMWrapper(nn.Module):def __init__(self, model, config, args):super().__init__()self.model = modelself.config = configself.args = argsdef forward(self,input_ids,attention_mask,position_ids,past_key_values,output_attentions=False,output_hidden_states=False,use_cache=True,):"""Note: you can modify modeling_gemma.py to make the converted model more efficient:hidden_states = outputs[0]hidden_states = hidden_states[:,-1:,:]logits = self.lm_head(hidden_states)"""outputs = self.model(input_ids=input_ids,attention_mask=attention_mask,position_ids=position_ids,past_key_values=past_key_values,inputs_embeds=None,use_cache=True,)logits = outputs.logitskv_caches_out = []for past_kv in outputs.past_key_values:kv_caches_out.extend(past_kv)topk_outputs = []if self.args.add_topk_warper > 0:logging.warning("add topk to glm model")if self.args.topk < 0:raise ValueError("topk {} is invalid")topk_outputs = torch.topk(logits, k=self.args.topk, dim=-1)return logits, *kv_caches_out, *topk_outputsdef export_llm_to_single_onnx(model, config, dtype, args, model_name):llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")layer_num = len(model.model.layers)hidden_size = config.hidden_sizehead_num = config.num_attention_headshead_dim = config.head_dimbatch = 1N = 1sumN = 32lastSum = sumN - Ninput_ids_shape = [batch, N]input_ids = torch.ones(input_ids_shape, dtype=torch.int64).to(args.device)# Note: orig atten_mask shape is [1, sumN]attention_mask = torch.randn([batch, 1, N, sumN], dtype=dtype).to(args.device)position_ids = torch.ones([batch, N], dtype=torch.int64).to(args.device)in_names = ["input_ids", "attention_mask", "position_ids"]dynamic_axes = {'input_ids': {1: 'N', },'attention_mask': {2: 'N', 3: 'sumN'},"position_ids": {1: 'N', },}if args.dyn_batch:dynamic_axes['input_ids'][0] = "batch"dynamic_axes['attention_mask'][0] = "batch"dynamic_axes['position_ids'][0] = "batch"kv_caches_in = []out_names = ["lm_logits"]kv_cache_in_shape = [1, 1, lastSum, head_dim]kv_cache_dyn_axes = {2: "sumN-N"}if args.dyn_batch:kv_cache_dyn_axes[0] = "batch"past_key_values = []for i in range(layer_num):past_key_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)past_value_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(args.device)kv_caches_in.extend([past_key_in, past_value_in])in_names.extend([f"past_key_in{i}", f"past_value_in{i}"])out_names.extend([f"past_key{i}", f"past_value{i}"])dynamic_axes[f"past_key_in{i}"] = kv_cache_dyn_axesdynamic_axes[f"past_value_in{i}"] = kv_cache_dyn_axespast_key_values.append((past_key_in, past_value_in))input_datas = (input_ids, attention_mask, position_ids, past_key_values)torch.onnx.export(llama_model_wrapper,input_datas,onnx_file_name,opset_version=args.opset,do_constant_folding=True,input_names=in_names,output_names=out_names,dynamic_axes=dynamic_axes,)def export_llama(args):device = args.devicedtype_map = {"float32": torch.float32,"float16": torch.float16,"bfloat16": torch.bfloat16,}dtype = dtype_map[args.dtype]print(f"begin load model from {args.model_path}")model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map=device, torch_dtype=dtype, trust_remote_code=True).eval()# model.model.layers = model.model.layers[:1]  # only export one layer for debugprint(f"finish load model from {args.model_path}")config = model.configprint("config:", config)print(f"begin export llm")export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")if __name__ == "__main__":parser = argparse.ArgumentParser(description='export llm',)parser.add_argument('-m', '--model_path', required=True, type=str)parser.add_argument('-o', '--out_dir', required=False, type=str, default="")parser.add_argument('--opset', required=False, type=int, default=15)parser.add_argument('-d', '--device', required=False, type=str, choices=["cpu", "cuda"], default="cuda")parser.add_argument('-p', '--dtype', required=False, type=str,choices=["float32", "float16", "bfloat16"], default="float16")parser.add_argument('--add_topk_warper', required=False, type=int, default=0)parser.add_argument('--topk', required=False, type=int, default=4)parser.add_argument('--dyn_batch', action='store_true')args = parser.parse_args()export_llama(args)

导出的onnx文件onnxsim:

GitHub - luchangli03/onnxsim_large_model: simplify >2GB large onnx model

导出的onnx模型推理示例(依赖文件在GitHub - luchangli03/export_llama_to_onnx: export llama to onnx)

import numpy as np
from onnx_rt_utils import OnnxRuntimeModel, get_random_data
from sample_utils import sample_topk
from transformers import AutoTokenizerdef prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum):"""only used at the first timein round 0, actually the lastSum is 0, thus past_key_in, past_value_in are empty tensor"""for i in range(layer_num):past_key_in = get_random_data([1, 1, lastSum, 256], "float16")past_value_in = get_random_data([1, 1, lastSum, 256], "float16")past_key_in_name = f"past_key_in{i}"past_value_in_name = f"past_value_in{i}"glm_model_inputs[past_key_in_name] = past_key_inglm_model_inputs[past_value_in_name] = past_value_inreturn glm_model_inputsdef prepare_kv_cache_from_outputs(glm_model_inputs, decoder_outputs, layer_num):offset = 1for i in range(layer_num):past_key_in_name = f"past_key_in{i}"past_value_in_name = f"past_value_in{i}"glm_model_inputs[past_key_in_name] = decoder_outputs[offset + i * 2]glm_model_inputs[past_value_in_name] = decoder_outputs[offset + i * 2 + 1]return glm_model_inputsdef get_atten_mask(N,  sumN,  padded_len):attention_mask = np.zeros(shape=[N * padded_len], dtype="float16")pad_num = padded_len - sumNif (N == sumN):for i in range(N):mask_num = N - 1 - i + pad_numstart = padded_len - mask_numfor j in range(start, padded_len):attention_mask[i * padded_len + j] = -65504else:if (N != 1):raise ValueError("N is not 1")lastSum = sumN - Nfor i in range(pad_num):attention_mask[lastSum + i] = -65504attention_mask = attention_mask.reshape([N, padded_len])return attention_mask# all decoder layer num
layer_num = 18
eos_token_id = 2pt_model_path = r"E:\test_models\llama\gemma-2b-it"
onnx_model_path = "llm_onnx.onnx"prompt = "Write me a poem about Machine Learning."
tokenizer = AutoTokenizer.from_pretrained(pt_model_path, trust_remote_code=True)
input_ids = tokenizer(prompt)['input_ids']print(input_ids)input_ids = np.array(input_ids).reshape([1, -1]).astype("int64")N = input_ids.shape[1]
sumN = N
lastSum = sumN - N
print("N:", N, sumN, lastSum)position_ids = np.arange(sumN).reshape([1, -1]).astype("int64")input_ids = input_ids.astype("int64")
position_ids = position_ids.astype("int64")glm_model = OnnxRuntimeModel(onnx_model_path)max_seq = 32glm_model_inputs = {}gen_tokens = []for i in range(max_seq):print("input_ids:", input_ids)print("position_ids:", position_ids)attention_mask = get_atten_mask(N, sumN, padded_len=sumN).astype("float16")print("attention_mask:", attention_mask)attention_mask = attention_mask.reshape([1, 1, N, sumN])glm_model_inputs["input_ids"] = input_idsglm_model_inputs["attention_mask"] = attention_maskglm_model_inputs["position_ids"] = position_idsif i == 0:glm_model_inputs = prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum)glm_model_outputs = glm_model(**glm_model_inputs)lm_logits = glm_model_outputs[0]print("lm_logits:", lm_logits)next_token = sample_topk(lm_logits, topk=1)gen_tokens.append(next_token)print("next_token:", next_token)if next_token == eos_token_id:breakinput_ids = np.array([next_token]).astype("int64").reshape([-1, 1])position_ids = np.array([sumN]).astype("int64").reshape([-1, 1])N = 1sumN += 1prepare_kv_cache_from_outputs(glm_model_inputs, glm_model_outputs, layer_num)gen_text = tokenizer.decode(gen_tokens)
print("Q:", prompt)
print("A:", gen_text)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/522763.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

YOLOv8创新改进:SPPF创新涨点篇 | SPPELAN:SPP创新结合ELAN ,效果优于SPP、SPPF| YOLOv9

💡💡💡本文独家改进:新颖SPPF创新涨点改进,SPP创新结合ELAN,来自于YOLOv9,助力YOLOv8,将SPPELAN代替原始的SPPF SPPELAN结构图如下: 💡💡💡在多个私有数据集和公开数据集VisDrone2019、PASCAL VOC实现涨点 收录 YOLOv8原创自研 https://blog.csdn.net/m0_…

Android使用WebView打开内嵌H5网页

Android打开外部网页链接请参考上一篇文章 https://public.blog.csdn.net/article/details/136384559 继上篇&#xff0c;新建assets文章夹&#xff0c;将H5的网页资源放到此文件夹下 把H5的资源文件都拷进来 这个时候&#xff0c;将添加打开本地网页的代码&#xff1a; //打…

JVM-垃圾收集底层算法实现

三色标记 背景描述 在并发标记的过程中&#xff0c;因为标记期间应用线程还在继续跑&#xff0c;对象间的引用可能发生变化&#xff0c;多标和漏标的情况就有可能发生。 如何解决上面的问题&#xff1f; 引入“三色标记” 意思就是&#xff0c;把Gcroots可达性分析遍历对象过程…

Matlab|基于目标级联法的微网群多主体分布式优化调度

目录 主要内容 1.1 上层微网群模型 1.2 下层微网模型 部分程序 实现效果 下载链接 主要内容 本文复现《基于目标级联法的微网群多主体分布式优化调度》文献的目标级联部分&#xff0c; 建立微网群系统的两级递阶优化调度模型: 上层是微网群能量调度中心优化调度…

AI 应用之路:质疑汤姆猫,成为汤姆猫,超越汤姆猫

过去一年&#xff0c;我对 AI 应用的看法经历了这样一个过程&#xff1a;质疑汤姆猫&#xff0c;理解汤姆猫&#xff0c;成为汤姆猫&#xff0c;超越汤姆猫。 什么是汤姆猫&#xff1f;汤姆猫是 2010 年移动互联网早期的一款应用&#xff0c;迅速走红&#xff0c;又淡出视野。…

React 事件机制原理

相关问题 React 合成事件与原生 DOM 事件的区别React 如何注册和触发事件React 事件如何解决浏览器兼容问题 回答关键点 React 的事件处理机制可以分为两个阶段&#xff1a;初始化渲染时在 root 节点上注册原生事件&#xff1b;原生事件触发时模拟捕获、目标和冒泡阶段派发合…

c++的STL(2)-- vector容器

目录 1. 默认构造 代码: 相关知识点: 2. 有参构造函数 以及 使用{}初始化对象 代码: 相关知识点: 3. vector容器在尾部添加和删除元素 代码: 使用push_back()和pop_back()进行尾部元素的添加和删除 相关知识点: 代码: 使用emplace_back在尾部添…

Unity插件之天气系统UniStorm

首先呢&#xff0c;它是一款强大的动态昼夜天气系统&#xff0c;能够以较快的帧速率创建AAA级动态生成的天气、照明和天空&#xff0c;并且具有300多个可定制的组件&#xff0c;允许用户创建任何可以想象的环境。 第一步&#xff1a;他需要两个物体Camera摄像机、Player播放器…

knife4j生产环境禁止打开页面

Knife4j是一个集Swagger2 和 OpenAPI3为一体的增强解决方案&#xff0c;官网地址&#xff1a;Knife4j 集Swagger2及OpenAPI3为一体的增强解决方案. | Knife4j 考虑到安全性问题&#xff0c;在实际服务部署到生产环境后就需要禁用到swagger页面的展示&#xff0c;这个时候只需…

报错:ModuleNotFoundError: No module named ‘tensorrt’

写在前面 我安装了tensorRT,在运行它自带的模型都没问题。 但是在代码中import tensorrt就报错&#xff1a; ModuleNotFoundError: No module named ‘tensorrt’。 网上搜了一大堆&#xff0c;发现是没有在自己的python环境下安装。 所以特意写这篇文章记录一下。 在进行下一…

继深圳后,重庆与鸿蒙展开原生应用开发合作

截至2023年底&#xff0c;开源鸿蒙开源社区已有250多家生态伙伴加入&#xff0c;开源鸿蒙项目捐赠人达35家&#xff0c;通过开源鸿蒙兼容性测评的伙伴达173个&#xff0c;累计落地230余款商用设备&#xff0c;涵盖金融、教育、智能家居、交通、数字政府、工业、医疗等各领域。 …

C++之创建与使用dll

目录 1、创建dll test.h test.cpp Source.def 2、使用dll testdll.cpp DLL&#xff0c;全称“Dynamic Link Library”&#xff0c;中文名为“动态链接库”&#xff0c;是一种在Windows操作系统中常见的库文件格式。它包含了可以由多个程序同时使用的代码和数据。与静态链接…