PyTorch开放神经网络交换(Open Neural Network Exchange)ONNX通用格式模型的熟悉

我们在深度学习中可以发现有很多不同格式的模型文件,比如不同的框架就有各自的文件格式:.model、.h5、.pb、.pkl、.pt、.pth等等,各自有标准就带来互通的不便,所以微软、Meta和亚马逊在内的合作伙伴社区一起搞一个ONNX(Open Neural Network Exchange)文件格式的通用标准,这样就可以使得模型在不同框架之间方便的进行互操作了。
在上节的YOLOv8的目标对象的分类,分割,跟踪和姿态估计的多任务检测实践(Netron模型可视化) 我们在最后有接触到这个格式文件,而且给出了一个 https://netron.app/站点,可以将这个文件的计算图和相关属性都能可视化的给呈现出来。YOLO的模型有自带的方法:

from ultralytics import YOLO
model = YOLO('yolov8n-cls.pt')
model.export(format="onnx")

或者命令行的方式
yolo export model=yolov8n-cls.pt format=onnx

1、PyTorch演示onnx

这里我们来熟悉下在PyTorchonnx是怎么样的。看一个简单示例:

import torchclass myModel(torch.nn.Module):def __init__(self):super(myModel, self).__init__()def forward(self, x):return x.reshape(1,3,64,64)model = myModel()
x = torch.randn(64,64,3)
torch.onnx.export(model, x, 'newmodel.onnx', input_names=['images'], output_names=['newimages'])

修改输入进来的形状,可以看到通过torch.onnx.export就可以生成onnx格式的模型文件。我们将它上传到上面的站点,将会生成如下的流程图:

 我们打印看下export方法有哪些参数:

export(model: 'Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction]', args: 'Union[Tuple[Any, ...], torch.Tensor]', f: 'Union[str, io.BytesIO]', export_params: 'bool' = True, verbose: 'bool' = False, training: '_C_onnx.TrainingMode' = <TrainingMode.EVAL: 0>, input_names: 'Optional[Sequence[str]]' = None, output_names: 'Optional[Sequence[str]]' = None, operator_export_type: '_C_onnx.OperatorExportTypes' = <OperatorExportTypes.ONNX: 0>, opset_version: 'Optional[int]' = None, do_constant_folding: 'bool' = True, dynamic_axes: 'Optional[Union[Mapping[str, Mapping[int, str]], Mapping[str, Sequence[int]]]]' = None, keep_initializers_as_inputs: 'Optional[bool]' = None, custom_opsets: 'Optional[Mapping[str, int]]' = None, export_modules_as_functions: 'Union[bool, Collection[Type[torch.nn.Module]]]' = False) -> 'None'

2、加载onnx模型

2.1、安装onnx库

为了能够正确的加载onnx格式文件,需要先安装onnx库,依然推荐加豆瓣镜像安装

pip install onnx -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

安装好了之后,我们看下是否成功安装:

import onnx
dir(onnx)
'''
['Any', 'AttributeProto', 'EXPERIMENTAL', 'FunctionProto', 'GraphProto', 'IO', 'IR_VERSION', 'IR_VERSION_2017_10_10', 'IR_VERSION_2017_10_30', 'IR_VERSION_2017_11_3', 'IR_VERSION_2019_1_22', 'IR_VERSION_2019_3_18', 'IR_VERSION_2019_9_19', 'IR_VERSION_2020_5_8', 'IR_VERSION_2021_7_30', 'MapProto', 'ModelProto', 'NodeProto', 'ONNX_ML', 'OperatorProto', 'OperatorSetIdProto', 'OperatorSetProto', 'OperatorStatus', 'Optional', 'OptionalProto', 'STABLE', 'SequenceProto', 'SparseTensorProto', 'StringStringEntryProto', 'TensorAnnotation', 'TensorProto', 'TensorShapeProto', 'TrainingInfoProto', 'TypeProto', 'TypeVar', 'Union', 'ValueInfoProto', 'Version', '_Proto', '__builtins__', '__cached__', '__doc__', '__file__', '__loader__', '__name__', '__package__', '__path__', '__spec__', '__version__', '_deserialize', '_get_file_path', '_load_bytes', '_save_bytes', '_serialize', 'checker', 'compose', 'convert_model_to_external_data', 'defs', 'external_data_helper', 'gen_proto', 'google', 'helper', 'hub', 'load', 'load_external_data_for_model', 'load_from_string', 'load_model', 'load_model_from_string', 'load_tensor', 'load_tensor_from_string', 'mapping', 'numpy_helper', 'onnx_cpp2py_export', 'onnx_data_pb', 'onnx_data_pb2', 'onnx_ml_pb2', 'onnx_operators_ml_pb2', 'onnx_operators_pb', 'onnx_pb', 'os', 'parser', 'printer', 'save', 'save_model', 'save_tensor', 'shape_inference', 'typing', 'utils', 'version', 'version_converter', 'write_external_data_tensors']
'''

