Pytorch模型训练后静态量化并加载int8量化模型推理

目录

  • 一、源码包准备
    • 1.1 源码包获取
    • 1.2 代表性验证集
    • 1.3 Pytorch模型
    • 1.4 推理测试图片
  • 二、环境准备
  • 三、模型转换
    • 3.1 参数修改
    • 3.2 代码
    • 3.3 量化转换结果
    • 3.4 量化前后模型大小对比
  • 四、量化模型推理
    • 4.1 参数修改
    • 4.2 代码
    • 4.3 推理结果
    • 4.4推理时间
  • 五、总结

一、源码包准备

1.1 源码包获取

网站源码包:Pytorch静态量化

教程中配套的源码包获取方法为文章末扫码到公众号中回复关键字:Pytorch模型训练后静态量化。获取下载链接。

下载解压后的样子如下:

在这里插入图片描述

1.2 代表性验证集

有代表行的验证集位于根目录下的data文件夹中,如下:

在这里插入图片描述

1.3 Pytorch模型

在我源码包中已经提供了一个Pytorch模型,位于根目录下的models文件夹中,如下:

在这里插入图片描述

1.4 推理测试图片

推理测试的图片位于根目录下的TestImages文件夹中,如下:

在这里插入图片描述

二、环境准备

下面是我自己的运行环境,仅供参考:

在这里插入图片描述

三、模型转换

在我提供源码包中,转换代码为pat_to_int.py脚本,将Pytorch的float32模型转为int8模型。

3.1 参数修改

使用此脚本需要修改的地方如下:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.2 代码

具体代码如下:

# !/usr/bin/env python3
# coding=utf-8import torch
import os
from pose_estimation import *def evaluate(model, val_data_dir='./data'):             # 定义一个函数evaluate,用于评估模型。函数接收两个参数,一个是模型,另一个是验证数据的目录。box_size = 368                                      # 定义了一些参数,包括框的大小、缩放搜索的比例和步长scale_search = [0.5, 1.0, 1.5, 2.0]param_stride = 8# Predict pictureslist_dir = os.walk(val_data_dir)                    # 使用os.walk函数遍历验证数据目录for root, dirs, files in list_dir:                  # 遍历验证数据目录中的所有文件for f in files:test_image = os.path.join(root, f)print("test image path", test_image)img_ori = cv2.imread(test_image)  # B,G,R order   # 使用cv2.imread函数读取图片。multiplier = [scale * box_size / img_ori.shape[0] for scale in scale_search]       # 计算缩放因子for i, scale in enumerate(multiplier):               # 遍历所有的缩放因子。h = int(img_ori.shape[0] * scale)w = int(img_ori.shape[1] * scale)pad_h = 0 if (h % param_stride == 0) else param_stride - (h % param_stride)pad_w = 0 if (w % param_stride == 0) else param_stride - (w % param_stride)new_h = h + pad_hnew_w = w + pad_wimg_test = cv2.resize(img_ori, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)                # 根据缩放因子调整图像大小。img_test_pad, pad = pad_right_down_corner(img_test, param_stride, param_stride)img_test_pad = np.transpose(np.float32(img_test_pad[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5feed = Variable(torch.from_numpy(img_test_pad))           # 将numpy数组转换为torch张量,并封装为Variableoutput1, output2 = model(feed)                            # 将输入数据传入模型,得到输出print(output1.shape, output2.shape)# loading model
state_dict = torch.load('./models/coco_pose_iter_440000.pth.tar')['state_dict']           # 加载预训练模型# create a model instance
model_fp32 = get_pose_model()                        # 创建一个新的模型实例
model_fp32.load_state_dict(state_dict)               # 将预训练模型的参数加载到新的模型实例中。
model_fp32.float()# model must be set to eval mode for static quantization logic to work
model_fp32.eval()# attach a global qconfig, which contains information about what kind
# of observers to attach. Use 'fbgemm' for server inference and
# 'qnnpack' for mobile inference. Other quantization configurations such
# as selecting symmetric or assymetric quantization and MinMax or L2Norm
# calibration techniques can be specified here.
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')             # 设置模型的量化配置# Prepare the model for static quantization. This inserts observers in
# the model that will observe activation tensors during calibration.
model_fp32_prepared = torch.quantization.prepare(model_fp32)                      # 准备模型进行静态量化。# calibrate the prepared model to determine quantization parameters for activations
# in a real world setting, the calibration would be done with a representative dataset
evaluate(model_fp32_prepared)                                                     # 使用定义的evaluate函数对准备好的模型进行评估# 在Pytorch进行静态量化是,通常需要一个代表性的数据集来确定激活的量化参数,这个过程通常称为校准。上面的evaluate(model_fo32_prepared)就是校准过程。
# evaluate函数会对验证集中的每一张图片进行预测,这个过程会触发模型中的观察器(observer)来收集激活的统计信息,这些信息会被用来确定量化参数。# Convert the observed model to a quantized model. This does several things:
# quantizes the weights, computes and stores the scale and bias value to be
# used with each activation tensor, and replaces key operators with quantized
# implementations.
model_int8 = torch.quantization.convert(model_fp32_prepared)                       # 将观察到的模型转换为量化模型   # convert函数使用收集到的 统计信息来确定激活的量化参数,并将模型转为量化模型。
print("model int8", model_int8)
# save model
torch.save(model_int8.state_dict(),"./openpose_vgg_quant.pth")                      # 保存量化后的模型

