【yolov5】onnx的INT8量化engine

GitHub上有大佬写好代码,理论上直接克隆仓库里下来使用

git clone https://github.com/Wulingtian/yolov5_tensorrt_int8_tools.git

然后在yolov5_tensorrt_int8_tools的convert_trt_quant.py 修改如下参数

BATCH_SIZE 模型量化一次输入多少张图片

BATCH 模型量化次数

height width 输入图片宽和高

CALIB_IMG_DIR 训练图片路径,用于量化

onnx_model_path onnx模型路径

engine_model_path 模型保存路径

其中这个batch_size不能超过照片的数量,然后跑这个convert_trt_quant.py

出问题了吧@_@

这是因为tensor的版本更新原因,这个代码的tensorrt版本是7系列的,而目前新的tensorrt版本已经没有了一些属性,所以我们需要对这个大佬写的代码进行一些修改

如何修改呢,其实tensorrt官方给出了一个caffe量化INT8的例子

https://github.com/NVIDIA/TensorRT/tree/master/samples/python/int8_caffe_mnist

如果足够NB是可以根据官方的这个例子修改一下直接实现onnx的INT8量化的

但是奈何我连半桶水都没有,只有一滴水,但是这个例子中的tensorrt版本是新的,于是我尝试将上面那位大佬的代码修改为使用新版的tensorrt

居然成功了??!!

成功量化后的模型大小只有4MB,相比之下的FP16的大小为6MB,FP32的大小为9MB

再看看检测速度,速度和FP16差不太多

但是效果要差上一些了

那肯定不能忘记送上修改的代码,折腾一晚上的结果如下,主要是 util_trt程序

# tensorrt-libimport os
import tensorrt as trt
import pycuda.autoinit
import pycuda.driver as cuda
from calibrator import Calibrator
from torch.autograd import Variable
import torch
import numpy as np
import time
# add verbose
TRT_LOGGER = trt.Logger(trt.Logger.VERBOSE) # ** engine可视化 **# create tensorrt-engine# fixed and dynamic
def get_engine(max_batch_size=1, onnx_file_path="", engine_file_path="",\fp16_mode=False, int8_mode=False, calibration_stream=None, calibration_table_path="", save_engine=False):"""Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it."""def build_engine(max_batch_size, save_engine):"""Takes an ONNX file and creates a TensorRT engine to run inference with"""with trt.Builder(TRT_LOGGER) as builder, \builder.create_network(1) as network,\trt.OnnxParser(network, TRT_LOGGER) as parser:# parse onnx model fileif not os.path.exists(onnx_file_path):quit('ONNX file {} not found'.format(onnx_file_path))print('Loading ONNX file from path {}...'.format(onnx_file_path))with open(onnx_file_path, 'rb') as model:print('Beginning ONNX file parsing')parser.parse(model.read())assert network.num_layers > 0, 'Failed to parse ONNX model. \Please check if the ONNX model is compatible 'print('Completed parsing of ONNX file')print('Building an engine from file {}; this may take a while...'.format(onnx_file_path))        # build trt enginebuilder.max_batch_size = max_batch_sizeconfig = builder.create_builder_config()config.max_workspace_size = 1 << 20if int8_mode:config.set_flag(trt.BuilderFlag.INT8)assert calibration_stream, 'Error: a calibration_stream should be provided for int8 mode'config.int8_calibrator  = Calibrator(calibration_stream, calibration_table_path)print('Int8 mode enabled')runtime=trt.Runtime(TRT_LOGGER)plan = builder.build_serialized_network(network, config)engine = runtime.deserialize_cuda_engine(plan)if engine is None:print('Failed to create the engine')return None   print("Completed creating the engine")if save_engine:with open(engine_file_path, "wb") as f:f.write(engine.serialize())return engineif os.path.exists(engine_file_path):# If a serialized engine exists, load it instead of building a new one.print("Reading engine from file {}".format(engine_file_path))with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:return runtime.deserialize_cuda_engine(f.read())else:return build_engine(max_batch_size, save_engine)

唔,convert_trt_quant.py的代码也给一下吧

