pytorch导出rot90算子至onnx

如何导出rot90算子至onnx

    • 1 背景描述
    • 2 等价替换
      • 2.1 rot90替换(NCHW)
      • 2.2 rot180替换(NCHW)
      • 2.3 rot270替换(NCHW)
    • 3 rot导出ONNX

1 背景描述

在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1opset_version为17):

import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(2, 3))return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.

简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90算子,onnx官方github链接:
https://github.com/onnx/onnx


2 等价替换

导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。

2.1 rot90替换(NCHW)

废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()来对比两个Tensor是否一致,结果一致,不信自己试试。

import torchdef self_rot90_counterclockwise(x: torch.Tensor):x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=1, dims=[2, 3])y1 = self_rot90_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

2.2 rot180替换(NCHW)

rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:

import torchdef self_rot180_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2, 3])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=2, dims=[2, 3])y1 = self_rot180_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

2.3 rot270替换(NCHW)

rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:

import torchdef self_rot270_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=3, dims=[2, 3])y1 = self_rot270_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()

3 rot导出ONNX

这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:

import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):# x = torch.rot90(x, k=1, dims=(2, 3))x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()

使用netron打开生成的rot90_counterclockwise.onnx文件,如下所示:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/219196.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【二叉树】oj题

在处理oj题之前我们需要先处理一下之前遗留的问题 在二叉树中寻找为x的节点 BTNode* BinaryTreeFind(BTNode* root, int x) {if (root NULL)return NULL;if (root->data x)return root;BTNode* ret1 BinaryTreeFind(root->left, x);BTNode* ret2 BinaryTreeFind(ro…

【云原生】什么是 Kubernetes ?

什么是 Kubernetes ? Kubernetes 是一个开源容器编排平台,管理着一系列的 主机 或者 服务器,它们被称作是 节点(Node)。 每一个节点运行了若干个相互独立的 Pod。 Pod 是 Kubernetes 中可以部署的 最小执行单元&#x…

机器学习【03】在本地浏览器使用远程服务器的Jupyter Notebook【conda环境】

1.激活虚拟环境 conda activate 虚拟环境名字2.虚拟环境下安装jupyter notebook pip install jupyter3.配置 jupyter 文件 在 Jupyter Notebook 的配置目录中生成一个配置文件 jupyter_notebook_config.py jupyter notebook --generate-config3.设置密码 jupyter notebook …

性能压测工具:wrk

一般我们压测的时候,需要了解衡量系统性能的一些参数指标,比如。 1、性能指标简介 1.1 延迟 简单易懂。green:一般指响应时间 95线:P95。平均100%的请求中95%已经响应的时间 99线:P99。平均100%的请求中99%已经响应的时间 平…

JVM字节码文件的相关概述解读

Java全能学习面试指南:https://javaxiaobear.cn 1、字节码文件 从下面这个图就可以看出,字节码文件是可以跨平台使用的 想要让一个Java程序正确地运行在JVM中,Java源码就必须要被编译为符合JVM规范的字节码。 https://docs.oracle.com/java…

如何使用nginx部署静态资源

Nginx可以作为静态web服务器来部署静态资源,这个静态资源是指在服务端真实存在,并且能够直接展示的一些文件数据,比如常见的静态资源有html页面、css文件、js文件、图片、视频、音频等资源相对于Tomcat服务器来说,Nginx处理静态资…

Flutter桌面应用开发之毛玻璃效果

目录 效果实现方案依赖库支持平台实现步骤注意事项话题扩展 毛玻璃效果:毛玻璃效果是一种模糊化的视觉效果,常用于图像处理和界面设计中。它可以通过在图像或界面元素上应用高斯模糊来实现。使用毛玻璃效果可以增加图像或界面元素的柔和感,同…

Elasticsearch集群部署,配置head监控插件

Elasticsearch是一个开源搜索引擎,基于Lucene搜索库构建,被广泛应用于全文搜索、地理位置搜索、日志处理、商业分析等领域。它采用分布式架构,可以处理大规模数据集和支持高并发访问。Elasticsearch提供了一个简单而强大的API,可以…

超级详细的 Maven 教程(基础+高级)

1. Maven 是什么 Maven 是 Apache 软件基金会组织维护的一款专门为 Java 项目提供构建和依赖管理支持的工具。 一个 Maven 工程有约定的目录结构,约定的目录结构对于 Maven 实现自动化构建而言是必不可少的一环,就拿自动编译来说,Maven 必须…

ArkTs变量类型、数据类型

可以参考官网学习路径学习HarmonyOS第一课|应用开发视频教程学习|HarmonyOS应用开发官网 一、变量 1、ArkTS语言 ArkTS是华为自研的开发语言。它在TypeScript(简称TS)的基础上,匹配ArkUI框架,扩展了声明式UI、状态管理等相应的…

赢麻了!义乌一个村有5000个网红,有人年收租就300万!

#义乌一村电商年成交额超300亿# ,在中国,电商行业的发展可谓是日新月异,而位于浙江省义乌市的江北下朱村,正是这股潮流的一个典型代表。这个村子,处处弥漫着“直播”的气息,仿佛每个人都在为这个新兴行业助力。 江北下…