3.3 量化转换结果

运行上面脚本后,会在根目录下得到一个openpose_vgg_quant.pth量化后的模型,如下:

在这里插入图片描述

3.4 量化前后模型大小对比

模型从量化前的199M缩减到量化后的50M,模型大小缩减为原来的四分之一。

在这里插入图片描述

四、量化模型推理

在我提供的源码包中,推理脚本为量化模型推理脚本为evaluate_model.py文件。将加载前一步转换得到的int8模型进行推理。

4.1 参数修改

在这里插入图片描述
在这里插入图片描述

4.2 代码

加载In8模型的代码为:

# Load int8 model
# 加载int8模型不能和之前加载float32模型一样,需要将模型通过prepare() , convert()操作转成量化模型,然后load_state_dict加载进模型。
state_dict = torch.load('./openpose_vgg_quant.pth')
model_fp32 = get_pose_model()                                                             # 创建一个新的模型实例。
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')                     # 设置模型的量化配置。这里使用的是fbgemm,它是Facebook为服务器端优化的8位整数量化库。
model_fp32_prepared = torch.quantization.prepare(model_fp32)                              # 准备模型进行静态量化。这个步骤会插入观察器到模型中,用于收集需要量化的张量的统计信息。
model_int8 = torch.quantization.convert(model_fp32_prepared)                              # 将准备好的模型转换为量化模型。这个步骤会使用收集到的统计信息来确定量化参数,并将模型中的浮点运算替换为量化运算。
model_int8.load_state_dict(state_dict)                                                    # 将加载的状态字典加载到量化模型中。这个步骤会将保存的参数值赋给模型。
model = model_int8                                                                        # 将量化模型赋值给model
model.eval()start_time = time.time()
# Predict pictures
test_image = './TestImages/test1.jpg'
img_ori = cv2.imread(test_image) # B,G,R ordermultiplier = [scale * box_size / img_ori.shape[0] for scale in scale_search]heatmap_avg = torch.zeros((len(multiplier), 19, img_ori.shape[0], img_ori.shape[1]))
paf_avg = torch.zeros((len(multiplier), 38, img_ori.shape[0], img_ori.shape[1]))

4.3 推理结果

运行上面脚本后,会输出如下结果,并将输出结果自动保存到根目录下的ResultImages文件夹中:

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

4.4推理时间

此模型的推理时间包括前处理和后处理,测试环境为:Nvidia GeForce RTX 3050。

量化前的推理时间为5.7s,量化后的推理时间为3.4s。

在这里插入图片描述

五、总结

以上就是Pytorch模型训练后静态量化并加载int8量化模型推理的详细过程。

总结不易,多多支持,谢谢!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/499668.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

什么是Sectigo?如何优惠申请?

Sectigo,全球领先的SSL/TLS证书提供商,以其卓越的安全性能和广泛的认可度赢得了业界的一致好评。我们的证书不仅能加密您的网站通信,确保敏感信息传输过程中的绝对安全,还能显著提升您的网站信誉,让访客一眼就能识别出…

价格战打响!阿里云服务器和腾讯云服务器价格对比

2024年阿里云服务器和腾讯云服务器价格战已经打响,阿里云服务器优惠61元一年起,腾讯云服务器62元一年,2核2G3M、2核4G、4核8G、8核16G、16核32G、16核64G等配置价格对比,阿腾云atengyun.com整理阿里云和腾讯云服务器详细配置价格表…

什么是VR紧急情况模拟|消防应急虚拟展馆|VR游戏体验馆加盟

VR紧急情况模拟是利用虚拟现实(Virtual Reality,简称VR)技术来模拟各种紧急情况和应急场景的训练和演练。通过VR技术,用户可以身临其境地体验各种紧急情况,如火灾、地震、交通事故等,以及应对这些紧急情况的…

第三百七十四回

文章目录 1. 概念介绍2. 实现方法2.1 基本用法2.2 特殊用法 3. 示例代码4. 内容总结 我们在上一章回中介绍了"分享三个使用TextField的细节"相关的内容,本章回中将介绍如何让Text组件中的文字自动换行.闲话休提,让我们一起Talk Flutter吧。 1.…

leetcode刷题(剑指offer) 46.全排列

46.全排列 给定一个不含重复数字的数组 nums ,返回其 所有可能的全排列 。你可以 按任意顺序 返回答案。 示例 1: 输入:nums [1,2,3] 输出:[[1,2,3],[1,3,2],[2,1,3],[2,3,1],[3,1,2],[3,2,1]]示例 2: 输入&#…