2.2、onnx模型信息

正确安装之后,就可以加载模型了。

import onnxmyModel = onnx.load("newmodel.onnx")
#打印整个模型信息
#print(myModel)
output = myModel.graph.output
#打印输出层的信息
print(output)

[name: "newimages"
type {
  tensor_type {
    elem_type: 1
    shape {
      dim {
        dim_value: 1
      }
      dim {
        dim_value: 3
      }
      dim {
        dim_value: 64
      }
      dim {
        dim_value: 64
      }
    }
  }
}
]

2.3、加载外部数据

需要注意的是,外部数据如果跟模型文件不是在同一个目录里面的话,需要使用load_external_data_for_model来加载外部数据

import onnx
from onnx.external_data_helper import load_external_data_for_modelmymodel = onnx.load('xx/model.onnx', load_external_data=False)
load_external_data_for_model(mymodel, 'datasets/')

这样就可以通过mymodel这个模型来加载指定目录的数据了。

3、保存新模型

我们可以试着在模型的基础上做一些修改,然后保存为一个新的模型文件。

import onnx
from onnx import helpermodel = onnx.load('newmodel.onnx')
prob_info = helper.make_tensor_value_info('xx',onnx.TensorProto.FLOAT, [1,3,128,128])
model.graph.output.insert(0, prob_info)
onnx.save(model, 'newmodel2.onnx')

将一个新的节点插入到输出层,我们来看下生成的图:

其中helper里面有很多用法,比如增加节点:

node = helper.make_node("Range",inputs=["start", "limit", "delta"],outputs=["output"])
start = np.float32(1)
limit = np.float32(15)
delta = np.float32(2)
output = np.arange(start, limit, delta, dtype=np.float32)
print(output)#[ 1.  3.  5.  7.  9. 11. 13.]

这里的Range是其中一个操作类型(算子),用法:Operators 更多的操作类型,有兴趣的可以进去查阅

4、常见对象

在onnix常出现的几个对象,AttributeProto,TensorProto,GraphProto,NodeProto一起来了解下:

4.1、AttributeProto 

属性

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProtoarg = helper.make_attribute("this_is_an_int", 1701)
print(arg)
'''
name: "this_is_an_int"
type: INT
i: 1701
'''

当然类型还可以是浮点数、字符串、数组等

4.2、NodeProto

节点

node_proto = helper.make_node("Relu", ["X"], ["Y"])
print(node_proto)
'''
input: "X"
output: "Y"
op_type: "Relu"
'''

其中op_type的算子挺多的,比如上面那个Range的用法。

4.3、AttributeProto和NodeProto结合

看下两者结合使用的一个例子,卷积核为3,步幅为1,填充为1的一个卷积。 

node_proto = helper.make_node("Conv", ["X", "W", "B"], ["Y"],kernel=3, stride=1, pad=1)
# 属性按顺序打印
node_proto.attribute.sort(key=lambda attr: attr.name)
print(node_proto)
'''
input: "X"
input: "W"
input: "B"
output: "Y"
op_type: "Conv"
attribute {name: "kernel"type: INTi: 3
}
attribute {name: "pad"type: INTi: 1
}
attribute {name: "stride"type: INTi: 1
}
'''

