pytorch量化库使用(2)

FX Graph Mode量化模式

训练后量化有多种量化类型(仅权重、动态和静态),配置通过qconfig_mapping ( prepare_fx函数的参数)完成。

FXPTQ API 示例:

import torch
from torch.ao.quantization import (get_default_qconfig_mapping,get_default_qat_qconfig_mapping,QConfigMapping,
)
import torch.ao.quantization.quantize_fx as quantize_fx
import copymodel_fp = UserModel()#
# post training dynamic/weight_only quantization
## we need to deepcopy if we still want to keep model_fp unchanged after quantization since quantization apis change the input model
model_to_quantize = copy.deepcopy(model_fp)
model_to_quantize.eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_dynamic_qconfig)
# a tuple of one or more example inputs are needed to trace the model
example_inputs = (input_fp32)
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# no calibration needed when we only have dynamic/weight_only quantization
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# post training static quantization
#model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qconfig_mapping("qnnpack")
model_to_quantize.eval()
# prepare
model_prepared = quantize_fx.prepare_fx(model_to_quantize, qconfig_mapping, example_inputs)
# calibrate (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# quantization aware training for static quantization
#model_to_quantize = copy.deepcopy(model_fp)
qconfig_mapping = get_default_qat_qconfig_mapping("qnnpack")
model_to_quantize.train()
# prepare
model_prepared = quantize_fx.prepare_qat_fx(model_to_quantize, qconfig_mapping, example_inputs)
# training loop (not shown)
# quantize
model_quantized = quantize_fx.convert_fx(model_prepared)#
# fusion
#
model_to_quantize = copy.deepcopy(model_fp)
model_fused = quantize_fx.fuse_fx(model_to_quantize)

量化堆栈

量化是将浮点模型转换为量化模型的过程。因此,在高层次上,量化堆栈可以分为两部分:1)。量化模型的构建块或抽象 2)。将浮点模型转换为量化模型的量化流程的构建块或抽象

量化模型

量化张量

为了在 PyTorch 中进行量化,我们需要能够用张量表示量化数据。量化张量允许存储量化数据(表示为 int8/uint8/int32)以及量化参数(如比例和 Zero_point)。除了允许以量化格式序列化数据之外,量化张量还允许许多有用的操作,使量化算术变得容易。

PyTorch 支持每张量和每通道的对称和非对称量化。每个张量意味着张量内的所有值都使用相同的量化参数以相同的方式量化。每个通道意味着对于每个维度(通常是张量的通道维度),张量中的值使用不同的量化参数进行量化。这可以减少将张量转换为量化值时的错误,因为异常值只会影响其所在的通道,而不是整个张量。

映射是通过使用转换浮点张量来执行的

 

 

 

请注意,我们确保浮点中的零在量化后表示没有错误,从而确保诸如填充之类的操作不会导致额外的量化误差。

以下是量化张量的几个关键属性:

  • QScheme (torch.qscheme):一个枚举,指定我们量化张量的方式

    • torch.per_tensor_affine

    • torch.per_tensor_对称

    • torch.per_channel_affine

    • torch.per_channel_symmetry

  • dtype (torch.dtype):量化张量的数据类型

    • 火炬.quint8

    • 火炬.qint8

    • 火炬.qint32

    • 火炬.float16

  • 量化参数(根据 QScheme 的不同而变化):所选量化方式的参数

    • torch.per_tensor_affine 的量化参数为

      • 刻度(浮动)

      • 零点(整数)

    • torch.per_channel_affine 的量化参数为

      • per_channel_scales(浮点数列表)

      • per_channel_zero_points(整数列表)

      • 轴(整数)

量化和反量化

模型的输入和输出都是浮点张量,但量化模型中的激活是量化的,因此我们需要运算符在浮点和量化张量之间进行转换。

  • 量化(浮点 -> 量化)

    • torch.quantize_per_tensor(x, 尺度, 零点, dtype)

    • torch.quantize_per_channel(x, 尺度, Zero_points, 轴, dtype)

    • torch.quantize_per_tensor_dynamic(x,dtype,reduce_range)

    • 到(火炬.float16)

  • 反量化(量化 -> 浮点)

    • quantized_tensor.dequantize() - 在 torch.float16 张量上调用 dequantize 会将张量转换回 torch.float

    • 火炬.反量化(x)

量化运算符/模块

  • 量化算子是以量化Tensor为输入,输出量化Tensor的算子。

  • 量化模块是执行量化操作的 PyTorch 模块。它们通常是为线性和卷积等加权运算定义的。

量化引擎

当执行量化模型时,qengine (torch.backends.quantized.engine) 指定使用哪个后端来执行。重要的是要确保qengine在量化激活和权重的取值范围方面与量化模型兼容。

量化流程

观察者和 FakeQuantize

  • 观察者是 PyTorch 模块,用于:

    • 收集张量统计信息,例如通过观察者的张量的最小值和最大值

    • 并根据收集的张量统计数据计算量化参数

  • FakeQuantize 是 PyTorch 模块,用于:

    • 模拟网络中张量的量化(执行量化/反量化)

    • 它可以根据观察者收集的统计数据计算量化参数,也可以学习量化参数

查询配置

  • QConfig 是 Observer 或 FakeQuantize Module 类的命名元组,可以使用 qscheme、dtype 等进行配置。它用于配置应如何观察操作员

    • 算子/模块的量化配置

      • 不同类型的 Observer/FakeQuantize

      • 数据类型

      • q方案

      • quant_min/quant_max:可用于模拟较低精度的张量

    • 目前支持激活和权重的配置

    • 我们根据为给定运算符或模块配置的 qconfig 插入输入/权重/输出观察器

一般量化流程

一般来说,流程如下

  • 准备

    • 根据用户指定的 qconfig 插入 Observer/FakeQuantize 模块

  • 校准/训练(取决于训练后量化或量化感知训练)

    • 允许观察者收集统计数据或 FakeQuantize 模块来学习量化参数

  • 转变

    • 将校准/训练模型转换为量化模型

量化有不同的模式,它们可以分为两种方式:

就我们应用量化流程的位置而言,我们有:

  1. Post Training Quantization(训练后应用量化,量化参数根据样本校准数据计算)

  2. 量化感知训练(在训练过程中模拟量化,以便使用训练数据与模型一起学习量化参数)

就我们如何量化运算符而言,我们可以:

  • 仅权重量化(仅权重静态量化)

  • 动态量化(权重静态量化,激活动态量化)

  • 静态量化(权重和激活都是静态量化的)

我们可以在同一量化流程中混合不同的量化运算符方式。例如,我们可以进行具有静态和动态量化运算符的训练后量化。

量化支持矩阵

 

量化定制

虽然提供了观察者根据观察到的张量数据选择比例因子和偏差的默认实现,但开发人员可以提供自己的量化函数。量化可以选择性地应用于模型的不同部分,或者针对模型的不同部分进行不同的配置。

我们还为conv1d()conv2d()、 conv3d()Linear()的每通道量化提供支持。

量化工作流程通过在模型的模块层次结构中添加(例如,将观察者添加为 .observer子模块)或替换(例如,转换nn.Conv2d为 nn.quantized.Conv2d)子模块来工作。这意味着该模型nn.Module在整个过程中保持基于常规的实例,因此可以与 PyTorch API 的其余部分一起使用。

量化自定义模块 API

Eager 模式和 FX 图形模式量化 API 都为用户提供了一个钩子,以指定以自定义方式量化的模块,并使用用户定义的逻辑进行观察和量化。用户需要指定:

  1. 源 fp32 模块的 Python 类型(模型中存在)

  2. 被观察模块的Python类型(由用户提供)。该模块需要定义一个from_float函数,该函数定义如何从原始 fp32 模块创建观察到的模块。

  3. 量化模块的Python类型(由用户提供)。该模块需要定义一个from_observed函数,该函数定义如何从观察到的模块创建量化模块。

  4. 描述上述 (1)、(2)、(3) 的配置,传递给量化 API。

然后框架将执行以下操作:

  1. 在准备模块交换期间,它将使用 (2) 中类的from_float函数将 (1) 中指定类型的每个模块转换为 (2) 中指定的类型。

  2. 在转换模块交换期间,它将使用 (3) 中类的from_observed函数将 (2) 中指定类型的每个模块转换为(3) 中指定的类型。

目前,要求ObservedCustomModule将具有单个 Tensor 输出,并且框架(而不是用户)将在该输出上添加观察者。观察者将作为自定义模块实例的属性存储在activation_post_process键下。未来可能会放宽这些限制。

自定义 API 示例:

import torch
import torch.ao.nn.quantized as nnq
from torch.ao.quantization import QConfigMapping
import torch.ao.quantization.quantize_fx# original fp32 module to replace
class CustomModule(torch.nn.Module):def __init__(self):super().__init__()self.linear = torch.nn.Linear(3, 3)def forward(self, x):return self.linear(x)# custom observed module, provided by user
class ObservedCustomModule(torch.nn.Module):def __init__(self, linear):super().__init__()self.linear = lineardef forward(self, x):return self.linear(x)@classmethoddef from_float(cls, float_module):assert hasattr(float_module, 'qconfig')observed = cls(float_module.linear)observed.qconfig = float_module.qconfigreturn observed# custom quantized module, provided by user
class StaticQuantCustomModule(torch.nn.Module):def __init__(self, linear):super().__init__()self.linear = lineardef forward(self, x):return self.linear(x)@classmethoddef from_observed(cls, observed_module):assert hasattr(observed_module, 'qconfig')assert hasattr(observed_module, 'activation_post_process')observed_module.linear.activation_post_process = \observed_module.activation_post_processquantized = cls(nnq.Linear.from_float(observed_module.linear))return quantized#
# example API call (Eager mode quantization)
#m = torch.nn.Sequential(CustomModule()).eval()
prepare_custom_config_dict = {"float_to_observed_custom_module_class": {CustomModule: ObservedCustomModule}
}
convert_custom_config_dict = {"observed_to_quantized_custom_module_class": {ObservedCustomModule: StaticQuantCustomModule}
}
m.qconfig = torch.ao.quantization.default_qconfig
mp = torch.ao.quantization.prepare(m, prepare_custom_config_dict=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.convert(mp, convert_custom_config_dict=convert_custom_config_dict)
#
# example API call (FX graph mode quantization)
#
m = torch.nn.Sequential(CustomModule()).eval()
qconfig_mapping = QConfigMapping().set_global(torch.ao.quantization.default_qconfig)
prepare_custom_config_dict = {"float_to_observed_custom_module_class": {"static": {CustomModule: ObservedCustomModule,}}
}
convert_custom_config_dict = {"observed_to_quantized_custom_module_class": {"static": {ObservedCustomModule: StaticQuantCustomModule,}}
}
mp = torch.ao.quantization.quantize_fx.prepare_fx(m, qconfig_mapping, torch.randn(3,3), prepare_custom_config=prepare_custom_config_dict)
# calibration (not shown)
mq = torch.ao.quantization.quantize_fx.convert_fx(mp, convert_custom_config=convert_custom_config_dict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/631.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【MySQL数据库 | 第十九篇】SQL性能分析工具

目录 前言: SQL执行频率: 慢查询日志: profile: profile各个指令: 总结: 前言: 本篇我们将为大家讲解SQL性能的分析工具,而只有熟练的掌握了性能分析的工具,才可以更…

C语言学习(二十九)---内存操作函数

在上一节内容中,我们学习了有关字符串操作的函数,其中分为了限制长度和不限制长度两种方式,虽然上节内容已经在很大程度上有助于程序的实现,但是其有一个致命的缺陷,聪明的你一定已经猜到了吧,对的&#xf…

一面、二面、三面有什么区别?

很多公司面试都分一面、二面、三面甚至更多,大家可能会好奇,为什么要面这么多面,每一面又有啥区别呢? 首先我来回答下为什么要这么多面,最核心的是最后3点: 如果光是一个人面,担心会看走眼&…

Python强类型编程

Python是一门强类型的动态类型语言,具体如下特性: 可以动态构造脚本执行、修改函数、对象类型结构、变量类型但不允许类型不匹配的操作 第一个例子体现动态性:用字符串直接执行代码,动态构建了一个函数并执行,甚至给…

IDEA创建一个Servlet项目(tomcat10)

一、创建maven项目 org.apache.maven.archetypes:maven-archetype-webapp 二、增加Servlet依赖 tomcat9及以前依赖 <!--加入servlet依赖&#xff08;servlet的jar&#xff09;--><dependency><groupId>javax.servlet</groupId><artifactId>ja…

设计模式-抽象工厂模式

抽象工厂模式 1、抽象工厂模式简介2、具体实现 1、抽象工厂模式简介 抽象工厂模式(Abstract Factory Pattern)在工厂模式尚添加了一个创建不同工厂的抽象接口(抽象类或接口实现)&#xff0c;该接口可叫做超级工厂。在使用过程中&#xff0c;我们首先通过抽象接口创建不同的工厂…

NoSQL之 Redis配置与优化

文章目录 一、关系数据库与非关系型数据库关系型数据库&#xff1a;非关系型数据库关系型数据库和非关系型数据库区别&#xff1a;非关系型数据库产生背景 二、Redis简介Redis 具有以下几个优点&#xff1a;使用场景&#xff1a;哪些数据适合放入缓存中Redis为什么这么快 三、R…

4、数据库操作语句:聚合函数

目录 1、定义 2、常用的聚合函数 1&#xff09;Avg/sum&#xff1a;只适用于数值类型的字段&#xff08;或变量&#xff09;。 2&#xff09;Max/min:适用于数值类型、字符串类型、日期时间类型的字段&#xff08;或变量&#xff09; 3&#xff09;Count&#xff1a; ①作…

SpringBoot自定义starter之接口日志输出

文章目录 前言文章主体1 项目全部源码2 项目结构介绍3 starter 的使用3.1 配置文件 application,yml的内容3.2 启动类3.3 控制器类 4 测试结果 结语 前言 本文灵感来源是一道面试题。 要求做一个可以复用的接口日志输出工具&#xff0c;在使用时引入依赖&#xff0c;即可使用。…

OpenCV学习笔记 | ROI区域选择提取 | Python

摘要 ROI区域是指图像中我们感兴趣的特定区域&#xff0c;OpenCV提供了一些函数来选择和提取ROI区域&#xff0c;我们可以使用OpenCV的鼠标事件绑定函数&#xff0c;然后通过鼠标操作在图像上绘制一个矩形框&#xff0c;该矩形框即为ROI区域。本文将介绍代码的实现以及四个主要…

Vue中如何进行游戏开发与游戏引擎集成?

Vue中如何进行游戏开发与游戏引擎集成&#xff1f; Vue.js是一款流行的JavaScript框架&#xff0c;它的MVVM模式和组件化开发思想非常适合构建Web应用程序。但是&#xff0c;如果我们想要开发Web游戏&#xff0c;Vue.js并不是最合适的选择。在本文中&#xff0c;我们将介绍如何…

edge自带断网游戏

在没有网络时你会不会很无聊&#xff1f;博主告诉你一个edge浏览器自带的断网小游戏&#xff0c;让你在断网时也能玩游戏&#xff01; 网址&#xff1a; 打开edge://surf这个断网游戏网站即可游玩&#xff1a; 作弊码既隐藏模式&#xff1a; 输入microsoft&#xff08;意思就…