基于BatchNorm的模型剪枝【详解+代码】

文章目录

    • 1、BatchNorm(BN)
    • 2、L1与L2正则化
      • 2.1 L1与L2的导数及其应用
      • 2.2 论文核心点
    • 3、模型剪枝的流程

  • ICCV经典论文,通俗易懂!论文题目:Learning Efficient Convolutional Networks through Network Slimming
  • 卷积后能得到多个特征图,这些图一定都重要吗?
  • 训练模型的时候能否加入一些策略,让权重参数体现出主次之分?
  • 以上这两点就是论文的核心,先看论文再看源码其实并不难!

如下图所示,每个conv-layer会被计算相应的channel scaling factors,然后根据channel scaling factors筛选conv-layer,达到模型瘦身的作用,图中的1.170,0.001,0.290等就是下面我们将要介绍的学习参数 γ \gamma γ 值,

在这里插入图片描述

1、BatchNorm(BN)

Network slimming,就是利用BN层中的缩放因子 γ \gamma γ
整体感觉就是一个归一化操作,但是BN中还额外引入了两个可训练的参数: γ \gamma γ β \beta β

BN的公式:
x ^ ( k ) = γ ⋅ x ( k ) − E [ x ( k ) ] V a r [ x ( k ) ] + β \hat x^{(k)}=\gamma \cdot \frac{x^{(k)}-E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}+\beta x^(k)=γVar[x(k)] x(k)E[x(k)]+β

  • 如果训练时候输入数据的分布总是改变,网络模型还能学的好吗?
    • 不能,网络刚开始学起来会很差,而且还容易导致过拟合,
  • 对于卷积层来说,它的输入可不是只有原始输入数据
    • 而是卷积层+BN层+relu层输出的数据,如果输入只来自卷积层,那么数据不在同一个分布内,网络刚开始学起来会很差,而且还容易导致过拟合
  • 以sigmoid为例,如果不经过BN层,很多输出值越来也偏离,导致模型收敛越来越难!
    在这里插入图片描述

A、BN的作用

  • BN要做的就是把越来越偏离的分布给他拉回来!
  • 再重新规范化到均值为0方差为1的标准正态分布
  • 这样能够使得激活函数在数值层面更敏感,训练更快
  • 有一种感觉:经过BN后,把数值分布强制分布在了非线性函数的线性区域中,而图像本身是非线性的,所以这是一个缺陷,所以就引入了 γ \gamma γ 参数,

B、BatchNorm参数

  • 如果都是线性的了,神经网络还有意义吗?
  • BN另一方面还需要保证一些非线性,对规范化后的结果再进行变换
  • 这两个参数是训练得到的: y ( k ) = γ x ^ ( k ) + β ( k ) y^{(k)} = \gamma \hat x^{(k)} + \beta ^{(k)} y(k)=γx^(k)+β(k)
  • 感觉就是从正态分布进行一些改变,拉动一下,变一下形状!

图中的1.170,0.001,0.290等就是学习参数 γ \gamma γ 值, γ \gamma γ 值越大则说明该特征层越重要,越小则不重要,

在这里插入图片描述

2、L1与L2正则化

如果学习到的 γ \gamma γ 值是1.17,1.16,1.15等,那如何筛选比较重要的 γ \gamma γ 值呢?使用L1正则化就可以实现筛选比较重要的 γ \gamma γ 值,

  • 论文中提出:训练时使用L1正则化能对参数进行稀疏作用,
  • L1:对权重参数稀疏与特征选择,会对一些权重参数稀疏化接近于0,
  • L2:平滑特征,会对权重参数都接近于0,

L1正则化: J ( θ → ) = 1 2 ∑ i = 1 m ( h θ ~ ( x ( i ) ) − y ( i ) ) 2 + λ ∑ j = 1 n ∣ θ j ∣ J\big(\overrightarrow{\theta}\big)= \frac{1}{2}\sum_{i=1}^m\big(h_{\widetilde{\theta}}(x^{(i)})-y^{(i)}\big)^2+\lambda \sum_{j=1}^n|\theta_j| J(θ )=21i=1m(hθ (x(i))y(i))2+λj=1nθj

L2正则化: J ( θ → ) = 1 2 ∑ i = 1 m ( h θ ~ ( x ( i ) ) − y ( i ) ) 2 + λ ∑ j = 1 n θ j 2 J\big(\overrightarrow{\theta}\big)= \frac{1}{2}\sum_{i=1}^m\big(h_{\widetilde{\theta}}(x^{(i)})-y^{(i)}\big)^2+\lambda \sum_{j=1}^n\theta_j^2 J(θ )=21i=1m(hθ (x(i))y(i))2+λj=1nθj2

其中 h θ ~ ( x ( i ) ) h_{\widetilde{\theta}}(x^{(i)}) hθ (x(i))是预测值, y ( i ) y^{(i)} y(i)是标签值,

2.1 L1与L2的导数及其应用

L1的导数:

L1求导后为:sign( θ \theta θ),相当于稳定前进,都为 ± 1 \pm 1 ±1;所以迭代次数够多,有些特征层权重 θ \theta θ 最后可以学成0了,所以L1可以做稀疏化,

在这里插入图片描述

L2的导数:

L2求导为:θ,梯度下降过程越来越慢,相应的权重参数都接近0,起到平滑的作用,

在这里插入图片描述