import numpy as np
import torch
import torch.nn as nn
import util_trt
import glob,os,cv2BATCH_SIZE = 1
BATCH = 79
height = 640
width = 640
CALIB_IMG_DIR = '/content/drive/MyDrive/yolov5/ikunData/images'
onnx_model_path = "runs/train/exp4/weights/FP32.onnx"
def preprocess_v1(image_raw):h, w, c = image_raw.shapeimage = cv2.cvtColor(image_raw, cv2.COLOR_BGR2RGB)# Calculate widht and height and paddingsr_w = width / wr_h = height / hif r_h > r_w:tw = widthth = int(r_w * h)tx1 = tx2 = 0ty1 = int((height - th) / 2)ty2 = height - th - ty1else:tw = int(r_h * w)th = heighttx1 = int((width - tw) / 2)tx2 = width - tw - tx1ty1 = ty2 = 0# Resize the image with long side while maintaining ratioimage = cv2.resize(image, (tw, th))# Pad the short side with (128,128,128)image = cv2.copyMakeBorder(image, ty1, ty2, tx1, tx2, cv2.BORDER_CONSTANT, (128, 128, 128))image = image.astype(np.float32)# Normalize to [0,1]image /= 255.0# HWC to CHW format:image = np.transpose(image, [2, 0, 1])# CHW to NCHW format#image = np.expand_dims(image, axis=0)# Convert the image to row-major order, also known as "C order":#image = np.ascontiguousarray(image)return imagedef preprocess(img):img = cv2.resize(img, (640, 640))img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)img = img.transpose((2, 0, 1)).astype(np.float32)img /= 255.0return imgclass DataLoader:def __init__(self):self.index = 0self.length = BATCHself.batch_size = BATCH_SIZE# self.img_list = [i.strip() for i in open('calib.txt').readlines()]self.img_list = glob.glob(os.path.join(CALIB_IMG_DIR, "*.jpg"))assert len(self.img_list) > self.batch_size * self.length, '{} must contains more than '.format(CALIB_IMG_DIR) + str(self.batch_size * self.length) + ' images to calib'print('found all {} images to calib.'.format(len(self.img_list)))self.calibration_data = np.zeros((self.batch_size,3,height,width), dtype=np.float32)def reset(self):self.index = 0def next_batch(self):if self.index < self.length:for i in range(self.batch_size):assert os.path.exists(self.img_list[i + self.index * self.batch_size]), 'not found!!'img = cv2.imread(self.img_list[i + self.index * self.batch_size])img = preprocess_v1(img)self.calibration_data[i] = imgself.index += 1# example onlyreturn np.ascontiguousarray(self.calibration_data, dtype=np.float32)else:return np.array([])def __len__(self):return self.lengthdef main():# onnx2trtfp16_mode = Falseint8_mode = True print('*** onnx to tensorrt begin ***')# calibrationcalibration_stream = DataLoader()engine_model_path = "runs/train/exp4/weights/int8.engine"calibration_table = 'yolov5_tensorrt_int8_tools/models_save/calibration.cache'# fixed_engine,校准产生校准表engine_fixed = util_trt.get_engine(BATCH_SIZE, onnx_model_path, engine_model_path, fp16_mode=fp16_mode, int8_mode=int8_mode, calibration_stream=calibration_stream, calibration_table_path=calibration_table, save_engine=True)assert engine_fixed, 'Broken engine_fixed'print('*** onnx to tensorrt completed ***\n')if __name__ == '__main__':main()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/169621.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

智能一体化管网水位监测仪怎么样?

城市排水管网是城市正常运行的关键环节&#xff0c;这是地上和地下通道的连接点&#xff0c;一旦出现问题便会影响城市生命线建设的工程进展。在复杂的地下管道内想要了解水位数据&#xff0c;对于政府部门来讲是一个管理难题。如果可以采取智能产品在其中发挥作用&#xff0c;…

threejs(12)-着色器打造烟雾水云效果

一、自己封装水波纹效果 src/main/main01.js import * as THREE from "three";import { OrbitControls } from "three/examples/jsm/controls/OrbitControls"; import gsap from "gsap"; import * as dat from "dat.gui"; import ver…

江西开放大学引领学习新时代:电大搜题助力学子迈向成功

江西开放大学&#xff08;简称江西电大&#xff09;一直以来致力于为学子提供灵活便捷的学习服务。近年来&#xff0c;携手电大搜题微信公众号&#xff0c;江西开放大学以其卓越的教学质量和创新的教学手段&#xff0c;为广大学子开启了一扇通向成功的大门。 作为一家知名的远…

深入了解JVM和垃圾回收算法

1.什么是JVM&#xff1f; JVM是Java虚拟机&#xff08;Java Virtual Machine&#xff09;的缩写&#xff0c;是Java程序运行的核心组件。JVM是一个虚拟的计算机&#xff0c;它提供了一个独立的运行环境&#xff0c;可以在不同的操作系统上运行Java程序。 2.如何判断可回收垃圾…

远程电脑未连接显示器时分辨率太小的问题处理

背景&#xff1a;单位电脑显示器坏了&#xff0c;使用笔记本通过向日葵远程连接&#xff0c;发现分辨率只有800*600并且不能修改&#xff0c;网上找了好久找到了处理方法这里记录一下&#xff0c;主要用到的是一个虚拟显示器软件usbmmidd_v2 1)下载usbmmidd_v2 2&#xff09;…

js案例:打地鼠游戏(打灰太狼)

效果预览图 游戏规则 当灰太狼出现的时候鼠标左键点击灰太狼加10分&#xff0c;小灰灰出现的时候鼠标左键点小灰灰击减10分&#xff0c;不点击不减分不加分。 整体思路 1.把获取背景图片中每个地洞的位置&#xff0c;把所有位置放到一个数组中。 2.封装随机数函数&#xff0c;随…

RSA 2048位算法的主要参数N,E,P,Q,DP,DQ,Qinv,D分别是什么意思 哪个是通常所说的公钥与私钥 -安全行业基础篇5

非对称加密算法RSA 在RSA 2048位算法中&#xff0c;常见的参数N、E、P、Q、DP、DQ、Qinv和D代表以下含义&#xff1a; N&#xff08;Modulus&#xff09;&#xff1a;模数&#xff0c;是两个大素数P和Q的乘积。N的长度决定了RSA算法的安全性。 E&#xff08;Public Exponent&a…

vite基础学习笔记:14.路由跳转(二)携带query参数

说明&#xff1a;自学做的笔记和记录&#xff0c;如有错误请指正 1. 路由跳转&#xff08;携带query参数&#xff09; &#xff08;1&#xff09;第一层路由&#xff08;点击卡片路由跳转至新页面-携带query参数&#xff09; 知识点&#xff1a; query传参对应的是path和qu…

移动医疗科技:开发互联网医院系统源码

在这个数字化时代&#xff0c;互联网医院系统成为了提供便捷、高效医疗服务的重要手段。本文将介绍利用移动医疗科技开发互联网医院系统的源码&#xff0c;为医疗行业的数字化转型提供有力支持。 智慧医疗、互联网医院这一类平台可以通过线上的形式进行部分医疗服务&#xff…

数据结构与算法C语言版学习笔记(3)-线性表的链式结构:链表

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言&#xff1a;回顾顺序表的优缺点&#xff1a;为什么要引入链式结构的线性表&#xff1f; 一、什么是链表&#xff1f;二、链表的分类①为什么要设置头节点&…

【计网 传输层概述】 中科大郑烇老师笔记 (十)

目录 0 引言1 概述1.1 传输服务和协议1.2 传输层 vs 网络层1.3 Internet传输层协议 TCP和UDP 2 多路复用、解复用2.1 UDP的多路复用2.2 TCP的多路复用 3 UDP3.1 概述3.2 UDP报文段3.3 拓展&#xff1a;TCP报文段 &#x1f64b;‍♂️ 作者&#xff1a;海码007&#x1f4dc; 专栏…

Apipost-Helper:IDEA中的类postman工具

今天给大家推荐一款IDEA插件&#xff1a;Apipost-Helper-2.0&#xff0c;写完代码IDEA内一键生成API文档&#xff0c;无需安装、打开任何其他软件&#xff1b;写完代码IDEA内一键调试&#xff0c;无需安装、打开任何其他软件&#xff1b;生成API目录树&#xff0c;双击即可快速…