Pytorch导出FP16 ONNX模型

一般Pytorch导出ONNX时默认都是用的FP32,但有时需要导出FP16的ONNX模型,这样在部署时能够方便的将计算以及IO改成FP16,并且ONNX文件体积也会更小。想导出FP16的ONNX模型也比较简单,一般情况下只需要在导出FP32 ONNX的基础上调用下model.half()将模型相关权重转为FP16,然后输入的Tensor也改成FP16即可,具体操作可参考如下示例代码。这里需要注意下,当前Pytorch要导出FP16的ONNX必须将模型以及输入Tensor的device设置成GPU,否则会报很多算子不支持FP16计算的提示。

import torch
from torchvision.models import resnet50def main():export_fp16 = Trueexport_onnx_path = f"resnet50_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = resnet50()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

通过Netron可视化工具可以看到导出的FP16 ONNX的输入/输出的tensor类型都是float16
在这里插入图片描述

并且通过对比可以看到,FP16的ONNX模型比FP32的文件更小(48.6MB vs 97.3MB)。
在这里插入图片描述
大多数情况可以按照上述操作进行正常转换,但也有一些比较头大的场景,因为你永远无法知道拿到的模型会有多奇葩,例如下面示例:
错误导出FP16 ONNX示例

import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

执行以上代码后会报如下错误信息:

