本文提纲:
- fx 和 eager 两种量化训练方式介绍
- 量化训练的流程介绍:以 mmdet 的 yolov3 为例
- 常用的精度调优 debug 工具介绍
- 案例分析:模型精度调优经验分享
第一部分:fx 和 eager 两种量化训练方式介绍
首先介绍一下量化训练的原理。
上图为单个神经元的计算,计算形式是加权求和,再经过非线性激活后得到输出,这个输出又可以作为下一个神经元的输入继续运输,所以神经网络的基础运算是矩阵的乘法。如果神经元的计算全部采用 float32 的形式,模型的内存占用和数据搬运都会很占资源。如果用 int8 替换 float32,内存的搬运效率能提高 75%,充分展示了量化的有效性。由于两个 int8 相乘会超出 int8 的表示范围,为了防止溢出,累加器使用 int32 类型的,累加后的结果会再次 requantized 到 int8;
量化的目标就是在尽可能不影响模型精度的情况下降低模型的功耗,实现模型压缩效果,常见的量化方式有后量化训练 PTQ 和量化感知训练 QAT。
量化感知训练其实是一种伪量化的过程,即在训练过程中模拟浮点转定点的量化过程,数据虽然都是表示为 float32,但实际的值会间隔地受到量化参数的限制。具体方法是在某些 op 前插入伪量化节点(fake quantization nodes),伪量化节点有两个作用:
1.在训练时,用以统计流经该 op 的数据的最大最小值,便于在部署量化模型时对节点进行量化
2.伪量化节点参与模型训练的前向推理过程,因此会模型训练中导入了量化损失,但伪量化节点是不参与梯度更新过程的。
上图是模型学习量化损失的示意图, 正常的量化流程是 quantize->mul(int)->dequantize,而伪量化是对原先的 float 先 quantize 到 int,再 dequantize 到 float,这个步骤用于模拟量化过程中 round 操作所带来的误差,用这个误差再去进行前向运算。上图可以比较直观的表示引起误差的原因,从左到右数第 4 个黑点表示一个浮点数,quantize 后映射到 253,dequantize 后取到了第 5 个黑点,这就引起了误差。
地平线基于 PyTorch 开发的 horizon_plugin_pytorch 量化训练工具,同时支持 Eager 和 fx 两种模式。
eager 模式的使用方式建议参考用户手册 -4.2 量化感知训练章节(4.2.2。 快速上手中有完整的快速上手示例,各使用阶段注意事项建议参考 4.2.3。 使用指南)。fx 模式的相关 API 介绍请参考用户手册 -4.2.3.4.2。 主要接口参数说明章节
第二部分:量化训练的流程介绍:以 mmdet 的 yolov3 为例
QAT 流程介绍
准备好浮点模型,加载训好的浮点权重
model = build_detector(cfg.model,train_cfg=cfg.get('train_cfg'),test_cfg=cfg.get('test_cfg'))model.init_weights()# 加载config里的 init_cfg
设置 BPU 架构
set_march(March.BAYES)
算子融合(eager 模式需要,fx 可省略)
# qat: run fuse_module to fuse conv+bn/relu/add opmodel.backbone.fuse_modules()model.neck.fuse_modules()model.bbox_head.fuse_modules()
设置量化配置
- 整个 model 使用默认的 qconfig
- 模型的输出,配置高精度输出
- det 模型 head 输出的 loss 损失函数的 qconfig 设置为 None
# qat: set qconfig for float modelmodel.qconfig = get_default_qat_qconfig()# qat: set default_qat_out_qconfig for last convfor m in model.bbox_head.convs_pred:m.qconfig = get_default_qat_out_qconfig()# qat: set None for loss qconfig, loss should be quantizedmodel.bbox_head.loss_cls.qconfig = Nonemodel.bbox_head.loss_conf.qconfig = Nonemodel.bbox_head.loss_xy.qconfig = Nonemodel.bbox_head.loss_wh.qconfig = None
将浮点模型转换为 qat 模型(示例使用 eager 模式)
qat_model = prepare_qat(model)qat_model.to(torch.device("cuda:1"))
开始 qat 训练
- 可以复用浮点的 train_detector,替换 model 即可
train_detector(qat_model,datasets,cfg,distributed=distributed,validate=(not args.no_validate),timestamp=timestamp,meta=meta)
qat 模型转定点(需要 load 训练好的 qat 模型权重)
quantized_model = convert(qat_model.eval())
deploy_model 和 example_input 准备
deploy_model = DeployModel(quantized_model.backbone, quantized_model.neck,quantized_model.bbox_head).to(torch.device("cuda:1"))example_input = torch.randn(size=(24, 3, 320, 320), device=torch.device("cuda:1"))
Trace 模型构建静态 graph,进行编译
- eval()使 bn、dropout 等处于正确的状态
- 编译只能在 cpu 上做
- check_model 用于检查算子是否能全部跑在 bpu 上,建议提前检查
traced_model = torch.jit.trace(deploy_model.eval(), example_input)traced_model.to(torch.device("cpu"))example_input.to(torch.device("cpu"))check_model(traced_model, example_input, advice=1)compile_model(traced_model, [example_input], opt=0, hbm="model.hbm")
如果 qat 精度不达标,如何插入 calibration?
1. 准备好浮点模型,加载训好的浮点权重
2. 设置BPU架构
3. 算子融合(eager模式需要,fx可省略)
4. 设置model的量化配置
-----------------calib_model-------------------
calib_model = prepare_qat(float_model)
calib_model.eval() # 使bn、dropout等处于正确的状态
set_fake_quantize(calib_model, FakeQuantState.CALIBRATION) # 不进行伪量化操作,仅观测算子输入输出统计量,更新scale
#校准训练(可复用浮点的train_detector,替换model即可)
train_detector(calib_model,datasets,cfg,distributed=distributed,validate=(not args.no_validate),timestamp=timestamp,meta=meta)#校准精度验证
calib_model.eval()
set_fake_quantize(calib_model, FakeQuantState.VALIDATION)
val(calib_model,val_dataloader,device)
-----------此时calib_model里的scale已经更新了-------------------------
qat_model = prepare_qat(float_model)
-----------qat_model加载calib训练好的模型权重,开始qat训练-----------------------------------------------
train_detector(qat_model,datasets,cfg,distributed=distributed,validate=(not args.no_validate),timestamp=timestamp,meta=meta)
伪量化节点(fake quantize)的三种状态:
- CALIBRATION 模式:即不进行伪量化操作,仅观测算子输入输出统计量,更新 scale
- QAT 模式:观测统计量并进行伪量化操作。
- VALIDATION 模式:不会观测统计量,仅进行伪量化操作。
以下常见误操作会导致一些异常现象:
- calibration 之前模型设置为 train()的状态,且未使用
set_fake_quantize
,等于是在跑 QAT 训练; - calibration 之前模型设置为 eval()的状态,且未使用
set_fake_quantize
,会导致 scale 一直处于初始状态,全为 1,calib 不起作用。 - calibration 之前模型设置为 eval()的状态,且正确使用了
set_fake_quantize
,但是在这之后又设置了一遍 model.eval(),这将导致 fake_quant 未处于训练状态,scale 一直处于初始状态,全为 1;
对 mobilenet_v2 模型做 qat 训练的设置
量化节点设置
关键代码:
from horizon_plugin_pytorch.quantization import QuantStubself.quant = QuantStub(scale=1/128) # 一般 pyramid 输入的 Quant 层,需要手动设置 scale=1/128def fuse_modules(self):x = self.quant(x)
# Copyright (c) OpenMMLab. All rights reserved.
import warningsimport torch.nn as nn
from mmcv.cnn import ConvModule
from mmcv.runner import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
from horizon_plugin_pytorch.quantization import QuantStubfrom ..builder import BACKBONES
from ..utils import InvertedResidual, make_divisible
import torch@BACKBONES.register_module()
class MobileNetV2(BaseModule):arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],[6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],[6, 320, 1, 1]]def __init__(self,widen_factor=1.,out_indices=(1, 2, 4, 7),frozen_stages=-1,conv_cfg=None,norm_cfg=dict(type='BN'),act_cfg=dict(type='ReLU6'),norm_eval=False,with_cp=False,pretrained=None,init_cfg=None):super(MobileNetV2, self).__init__(init_cfg)# qat: model start with Quantization node# and set scale=1/128self.quant = QuantStub(scale=1/128) # 一般pyramid输入的Quant层,需要手动设置scale=1/128self.pretrained = pretrainedassert not (init_cfg and pretrained), \'init_cfg and pretrained cannot be specified at the same time'if isinstance(pretrained, str):warnings.warn('DeprecationWarning: pretrained is deprecated, ''please use "init_cfg" instead')self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)elif pretrained is None:if init_cfg is None:self.init_cfg = [dict(type='Kaiming', layer='Conv2d'),dict(type='Constant',val=1,layer=['_BatchNorm', 'GroupNorm'])]else:raise TypeError('pretrained must be a str or None')self.widen_factor = widen_factorself.out_indices = out_indicesif not set(out_indices).issubset(set(range(0, 8))):raise ValueError('out_indices must be a subset of range'f'(0, 8). But received {out_indices}')if frozen_stages not in range(-1, 8):raise ValueError('frozen_stages must be in range(-1, 8). 'f'But received {frozen_stages}')self.out_indices = out_indicesself.frozen_stages = frozen_stagesself.conv_cfg = conv_cfgself.norm_cfg = norm_cfgself.act_cfg = act_cfgself.norm_eval = norm_evalself.with_cp = with_cpself.in_channels = make_divisible(32 * widen_factor, 8)self.conv1 = ConvModule(in_channels=3,out_channels=self.in_channels,kernel_size=3,stride=2,padding=1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg)self.layers = []for i, layer_cfg in enumerate(self.arch_settings):expand_ratio, channel, num_blocks, stride = layer_cfgout_channels = make_divisible(channel * widen_factor, 8)inverted_res_layer = self.make_layer(out_channels=out_channels,num_blocks=num_blocks,stride=stride,expand_ratio=expand_ratio)layer_name = f'layer{i + 1}'self.add_module(layer_name, inverted_res_layer)self.layers.append(layer_name)if widen_factor > 1.0:self.out_channel = int(1280 * widen_factor)else:self.out_channel = 1280layer = ConvModule(in_channels=self.in_channels,out_channels=self.out_channel,kernel_size=1,stride=1,padding=0,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg)self.add_module('conv2', layer)self.layers.append('conv2')def make_layer(self, out_channels, num_blocks, stride, expand_ratio):"""Stack InvertedResidual blocks to build a layer for MobileNetV2.Args:out_channels (int): out_channels of block.num_blocks (int): number of blocks.stride (int): stride of the first block. Default: 1expand_ratio (int): Expand the number of channels of thehidden layer in InvertedResidual by this ratio. Default: 6."""layers = []for i in range(num_blocks):if i >= 1:stride = 1layers.append(InvertedResidual(self.in_channels,out_channels,mid_channels=int(round(self.in_channels * expand_ratio)),stride=stride,with_expand_conv=expand_ratio != 1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg,with_cp=self.with_cp))self.in_channels = out_channelsreturn nn.Sequential(*layers)def _freeze_stages(self):if self.frozen_stages >= 0:for param in self.conv1.parameters():param.requires_grad = Falsefor i in range(1, self.frozen_stages + 1):layer = getattr(self, f'layer{i}')layer.eval()for param in layer.parameters():param.requires_grad = False# qat: do fuse modeldef fuse_modules(self):self.conv1.fuse_modules()for layer_name in self.layers:layer = getattr(self, layer_name)if hasattr(layer, "fuse_modules"):layer.fuse_modules()elif isinstance(layer, nn.Sequential):for m in layer:if hasattr(m, "fuse_modules"):m.fuse_modules()def forward(self, x):"""Forward function."""# qat: qat model start with QuantStubx = self.quant(x)x = self.conv1(x)outs = []for i, layer_name in enumerate(self.layers):layer = getattr(self, layer_name)x = layer(x)if i in self.out_indices:outs.append(x)return tuple(outs)def train(self, mode=True):"""Convert the model into training mode while keep normalization layerfrozen."""super(MobileNetV2, self).train(mode)self._freeze_stages()if mode and self.norm_eval:for m in self.modules():# trick: eval have effect on BatchNorm onlyif isinstance(m, _BatchNorm):m.eval()
算子融合
[7.5.5. 算子融合 — Horizon Open Explorer](https://developer.horizon.ai/api/v1/fileData/horizon_j5_open_explorer_cn_doc/plugin/source/advanced_content/op_fusion.html?highlight=算子融合 算子 融合#)
举个例子:mmcv/cnn/bricks/conv_module.py
class ConvModule(nn.Module):
...
# qat: fuse conv + bn/reludef fuse_modules(self):fuse_list = Noneif self.with_norm:if self.with_activation:fuse_list = ["conv", self.norm_name, "activate"] # conv+bn+reluelse:fuse_list = ["conv", self.norm_name] # conv+bnelse:if self.with_activation:fuse_list = ["conv", "activate"] # conv+reluif fuse_list is not None:torch.quantization.fuse_modules(self,fuse_list,inplace=True,fuser_func=quantization.fuse_known_modules,)
eager 方案麻烦的是,基本每个模块都要手动去设置算子融合
反量化节点设置
mmdetection-master/mmdet/models/dense_heads/yolo_head.py
关键代码:
self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来self.dequant.append(DeQuantStub())def fuse_modules(self):pred_map = self.dequant[i](self.convs_pred[i](x))
class YOLOV3Head(BaseDenseHead, BBoxTestMixin):def __init__(self,num_classes,in_channels,out_channels=(1024, 512, 256),anchor_generator=dict(type='YOLOAnchorGenerator',base_sizes=[[(116, 90), (156, 198), (373, 326)],[(30, 61), (62, 45), (59, 119)],[(10, 13), (16, 30), (33, 23)]],strides=[32, 16, 8]),bbox_coder=dict(type='YOLOBBoxCoder'),featmap_strides=[32, 16, 8],one_hot_smoother=0.,conv_cfg=None,norm_cfg=dict(type='BN', requires_grad=True),# qat# act_cfg=dict(type='LeakyReLU', negative_slope=0.1),act_cfg=dict(type='ReLU'),loss_cls=dict(type='CrossEntropyLoss',use_sigmoid=True,loss_weight=1.0),loss_conf=dict(type='CrossEntropyLoss',use_sigmoid=True,loss_weight=1.0),loss_xy=dict(type='CrossEntropyLoss',use_sigmoid=True,loss_weight=1.0),loss_wh=dict(type='MSELoss', loss_weight=1.0),train_cfg=None,test_cfg=None,init_cfg=dict(type='Normal', std=0.01,override=dict(name='convs_pred'))):super(YOLOV3Head, self).__init__(init_cfg)# Check paramsassert (len(in_channels) == len(out_channels) == len(featmap_strides))self.num_classes = num_classesself.in_channels = in_channelsself.out_channels = out_channelsself.featmap_strides = featmap_stridesself.train_cfg = train_cfgself.test_cfg = test_cfgif self.train_cfg:self.assigner = build_assigner(self.train_cfg.assigner)if hasattr(self.train_cfg, 'sampler'):sampler_cfg = self.train_cfg.samplerelse:sampler_cfg = dict(type='PseudoSampler')self.sampler = build_sampler(sampler_cfg, context=self)self.fp16_enabled = Falseself.one_hot_smoother = one_hot_smootherself.conv_cfg = conv_cfgself.norm_cfg = norm_cfgself.act_cfg = act_cfgself.bbox_coder = build_bbox_coder(bbox_coder)self.prior_generator = build_prior_generator(anchor_generator)self.loss_cls = build_loss(loss_cls)self.loss_conf = build_loss(loss_conf)self.loss_xy = build_loss(loss_xy)self.loss_wh = build_loss(loss_wh)self.num_base_priors = self.prior_generator.num_base_priors[0]assert len(self.prior_generator.num_base_priors) == len(featmap_strides)self._init_layers()def _init_layers(self):self.convs_bridge = nn.ModuleList()self.convs_pred = nn.ModuleList()self.dequant = nn.ModuleList() # 不止1个反量化节点,用list包起来for i in range(self.num_levels):conv_bridge = ConvModule(self.in_channels[i],self.out_channels[i],3,padding=1,conv_cfg=self.conv_cfg,norm_cfg=self.norm_cfg,act_cfg=self.act_cfg)conv_pred = nn.Conv2d(self.out_channels[i],self.num_base_priors * self.num_attrib, 1)self.convs_bridge.append(conv_bridge)self.convs_pred.append(conv_pred)self.dequant.append(DeQuantStub())def fuse_modules(self):for m in self.convs_bridge:m.fuse_modules()def forward(self, feats):"""Forward features from the upstream network.Args:feats (tuple[Tensor]): Features from the upstream network, each isa 4D-tensor.Returns:tuple[Tensor]: A tuple of multi-level predication map, each is a4D-tensor of shape (batch_size, 5+num_classes, height, width)."""assert len(feats) == self.num_levelspred_maps = []for i in range(self.num_levels):x = feats[i]x = self.convs_bridge[i](x)pred_map = self.dequant[i](self.convs_pred[i](x))pred_maps.append(pred_map)return tuple(pred_maps),
第三部分:常用的精度调优 debug 工具介绍
工具:集成接口、量化配置检查、模型可视化、相似度对比、统计量、分步量化、异构模型部署 device 检查
第四部分:模型精度调优分享
模型精度调优时常遇到的问题:
-
calib 模型的精度和 float 对齐,quantized 模型的精度损失较大
正常情况下,calib/qat 模型的精度和 quantized 模型的精度损失很小(1%), 如果偏差过大,可能是 calib/qat 的流程不对。
原因:calib 模型伪量化节点的状态不正确,导致 calib 阶段,测试的是 float 模型的精度,而 quantized 阶段,测试的是 calib 模型的精度,所以精度损失本质上还是量化精度的损失。
如何避免:
- 正确设置 calib 训练和评测时的伪量化节点状态。
- 让客户在 calib 的基础上,做 qat, 评测 qat 模型的精度。(客户的数据量大,qat 时间太长,一直没有选择 qat,导致这个问题被暴露出来了)
如何设置正确的 calib 伪量化节点的状态?(fx 和 eager 都是一样的)
http://model.aidi.hobot.cc/api/docs/horizon_plugin_pytorch/latest/html/user_guide/calibration.html
#加载浮点模型权重model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))set_march(March.BAYES)#校准配置calib_model = prepare_qat_fx(model,{"":default_calib_8bit_fake_quant_qconfig,"module_name":...}).to(device)calib_model.to(device)#校准需要全程开启eval()状态calib_model.eval()#校准的训练阶段,设置伪量化节点模式为 CALIBRATIONset_fake_quantize(calib_model, FakeQuantState.CALIBRATION)train(cfg, calib_model, device, distributed)#校准的评测阶段,设置伪量化节点的模式为 VALIDATIONset_fake_quantize(calib_model, FakeQuantState.VALIDATION)#加载校准的模型权重calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))#测试校准的精度run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650
注意:16 行的 train 在评测时,也要设置 FakeQuantState.VALIDATION,不然 scale 不生效,评测的指标也不对
常见问题:
- 数据校准之前模型设置为 train()的状态,且未使用
set_fake_quantize
,等于 caib 阶段是在跑 QAT 训练; - 校准的评测阶段,未设置伪量化节点的模式为 VALIDATION, 实际评测的是 float 模型;
总结 2: 如果做 calib,一定要仔细检查伪量化节点状态和模型状态是否正确,避免不符合预期的结果
2.当量化精度损失超过大,如何调优?
- 使用 model_profiler() 这个集成接口,生成压缩包。
- 检查是否配置高精度输出、是否存在未融合的算子、是否共享 op、是否算子分布过大 int8 兜不住?
- 注意:使用 debug 集成接口时,要保证浮点模型训练到位,并传入真实数据
3.多任务模型的精度调优建议
- qat 调优策略和常规模型一样,ptq+qat
- 如果只有一个 head 精度有损失,可以固定其他部分,单独使用这个 head 的数据做 calib
4.calib 和 qat 流程的正确衔接
calib:
#加载浮点模型权重model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_float_131892.pth"))set_march(March.BAYES)#校准配置calib_model = prepare_qat_fx(model,{"":default_calib_8bit_fake_quant_qconfig,"module_name":...}).to(device)calib_model.to(device)#校准需要全程开启eval()状态calib_model.eval()#校准的训练阶段,设置伪量化节点模式为 CALIBRATIONset_fake_quantize(calib_model, FakeQuantState.CALIBRATION)train(cfg, calib_model, device, distributed)#校准的评测阶段,设置伪量化节点的模式为 VALIDATIONset_fake_quantize(calib_model, FakeQuantState.VALIDATION)#加载校准的模型权重calib_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))#测试校准的精度run_test(cfg, calib_model, vis=args.vis, eval_score_iou=args.eval_score_iou, eval_all_depths=args.eval_all_depths) # 11.8650
qat:
set_march(March.BAYES)qat_model = prepare_qat_fx(model,{"":default_qat_8bit_fake_quant_qconfig,"module_name":'''}).to(device)qat_model.to(device)#加载校准模型权重qat_model.load_state_dict(torch.load("output/toy_experiments/model_moderate_best_soft_calib_118633.pth"))#训练阶段,保证模型处于model.train()状态,这样伪量化节点也处于qat模式train(cfg, qat_model, device, distributed)
5.检查 conv 高精度输出
方式 1:查看 qconfig_info.txt,重点关注 DeQuantStub 附近的 conv 是不是 float32 输出
qconfig_info.txt
方式 2:打印 qat_model 的最后一层,查看该层是否有 (activation_post_process): FakeQuantize
高精度的 conv:
(1): ConvModule2d((0): Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)(weight_fake_quant): FakeQuantize(fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1.]), zero_point=tensor([0, 0, 0])(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([])))))
)
int8 的 conv
(0): ConvModule2d((0): ConvReLU2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)(weight_fake_quant): FakeQuantize(fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_channel_symmetric, ch_axis=0, scale=tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), zero_point=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])(activation_post_process): MovingAveragePerChannelMinMaxObserver(min_val=tensor([]), max_val=tensor([])))(activation_post_process): FakeQuantize(fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=-128, quant_max=127, dtype=qint8, qscheme=torch.per_tensor_symmetric, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0])(activation_post_process): MovingAverageMinMaxObserver(min_val=tensor([]), max_val=tensor([]))))
6.检查共享 op
打开 qconfig_info.txt,后面标有(n)的就是共享的
特殊情况:layernorm 在 QAT 阶段是多个小量化算子拼接而成,module 的重复调用,也会产生大量 op 共享的问题
解决办法: 将 layernorm 替换为 batchnorm,测试了 float 精度,没有下降。
7.检查未融合的算子
打开 qconfig_info.txt,全局搜 BatchNorm2d 和 ReLU,如果前面有 conv,那就是没做算子融合
可以融合的算子:
- conv+bn
- conv+relu
- conv+add
- conv+bn+relu
- conv+bn+add
- conv+bn+relu+add
8.检查数据分布特别大的算子
打开 float 模型的统计量分布,一般是 model0_statistic.txt
有两个表,第一个表是按模型结构排列的;第二个表是按数据分布范围排列的
拖到第二个表,看前几行是那些 op
可以看到很多 conv 的分布很异常,使用的是 int8 量化
解决办法:
- 检查这些 conv 后面是否有 bn,添加 bn 后,数据能收敛一些
- 如果结构上已经加了 bn,数据分布还大,可以配置 int16 量化
- int16 调这两个接口,default_qat_16bit_fake_quant_qconfig 和 default_calib_16bit_fake_quant_qconfig
- 中间算子的写法和高精度输出类似 model.xx.qconfig = default_qat_16bit_fake_quant_qconfig ()