模型部署 - onnx 的导出和分析 -(1) - PyTorch 导出 ONNX - 学习记录

onnx 的导出和分析

  • 一、PyTorch 导出 ONNX 的方法
    • 1.1、一个简单的例子 -- 将线性模型转成 onnx
    • 1.2、导出多个输出头的模型
    • 1.3、导出含有动态维度的模型
  • 二、pytorch 导出 onnx 不成功的时候如何解决
    • 2.1、修改 opset 的版本
    • 2.2、替换 pytorch 中的算子组合
    • 2.3、在 pytorch 登记( 注册 ) onnx 中某些算子
      • 2.3.1、注册方法一
      • 2.3.2、注册方法二
    • 2.4、直接修改 onnx,创建 plugin

一、PyTorch 导出 ONNX 的方法

1.1、一个简单的例子 – 将线性模型转成 onnx

首先我们用 pytorch 定义一个线性模型,nn.Linear : 线性层执行的操作是 y = x * W^T + b,其中 x 是输入,W 是权重,b 是偏置。(实际上就是一个矩阵乘法)

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return x

然后我们再定义一个函数,用于导出 onnx

def export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")

可以看到,这里面的关键在函数 torch.onnx.export(),这是 pytorch 导出 onnx 的基本方式,这个函数的参数有很多,但只要一些基本的参数即可导出模型,下面是一些基本参数的定义:

  • model (torch.nn.Module): 需要导出的PyTorch模型
  • args (tuple or Tensor): 一个元组,其中包含传递给模型的输入张量
  • f (str): 要保存导出模型的文件路径。
  • input_names (list of str): 输入节点的名字的列表
  • output_names (list of str): 输出节点的名字的列表
  • opset_version (int): 用于导出模型的 ONNX 操作集版本

最后我们完整的运行一下代码:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

导出模型后,我们用 netron 查看模型,在终端输入

netron model.onnx

在这里插入图片描述

1.2、导出多个输出头的模型

第一步:定义一个多输出的模型:

class Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2

第二步:编写导出 onnx 的函数

def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")

可以看到,和例 1.1 不一样的地方是 torch.onnx.export 的 output_names
例1.1:output_names = [“output0”]
例1.2:output_names = [“output0”, “output1”]

运行一下完整代码:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights1, weights2, bias=False):super().__init__()self.linear1 = nn.Linear(in_features, out_features, bias)self.linear2 = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear1.weight.copy_(weights1)self.linear2.weight.copy_(weights2)def forward(self, x):x1 = self.linear1(x)x2 = self.linear2(x)return x1, x2def export_onnx():input    = torch.zeros(1, 1, 1, 4)weights1 = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)weights2 = torch.tensor([[2, 3, 4, 5],[3, 4, 5, 6],[4, 5, 6, 7]],dtype=torch.float32)model   = Model(4, 3, weights1, weights2)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0", "output1"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

用 netron 查看模型,结果如下,模型多出了一个输出结果
在这里插入图片描述

1.3、导出含有动态维度的模型

完整运行代码如下:

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新torch.onnx.export(model         = model, args          = (input,),f             = "model.onnx",input_names   = ["input0"],output_names  = ["output0"],dynamic_axes  = {'input0':  {0: 'batch'},'output0': {0: 'batch'}},opset_version = 12)print("Finished onnx export")if __name__ == "__main__":export_onnx()

可以看到,比例 1.1 多了一行 torch.onnx.export 的 dynamic_axes 。我们可以用 dynamic_axes 来指定动态维度,其中 'input0': {0: 'batch'} 中的 0 表示在第 0 维度上的元素是动态的,这里取名为 ‘batch’

用 netron 查看模型:
在这里插入图片描述
可以看到相对于例1.1,他的维度 0 变成了动态的,并且名为 ‘batch’

二、pytorch 导出 onnx 不成功的时候如何解决

上面是 onnx 可以直接被导出的情况,是因为对应的 pytorch 和 onnx 版本都有相应支持的算子在里面。但是有些时候,我们不能顺利的导出 onnx,下面记录一下常见的解决思路 。

2.1、修改 opset 的版本

这是首先应该考虑的思路,因为有可能只是版本过低然后有些算子还不支持,所以考虑提高 opset 的版本

