如何导出rot90算子至onnx
- 1 背景描述
- 2 等价替换
- 2.1 rot90替换(NCHW)
- 2.2 rot180替换(NCHW)
- 2.3 rot270替换(NCHW)
- 3 rot导出ONNX
1 背景描述
在部署模型时,如果某些模型中或者前后处理中含有rot90
算子,但又希望一起和模型导出onnx时,可能会遇到如下错误(当前使用环境pytorch2.0.1
,opset_version
为17):
import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):x = torch.rot90(x, k=1, dims=(2, 3))return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()
torch.onnx.errors.UnsupportedOperatorError: Exporting the operator ‘aten::rot90’ to ONNX opset version 17 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.
简单的说就是不支持导出该算子,包括在onnx支持的算子文档中也找不到rot90
算子,onnx官方github链接:
https://github.com/onnx/onnx
2 等价替换
导不出咋办,那就想想旋转矩阵的原理,以及如何使用现有支持的算子替换。
2.1 rot90替换(NCHW)
废话不多说,rot90度(以逆时针为例)可以使用翻转和转置实现。具体代码如下,使用torch自带的rot90与自己实现的对比,通过torch.equal()
来对比两个Tensor是否一致,结果一致,不信自己试试。
import torchdef self_rot90_counterclockwise(x: torch.Tensor):x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=1, dims=[2, 3])y1 = self_rot90_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()
2.2 rot180替换(NCHW)
rot180度(以逆时针为例)可以使用翻转实现。具体代码如下:
import torchdef self_rot180_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2, 3])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=2, dims=[2, 3])y1 = self_rot180_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()
2.3 rot270替换(NCHW)
rot270度(以逆时针为例)可以使用翻转和转置实现。具体代码如下:
import torchdef self_rot270_counterclockwise(x: torch.Tensor):x = x.flip(dims=[2]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))y0 = torch.rot90(x, k=3, dims=[2, 3])y1 = self_rot270_counterclockwise(x)print(torch.equal(y0, y1))if __name__ == '__main__':main()
3 rot导出ONNX
这里以rot90度(以逆时针为例)结合刚刚的等价实现来导出ONNX:
import torch
import torch.nn as nnclass RotModel(nn.Module):def forward(self, x: torch.Tensor):# x = torch.rot90(x, k=1, dims=(2, 3))x = x.flip(dims=[3]).permute([0, 1, 3, 2])return xdef main():print("pytorch version:", torch.__version__)model = RotModel()with torch.inference_mode():x = torch.randn(size=(1, 3, 224, 224))torch.onnx.export(model,args=(x,),f="rot90_counterclockwise.onnx",opset_version=17)if __name__ == '__main__':main()
使用netron
打开生成的rot90_counterclockwise.onnx
文件,如下所示: