深度学习模型部署(番外3)神经网络不同层的量化方法

神经网络层量化

批归一化层Batch Normalization(BN层)

关于归一化的原理可以看之前的这篇blog:BatchNorm原理与应用
批归一化在推理过程中会被融合到上一层或者下一层中,这种处理方式被称为批归一化折叠。这样可以减少量化,也可以减少属于的运算和读写,提高推理速度。
例如对于全连接层 y = W X + b y = WX+b y=WX+b后的批归一化层:
y = B a t c h N o r m ( W X + b ) = B a t c h N o r m ( W X ) = γ ( W x − μ σ 2 + ϵ ) + β = γ W x σ 2 + ϵ + ( β − γ μ σ 2 + ϵ ) = W ~ X + b ~ \begin{align} y& = BatchNorm(WX+b) \\ & = BatchNorm(WX) \\ & = \gamma(\frac{Wx-\mu }{\sqrt[]{\sigma^2+\epsilon } } )+\beta \\ & = \frac{\gamma Wx}{\sqrt[]{\sigma^2+\epsilon } } +(\beta-\frac{\gamma \mu}{\sqrt[]{\sigma^2+\epsilon } }) \\ & = \widetilde{W}X+\widetilde{b} \end{align} y=BatchNorm(WX+b)=BatchNorm(WX)=γ(σ2+ϵ Wxμ)+β=σ2+ϵ γWx+(βσ2+ϵ γμ)=W X+b
从而就将BN层融入到了全连接层的参数中
W ~ k , : = γ k W k , : σ k 2 + ϵ , b ~ k = β k − γ k μ k σ k 2 + ϵ . \begin{aligned} \widetilde{\mathbf{W}}_{k,:}& =\frac{\boldsymbol{\gamma}_k\mathbf{W}_{k,:}}{\sqrt{\mathbf{\sigma}_k^2+\epsilon}}, \\ \widetilde{\mathbf{b}}_{k}& =\boldsymbol{\beta}_k-\frac{\boldsymbol{\gamma}_k\boldsymbol{\mu}_k}{\sqrt{\mathbf{\sigma}_k^2+\epsilon}}. \end{aligned} W k,:b k=σk2+ϵ γkWk,:,=βkσk2+ϵ γkμk.

激活函数层

一般线性层之后都会跟一个激活函数层,从底层的角度考虑,如果在线性层计算完后将数据从寄存器放回内存,再取出来进行非线性层计算,这种方法需要进行读取,非常浪费时间,那么是否能考虑把激活函数层也进行量化,让其可以和量化过的线性层同用定点运算,这样就可以不用放回再取出了,可以直接接着运行。激活函数的种类有很多,像ReLU这种比较简单的激活函数,很容易量化,但是像sigmoid这种激活函数就很难量化,需要复杂的支持。如果不能量化,我们需要在激活函数前后各加一个量化器,这样对精度的影响非常大,很多新的激活函数带来的精度提升在量化后会降低很多。
在这里插入图片描述

池化层

不同的池化层,量化方法也不同。
对于最大池化,输出就来自输入中的最大值,所以对于activation不需要进行量化。
但是对于平均池化,计算出的平均值不一定是一个整数,所以要对activation进行量化,但是输入和输出的范围是差不多的,所以可以公用一个量化器。

实现

pytorch对于量化提供了三种方案:

  • Eager Mode quantization:自己选择量化,自己选择融合
  • FX Graph Mode Quantization:提供了自动量化,自动评估,但是跟nn.Module的兼容性需要用户自己负责,比第一种方案自动化程度高一些
  • PyTorch 2 Export Quantization:pytorch2.1新引入的量化方案,自动化程度更高,也是pytorch官方推荐新手用的方案。
    相关介绍的链接:pytorch官方文档
    一个简单训练后静态量化的demo:
import torch# define a floating point model where some layers could be statically quantized
class M(torch.nn.Module):def __init__(self):super().__init__()# QuantStub 将float tensor转化为量化表示self.quant = torch.ao.quantization.QuantStub()self.conv = torch.nn.Conv2d(1, 1, 1)self.relu = torch.nn.ReLU()# DeQuantStub 将量化表示转化为float tensorself.dequant = torch.ao.quantization.DeQuantStub()def forward(self, x):# 自己手动指定量化模型中的量化点x = self.quant(x)x = self.conv(x)x = self.relu(x)# 自己决定何时将量化表示转化为float tensorx = self.dequant(x)return xmodel_fp32 = M()# 模型必须设置为eval模式,以便在量化过程中,模型的行为和量化后的行为一致
model_fp32.eval()# 模型量化配置,里面包括了默认的量化配置,可以通过`torch.ao.quantization.get_default_qconfig('x86')`获取
# 对于PC端的量化,推荐使用`x86`,对于移动端的量化,推荐使用`qnnpack`
# 其他的量化配置,比如选择对称量化还是非对称量化,以及MinMax还是L2Norm校准技术,都可以在这里指定
model_fp32.qconfig = torch.ao.quantization.get_default_qconfig('x86')# 手动进行融合,将一些常见的操作融合在一起,以便后续的量化
# 常见的融合包括`conv + relu`和`conv + batchnorm + relu`
model_fp32_fused = torch.ao.quantization.fuse_modules(model_fp32, [['conv', 'relu']])# 准备模型,插入观察者,观察激活张量,观察者用于校准量化参数
model_fp32_prepared = torch.ao.quantization.prepare(model_fp32_fused)# 进行校准,这里输入需要使用代表性的数据,以便观察者能够观察到激活张量的分布,从而计算出量化参数
input_fp32 = torch.randn(4, 1, 4, 4)
model_fp32_prepared(input_fp32)# 将模型转化为量化模型,这里会将权重量化,计算并存储每个激活张量的scale和bias值,以及用量化实现替换关键操作
model_int8 = torch.ao.quantization.convert(model_fp32_prepared)# 运行量化模型,这里的计算都是在int8上进行的
res = model_int8(input_fp32)

得益于pytorch,onnxruntime,tensorrt等工具,模型量化以及部署已经变得非常简单,但是有些知识我们还是要学,就如闫令琪老师说的:工具的发展可以简化我们工作流程,但是不能简化我们学习的知识,API是API,知识是知识。
觉得有帮助,请点赞+收藏,thanks

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/513726.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

将ppt里的视频导出来

将ppt的后缀从pptx改为zip 找到【media】里面有存放图片和音频以及视频,看文件名后缀可以找到,mp4的即为视频,直接复制粘贴到桌面即可。 关闭压缩软件把ppt后缀改回,不影响ppt正常使用。

Android Studio开发(一) 构建项目

1、项目创建测试 1.1 前言 Android Studio 是由 Google 推出的官方集成开发环境(IDE),专门用于开发 Android 应用程序。 基于 IntelliJ IDEA: Android Studio 是基于 JetBrains 的 IntelliJ IDEA 开发的,提供了丰富的功能和插件…

大数据开发-Hadoop分布式集群搭建

大数据开发-Hadoop分布式集群搭建 文章目录 大数据开发-Hadoop分布式集群搭建环境准备Hadoop配置启动Hadoop集群Hadoop客户端节点Hadoop客户端节点 环境准备 JDK1.8Hadoop3.X三台服务器 主节点需要启动namenode、secondary namenode、resource manager三个进程 从节点需要启动…

【Bugs】java: 错误: 不支持发行版本 xx

文章目录 报错场景:报错原因:解决方法: 报错场景: IDEA运行Java项目报错,点击运行之后,IDEA在编译代码的时候就出现报错: 报错类型一:java: 错误: 不支持发行版本 21报错类型二&am…

华为数通方向HCIP-DataCom H12-821题库(多选题:81-100)

第81题 在如图所示的网络中,所有的交换机运行RSTP协议,假如SWB的E1接口故障后, RSTP的处理过程时: A、SWB删除MAC地址表中以E1为目的的端口的端口项 B、在所有非边缘转发端口.上向外发送拓扑改变通知( Topology Change Notication ),通知其他交换机网络中出现了拓扑改变 C、重…

Linux第70步_新字符设备驱动的一般模板

1、了解“申请和释放设备号函数” int alloc_chrdev_region(dev_t *dev, unsigned baseminor, unsigned count, const char *name) //注册字符设备驱动 //dev:保存申请到的设备号 //baseminor:次设备号的起始地址 //count:要申请的设备数…

低密度奇偶校验码LDPC(九)——QC-LDPC译码器FPGA全并行设计

往期博文 低密度奇偶校验码LDPC(一)——概述_什么是gallager构造-CSDN博客 低密度奇偶校验码LDPC(二)——LDPC编码方法-CSDN博客 低密度奇偶校验码LDPC(三)——QC-LDPC码概述-CSDN博客 低密度奇偶校验码…

【操作系统概念】 第3章:进程

文章目录 0.前言3.1进程概念3.1.1 进程3.1.2 进程状态3.1.3 进程控制块(PCB) 3.2、进程调度3.2.1 调度队列3.2.2 调度程序3.2.3 上下文切换 3.3 进程操作3.3.1 进程创建3.3.2 进程终止 3.4 进程间通信 0.前言 早期的计算机一次只能执行一个程序。这种程序…

安泰ATA-P1005压电叠堆放大器在纳米定位台驱动中的应用

纳米定位台是一种高精度的微纳米级定位设备,主要用于微纳米加工、显微镜下的样品定位、纳米精度的测量和调试等领域。内置高性能压电陶瓷,运动范围可达500μm,具有体积小、无摩擦、响应速度快、高精度位移等特点,ATA-P1005压电叠堆…

stm32学习笔记:I2C通信外设原理(未完)

软件实现和硬件实现 串口通信为异步时序,用软件实现很麻烦,基本上用硬件实现 而I2C协议通信为同步时序,软件实现简单且灵活,硬件实现比较麻烦,故软件比较常用 但I2C硬件实现功能比较大,执行效率高&#xff…

【Proteus仿真】【Arduino单片机】坐姿矫正提醒器设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器,使用LCD1602液晶显示模块、HC-SR04超声波模块、蜂鸣器、按键、人体红外传感器等。 主要功能: 系统运行后,LCD1602显示超声…

Linux运维:在线/离线安装Telnet客户端和Telnet服务

Linux运维:在线/离线安装Telnet客户端和Telnet服务 前言1.1 在线安装Telnet1.2 离线安装Telnet1.3 Telnet服务有关的命令 💖The Begin💖点点关注,收藏不迷路💖 前言 Telnet是一种用于远程登录到其他计算机的协议&…