Qkeras量化模型-直接搭建模型的量化感知训练

qkeras是谷歌的感知训练量化框架,具有一些功能:

1、支持导入keras模型到qkeras模型;

2、支持剪枝和量化,使用tensorflow lite一起配合,简直不要太好用;

3、支持指定量化函数,量化的bit数目、量化的小数点左边bit数、对称性等;

4、容易扩展,可以使用keras的设计准则自定义构建块去扩展keras的函数,并且发展出sota模型;

qkeras的量化:

1、使用keras搭建的模型,可以直接转换成qkeras的量化模型;

2、使用pytorch、mxnet等其他框架搭建的模型,可以手动转换成qkeras搭建的模型;

示例:
官网的实例直接运行会有比较多的问题,这里改成这样的导入方式,解决问题;并且保留最核心的部分,删除无用语句。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_functionimport tensorflow as tf
from qkeras import *
from qkeras.utils import model_save_quantized_weights
import numpy as npnp.random.seed(42)NB_EPOCH = 100
BATCH_SIZE = 64
VERBOSE = 1
NB_CLASSES = 10
OPTIMIZER = tf.keras.optimizers.legacy.Adam(lr=0.0001, decay=0.000025)
VALIDATION_SPLIT = 0.1
train = 1# 1、导入数据;
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
RESHAPED = 784
x_test_orig = x_test
x_train = x_train.astype("float32")
x_test = x_test.astype("float32")
x_train = x_train[..., np.newaxis]
x_test = x_test[..., np.newaxis]
x_train /= 256.0
x_test /= 256.0y_train = tf.keras.utils.to_categorical(y_train, NB_CLASSES)
y_test = tf.keras.utils.to_categorical(y_test, NB_CLASSES)# 2、搭建量化模型;
x = x_in = tf.keras.layers.Input(x_train.shape[1:-1] + (1,), name="input")
x = QConv2D(32, (2, 2), strides=(2,2),kernel_quantizer=quantized_bits(4,0,1),bias_quantizer=quantized_bits(4,0,1),name="conv2d_0_m")(x)
x = QActivation("quantized_relu(4,0)", name="act0_m")(x)
x = QConv2D(64, (3, 3), strides=(2,2),kernel_quantizer=quantized_bits(4,0,1),bias_quantizer=quantized_bits(4,0,1),name="conv2d_1_m")(x)
x = QActivation("quantized_relu(4,0)", name="act1_m")(x)
x = QConv2D(64, (2, 2), strides=(2,2),kernel_quantizer=quantized_bits(4,0,1),bias_quantizer=quantized_bits(4,0,1),name="conv2d_2_m")(x)
x = QActivation("quantized_relu(4,0)", name="act2_m")(x)
x = tf.keras.layers.Flatten()(x)
x = QDense(NB_CLASSES, kernel_quantizer=quantized_bits(4,0,1),bias_quantizer=quantized_bits(4,0,1),name="dense")(x)
x = tf.keras.layers.Activation("softmax", name="softmax")(x)model = tf.keras.Model(inputs=[x_in], outputs=[x])
model.summary()
model.compile(loss="categorical_crossentropy", optimizer=OPTIMIZER, metrics=["accuracy"])# 3、训练模型;
if train:history = model.fit(x_train, y_train, batch_size=BATCH_SIZE,epochs=NB_EPOCH, initial_epoch=1, verbose=VERBOSE,validation_split=VALIDATION_SPLIT)# 4、计算测试准确度;score = model.evaluate(x_test, y_test, verbose=VERBOSE)print("Test score:", score[0])print("Test accuracy:", score[1])# 5、保存量化模型;model_save_quantized_weights(model)#6、打印模型参数shape;for layer in model.layers:for w, weight in enumerate(layer.get_weights()):print(layer.name, w, weight.shape)print_qstats(model)

基本逻辑步骤:

1、导入数据;

2、搭建量化模型;

3、训练模型;

4、计算测试准确度;

5、保存量化模型;

6、打印模型参数shape;

依赖版本:

# This file may be used to create an environment using:
# $ conda create --name <env> --file <this file>
# platform: win-64
ca-certificates=2023.01.10=haa95532_0
certifi=2022.12.7=py37haa95532_0
openssl=1.1.1t=h2bbff1b_0
pip=22.3.1=py37haa95532_0
python=3.7.16=h6244533_0
setuptools=65.6.3=py37haa95532_0
sqlite=3.41.2=h2bbff1b_0
vc=14.2=h21ff451_1
vs2015_runtime=14.27.29016=h5e58377_2
wheel=0.38.4=py37haa95532_0
wincertstore=0.2=py37haa95532_2

训练过程:
在这里插入图片描述

参考:

GitHub - google/qkeras: QKeras: a quantization deep learning library for Tensorflow Keras

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/189880.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

