Pytorch转onnx

pytorch 转 onnx 模型需要函数 torch.onnx.export。

def export(model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],args: Union[Tuple[Any, ...], torch.Tensor],f: Union[str, io.BytesIO],export_params: bool = True,verbose: bool = False,training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,input_names: Optional[Sequence[str]] = None,output_names: Optional[Sequence[str]] = None,operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,opset_version: Optional[int] = None,do_constant_folding: bool = True,dynamic_axes: Optional[Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]] = None,keep_initializers_as_inputs: Optional[bool] = None,custom_opsets: Optional[Mapping[str, int]] = None,export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
) -> None:

常用参数说明

model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
f——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
opset_version——算子指令集合
do_constant_folding——是否使用常量折叠,默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道

参数说明
ONNX算子文档
ONNX 算子的定义情况,都可以在官方的算子文档中查看
这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。
在这里插入图片描述
练习

import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef infer():in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)x = model(in_features)print("result is: ", x)def export_onnx():input   = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model   = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新# pytorch导出onnx的方式,参数有很多,也可以支持动态size# 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西torch.onnx.export(model         = model, args          = (input,),f             = "../models/example.onnx",input_names   = ["input0"],output_names  = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":infer()export_onnx()

然后使用netron打开onnx文件,如果没有安装netron,在终端使用pip install netron。
在这里插入图片描述

参考链接
模型部署入门教程(三):PyTorch 转 ONNX 详解

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/596132.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【QT+QGIS跨平台编译】056:【pdal_json_schema+Qt跨平台编译】(一套代码、一套框架,跨平台编译)

点击查看专栏目录 文章目录 一、pdal_json_schema介绍二、pdal下载三、文件分析四、pro文件五、编译实践一、pdal_json_schema介绍 pdal_json_schema 是与 PDAL(Point Data Abstraction Library)相关的 JSON 模式文件。PDAL 是一个用于处理和分析点云数据的开源库。JSON 模式…

DHCP-PXE

Dynamic Host Configuration Protocol 动态主机配置协议 1.Selinux 调试为Permission 防火墙配置 搭建DHCP的主机必须有一个静态地址,提前配置好 安装DHCP软件 服务名为dhcpd DHCP地址分配四次会话, DISCOVERY发现 OFFER 提供 REQUEST 回应 A…

5G网络架构及技术(二):OFDM一

ToDo: 等把这些讲义看完后得单开一个文章整理思维导图   该部分由于内容比较重要,OFDM是5G物理层的基础,但学习时直接跳到5G OFDM去看它的那些参数设置感觉没什么意义,还得从发展的角度进行学习,先从最先用到OFDM的WiFi协议开始…

WCH恒沁单片机-CH32V307学习记录2----FreeRTOS移植

RISC-V 单片机 FreeRTOS 移植 前面用了 5 篇博客详细介绍了 FreeRTOS 在 ARM Cortex-M3 MCU 上是如何运行的。 FreeRTOS从代码层面进行原理分析系列 现在我直接用之前的 RISC-V MCU 开发板子(CH32V307VCT6)再次对 FreeRTOS 进行移植,其实也…

【C语言自定义类型之----结构体,联合体和枚举】

一.结构体 1.结构体类型的声明 srruct tag {nemer-list;//成员列表 }varible-list;//变量列表结构体在声明的时候,可以不完全声明。 例如:描述一个学生 struct stu {char name[20];//名字int age;//年龄char sex[20];//性别 };//分号不能省略2.结构体…

C语言实现快速排序算法

1. 什么是快速排序算法 快速排序的核心思想是通过分治法(Divide and Conquer)来实现排序。 算法的基本步骤是: 1. 选择一个基准值(通常是数组中的某个元素),将数组分成两部分,使得左边的部分所有元素都小于…

Open CASCADE学习|在给定的TopoDS_Shape中查找与特定顶点 V 对应的TopoDS_Edge编号

enum TopAbs_ShapeEnum{TopAbs_COMPOUND,TopAbs_COMPSOLID,TopAbs_SOLID,TopAbs_SHELL,TopAbs_FACE,TopAbs_WIRE,TopAbs_EDGE,TopAbs_VERTEX,TopAbs_SHAPE}; 这段代码定义了一个名为 TopAbs_ShapeEnum 的枚举类型,它包含了表示不同几何形状类型的常量。这些常量通常…

刷题之Leetcode283题(超级详细)

283.移动零 283. 移动零https://leetcode.cn/problems/move-zeroes/ 给定一个数组 nums,编写一个函数将所有 0 移动到数组的末尾,同时保持非零元素的相对顺序。 请注意 ,必须在不复制数组的情况下原地对数组进行操作。 示例 1: 输入: nu…

【了解下Oracle】

🌈个人主页:程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

golang 选择排序

学习笔记~ // Author sunwenbo // 2024/4/6 21:49 package mainimport "fmt"/* 选择排序基本介绍选择式排序也属于内部排序法,是从预排序的数据中按指定的规则选出某一元素,经过和其他元素重整,再依原则交换位置后达到…

快速了解FastAPI与Uvicorn是什么?

概念 什么是Uvicorn Python Uvicorn 是一个快速的 ASGI(Asynchronous Server Gateway Interface)服务器,用于构建异步 Web 服务。它基于 asyncio 库,支持高性能的异步请求处理,适用于各种类型的 Web 应用程序。 Uvi…

RUST语言值所有权之内存复制与移动

1.RUST中每个值都有一个所有者,每次只能有一个所有者 String::from函数会为字符串hello分配一块内存 内存示例如下: 在内存分配前调用s1正常输出 在分配s1给s2后调用报错 因为s1分配给s2后,s1的指向自动失效 s1被move到s2 s1自动释放 字符串克隆使用