另一种更有可读性的打印:

print(helper.printable_node(node_proto))
%Y = Conv[kernel = 3, pad = 1, stride = 1](%X, %W, %B)

4.4、TensorProto和GraphProto

graph_proto = helper.make_graph([helper.make_node("FC", ["X", "W1", "B1"], ["H1"]),helper.make_node("Relu", ["H1"], ["R1"]),helper.make_node("FC", ["R1", "W2", "B2"], ["Y"]),],"MLP",[helper.make_tensor_value_info("X" , TensorProto.FLOAT, [1]),helper.make_tensor_value_info("W1", TensorProto.FLOAT, [1]),helper.make_tensor_value_info("B1", TensorProto.FLOAT, [1]),helper.make_tensor_value_info("W2", TensorProto.FLOAT, [1]),helper.make_tensor_value_info("B2", TensorProto.FLOAT, [1]),],[helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1]),]
)
print(helper.printable_graph(graph_proto))
'''
graph MLP (%X[FLOAT, 1]%W1[FLOAT, 1]%B1[FLOAT, 1]%W2[FLOAT, 1]%B2[FLOAT, 1]
) {%H1 = FC(%X, %W1, %B1)%R1 = Relu(%H1)%Y = FC(%R1, %W2, %B2)return %Y
'''

5、解析器

5.1、parse_graph

onnx.parser.parse_graph将文本表示,创建成ONNX图形

input = """agraph (float[N, 128] X, float[128, 10] W, float[10] B) => (float[N, 10] C){T = MatMul(X, W)S = Add(T, B)C = Softmax(S)}
"""
graph = onnx.parser.parse_graph(input)
print(helper.printable_graph(graph))
'''
graph agraph (%X[FLOAT, Nx128]%W[FLOAT, 128x10]%B[FLOAT, 10]
) {%T = MatMul(%X, %W)%S = Add(%T, %B)%C = Softmax(%S)return %C
}
'''

 5.2、parse_model

onnx.parser.parse_model将文本表示,创建成ONNX模型

input = """<ir_version: 7,opset_import: ["" : 10]>agraph (float[N, 128] X, float[128, 10] W, float[10] B) => (float[N, 10] C){T = MatMul(X, W)S = Add(T, B)C = Softmax(S)}
"""
model = onnx.parser.parse_model(input)
print(model)'''
ir_version: 7
opset_import {domain: ""version: 10
}
graph {node {input: "X"input: "W"output: "T"op_type: "MatMul"domain: ""}node {input: "T"input: "B"output: "S"op_type: "Add"domain: ""}node {input: "S"output: "C"op_type: "Softmax"domain: ""}name: "agraph"input {name: "X"type {tensor_type {elem_type: 1shape {dim {dim_param: "N"}dim {dim_value: 128}}}}}input {name: "W"type {tensor_type {elem_type: 1shape {dim {dim_value: 128}dim {dim_value: 10}}}}}input {name: "B"type {tensor_type {elem_type: 1shape {dim {dim_value: 10}}}}}output {name: "C"type {tensor_type {elem_type: 1shape {dim {dim_param: "N"}dim {dim_value: 10}}}}}
}
'''

引用来源
github:https://github.com/onnx/onnx

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/4561.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

jupyter notebook优化

一&#xff0e;这个是jupyter notebook主题设置的相关教程&#xff0c;如果经常看着高亮的屏幕&#xff0c;对于眼睛会是一种损伤&#xff01; https://blog.csdn.net/qq_41566627/article/details/104984796?utm_mediumdistribute.pc_relevant.none-task-blog-2%7Edefault%7…

乐鑫线上研讨会|探索 LCD 屏在物联网中的发展趋势

LCD 屏通过显示实时信息并提供交互式体验&#xff0c;现已成为各类设备的重要组成部分。在当下的 AIoT 时代&#xff0c;随着物联网技术的快速发展和应用场景的不断拓展&#xff0c;LCD 作为人机交互的主要输入输出设备&#xff0c;在智能家居、智能安防、工业控制、智慧城市等…

