使用Pytorch导出自定义ONNX算子

在实际部署模型时有时可能会遇到想用的算子无法导出onnx,但实际部署的框架是支持该算子的。此时可以通过自定义onnx算子的方式导出onnx模型(注:自定义onnx算子导出onnx模型后是无法使用onnxruntime推理的)。下面给出个具体应用中的示例:需要导出pytorch的affine_grid算子,但在pytorch的2.0.1版本中又无法正常导出该算子,故可通过如下自定义算子代码导出。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomAffineGrid(Function):@staticmethoddef forward(ctx, theta: torch.Tensor, size: torch.Tensor):grid = F.affine_grid(theta=theta, size=size.cpu().tolist())return grid@staticmethoddef symbolic(g: torch.Graph, theta: torch.Tensor, size: torch.Tensor):return g.op("AffineGrid", theta, size)class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor, theta: torch.Tensor, size: torch.Tensor):grid = CustomAffineGrid.apply(theta, size)x = F.grid_sample(x, grid=grid, mode="bilinear", padding_mode="zeros")return xdef main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)theta = torch.randn(1, 2, 3)size = torch.as_tensor([1, 3, 512, 512])torch.onnx.export(model=custum_model,args=(x, theta, size),f="custom.onnx",input_names=["input0_x", "input1_theta", "input2_size"],output_names=["output"],dynamic_axes={"input0_x": {2: "h0", 3: "w0"},"output": {2: "h1", 3: "w1"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

在上面代码中,通过继承torch.autograd.Function父类的方式实现导出自定义算子,继承该父类后需要用户自己实现forward以及symbolic两个静态方法,其中forward方法是在pytorch正常推理时调用的函数,而symbolic方法是在导出onnx时调用的函数。对于forward方法需要按照正常的pytorch语法来实现,其中第一个参数必须是ctx但对于当前导出onnx场景可以不用管它,后面的参数是实际自己传入的参数。对于symbolic方法的第一个必须是g,后面的参数任为实际自己传入的参数,然后通过g.op方法指定具体导出自定义算子的名称,以及输入的参数(注:上面示例中传入的都是Tensor所以可以直接传入,对与非Tensor的参数可见下面一个示例)。最后在使用时直接调用自己实现类的apply方法即可。使用netron打开自己导出的onnx文件,可以看到如下所示网络结构。
在这里插入图片描述

有时按照使用的推理框架导出自定义算子时还需要设置一些参数(非Tensor)那么可以参考如下示例,例如要导出int型的参数k那么可以通过传入k_i来指定,要导出float型的参数scale那么可以通过传入scale_f来指定,要导出string型的参数clockwise那么可以通过传入clockwise_s来指定:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from torch.onnx import OperatorExportTypesclass CustomRot90AndScale(Function):@staticmethoddef forward(ctx, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(3, 2))  # clockwise 90x *= 1.2return x@staticmethoddef symbolic(g: torch.Graph, x: torch.Tensor):return g.op("Rot90AndScale", x, k_i=1, scale_f=1.2, clockwise_s="yes")class MyModel(nn.Module):def __init__(self) -> None:super().__init__()def forward(self, x: torch.Tensor):return CustomRot90AndScale.apply(x)def main():with torch.inference_mode():custum_model = MyModel()x = torch.randn(1, 3, 224, 224)torch.onnx.export(model=custum_model,args=(x,),f="custom_rot90.onnx",input_names=["input"],output_names=["output"],dynamic_axes={"input": {2: "h0", 3: "w0"},"output": {2: "w0", 3: "h0"}},opset_version=16,operator_export_type=OperatorExportTypes.ONNX_FALLTHROUGH)if __name__ == '__main__':main()

使用netron打开自己导出的onnx文件,可以看到如下所示信息。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/521041.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pytorch续写tensorboard

模型训练到一半有 bug 停了,可以 resume 继续炼,本篇给出 pytorch 在 resume 训练时续写 tensorboard 的简例,参考 [1-3],只要保证 writer 接收的 global step 是连着的就行。 Code import numpy as np from torch.utils.tensor…

【节能减排/能耗分析/设备运维】AcrelEMS-Zone园区能源管理系统解决方案

市场规模 智慧园区市场稳步增长,市场规模由2019年的1191亿元增至2021年的1394亿元。智慧园区作为产业升级转型的重要载体,近年来受到国家政策大力支持,行业前景广阔,预计2022年智慧园区市场规模将达1543亿元,2024年市…

如何让JMeter也生成精美详细allure测试报告

(全文约2000字,阅读约需5分钟,首发于公众号:测试开发研习社,欢迎关注) 内容目录: 一、需求 二、思路 三、验证 四、实现 五、优化 六、彩蛋 篇幅较长,建议先收藏后阅读 一、需…

LCR 168. 丑数

解题思路&#xff1a; class Solution {public int nthUglyNumber(int n) {int a 0, b 0, c 0;int[] res new int[n];res[0] 1;for(int i 1; i < n; i) {int n2 res[a] * 2, n3 res[b] * 3, n5 res[c] * 5;res[i] Math.min(Math.min(n2, n3), n5);if (res[i] n2)…

345.反转字符串中的元音字母

题目&#xff1a;给你一个字符串 s &#xff0c;仅反转字符串中的所有元音字母&#xff0c;并返回结果字符串。 元音字母包括 a、e、i、o、u&#xff0c;且可能以大小写两种形式出现不止一次。 class Solution {//画图&#xff0c;好理解点public String reverseVowels(String…

kasan排查kernel内存越界示例(linux5.18.11)

参考资料&#xff1a; 1&#xff0c;内核源码目录中的Documentation\dev-tools\kasan.rst 2&#xff0c;KASAN - Kernel Address Sanitizer | Naveen Naidu (naveenaidu.dev) 一、kasan实现原理 KASAN&#xff08;Kernel Address SANitizer&#xff09;是一个动态内存非法访…

51-26 DriveMLM: 多模态大型语言模型与自动驾驶行为规划状态对齐

DriveMLM是来自上海AILab、港中文、商汤、斯坦福、南京大学和清华大学的工作。该模型使用各种传感器(如相机、激光雷达)、驾驶规则和用户指令作为输入&#xff0c;采用多模态LLM对AD系统的行为规划进行建模&#xff0c;做出驾驶决策并提供解释。该模型可以用于闭环自动驾驶&…

蓝桥杯——web(ECharts)

ECharts 初体验 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><script src"echarts.js">&l…

Docker-完整项目的部署(保姆级教学)

目录 1 手动部署(白雪版) 1.1 创建网络 1.2 MySQL的部署 1.2.1 准备 1.2.2 部署 1.3 Java项目的部署 1.3.1 准备 1.3.1.1 将Java项目打成jar包 1.3.1.2 编写Dockerfile文件 1.3.2 部署 1.3.2.1 将jar包、Dockerfile文件放在linux同一个文件夹下 1.3.2.2 构建镜像 …

【Godot4自学手册】第二十一节掉落金币和收集

这一节我们主要学习敌人死亡后随机掉落金币&#xff0c;主人公可以进行拾取功能。 一、新建金币场景 新建场景&#xff0c;节点选择CharacterBody2D&#xff0c;命名为Coins&#xff0c;将场景保存到Scenes目录下。 1.新建节点 为根节点依次添加CollisionShape2D节点&#…

管理技巧 | 提升团队效能:如何与下属进行有效沟通

本文节选霍格沃兹测试开发学社沟通管理公开课- 某外企PMO Angelia老师的分享 在日常的管理工作中&#xff0c;沟通作为一项基础而关键的技能&#xff0c;往往决定了团队的协作效率和目标达成率。作为一个曾经从基层员工一路成长为管理者的Angelia老师&#xff0c;深知沟通的艺术…

极海APM32F407典型应用——可编程逻辑控制器方案

PLC&#xff08;可编程逻辑控制器&#xff09;作为可控制、执行和监控自动化机器设备的数字运算操作电子系统&#xff0c;广泛应用于楼宇设备控制、水处理、能源、工业自动化等众多领域&#xff0c;并已形成广大的市场规模&#xff0c;随着汽车电子“新四化”发展&#xff0c;将…