比如下面的这个报错,提示当前 onnx 的 opset 版本不支持这个算子,那我们可以去官方手册搜索一下是否在高的版本支持了这个算子
在这里插入图片描述

官方手册地址:https://github.com/onnx/onnx/blob/main/docs/Operators.md

在这里插入图片描述
又比如说 Acosh 这个算子,在 since version 9 才开始支持,那我们用 7 的时候就是不合适的,升级 opset 版本即可

2.2、替换 pytorch 中的算子组合

有些时候 pytorch 中的一些算子操作在 onnx 中并没有,那我们可以把这些算子替换成 onnx 支持的算子

2.3、在 pytorch 登记( 注册 ) onnx 中某些算子

有些算子在 onnx 中是有的,但是在 pytorch 中没被登记,则需要注册一下
比如下面这个案例,我们想要导出 asinh 这个算子的模型

import torch
import torch.onnxclass Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 9)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()

但是报错,提示 opset_version = 9 不支持这个算子
在这里插入图片描述

但是我们打开官方手册去搜索发现 asinh 在 version 9 又是支持的
在这里插入图片描述
这里的问题是 PyTorch 与 onnx 之间没有建立 asinh 的映射 (没有搭建桥梁),所以我们编写一个注册代码,来手动注册一下这个算子

2.3.1、注册方法一

完整代码如下:

import torch
import torch.onnx
import onnxruntime
from torch.onnx import register_custom_op_symbolicdef asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定义完onnx以后必须要进行一下验证validate_onnx()

这段代码的关键在于 算子的注册:

1、定义 asinh_symbolic 函数

def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)
  1. 函数必须是 asinh_symbolic 这个名字
  2. g: 就是 graph,计算图 (在计算图中添加onnx算子)
  3. input :symblic的参数需要与Pytorch的asinh接口函数的参数对齐
    (def asinh( input: Tensor, *, out: Optional[Tensor]=None) -> Tensor: … )
  4. 符号函数内部调用 g.op, 为 onnx 计算图添加 Asinh 算子
  5. g.op中的第一个参数是onnx中的算子名字: Asinh

2、使用 register_custom_op_symbolic 函数

register_custom_op_symbolic('aten::asinh', asinh_symbolic, 12)
  1. aten 是"a Tensor Library"的缩写,是一个实现张量运算的C++库
  2. asinh 是在名为 aten 的一个c++命名空间下进行实现的
  3. 将 asinh_symbolic 这个符号函数,与PyTorch的 asinh 算子绑定
  4. register_op 中的第一个参数是PyTorch中的算子名字: aten::asinh
  5. 最后一个参数表示从第几个 opset 开始支持(可自己设置)

3、自定义完 onnx 以后必须要进行一下验证,可使用 onnxruntime

2.3.2、注册方法二

import torch
import torch.onnx
import onnxruntime
import functools
from torch.onnx import register_custom_op_symbolic
from torch.onnx._internal import registration_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)class Model(torch.nn.Module):def __init__(self):super().__init__()def forward(self, x):x = torch.asinh(x)return xdef validate_onnx():input = torch.rand(1, 5)# PyTorch的推理model = Model()x     = model(input)print("result from Pytorch is :", x)# onnxruntime的推理sess  = onnxruntime.InferenceSession('asinh2.onnx')x     = sess.run(None, {'input0': input.numpy()})print("result from onnx is:    ", x)def export_norm_onnx():input   = torch.rand(1, 5)model   = Model()model.eval()file    = "asinh2.onnx"torch.onnx.export(model         = model, args          = (input,),f             = file,input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished normal onnx export")if __name__ == "__main__":export_norm_onnx()# 自定义完onnx以后必须要进行一下验证validate_onnx()

与上面例子不同的是,这个注册方式跟底层文件的写法是一样的(文件在虚拟环境中的 torch/onnx/symbolic_opset*.py )

通过torch._internal 中的 registration 来注册这个算子,让这个算子可以与底层C++实现的 aten::asinh 绑定

_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=9)
@_onnx_symbolic('aten::asinh')
def asinh_symbolic(g, input, *, out=None):return g.op("Asinh", input)

