深度学习--模型优化--模型的剪枝--92

目录
  • 1. 模型压缩
  • 2. 神经网络剪枝
  • 4. 非结构化剪枝
  • 4. Pruning neurons结构化剪枝

1. 模型压缩

目的:使得模型体积更小,模型推理速度更快

评估指标:

Compression Ratio
压缩率 = 总参数量 / 非0参数量
原始网络参数量 / 优化后的网络模型中非0参数量

脱水前的重量 / 脱水后的重量

Theoretical Speedup
速度 = 总FLOPS / 非0 FLOPS
脱水前的浮点数运算量 / 脱水后的浮点数运算量

数据,模型,硬件等维度压缩和加速模型的方法。
1.压缩已有的网络,包含:张量分解,模型剪枝,模型量化;
(针对既有模型)
2.构建新的小型网络,包含:知识蒸馏,紧凑网络设计;(针
对新模型)

2. 神经网络剪枝

我们人从小也都经历过神经网络的剪枝

剪枝一词在决策树算法中,其实就已经听过。使用剪枝顾名思义就是将网络模型变得更加简单,顺带减少过拟合。

在训练期间删除连接
密集张量将变得稀疏(用零填充)
可以通过结构化块删除连接

好处:
减少过拟合
稀疏性优势
文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
模型更小,可以减少内存带宽消耗量。
对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

剪枝与dropout的区别:
剪枝并非 Dropout / Dropconnect
剪枝是根据权重的绝对值来选择去掉部分连接。
dropout是训练过程中随机丢弃某些神经元,让每个神经元都能学到知识。

剪枝改变权重张量,不改变激活张量
激活张量是实际推理神经元的输出值,
权重张量是已经落盘的神经元之间的连接权重,推理过程是不变的

不同粒度的剪枝:

非结构化剪枝:数量,把连接上的参数置0即为剪枝
结构化剪枝:改变结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
可以很精准的删除掉每一个非互相依赖的参数,或者一次删除更大的部分。 越细粒度(非结构化的)剪枝,就会越精准,但是同时更难去加速inference。
另一方面,一次删除更大的部分(结构化剪枝)会没那么精准,但是使得稀疏阵计算更容易。所以剪枝的粒度对于精度和速度是个 trade-off !

剪枝的方式:
One-shot Pruning
Iterative Pruning
Automatic Gradual Pruning

模型剪枝来压缩模型的流程:
1 构建模型,并训练得到一个模型
2 使用训练期间对应的剪枝API训练模型
3 生成的权重将包含许多零值
4 使用文件压缩库 gzip、bzip等进行压缩

哪些参数需要被剪枝?Magnitude 幅度也就是绝对值大小
Weight Magnitude Pruning --正向传播
Gradient Magnitude Pruning --反向传播

权重大小剪枝:
其实很简单,根据权重大小剪枝已经被证明很有效。它简单的根据粒度(weights/group/kernel/filter)计算L1正则项,然后根据保留参数的比例,去删除那些数值小的。

梯度大小剪枝:
根据梯度大小剪枝,仅有的不同是我们要用参数权重乘上对应的梯度,然后再计算L1正则项

剪枝的比率影响精度:

哪些层的参数更容易被剪掉:
因为卷积层(conv)中的参数相比全连接层(fc)来说,参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。
越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

4. 非结构化剪枝

根据连接重要性判断是否裁剪掉连接

W矩阵举例

numpy.percentile 是 NumPy 库中的一个函数,用于计算给定数组中的百分位数。百分位数是指在一个数据集中,某个百分比位置的值。例如,第 50 百分位数(即 50% 分位数)就是中位数。

import numpy as np# 创建一个示例数组
data = np.array([1, 2, 3, 4, 5])# 计算第 50 百分位数(中位数)
median = np.percentile(data, 50)
print(median)  # 输出: 3.0# 计算多个百分位数
percentiles = np.percentile(data, [25, 50, 75])
print(percentiles)  # 输出: [2. 3. 4.]

numpy.percentile 函数在数据分析和统计学中非常有用,常用于计算数据的分布情况、异常值检测、数据归一化等。

可以做的事情,是对每一层都增加一个变量存储mask矩阵,mask矩阵就根据希望保留的参数比例来存储。
mask掩码矩阵是会随着迭代调整的,直到mask稳定下来,就完成真正的剪枝。训练时Loss中依然可以加上L1正则项。

误剪的参数是否可以恢复?
值得注意的是mask矩阵,它在正向传播和计算loss的时候肯定是要用到的,在反向传播的时候,被剪枝剪掉的参数,是否要进行调参呢?
可以有一定的概率sigma(iter)让它进行调参,其实随着iter迭代次数增加,可以发现概率就会越大。这样的话下次得到的mask矩阵就有可能是不同的。这有助于恢复一些被错误剪掉的连接参数。

4. Pruning neurons结构化剪枝

和非结构化剪枝很类似,并且训练时Loss中依然可以加上L1正则项。用卷积神经网络举例,根据卷积核的L1 Norm判断是否裁剪掉卷积核,具体就是计算卷积核们的绝对值之和,进行排序,剪掉和最小的核以及对应的特征图。



Layer_i 和 Layer{i+1} 之间的卷积核减少,是不影响
Layer{i+2} 层 tensor 形状的

减少B中间那些使得生成结果C的数值几乎不变的通道,间接的剪枝了前面的卷积核

基于Batch Normalization缩放因子进行剪枝
基于BN中的缩放因子γ来对不重要的通道进行裁剪,间接的剪枝了前面的卷积核。

BN:

sparsity:
往往最重要的参数就是模型最终稀疏度sparsity,它表示最终模型中有多少是0,80%就是最终80%的参数为0。含有多少水分

开始步数、结束步数,用于控制什么时候开始和结束剪枝训练,把开始步数设置为大于0的值可以给模型一定时间先收敛,然后再用优
化技术,一般来说这样的效果会比较好。

剪枝效果:
50%-70%左右的稀疏性,准确率降低幅度并不大
独立于量化技巧,通常与量化配合效果不错
可以通过微调尝试不同的参数组合。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/736738.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

thinkphp6 使用FFMpeg获取视频信息

1.本地安装 FFMpeg,官网下载地址:https://ffmpeg.org/download.html#build-windows 解压后,把文件夹放到自定义目录,添加系统变量 2.安装依赖,composer.json 添加"php-ffmpeg/php-ffmpeg": "^0.19.0",3.封装class类<?php namespace app\api\cont…

JDK导入Lets Encrypt根证书

项目在调用https接口时报错:PKIX path building failed: sun.security.provider.certpath.SunCertPathBuilderException: unable to find valid certification path to requested target 原因可能是更新换新证书后,HTTPS 域名的公钥证书不在 JDK/JRE 的证书库中,被Java认为是…

element plus 日历组件默认中文样式,配置日期周一为周起始日

element ui 或者 plus 其实都是西方的展示方式,日立组件的周日视为每一周的开始日期,我们则是周日为每周的最后一天。那咱们要改成周一为每周的开始日期,如下图:elementui 是可以直接属性配置的,element plus不得行,但是配置下面代码到main.ts就可以了~ import ElementPl…

Codeforces Global Round 26 A~C2

惹啊啊啊啊,这场做得我发昏,最近总感觉不在状态,但还是再在冲击1600-1800的题目. A. Strange Splitting---------------------------------题解--------------------------------------------------- 给你一个数组,让你自己构造一个RB字符串让R位置的数组中的数字的最大值-…

搭建工程之一 eclipse 中基于 maven 的 webapp工程能基于tomcat运行

一、背景作为开发人员,开发的web(运行在tomcat 容器中)希望能够在本地开发工具(eclipse)中运行调试,加快开发测试进度。 二、操作步骤 1、创建maven工程 在 eclipse 上右键,选择"New"---"Other" --- "Maven" ---- " Maven Project &…

【PythonGIS】基于Geopandas和Shapely计算矢量面最短路径

在GIS进行空间分析时经常会需要计算最短路径,我也是最近在计算DPC的时候有这方面的需求,刚开始直接是用面的中心点求得距离,但其对不规则或空洞面很不友好。所以今天跟大家分享一下基于Geopandas和Shapely计算矢量面最短路径,这里的最短即点/边的最短!​ 在GIS进行…

ls 设置颜色

1 查看别名对应的真实命令 2 设置颜色 格式: alias 别名=命令 示例 3 取消颜色 示例

主键Id自增,如何获取Id(Dapper)

这里用的是Dapper,以前用EF的时候好像有用到过db.savechanges(). 但是项目中没有这个,所以用以下的方法去获取id 背景:涉及到多表入库,需要获取主表的Id,所以用到了这个(timeFields 可以忽略)/// <summary>/// 单个添加/// </summary>/// <typeparam name…

JDK、Tomcat、Maven配置

一、JDK安装及配置 1.下载地址:https://www.oracle.com/java/technologies/downloads/2.下载后直接本地安装,选择路径默认即可,类似如下路径:C:/Program Files/Java/jdk_1.8.0_301 3. 配置环境变量路径:程序->计算机->右键->属性->高级系统设置->高级->…

VSCode + Qt + QMake 开发编译环境搭建

鉴于Qt官方IDE太过难用,VSCode+各种插件功能强大,遂采用VSCode来写Qt项目; 本博客在 Windows 平台进行指导操作,Mac、Linux 平台配置方式类似,学习其本质就可。前置准备VSCode,最新版本即可 本地 Qt 环境,版本随意,本文主要针对较老版本使用Qmake构建系统的项目环境变…