吴恩达《机器学习》8-7:多元分类

在机器学习领域&#xff0c;经常会遇到不止两个类别的分类问题。这时&#xff0c;需要使用多类分类技术。本文将深入探讨多类分类&#xff0c;并结合学习内容中的示例&#xff0c;了解神经网络在解决这类问题时的应用。 一、理解多类分类 多类分类问题是指当目标有多个类别时…

【STM32】RTC(实时时钟)

1.RTC简介 本质&#xff1a;计数器 RTC中断是外部中断&#xff08;EXTI&#xff09; 当VDD掉电的时候&#xff0c;Vbat可以通过电源--->实时计时 STM32的RTC外设&#xff08;Real Time Clock&#xff09;&#xff0c;实质是一个 掉电 后还继续运行的定时器。从定时器的角度…

在c#中如何将多个点位(Point)转换为多边形(Polygon)并装换为shp图层

&#x1f47b;如图&#xff0c;我现在有一组经纬度点位Point&#xff0c;接下来我们将他装换为多边形Polygon格式 &#x1f47b;使用QGIS > 图层 > 添加图层 > 添加分隔文本图层 > 打开这个csv点位文件 &#x1f47b;打开后如左下图&#xff0c;csv文件中的四个点位…

VS 将 localhost访问改为ip访问

项目场景&#xff1a; 使用vs进行本地调试时需要多人访问界面,使用ip访问报错 问题描述 vs通过ip访问报错 虚拟机或其它电脑不能正常打开 原因分析&#xff1a; 原因是vs访问规则默认是iis,固定默认启动地址是localhost 解决方案&#xff1a; 1.vs项目启动之后会出现这个 右…

BUUCTF 被偷走的文件 1

BUUCTF:https://buuoj.cn/challenges 题目描述&#xff1a; 一黑客入侵了某公司盗取了重要的机密文件&#xff0c;还好管理员记录了文件被盗走时的流量&#xff0c;请分析该流量&#xff0c;分析出该黑客盗走了什么文件。 密文&#xff1a; 下载附件&#xff0c;解压得到一个…

vue3安装vue-router

环境 node 18.14.2 yarn 1.22.19 windows 11 vite快速创建vue项目 参考 安装vue-touter 官网 yarn add vue-router4src下新建router文件夹&#xff0c;该文件夹下新建index.ts // router/index.ts 文件 import { createRouter, createWebHashHistory, RouterOptions, Ro…

Java中利用OpenCV进行人脸识别

OpenCV 概述 ​ OpenCV&#xff08;Open Source Computer Vision Library&#xff09;是一个开源计算机视觉库&#xff0c;它提供了丰富的工具和算法&#xff0c;用于处理图像和视频数据。该库由一系列高效的计算机视觉算法组成&#xff0c;涵盖了许多领域&#xff0c;包括目…

解决requests库中的期限处理问题:从404到异常再到修复

目录 引言 一、了解HTTP 404错误 二、问题分析 三、解决方法 1、控制请求频率 2. 使用代理服务器 3、异常处理与重试机制 4、修复问题源头 5、联系目标网站管理员 四、总结 引言 在利用Python的requests库进行网络爬虫或API请求时&#xff0c;我们有时会遇到“HTTP …

机器学习笔记 - 使用 PyTorch 的多任务学习和 HydraNet

一、HydraNet简述 特斯拉使用了一个模型可以解决他们正在处理的每一项可能的任务。 例如:物体检测、道路曲线估计、深度估计、3D重建、视频分析、物体追踪、ETC等等。 以下是在 NVIDIA GPU 上以 3 种不同配置运行的 2 个计算机视觉模型的基准测试。 在第一个配置中,我…

【LLM之基座】qwen 14b-4int 部署踩坑

由于卡只有24G&#xff0c;qwen14b 原生需要 30GB&#xff0c;按照官方团队的说法&#xff0c;他们用的量化方案是基于AutoGPTQ的&#xff0c;而且根据评测&#xff0c;量化之后的模型效果在几乎没有损失的情况下&#xff0c;显存降低到13GB&#xff0c;妥妥穷狗福音&#xff0…

【数据结构】图的简介(图的逻辑结构)

一.引例&#xff08;哥尼斯堡七桥问题&#xff09; 哥尼斯堡七桥问题是指在哥尼斯堡市&#xff08;今属俄罗斯&#xff09;的普雷格尔河&#xff08;Pregel River&#xff09;中&#xff0c;是否可以走遍每座桥一次且仅一次&#xff0c;最后回到起点的问题。这个问题被认为是图…

【Proteus仿真】【Arduino单片机】DS1302时钟

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器&#xff0c;使用PCF8574、LCD1602液晶、DS1302等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD1602显示时间日期。 二、软件设计 /* 作者&#xff1a;…