pytorch 转 onnx 模型需要函数 torch.onnx.export。
def export(model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],args: Union[Tuple[Any, ...], torch.Tensor],f: Union[str, io.BytesIO],export_params: bool = True,verbose: bool = False,training: _C_onnx.TrainingMode = _C_onnx.TrainingMode.EVAL,input_names: Optional[Sequence[str]] = None,output_names: Optional[Sequence[str]] = None,operator_export_type: _C_onnx.OperatorExportTypes = _C_onnx.OperatorExportTypes.ONNX,opset_version: Optional[int] = None,do_constant_folding: bool = True,dynamic_axes: Optional[Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]] = None,keep_initializers_as_inputs: Optional[bool] = None,custom_opsets: Optional[Mapping[str, int]] = None,export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
) -> None:
常用参数说明
model——需要导出的pytorch模型
args——模型的输入参数,满足输入层的shape正确即可。
f——输出的onnx模型的位置。例如‘yolov5.onnx’。
export_params——输出模型是否可训练。default=True,表示导出trained model,否则untrained。
verbose——是否打印模型转换信息。default=False。
input_names——输入节点名称。default=None。
output_names——输出节点名称。default=None。
opset_version——算子指令集合
do_constant_folding——是否使用常量折叠,默认即可。default=True。
dynamic_axes——模型的输入输出有时是可变的,如Rnn,或者输出图像的batch可变,可通过该参数设置。如输入层的shape为(b,3,h,w),batch,height,width是可变的,但是chancel是固定三通道
参数说明
ONNX算子文档
ONNX 算子的定义情况,都可以在官方的算子文档中查看
这份文档中最重要的开头的这个算子变更表格。表格的第一列是算子名,第二列是该算子发生变动的算子集版本号,也就是我们之前在torch.onnx.export中提到的opset_version表示的算子集版本号。通过查看算子第一次发生变动的版本号,我们可以知道某个算子是从哪个版本开始支持的;通过查看某算子小于等于opset_version的第一个改动记录,我们可以知道当前算子集版本中该算子的定义规则。
练习
import torch
import torch.nn as nn
import torch.onnxclass Model(torch.nn.Module):def __init__(self, in_features, out_features, weights, bias=False):super().__init__()self.linear = nn.Linear(in_features, out_features, bias)with torch.no_grad():self.linear.weight.copy_(weights)def forward(self, x):x = self.linear(x)return xdef infer():in_features = torch.tensor([1, 2, 3, 4], dtype=torch.float32)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)x = model(in_features)print("result is: ", x)def export_onnx():input = torch.zeros(1, 1, 1, 4)weights = torch.tensor([[1, 2, 3, 4],[2, 3, 4, 5],[3, 4, 5, 6]],dtype=torch.float32)model = Model(4, 3, weights)model.eval() #添加eval防止权重继续更新# pytorch导出onnx的方式,参数有很多,也可以支持动态size# 我们先做一些最基本的导出,从netron学习一下导出的onnx都有那些东西torch.onnx.export(model = model, args = (input,),f = "../models/example.onnx",input_names = ["input0"],output_names = ["output0"],opset_version = 12)print("Finished onnx export")if __name__ == "__main__":infer()export_onnx()
然后使用netron打开onnx文件,如果没有安装netron,在终端使用pip install netron。
参考链接
模型部署入门教程(三):PyTorch 转 ONNX 详解