参考代码如下，基于 GitHub 仓库 [luchangli03/export_llama_to_onnx](https://github.com/luchangli03/export_llama_to_onnx)（export llama to onnx）修改而来，后续会合入该仓库：
模型权重链接参考:
https://huggingface.co/google/gemma-2b-it
可以对 modeling_gemma.py 做一些修改（将 transformers 升级到最新版本即可获得其内置的该模型代码），从而提升导出的 ONNX 模型的性能：
1. GemmaForCausalLM 中原始的 logits 计算为：
```python
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
```
修改为：
```python
hidden_states = outputs[0]
hidden_states = hidden_states[:, -1:, :]
logits = self.lm_head(hidden_states)
```
这样只对最后一个 token 计算 lm_head，从而降低 prefill 阶段 lm_head 的计算量。
2. 模型使用了 GemmaSdpaAttention，导出的 ONNX 模型会从一个很大的张量中索引出一部分，而这部分仅仅用作 attention mask：
```python
causal_mask = attention_mask
if attention_mask is not None and cache_position is not None:
    causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
```
这里既增加了存储，又增加了计算。实际上可以直接把扩展后的 attention mask 作为 ONNX 输入传进来，从而完全消除这部分存储和计算。
不知为何很多模型（例如千问等）都输入一个 [1, sumN] 的向量，然后在内部扩展为一个 [1, 1, seq_len, sumN] 的 mask；这些操作都可以替换为让模型直接接收 [1, 1, seq_len, sumN] 的 mask 输入。
这里对 modeling_gemma.py 的修改方法为：
```python
class GemmaModel(GemmaPreTrainedModel):
    def forward(...):
        ...
        # causal_mask = self._update_causal_mask(attention_mask, inputs_embeds)
        causal_mask = attention_mask
        ...

class GemmaSdpaAttention(GemmaAttention):
    def forward(...):
        ...
        # if attention_mask is not None and cache_position is not None:
        #     causal_mask = causal_mask[:, :, cache_position, : key_states.shape[-2]]
        ...
```
模型导出代码（已包含上述修改；如果不想修改 modeling_gemma.py，只需相应修改下面代码中 attention mask 的 shape 和 dtype 即可）：
import os
import argparse
import torch
from torch import nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging


class LLMForCausalLMWrapper(nn.Module):
    """Adapt a Hugging Face causal LM to a flat tensor interface for ONNX export.

    ``forward`` returns ``(logits, past_key0, past_value0, past_key1, ...)``.
    When ``args.add_topk_warper > 0`` it also returns the top-k values and
    indices of ``logits`` along the last (vocabulary) axis.
    """

    def __init__(self, model, config, args):
        super().__init__()
        self.model = model
        self.config = config
        self.args = args

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        past_key_values,
        output_attentions=False,
        output_hidden_states=False,
        use_cache=True,
    ):
        """Run one forward step and flatten the outputs for ONNX.

        Note: you can modify modeling_gemma.py to make the converted model more efficient:
            hidden_states = outputs[0]
            hidden_states = hidden_states[:, -1:, :]
            logits = self.lm_head(hidden_states)

        ``output_attentions``, ``output_hidden_states`` and ``use_cache`` are
        accepted for signature compatibility but ignored. The cache is always
        requested because the present KV tensors are part of the exported
        outputs.

        Raises:
            ValueError: if top-k output is enabled and ``args.topk`` is not positive.
        """
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=None,
            use_cache=True,
        )
        logits = outputs.logits

        # Flatten [(k0, v0), (k1, v1), ...] into [k0, v0, k1, v1, ...].
        kv_caches_out = []
        for past_kv in outputs.past_key_values:
            kv_caches_out.extend(past_kv)

        topk_outputs = []
        if self.args.add_topk_warper > 0:
            logging.warning("add topk to llm model")
            if self.args.topk <= 0:
                raise ValueError(f"topk {self.args.topk} is invalid")
            # torch.topk returns (values, indices); both become extra graph outputs.
            topk_outputs = torch.topk(logits, k=self.args.topk, dim=-1)
        return logits, *kv_caches_out, *topk_outputs


def export_llm_to_single_onnx(model, config, dtype, args, model_name):
    """Trace ``model`` with dummy inputs and write ``<args.out_dir>/<model_name>.onnx``.

    Graph inputs:
        input_ids [batch, N], attention_mask [batch, 1, N, sumN] (the already
        expanded mask, see the modeling_gemma.py changes), position_ids
        [batch, N], and for every layer i past_key_in{i} / past_value_in{i}
        of shape [batch, kv_heads, sumN - N, head_dim].
    Graph outputs:
        lm_logits, past_key{i} / past_value{i} per layer, plus topk_values /
        topk_indices when ``args.add_topk_warper > 0``.
    """
    llama_model_wrapper = LLMForCausalLMWrapper(model, config, args)
    if args.out_dir:
        # torch.onnx.export does not create missing parent directories.
        os.makedirs(args.out_dir, exist_ok=True)
    onnx_file_name = os.path.join(args.out_dir, f"{model_name}.onnx")

    device = args.device
    layer_num = len(model.model.layers)
    head_dim = config.head_dim
    # Read the KV head count from the config instead of hard-coding 1.
    # Models with several KV heads would otherwise be traced with a wrong cache shape.
    kv_head_num = getattr(config, "num_key_value_heads", None) or config.num_attention_heads

    # Dummy sizes used only for tracing; the real sizes are dynamic axes.
    batch = 1
    N = 1
    sumN = 32
    lastSum = sumN - N

    input_ids = torch.ones([batch, N], dtype=torch.int64).to(device)
    # Note: orig atten_mask shape is [1, sumN]; here the expanded mask is fed directly.
    attention_mask = torch.randn([batch, 1, N, sumN], dtype=dtype).to(device)
    position_ids = torch.ones([batch, N], dtype=torch.int64).to(device)

    in_names = ["input_ids", "attention_mask", "position_ids"]
    dynamic_axes = {
        "input_ids": {1: "N"},
        "attention_mask": {2: "N", 3: "sumN"},
        "position_ids": {1: "N"},
    }
    if args.dyn_batch:
        for name in in_names:
            dynamic_axes[name][0] = "batch"

    out_names = ["lm_logits"]
    kv_cache_in_shape = [batch, kv_head_num, lastSum, head_dim]
    kv_cache_dyn_axes = {2: "sumN-N"}
    if args.dyn_batch:
        kv_cache_dyn_axes[0] = "batch"

    past_key_values = []
    for i in range(layer_num):
        past_key_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(device)
        past_value_in = torch.randn(kv_cache_in_shape, dtype=dtype).to(device)
        in_names.extend([f"past_key_in{i}", f"past_value_in{i}"])
        out_names.extend([f"past_key{i}", f"past_value{i}"])
        # Give every input its own dict so the entries are not aliased.
        dynamic_axes[f"past_key_in{i}"] = dict(kv_cache_dyn_axes)
        dynamic_axes[f"past_value_in{i}"] = dict(kv_cache_dyn_axes)
        past_key_values.append((past_key_in, past_value_in))

    if args.add_topk_warper > 0:
        # Name the two extra outputs produced by the wrapper's torch.topk.
        out_names.extend(["topk_values", "topk_indices"])

    input_datas = (input_ids, attention_mask, position_ids, past_key_values)

    torch.onnx.export(
        llama_model_wrapper,
        input_datas,
        onnx_file_name,
        opset_version=args.opset,
        do_constant_folding=True,
        input_names=in_names,
        output_names=out_names,
        dynamic_axes=dynamic_axes,
    )


def export_llama(args):
    """Load the checkpoint at ``args.model_path`` and export it as ``llm_onnx.onnx``."""
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    dtype = dtype_map[args.dtype]

    print(f"begin load model from {args.model_path}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path, device_map=args.device, torch_dtype=dtype, trust_remote_code=True
    ).eval()
    # model.model.layers = model.model.layers[:1]  # only export one layer for debug
    print(f"finish load model from {args.model_path}")

    config = model.config
    print("config:", config)

    print("begin export llm")
    export_llm_to_single_onnx(model, config, dtype, args, "llm_onnx")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="export llm")
    parser.add_argument("-m", "--model_path", required=True, type=str)
    parser.add_argument("-o", "--out_dir", required=False, type=str, default="")
    parser.add_argument("--opset", required=False, type=int, default=15)
    parser.add_argument("-d", "--device", required=False, type=str, choices=["cpu", "cuda"], default="cuda")
    parser.add_argument("-p", "--dtype", required=False, type=str, choices=["float32", "float16", "bfloat16"], default="float16")
    parser.add_argument("--add_topk_warper", required=False, type=int, default=0)
    parser.add_argument("--topk", required=False, type=int, default=4)
    parser.add_argument("--dyn_batch", action="store_true")

    args = parser.parse_args()
    export_llama(args)
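导出的 ONNX 文件可以用 onnxsim 简化；对于超过 2GB 的大模型，可使用 [luchangli03/onnxsim_large_model](https://github.com/luchangli03/onnxsim_large_model)（simplify >2GB large onnx model）。
导出的 ONNX 模型推理示例（依赖的 onnx_rt_utils、sample_utils 等文件见 [luchangli03/export_llama_to_onnx](https://github.com/luchangli03/export_llama_to_onnx)）：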
import numpy as np
from onnx_rt_utils import OnnxRuntimeModel, get_random_data
from sample_utils import sample_topk
from transformers import AutoTokenizer


def prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum, head_dim=256, kv_head_num=1, dtype="float16"):
    """Populate the per-layer KV-cache inputs for the first round only.

    In round 0 ``lastSum`` is actually 0, so ``past_key_in{i}`` and
    ``past_value_in{i}`` are empty tensors of shape
    [1, kv_head_num, 0, head_dim].

    Args:
        glm_model_inputs: input dict for the ONNX model; updated in place.
        layer_num: number of decoder layers.
        lastSum: number of cached (past) tokens.
        head_dim, kv_head_num, dtype: cache geometry. The defaults reproduce the
            previously hard-coded values ([1, 1, lastSum, 256], "float16");
            override them for models with a different KV layout.

    Returns:
        The same ``glm_model_inputs`` dict.
    """
    for i in range(layer_num):
        glm_model_inputs[f"past_key_in{i}"] = get_random_data([1, kv_head_num, lastSum, head_dim], dtype)
        glm_model_inputs[f"past_value_in{i}"] = get_random_data([1, kv_head_num, lastSum, head_dim], dtype)
    return glm_model_inputs


def prepare_kv_cache_from_outputs(glm_model_inputs, decoder_outputs, layer_num):
    """Feed the previous step's KV-cache outputs back in as ``past_*_in`` inputs.

    Assumes ``decoder_outputs`` follows the exporter's output order
    [lm_logits, past_key0, past_value0, past_key1, ...], hence the offset of 1.
    ``glm_model_inputs`` is updated in place and also returned.
    """
    offset = 1  # output 0 is lm_logits
    for i in range(layer_num):
        glm_model_inputs[f"past_key_in{i}"] = decoder_outputs[offset + i * 2]
        glm_model_inputs[f"past_value_in{i}"] = decoder_outputs[offset + i * 2 + 1]
    return glm_model_inputs


def get_atten_mask(N, sumN, padded_len):
    """Build the additive float16 attention mask for one forward call.

    Allowed positions hold 0; masked positions hold -65504, the most negative
    finite float16 value.

    Args:
        N: number of query tokens in this call.
        sumN: number of valid key positions (past tokens plus the N current ones).
        padded_len: key length of the mask; padded_len - sumN slots are padding.

    Returns:
        np.ndarray of dtype float16 and shape [N, padded_len].

    Raises:
        ValueError: if N != sumN and N != 1 (only prefill without past and
            single-token decoding are supported), or if sumN < N in decoding.
    """
    attention_mask = np.zeros([N, padded_len], dtype="float16")
    pad_num = padded_len - sumN
    if N == sumN:
        # Prefill: query row i may only see keys 0..i. Every column j > i is
        # masked, which covers future tokens and trailing padding alike.
        future = np.arange(padded_len)[None, :] > np.arange(N)[:, None]
        attention_mask[future] = -65504
    else:
        if N != 1:
            raise ValueError("N is not 1")
        lastSum = sumN - N
        if lastSum < 0:
            raise ValueError("sumN must be >= N")
        # Decode: the pad_num columns starting at lastSum are masked. The first
        # lastSum columns and the last column stay visible, so the key layout is
        # presumably [past | padding | current token] -- confirm with the caller
        # if padding is ever used (padded_len > sumN).
        if pad_num > 0:
            attention_mask[0, lastSum:lastSum + pad_num] = -65504
    return attention_mask
# all decoder layer num
layer_num = 18  # decoder layer count of gemma-2b-it
pt_model_path = r"E:\test_models\llama\gemma-2b-it"
onnx_model_path = "llm_onnx.onnx"
prompt = "Write me a poem about Machine Learning."

tokenizer = AutoTokenizer.from_pretrained(pt_model_path, trust_remote_code=True)
# In the Gemma vocabulary <eos> is 1 and <bos> is 2. The previously hard-coded
# value 2 was the BOS id, so generation never stopped at EOS. Ask the tokenizer.
eos_token_id = tokenizer.eos_token_id

input_ids = tokenizer(prompt)["input_ids"]
print(input_ids)
input_ids = np.array(input_ids, dtype="int64").reshape([1, -1])
N = input_ids.shape[1]
sumN = N
lastSum = sumN - N
print("N:", N, sumN, lastSum)
position_ids = np.arange(sumN, dtype="int64").reshape([1, -1])

glm_model = OnnxRuntimeModel(onnx_model_path)
max_seq = 32
glm_model_inputs = {}
gen_tokens = []
for i in range(max_seq):
    print("input_ids:", input_ids)
    print("position_ids:", position_ids)
    attention_mask = get_atten_mask(N, sumN, padded_len=sumN).astype("float16")
    print("attention_mask:", attention_mask)
    attention_mask = attention_mask.reshape([1, 1, N, sumN])

    glm_model_inputs["input_ids"] = input_ids
    glm_model_inputs["attention_mask"] = attention_mask
    glm_model_inputs["position_ids"] = position_ids
    if i == 0:
        # First (prefill) round: no cache yet, so feed empty KV tensors.
        glm_model_inputs = prepare_kv_cache_round0(glm_model_inputs, layer_num, lastSum)

    glm_model_outputs = glm_model(**glm_model_inputs)
    lm_logits = glm_model_outputs[0]
    print("lm_logits:", lm_logits)

    next_token = sample_topk(lm_logits, topk=1)
    gen_tokens.append(next_token)
    print("next_token:", next_token)
    if next_token == eos_token_id:
        break

    # Next round: decode a single token at the next position, reusing the cache.
    input_ids = np.array([next_token]).astype("int64").reshape([-1, 1])
    position_ids = np.array([sumN]).astype("int64").reshape([-1, 1])
    N = 1
    sumN += 1
    prepare_kv_cache_from_outputs(glm_model_inputs, glm_model_outputs, layer_num)

# skip_special_tokens drops the trailing <eos> appended before the break.
gen_text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
print("Q:", prompt)
print("A:", gen_text)