react封装jsoneditor

1、安装&#xff1a; api文档&#xff1a;jsoneditor npm install jsoneditor2、代码&#xff1a; JsonEditor/index.tsx: import { useMemoizedFn } from ahooks; import JSONEditor from jsoneditor; import { useEffect, useState } from react; import ./index.less;in…

redis 主从复制 哨兵 安装部署

学习开始前先了解一下 Redis是一个开源的内存数据结构存储系统&#xff0c;它支持多种数据结构&#xff0c;如字符串、哈希表、列表、集合、有序集合等。Redis最大的特点是数据存储在内存中&#xff0c;因此读写速度非常快&#xff0c;同时也支持数据持久化&#xff0c;可以将数…

【K8S系列】如何高效查看 k8s日志

序言 你只管努力&#xff0c;其他交给时间&#xff0c;时间会证明一切。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记一级论点蓝色&#xff1a;用来标记二级论点 Kubernetes (k8s) 是一个容器编排平台&#x…

AndroidStudio xml布局文件输入没有提示

AndroidStudio xml布局文件输入没有提示&#xff0c;如下图&#xff1a; 原因是老的AndroidStudio与新的sdk版本不一致 方法1&#xff1a;修改compileSdkVersion低于33即可&#xff0c;不建议 方法2&#xff1a;升级AndroidStudio版本&#xff0c;建议 如下是我的AndroidStu…

Docker部署(1)——将jar包打成docker镜像并启动容器

在代码编写完成即将部署的时候&#xff0c;如果采用docker容器的方法&#xff0c;需要将jar包打成docker镜像并通过镜像将容器启动起来。具体的步骤如下。 一、首先下载java镜像 先使用docker search java命令进行搜索。 然而在拉取镜像的时候要注意不能直接去选择pull java ,…

改进的白鲸优化算法

改进的白鲸优化算法 一、算法灵感二、算法介绍2.1 初始化2.2 探索阶段2.3 开发阶段2.4 鲸落阶段 三、改进的白鲸优化算法3.1 集体行动策略3.2 小孔成像策略3.3 二次插值策略3.4 IBWO伪代码 一、算法灵感 白鲸优化算法(Beluga whale optimization, BWO)是2022年提出的一种元启发…

需求分析引言:架构漫谈(三)可用性专题

前文介绍了非功能性需求的各个指标和一些业界的标准。 非功能性需求里有一项可靠性&#xff0c;与之关联的一个指标叫可用性 本文对非功能性需求里的可用性、可靠性&#xff0c;进行一些详细的说明。 概念 我们在网上的云服务商处&#xff0c;经常看到产品介绍里会有这种字样…

系列五、NotePad++下载安装

一、下载 链接&#xff1a;https://pan.baidu.com/s/1U2f74vfBJIds7W2wJYnBxg?pwdyyds 提取码&#xff1a;yyds 二、安装 2.1、安装NotePad 解压NotePad-x64.zip至指定目录即可&#xff0c;例如 2.2、安装NppFTP 2.2.1、查看NotePad对应的位数&#xff08;32位or64位&a…

最强优化指令大全 | 【Linux技术专题】「系统性能调优实战」终极关注应用系统性能调优及原理剖析(下册)

Linux命令相关查看指标 CPU 指标 vmstat指令 vmstat -n m该命令用于每隔n秒采集系统的性能统计信息&#xff0c;共采集m次。 [rootsvr01]$ vmstat 1 3procs -----------memory---------- ---swap-- -----io---- --system-- -----cpu-----r b swpd free buff cache …

行业榜单揭晓:项目管理软件排行榜中的行业翘楚!

在当今复杂的商业环境中&#xff0c;项目管理软件已经成为企业管理不可或缺的工具之一。它们帮助企业组织和协调各个部门的工作&#xff0c;并确保项目能够按时、按预算、按质量要求完成。这就是为什么许多公司都在积极寻找最好的项目管理软件。但是&#xff0c;市场上有许多项…