/src/ATen/native/cudnn/Conv_v8.cpp:80.)return F.conv2d(input, weight, bias, self.stride,
Traceback (most recent call last):File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 47, in <module>main()File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 36, in mainmodel(x)File "/home/wz/miniconda3/envs/torch2.0.1/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_implreturn forward_call(*args, **kwargs)File "/home/wz/my_projects/py_projects/export_fp16/example.py", line 17, in forwardx = F.conv2d(x, weight=kernel, bias=None, stride=1)RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.cuda.FloatTensor) should be the same

简单来说就是在推理过程中遇到两种不同类型的数据要计算,torch.cuda.HalfTensor(FP16) 和torch.cuda.FloatTensor(FP32)。遇到这种情况一般常见有两种解法:

  • 一种是找到数据类型与我们预期不一致的地方,然后改成我们要想的dtype,例如上面示例是将kernel的dtype写死成了torch.float32,我们可以改成torch.float16或者写成x.dtype(这种会比较通用,会根据输入的Tensor类型自动切换)。这种方法有个弊端,如果代码里写死dtype的位置很多,改起来会比较头大。
  • 另一种是使用torch.autocast上下文管理器,该上下文管理器能够实现推理过程中自动进行混合精度计算,例如遇到能进行float16/bfloat16计算的场景会自动切换。具体使用方法可以查看官方文档。下面示例代码就是用torch.autocast上下文管理器来做自动转换。
import torch
import torch.nn as nn
import torch.nn.functional as Fclass MyModel(nn.Module):def __init__(self) -> None:super().__init__()self.conv = nn.Conv2d(3, 1, kernel_size=3, stride=2, padding=1)def forward(self, x):x = self.conv(x)kernel = torch.tensor([[0.1, 0.1, 0.1],[0.1, 0.1, 0.1],[0.1, 0.1, 0.1]], dtype=torch.float32, device=x.device).reshape([1, 1, 3, 3])x = F.conv2d(x, weight=kernel, bias=None, stride=1)return xdef main():export_fp16 = Trueexport_onnx_path = f"my_model_fp{16 if export_fp16 else 32}.onnx"device = torch.device("cuda:0")model = MyModel()model.eval()model.to(device)if export_fp16:model.half()with torch.autocast(device_type="cuda", dtype=torch.float16):with torch.inference_mode():dtype = torch.float16 if export_fp16 else torch.float32x = torch.randn(size=(1, 3, 224, 224), dtype=dtype, device=device)model(x)torch.onnx.export(model=model,args=(x,),f=export_onnx_path,input_names=["image"],output_names=["output"],dynamic_axes={"image": {2: "width", 3: "height"}},opset_version=17)if __name__ == '__main__':main()

使用上述代码能够正常导出ONNX模型,并且使用Netron可视化后可以看到导出的FP16 ONNX模型是符合预期的。
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/615808.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

FebHost:给你注册法国.FR域名的8大理由

如果您的企业与法国有联系&#xff0c;或者您的目标受众是法国人&#xff0c;那么拥有 .fr 域名可以成为您的战略资产。以下是您可以考虑选择 .fr 域名的几个原因&#xff1a; 地理定位&#xff1a; 如果您的企业面向法国受众&#xff0c;或以与法国或法国境内某一特定地区的联…

每日一题---OJ题: 环形链表 II

片头 嗨! 小伙伴们,大家好! 我们又见面啦,在上一篇中,我们学习了环形链表I, 今天我们继续来打boss,准备好了吗? Ready Go ! ! ! emmm,同样都是环形链表,有什么不一样的地方呢? 肯定有, 要不然也不会一个标记为"简单" ,一个标记为"中等"了,哈哈哈哈哈 …

全新4.0版本圈子社交论坛系统 ,可打包小程序,于TP6+uni-app 全开源 可打包小程序app uniapp前端+全开源+独立版

简述 首先 圈子系统的核心是基于共同的兴趣或爱好将用户聚集在一起&#xff0c;这种设计使得用户能够迅速找到与自己有共同话题和兴趣的人。 其次 圈子系统提供了丰富的社交功能&#xff0c;如发帖、建圈子、发活动等&#xff0c;并且支持小程序授权登录、H5和APP等多种形式…

最齐全,最简单的免费SSL证书获取方法——实现HTTPS访问

一&#xff1a;阿里云 优势&#xff1a;大平台&#xff0c;在站长中知名度最高&#xff0c;提供20张免费单域名SSL证书 缺点&#xff1a;数量有限&#xff0c;并且只有单域名证书&#xff0c;通配符以及多域名没有免费版本。并且提供的单域名证书只有三个月的期限。 二&#…

【蓝桥杯】蓝桥杯算法复习(五)

&#x1f600;大家好&#xff0c;我是白晨&#xff0c;一个不是很能熬夜&#x1f62b;&#xff0c;但是也想日更的人✈。如果喜欢这篇文章&#xff0c;点个赞&#x1f44d;&#xff0c;关注一下&#x1f440;白晨吧&#xff01;你的支持就是我最大的动力&#xff01;&#x1f4…

Leetcode刷题之合并两个有序数组

Leetcode刷题之合并两个有序数组 一、题目描述二、题目解析 一、题目描述 给你两个按 非递减顺序 排列的整数数组 nums1 和 nums2&#xff0c;另有两个整数 m 和 n &#xff0c;分别表示 nums1 和 nums2 中的元素数目。 请你 合并 nums2 到 nums1 中&#xff0c;使合并后的数…

使用Vivado Design Suite进行功率优化

功率优化是一个可选步骤&#xff0c;它通过使用时钟门控来优化动态功率。它既可以在Project模式下使用&#xff0c;也可以在Non-Project模式下使用&#xff0c;并且可以在逻辑优化之后或布局之后运行&#xff0c;以减少设计中的功率需求。功率优化包括Xilinx的智能时钟门控解决…

【零基础学数据结构】双向链表

1.双向链表的概念 1.1头节点 1.2带头双向循环链表 注意&#xff1a; 哨兵位创建后&#xff0c;首尾连接自己 1.3双链表的初始化 // 双向链表的初始化 void ListInit(ListNode** pphead) {// 给双链表创建一个哨兵位*pphead ListBuyNode(-1); } 2.双向链表的打印 // 双向…

HarmonyOS开发实例:【app帐号管理】

应用帐号管理 介绍 本示例选择应用进行注册/登录&#xff0c;并设置帐号相关信息&#xff0c;简要说明应用帐号管理相关功能。效果图如下&#xff1a; 效果预览 使用说明参考鸿蒙文档&#xff1a;qr23.cn/AKFP8k点击或者转到。 1.首页面选择想要进入的应用&#xff0c;首次进…

Redis 之集群模式

一 集群原理 集群&#xff0c;即Redis Cluster&#xff0c;是Redis 3.0开始引入的分布式存储方案。 集群由多个节点(Node)组成&#xff0c;Redis的数据分布在这些节点中。 集群中的节点分为主节点和从节点&#xff1a;只有主节点负责读写请求和集群信息的维护&#xff1b;从…

接口测试用例编写和接口测试模板

一、简介 接口测试区别于传统意义上的系统测试&#xff0c;下面介绍接口测试用例和接口测试报告。 二、接口测试用例模板 功能测试用例最重要的两个因素是测试步骤和预期结果&#xff0c;接口测试属于功能测试&#xff0c;所以同理。接口测试的步骤中&#xff0c;最重要的是将…

项目管理软件评测:选择合适软件是关键

在过去&#xff0c;中小企业项目管理沿用的是office全家桶&#xff0c;用到后面项目由简单变复杂&#xff0c;项目资源越来越庞大&#xff0c;项目成员越来越多&#xff0c;项目管理问题日益凸显。好用的项目管理软件是化解问题的好方法&#xff0c;好用的项目管理软件是什么样…