PyTorch的ONNX结合MNIST手写数字数据集的应用(.pth和.onnx的转换与onnx运行时)

在PyTorch以前的模型都是.pth格式,后面Meta跟微软一起做了一个.onnx的通用格式。这里对这两种格式文件,分别做一个介绍,依然使用MNIST数据集来做示例

1、CUDA下的pth文件

那pth文件里面是什么结构呢?其实在以前的文章就有介绍过,属于字典类型,而且是有序字典类型,这样就可以按照一定的顺序进行处理。

1.1、了解pth结构

先来查看一下pth文件的内容:
MNIST预训练模型.pth文件

import torch
model=torch.load("lenet_mnist_model.pth",map_location=torch.device('cpu'))
print(type(model),len(model))
for k,v in model.items():print(k,v.size())
'''
<class 'collections.OrderedDict'> 8
conv1.weight torch.Size([10, 1, 5, 5])
conv1.bias torch.Size([10])
conv2.weight torch.Size([20, 10, 5, 5])
conv2.bias torch.Size([20])
fc1.weight torch.Size([50, 320])
fc1.bias torch.Size([50])
fc2.weight torch.Size([10, 50])
fc2.bias torch.Size([10])
'''

可以看到类型是OrderedDict,两个卷积层加上两个全连接层。每个层都带有权重和偏置,简单显示了它们的形状。

1.2、torch.device

这里有一个需要注意的地方就是,如果将

model=torch.load("lenet_mnist_model.pth",map_location=torch.device('cpu'))

修改为

model=torch.load("lenet_mnist_model.pth")

就会报如下错误:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

也就是说想在CUDA中做反序列化操作,而CUDA是不可用的,从这里可以看到这个模型的训练是在有CUDA的环境下进行的,所以我们这里指定到CPU设备上。

2、CPU下的pth文件

 我们来看一个在CPU的环境下的加载方法,mnist.pth文件下载地址:mnist.pth

import torch
model=torch.load("mnist.pth")
print(type(model['net']),len(model['net']))
for k,v in model['net'].items():print(k,v.size())'''
<class 'collections.OrderedDict'> 10
conv1.weight torch.Size([6, 1, 3, 3])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 3, 3])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])
'''

这里可以不指定map_location参数,默认是cpu设备,可以看到这个pth文件结构是两个卷积层加三个全连接层。

3、pth转onnx

我们根据上面的mnist.pth结构,自己来构造一个模型:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=3,stride=1,padding=0)self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=3,stride=1,padding=0)self.fc1   = nn.Linear(400, 120)self.fc2   = nn.Linear(120, 84)self.fc3   = nn.Linear(84, 10)def forward(self, x):out = self.conv1(x) # torch.Size([1, 6, 26, 26])out = F.max_pool2d(F.relu(out), 2) # [1, 6, 13, 13]out = self.conv2(out) # [1, 16, 11, 11]out = F.max_pool2d(F.relu(out), 2)  # [1, 16, 5, 5]out = out.view(out.size(0), -1) # [1, 400]out = self.fc1(out) # [1, 120]out = self.fc2(F.relu(out)) # [1, 84]out = self.fc3(F.relu(out))  # [1, 10]return outnet = LeNet()
net = net.to('cpu')
checkpoint = torch.load('mnist.pth')
net.load_state_dict(checkpoint['net'])
batch_size = 1
input_shape = (1,28,28)
x = torch.randn(batch_size,*input_shape)
net.eval()
torch.onnx.export(net,x,"mnist.onnx")

构造一样的结构,加载mnist.pth,然后就可以通过export转换成onnx格式的文件了。我们上传到https://netron.app/ 站点,可视化整个模型图,然后点击每个节点,将在右边出现它们的属性值:

4、onnx运行时

onnxruntime主要是拿来推理,当然在ir7的版本也增加了训练等功能,我们来了解下这个东西 

4.1、安装模块

如果缺少onnxruntime模块,就会报错:

ModuleNotFoundError: No module named 'onnxruntime'

这里在JupyterLab中,所以在前面加一个叹号安装

!pip install onnxruntime -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com

import torch
import onnxruntime as ort
import numpy as npsession = ort.InferenceSession("mnist.onnx")
x = np.random.rand(1, 1, 28, 28).astype(np.float32)
outputs = session.run(None, {"input": x})
print(outputs[0])

4.2、名称一致 

这里容易出错:InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:input

也就是说这个sess.run([output_name], {input_name: x})中的输入名称错误,所以名称要一样,这里的输入名称是input.1,修改成outputs = session.run(None, {"input.1": x})就可以了
怎么查看名称,可以通过上面站点可视化直接看到名称,也可以使用下面代码获取

input_name = session.get_inputs()
print(input_name[0].name)#input.1

同样的,如果输出名称也想指定,可以使用下面代码获取

out_name = session.get_outputs()[0].name

4.3、三通道转一通道

彩色三通道的图片转成灰色的单通道图片:

import cv2
import numpy as np
img = cv2.imread('1.png', cv2.IMREAD_GRAYSCALE)
cv2.imwrite('1.jpg',img)
print(img.shape)#(28, 28)