2.2 论文核心点

以BN中的 γ \gamma γ 为切入点,即 γ \gamma γ 越小,其对应的特征图越不重要,
为了使得 γ \gamma γ 能有特征选择的作用,引入L1正则来控制 γ \gamma γ

L = ∑ ( x , y ) l ( f ( x , W ) , y ) + λ ∑ γ ∈ Γ g ( γ ) L=\sum_{(x,y)}l\big(f(x,W),y\big)+\lambda\sum_{\gamma \in \Gamma}g(\gamma) L=(x,y)l(f(x,W),y)+λγΓg(γ)

其中 l ( f ( x , W ) , y ) l\big(f(x,W),y\big) l(f(x,W),y)是loss损失函数, γ \gamma γ 是BN中的参数 γ \gamma γ

3、模型剪枝的流程

训练-剪枝-再训练,整体流程如下图所示,

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/462122.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何写一个其他人可以使用的GitHub Action

前言 在GitHub中,你肯定会使用GitHub Actions自动部署一个项目到GitHub Page上,在这个过程中总要使用workflows工作流,并在其中使用action,在这个使用的过程中,总会好奇怎么去写一个action呢,所以&#xff…

亲测解决vscode的debug用不了、点了没反应

这个问题在小虎登录vscode同步了设置后出现,原因是launch文件被修改或删除。解决方法是重新添加launch。 坏境配置 win11 + vscode 解决方法 Ctrl + shift + P,搜索debug添加配置: 选择python debugger。 结果生成了一个文件在当前路径: launch内容: {// Use Int…

【Java八股面试系列】JVM-垃圾回收

目录 垃圾回收 堆空间的基本结构 内存分配和回收原则 分代收集机制 Minor GC 流程 空间分配担保 老年代 大对象直接进入老年代 长期存活的对象将进入老年代 GC的区域 对象存活判定算法 引用计数法 可达性分析算法 finalize() 字符串常量判活 类判活 垃圾回收算…

网络原理——数据链路层

以太网是数据链路层的核心协议 1. 以太网数据帧的组成部分 帧起始符(Preamble):8字节的连续数据0xAA,标识一个新数据帧的开始,用于同步收发双方的时钟。 目的MAC地址(Destination MAC Address&#xff09…

微软.NET6开发的C#特性——委托和事件

我是荔园微风,作为一名在IT界整整25年的老兵,看到不少初学者在学习编程语言的过程中如此的痛苦,我决定做点什么,下面我就重点讲讲微软.NET6开发人员需要知道的C#特性,然后比较其他各种语言进行认识。 C#经历了多年发展…

Lua 教程

Lua 教程 (今天又又又开新坑啦) Lua 教程 手册简介 Lua 是一种轻量小巧的脚本语言,用标准C语言编写并以源代码形式开放。 手册说明 Lua是什么? Lua 是一个小巧的脚本语言。是巴西里约热内卢天主教大学(Pontifical Catholic University of Rio de …

flink反压及解决思路和实操

1. 反压原因 反压其实就是 task 处理不过来,算子的 sub-task 需要处理的数据量 > 能够处理的数据量,比如: 当前某个 sub-task 只能处理 1w qps 的数据,但实际上到来 2w qps 的数据,但是实际只能处理 1w 条&#…

年-月-日的输入方法

大家对于输入的函数一定有所认识&#xff0c;比如c中位于 #include <iostream> 中的 cin 函数&#xff0c;这个函数输入单个十分好用&#xff0c;但是对于年月日这种较为复杂的就行不通了&#xff0c;就只能输入最前面的一个 那怎么输入像这样的年月日呢 答案就是用 scan…

FPGA_ip_Rom

一 理论 Rom存储类ip核&#xff0c;Rom是只读存储器的简称&#xff0c;是一种只能读出事先存储数据的固态半导体存储器。 特性&#xff1a; 一旦储存资料&#xff0c;就无法再将之改变或者删除&#xff0c;且资料不会因为电源关闭而消失。 单端口Rom: 双端口rom: 二 Rom ip核…

[word] word中页眉怎么设置与上一节不同 #笔记#笔记#经验分享

word中页眉怎么设置与上一节不同 word中页眉怎么设置与上一节不同 1、首先打开一个文档&#xff0c;点击上方的命令栏&#xff0c;找到“页眉”指令。 2、点击编辑&#xff0c;输入页眉的文字&#xff0c;输入完成之后&#xff0c;会看到两页的页眉是一样的。 3、在“页面布局…

Packet Tracer - Configure IOS Intrusion Prevention System (IPS) Using the CLI

Packet Tracer - 使用CLI配置IOS入侵防御系统&#xff08;IPS&#xff09; 地址表 目标 启用IOS入侵防御系统&#xff08;IPS&#xff09;。 配置日志记录功能。 修改IPS签名规则。 验证IPS配置。 背景/场景 您的任务是在R1上启用IPS&#xff0c;扫描进入192.168.1.0网络…

在 Next 中, ORM 框架 Prisma 使用

Prisma 介绍 Prisma 是一个 ORM 框架&#xff0c;主要用于 Node.js 或 TypeScript 作为后端开发的应用&#xff0c;主要有三部分组成&#xff1a; Prisma Client&#xff1a;自动生成且类型安全的查询构建器&#xff0c;适用于 Nodex.js 和 TS&#xff1b;Prisma Migrate: 迁…