2.4、直接修改 onnx,创建 plugin

直接手动创建一个 onnx (这是一个思路,会在后续博客进行总结记录)

参考链接

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/505824.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

查看php版本

查看PHP版本有多种方法,下面列出几种常见的: 命令行方式: 在命令行中输入以下命令以显示PHP的完整版本信息: php -v 如果系统中存在多个PHP版本,并且你想检查特定版本,可以指定该版本的路径执行相同命令。 …

Linux系统管理:虚拟机 Kali Linux 安装

目录 一、理论 1.Kali Linux 二、实验 1.虚拟机Kali Linux安装准备阶段 2.安装Kali Linux 2. Kali Linux 更换国内源 3. Kali Linux 设置固定IP 4. Kali Linux 开启SSH远程连接 5. MobaXterm远程连接 Kali Linux 三、问题 1.apt 命令 取代哪些 apt-get命令 一、理论…

使用Spark探索数据

需求分析 使用Spark来探索数据是一种高效处理大规模数据的方法,需要对数据进行加载、清洗和转换,选择合适的Spark组件进行数据处理和分析。需求分析包括确定数据分析的目的和问题、选择合适的Spark应用程序和算法、优化数据处理流程和性能、可视化和解释…

Mysql学习之MVCC解决读写问题

多版本并发控制 什么是MVCC MVCC (Multiversion Concurrency Control)多版本并发控制。顾名思义,MVCC是通过数据行的多个版本管理来实现数据库的并发控制。这项技术使得在InnoDB的事务隔离级别下执行一致性读操作有了保证。换言之&#xff0…

C++的内联函数

目录 前言 内联函数 为什么声明和定义分离 为什么声明和定义分离后不出错 为什么内联函数不支持声明和定义分离 为什么内联函数支持声明和定义不分离 坚持声明和定义不分离的解决方法 static修饰函数 inline修饰函数 结论 声明和定义不分离的应用场景 前言 在C语言…

二次元风格地址发布页源码

二次元风格地址发布页源码,源码由HTMLCSSJS组成,记事本打开源码文件可以进行内容文字之类的修改,双击html文件可以本地运行效果,也可以上传到服务器里面,重定向这个界面 下载地址 https://www.qqmu.com/2347.html

C3_W2_Collaborative_RecSys_Assignment_吴恩达_中英_Pytorch

Practice lab: Collaborative Filtering Recommender Systems(实践实验室:协同过滤推荐系统) In this exercise, you will implement collaborative filtering to build a recommender system for movies. 在本次实验中,你将实现协同过滤来构建一个电影推荐系统。 …

在什么时候企业档案才会发生调整

档案在企业中通常会调整在以下几个时刻: 1. 入职时:员工入职时,企业会要求员工提供个人档案,包括身份证件、学历证明、工作经历等相关文件。 2. 离职时:员工离职时,企业会整理员工的离职档案,包…

Spring重点记录

文章目录 1.Spring的组成2.Spring优点3.IOC理论推导4.IOC本质5.IOC实现:xml或者注解或者自动装配(零配置)。6.hellospring6.1beans.xml的结构为:6.2.Spring容器6.3对象的创建和控制反转 7.IOC创建对象方式7.1以有参构造的方式创建…

SpringCloud-MQ消息队列

一、消息队列介绍 MQ (MessageQueue) ,中文是消息队列,字面来看就是存放消息的队列。也就是事件驱动架构中的Broker。消息队列是一种基于生产者-消费者模型的通信方式,通过在消息队列中存放和传递消息,实现了不同组件、服务或系统…

鸿蒙Harmony应用开发—ArkTS声明式开发(通用属性:尺寸设置)

用于设置组件的宽高、边距。 说明: 从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 width width(value: Length) 设置组件自身的宽度,缺省时使用元素自身内容需要的宽度。若子组件的宽大于父组件的…

电子电气架构——汽车以太网诊断路由汇总

电子电气架构——汽车以太网诊断路由汇总 我是穿拖鞋的汉子,魔都中坚持长期主义的工程师。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 人们会在生活中不断攻击你。他们的主要武器是向你灌输对自己的怀疑:你的价值、你的能力、你的潜力。他们往往会将…