5、转成json格式

有时候的需求需要可读文件,一般json是很常见的,也可以进行转换:

import onnx
import json
from google.protobuf.json_format import MessageToJsononnx_model = onnx.load("mnist.onnx")
s = MessageToJson(onnx_model)
onnx_json = json.loads(s)output_json_path = 'mnist2.json'with open(output_json_path, 'w') as f:json.dump(onnx_json, f, indent=2)

这样就将onnx文件转成了json格式的文件了

引用来源
github:https://github.com/onnx/onnx
可视化模型:https://netron.app/
ONNX实践:http://www.icfgblog.com/index.php/software/227.html

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/6052.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML 超链接标签、图片标签

超链接标签 超链接描述 HTML使用标签<a>来设置超文本链接 超链接可以是一个字&#xff0c;一个词&#xff0c;或者一组词&#xff0c;也可以是一幅图像&#xff0c;您可以点击这些内容来跳转到新的文档或者当前文档中的某个部分。 <a href"url">链接文…

图片加载失败捕获上报及处理

图片加载失败捕获上报及处理 前端页面中加载最多的静态资源之一就是图片了&#xff0c;当出现图片加载失败时&#xff0c;非常影响用户体验。这时候我们就需要对图片是否成功加载进行判断&#xff0c;并对图片加载失败进行处理。 图片加载监听 单个捕获 HTML中的img标签可以…

微信小程序浏览docx,pdf等文件在线预览使用wx.openDocument

wx.downloadFile({ url: fileUrl,//pdf链接success(res) {wx.openDocument({ //打开文档filePath: res.tempFilePath,fileType: "pdf",//文档类型showMenu: true,success: function (res) {wx.showToast({title: 打开文档成功,})},fail: function (res) {wx.showToas…

springboot整合xxl-job

文章目录 前言一、xxl-job是什么&#xff1f;二、使用步骤1.下载源码,并部署好2.模仿xxl-job-executor-sample-springboot 自己建立一个服务1 引入xxl-job核心依赖2 创建服务,配置yml3 创建一个配置类,用于读取上述配置,并配置好handel信息4 创建一个执行器的任务类,用于执行真…

阵列模式综合第三部分:深度学习(附源码)

一、前言 这个例子展示了如何设计和训练卷积神经网络&#xff08;CNN&#xff09;来计算产生所需模式的元素权重。 二、介绍 模式合成是阵列处理中的一个重要课题。阵列权重有助于塑造传感器阵列的波束图案&#xff0c;以匹配所需图案。传统上&#xff0c;由于空间信号处理和频…

举例说明什么是卷积神经网络

卷积神经网络&#xff08;Convolutional Neural Network, CNN&#xff09;是一种深度学习模型&#xff0c;主要应用于计算机视觉任务&#xff0c;如图像分类、物体检测等。它通过卷积层、池化层和全连接层等组件来实现对图像的特征提取和分类。 现在我们以一个简单的图像分类任…

Android Compose UI实战练手----Google Bloom登录页

目录 1.概述2.页面展示1.1 亮色主题1.2暗色主题 3.登录页面拆分以及编码实现3.1 登录页面拆分3.2 编码实现3.2.1 LoginPage3.2.2 LoginTitle3.2.3 LoginInoutBox3.2.4 LoginHintWithUnderLine3.2.5 LoginButton 4.源码地址 1.概述 在之前的章节中我们已经介绍了如何实现Google…

什么是网络货运平台?

一、什么是网络货运平台&#xff1f; 网络货运平台是依托互联网平台整合配置运输资源&#xff0c;以承运人身份与托运人签订运输服务合同、承担承运人责任&#xff0c;委托实际承运人完成运输服务的物流平台。它通过互联网形式实现运输过程真实、公平、公正、合法&#xff0c…

logback-spring.xml详解

本文来写说下logback-spring.xml相关的知识与概念 文章目录 概述configuration元素定义上下文名称定义变量appender组件RollingFileAppender配置logger配置root配置ELK的配置输出logback状态数据异步输出日志代码中的日志格式本文小结 概述 对于xml日志文件的配置&#xff0c;大…

前端web入门-移动web-day09

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 空间转换 空间转换 – 平移 视距 perspective 空间 – 旋转 立体呈现 – transform-style 空间转换…

Ubuntu下编译VTK

1.先安装QT&#xff0c;不知道不装行不行&#xff0c;我们项目需要。 2.去VTK官网下载VTK源码。 3.解压源码。 4.编译需要用cmake-gui&#xff0c;装QT的一般都有&#xff0c;但需要把路径添加到PATH才能用。 5.打开cmake-gui&#xff0c;设置源码路径&#xff0c;编译输出路…

只要你会vue,5分钟学不会 svelte 你来找我

&#x1f33b; 前言 2023年了&#xff0c;国内前端领域基本被Vue、React占领市场&#xff0c;近几年似乎前端技术栈的迭代更新缓慢了下来。 当然近几年也出现了像svelte、solid.js等一些新兴的前端框架&#xff0c;这些框架有很多创新的点&#xff0c;比如svelte相比于vue,re…