STM32F103学习笔记(七) PWR电源管理(原理篇)

目录 1. PWR电源管理简介 2. STM32F103的PWR模块概述 2.1 PWR模块的基本工作原理 2.2 电源管理的功能和特点 3. PWR模块的常见应用场景 4. 常见问题与解决方案 1. PWR电源管理简介 PWR(Power)模块是STM32F103系列微控制器中的一个重要组成部分&…

C语言 变量

变量其实只不过是程序可操作的存储区的名称。C 中每个变量都有特定的类型,类型决定了变量存储的大小和布局,该范围内的值都可以存储在内存中,运算符可应用于变量上。 变量的名称可以由字母、数字和下划线字符组成。它必须以字母或下划线开头…

导出数据库表结构到文档中

导出效果: 完整代码: Controller层: import io.swagger.annotations.Api; import io.swagger.annotations.ApiOperation; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotatio…

Android Stdio Execution failed for task ‘:app:compileDebugKotlin‘ 报错解决

具体报错信息如下: compileDebugJavaWithJavac task (current target is 1.8) and compileDebugKotlin task (current target is 17)jvm target compatibility should be set to the same Java version.很显然,这是一个版本冲突问题,compile…

云上攻防-云服务篇弹性计算服务器云数据库实例元数据控制角色AK控制台接管

知识点: 1、云服务-弹性计算服务器-元数据&SSRF&AK 2、云服务-云数据库-外部连接&权限提升 章节点: 云场景攻防:公有云,私有云,混合云,虚拟化集群,云桌面等 云厂商攻防:阿里云&am…

Tomcat服务部署

1、安装jdk、设置环境变量并测试 第一步:安装jdk 在部署 Tomcat 之前必须安装好 jdk,因为 jdk 是 Tomcat 运行的必要环境。 1. #关闭防火墙 systemctl stop firewalld systemctl disable firewalld setenforce 02. #将安装 Tomcat 所需软件包传到/opt…

90%电商APP已沦落至无人下载,用户主观意愿——是真正实用性价值!

90%电商APP已沦落至无人下载,用户主观意愿——是真正实用性价值! 文丨微三云营销总监胡佳东,点击上方“关注”,为你分享市场商业模式电商干货。 - 引言:在互联网发展的大时代下,似乎每个月都有新的APP出现…

Linux Shell脚本练习(一)

一、 Linux下执行Shell脚本的方式: 1、用shell程序执行脚本: a、根据你的shell脚本的类型,选择shell程序,常用的有sh,bash,tcsh等 b、程序的第一行#!/bin/bash里面指明了shell类型的,比如#!/…

Programming Abstractions in C阅读笔记:p306-p307

《Programming Abstractions in C》学习第75天,p306-p307总结,总计2页。 一、技术总结 1.Quicksort algorithm(快速排序) 由法国计算机科学家C.A.R(Charles Antony Richard) Hoare(东尼.霍尔)在1959年开发(develop), 1961年发表…

Windows下使用C++调用海康威视SDK获取实时视频流进行检测

目录 准备海康威视的SDK官网下载下载后解压 Vs 2022创建项目创建32位的环境 将相关文件复制到工程目录下工程配置海康威视SDK配置包含目录配置库目录将dll文件添加到环境中在附加依赖项添加如下内容 工程配置OpenCV配置工程配置包含目录配置库目录 测试 准备海康威视的SDK 官网…

位段 详解

目录 位段的声明位段的内存分配位段的跨平台问题 位段的声明 位段的声明和结构是类似的,有两个不同: 位段的成员必须是 int、unsigned int 或signed int位段的成员名后边有一个冒号和一个数字 例如,A是一个位段类型: struct A…

程序员的金三银四求职宝典

目录 简介: 1.准备简历: 2.强调技术能力: 3.建立个人品牌: 4.提前准备面试: 5.关注招聘信息渠道: 6.提前与内推: 7.心态调整: 结论: 简介: 金三银四是…

老卫带你学---leetcode刷题(130. 被围绕的区域)

130. 被围绕的区域 问题 给你一个 m x n 的矩阵 board ,由若干字符 ‘X’ 和 ‘O’ ,找到所有被 ‘X’ 围绕的区域,并将这些区域里所有的 ‘O’ 用 ‘X’ 填充。 示例 1: 输入:board [[“X”,“X”,“X”,“X”]…

基于相位的运动放大:如何检测和放大难以察觉的运动(02/2)

目录 一、说明二、算法三、准备处理四、高斯核五、带通滤波器六、复杂的可操纵金字塔七、最终预处理步骤八、执行处理九、金字塔的倒塌十、可视化结果十一、结论 一、说明 日常物体会产生人眼无法察觉的微妙运动。在视频中,这些运动的幅度小于一个像素,…

2月28日做题总结(C/C++真题)

今天是2月28日,做题第三天。道阻且长,行则将至;行而不辍,则未来可期! 第一题 static char a[2]{1,2,3};说法是否正确? A---正确 B---错误 正确答案:B 解析:数组定义时&#xf…