【论文笔记】MetaPruning: Meta Learning for Automatic Neural Network Channel Pruning

Abstract

提出了一种新颖的元学习方法,用于自动剪枝非常深的神经网络。首先训练一个称为PruningNet的元网络,该网络能够针对目标网络生成权重参数,以生成任何剪枝结构。使用简单的随机结构抽样方法来训练PruningNet。然后,应用进化过程来搜索表现良好的剪枝网络。这种搜索非常高效,因为权重是由经过训练的PruningNet直接生成的,在搜索时不需要任何微调。通过为目标网络训练单个PruningNet,可以在几乎没有人为参与的情况下,在不同约束条件下搜索各种剪枝网络。与最先进的剪枝方法相比,在MobileNet V1/V2和ResNet上展示了更优越的性能。

github仓库

1 Introduction

典型的剪枝方法包含三个阶段:训练大型过度参数化网络、剪枝不太重要的权重或通道、微调或重新训练剪枝后的网络。第二阶段是关键。它通常执行迭代分层剪枝和快速微调或权重重建以保持精度。

最近的一项研究发现,无论是否继承原始网络中的权重,剪枝网络都可以达到相同的精度。这一发现表明,通道剪枝的本质是找到良好的剪枝结构——逐层通道数。

然而,详尽地寻找最佳剪枝结构在计算上是令人望而却步的。考虑一个 10 10 10层的网络,每层包含 32 32 32个通道。分层通道数的可能组合可以是 3 2 10 32^{10} 3210

受到最近的神经架构搜索(NAS)的启发,特别是One-Shot模型,以及HyperNetwork中的权重预测机制,本文提出训练一个PruningNet,它可以为所有候选修剪网络结构生成权重,这样就可以通过评估验证数据的准确性来搜索性能良好的结构,这是非常高效的。

在这里插入图片描述

图1:MetaPruning有两步。1)训练PruningNet。每次迭代,随机生成一个网络编码向量(network encoding vectors),并对应生成剪枝后的网络(Pruned Network)。PruningNet将网络编码向量作为输入并生成Pruned Network的权重。2)搜索最佳的Pruned Network。通过改变网络编码向量构建了许多剪枝网络,并使用PrunedNet预测的权重来评估它们在验证数据上的优劣。搜索时无需进行微调。

为了训练PruningNet,使用随机结构采样。PruningNet使用相应的网络编码向量生成剪枝网络的权重,即每层的通道数。通过随机输入不同的网络编码向量,PruningNet逐渐学习为各种修剪结构生成权重。训练结束后,通过进化搜索方法来搜索性能良好的剪枝网络,该方法可以灵活地结合计算FLOP或硬件延迟等各种约束。此外,通过确定每一层或每一阶段的通道来直接搜索最佳剪枝网络,可以在捷径中剪枝通道而无需额外的努力,这在以前的通道剪枝解决方案中很少得到解决。
本文将这个方法称为MetaPruning。

本文贡献分为4点:

  • 提出了一种元学习方法,MetaPruning,用于通道剪枝。方法的核心是学习一个元网络Pruning Net,它为各种修建结构生成权重。通过单个经过训练的PruningNet,可以在不同约束下搜索各种剪枝网络。
  • 与传统的剪枝方法相比,MetaPruning跳出繁琐的超参数调整,能够根据所需的指标直接进行优化。
  • 与其他AutoML方法相比,MetaPruning可以轻松地在搜索所需结构时强制实施约束,而无需手动调整强化学习超参数。
  • 元学习能够毫不费力地修剪类似ResNet结构的快捷连接中的通道,这并非易事,因为快捷连接中的通道影响不止一层。

3 Methodology

本节引入了元学习方法,用于自动修剪深度神经网络中的通道,修剪后的网络可以轻松满足各种约束。

公式化通道剪枝问题:
( c 1 , c 2 , ⋯ , c l ) ∗ = argmin ⁡ c 1 , c 2 , ⋯ , c l L ( A ( c 1 , c 2 , ⋯ , c l ; w ) s.t.  C < constraint (c_1,c_2,\cdots,c_l)^*={\underset{c_1,c_2,\cdots,c_l}{\operatorname{arg min}}}\ \mathcal{L}(\mathcal{A}(c_1,c_2,\cdots,c_l;w)\ \text{s.t.} \ \mathcal{C}<\text{constraint} (c1,c2,,cl)=c1,c2,,clargmin L(A(c1,c2,,cl;w) s.t. C<constraint
A \mathcal{A} A为剪枝前的网络。尝试找到剪枝后的网络,从第一层到第 l l l层具有 ( c 1 , c 2 , ⋯ , c l ) (c_1,c_2,\cdots,c_l) (c1,c2,,cl)个通道,使得权重被训练后具有最小的损失,同时使成本 C \mathcal{C} C满足约束(FLOP或延迟)。
为了实现这一目标,提出构建一个PruningNet,一种元网络,可以仅通过评估验证数据来快速获得所有潜在修剪网络结构的优点。然后可以应用任何搜索方法,即本文中的进化算法,来搜索最佳剪枝网络。

3.1 PruningNet training

通道剪枝并非易事,因为通道中的分层依赖性使得剪枝一个通道可能会显着影响后续层,从而降低整体精度。以前的方法试图将通道剪枝问题分解为逐层剪枝不重要通道的子问题或添加稀疏正则化。

考虑整体剪枝网络结构来执行信道剪枝任务,有利于寻找信道剪枝的最优解,并且可以解决捷径剪枝问题。然而,获得最佳剪枝网络并不简单,考虑到一个10层且每层包含32个通道的小型网络,可能的剪枝网络结构的组合是巨大的。

受最近工作的启发,该工作表明剪枝留下的权重与剪枝后的网络结构相比并不重要,这鼓励直接找到最佳剪枝后的网络结构。从这个意义上说,可以直接预测最佳剪枝网络,而无需迭代确定重要的权重过滤器。为了实现这一目标,构建了一个元网络 PruningNet,为各种修剪后的网络结构提供合理的权重,以对其性能进行排名。

PruningNet是一个元网络,它以网络编码向量 ( c 1 , c 2 , ⋯ , c l ) (c_1,c_2,\cdots,c_l) (c1,c2,,cl)作为输入,并输出剪枝网络的权重:
W = P r u n i n g N e t ( c 1 , c 2 , ⋯ , c l ) W = PruningNet(c_1,c_2,\cdots,c_l) W=PruningNet(c1,c2,,cl)
在这里插入图片描述

图2:提出的PruningNet随机训练方法。在每次迭代中,随机化一个网络编码向量。 PruningNet通过将向量作为输入来生成权重。剪枝网络是根据向量构建的。裁剪 PruningNet生成的权重以匹配 Pruned Network中的输入和输出通道。通过在每次迭代中改变网络编码向量,PruningNet可以学习为各种修剪网络生成不同的权重。

PruningNet块由两个全连接层组成。在前向传递中,PruningNet将网络编码向量(即每层的通道数)作为输入,并生成权重矩阵。同时,构造剪枝网络,每层的输出通道宽度等于网络编码向量中的元素。生成的权重矩阵被裁剪以匹配剪枝网络中输入和输出通道的数量,如图2所示。给定一批输入图像,可以使用生成的权重计算剪枝网络的损失。

在向后传递中,不是更新Pruned Networks中的权重,而是计算PruningNet中权重的梯度。由于PruningNet中全连接层的输出与Pruned Network中前一个卷积层的输出之间的重塑操作以及卷积操作也是可微的,因此可以轻松计算PruningNet中权重的梯度由链式法则。 PruningNet是端到端可训练的。PruningNet与Pruned Network连接的详细结构如图3所示。

在这里插入图片描述
图3:(a)PruningNet与Pruned Network连接的网络结构。PruningNet和Pruned Network通过网络编码向量和小批量图像的输入进行联合训练。(b)对PruningNet块生成的权重矩阵进行重塑和裁剪操作。

为了训练PruningNet,提出了随机结构采样。在训练阶段,网络编码向量是通过在每次迭代时随机选择每层的通道数来生成的。通过不同的网络编码,构建不同的Pruned Network,并由PruningNet提供相应的权重。通过使用不同的编码向量进行随机训练,PruningNet学会预测各种不同修剪网络的合理权重。

3.2 Pruned-Network search

PruningNet训练完成后,可以通过将网络编码输入PruningNet,生成相应的权重并对验证数据进行评估来获得每个潜在剪枝网络的准确率。
由于网络编码向量数量巨大,无法一一列举。为了找出约束下高精度的剪枝网络,使用进化搜索,它能够轻松地合并任何软或硬约束。

在MetaPruning中使用的进化算法中,每个剪枝网络在每一层中都用一个通道数向量进行编码,称为剪枝网络的基因。在硬约束下,首先随机选择一些基因,通过评估得到相应剪枝网络的准确率。然后选择准确度最高的前k个基因来产生突变和交叉的新基因。突变是通过随机改变基因中一定比例的元素来进行的。交叉意味着随机重组两个亲本基因中的基因以产生后代。可以通过消除不合格的基因来轻松地强制执行约束。通过进一步重复top k选择过程和新基因生成过程多次迭代,可以获得满足约束的基因,同时达到最高的准确率。具体算法参见Algorithm 1。

Algorithm 1: Evolutionary Search Algorithm

超参数:人口规模 P \mathcal{P} P,突变数 M \mathcal{M} M,交叉数 S \mathcal{S} S,最大迭代次数 N \mathcal{N} N
输入:PruningNet,限制 C \mathcal{C} C
输出:具有最高准确率的基因 G top \mathcal{G}_{\text{top}} Gtop

g 0 = Random ( P ) , s.t.  C ; \mathcal{g}_0=\text{Random}(\mathcal{P}), \text{s.t.}\ \mathcal{C}; g0=Random(P),s.t. C;
G topK = ∅ \mathcal{G}_{\text{topK}}=\emptyset GtopK=
for i = 0 : N i=0:\mathcal{N} i=0:N do:

{ G i , accuracy } = Inference ( P r u n i n g N e t ( G i ) ) \{\mathcal{G}_i,\text{accuracy}\}=\text{Inference}(PruningNet(\mathcal{G}_i)) {Gi,accuracy}=Inference(PruningNet(Gi))
G topK , accuracy topK = TopK ( { G i , accuracy } ) \mathcal{G}_{\text{topK}}, \text{accuracy}_\text{topK}=\text{TopK}(\{\mathcal{G}_i,\text{accuracy}\}) GtopK,accuracytopK=TopK({Gi,accuracy})
G mutation = Mutation ( G topK , M ) , s.t.  C \mathcal{G}_{\text{mutation}}=\text{Mutation}(\mathcal{G}_{\text{topK}},\mathcal{M}),\ \text{s.t.}\ \mathcal{C} Gmutation=Mutation(GtopK,M), s.t. C
G crossover = Crossover ( G topK , S ) , s.t.  C \mathcal{G}_{\text{crossover}}=\text{Crossover}(\mathcal{G}_{\text{topK}},\mathcal{S}),\ \text{s.t.}\ \mathcal{C} Gcrossover=Crossover(GtopK,S), s.t. C
G i = G mutation + G crossover \mathcal{G}_i=\mathcal{G}_\text{mutation}+\mathcal{G}_\text{crossover} Gi=Gmutation+Gcrossover
end for
G top1 , accuracy top1 = Top1 ( { G N , accuracy } ) \mathcal{G}_\text{top1},\text{accuracy}_\text{top1}=\text{Top1}(\{\mathcal{G}_\mathcal{N},\text{accuracy}\}) Gtop1,accuracytop1=Top1({GN,accuracy})
return G top1 \mathcal{G}_{\text{top1}} Gtop1

4 Experimental Results

在本节中,展示了提出的MetaPruning方法的有效性。首先解释实验设置并介绍如何在MobileNet V1、V2和 ResNet上应用MetaPruning,它可以很容易地推广到其他网络结构。其次,将的结果与统一的剪枝基线以及最先进的通道剪枝方法进行比较。第三,可视化通过MetaPruning获得的修剪网络。最后,进行消融研究以阐述方法中权重预测的效果。

4.1 Experiment settings

所提出的MetaPruning非常有效。因此在ImageNet 2012分类数据集上进行所有实验是可行的。

MetaPruning方法由两个阶段组成。在第一阶段,PruningNet是通过随机结构采样从头开始训练的,与正常训练网络一样需要 1 4 \frac{1}{4} 41数量的epochs。进一步延长PruningNet训练在获得的Pruned Net中几乎没有产生最终的精度增益。在第二阶段,使用进化搜索算法来找到最佳的修剪网络。通过PruningNet预测所有PrunedNet的权重,搜索时无需微调或重新训练,这使得进化搜索非常高效。在8个Nvidia 1080Ti GPU上推断PrunedNet只需几秒钟。然后从头开始训练从搜索中获得的最佳PrunedNet。对于两个阶段的训练过程,使用标准数据增强策略来处理输入图像。对于MobileNets的实验,采用与相同的训练方案;对于ResNet,采用中的训练方案。所有实验的输入图像分辨率均设置为224×224。

在训练时,将原始训练图像分成子验证数据集和子训练数据集。子验证数据集包含从训练图像中随机选择的50000张图像,每个1000类别有50张图像,而剩余的图像则组成子训练数据集。在子训练数据集上训练PruningNet,并在搜索阶段评估剪枝网络在子验证数据集上的性能。在搜索时,使用20000张子训练图像重新计算BatchNorm层中的运行均值和运行方差,以正确推断剪枝网络的性能,这仅需几秒钟时间。在获得最佳剪枝网络后,将剪枝网络从头开始在原始训练数据集上进行训练,并在测试数据集上进行评估。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/595839.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Scala第十八章节(Iterable集合、Seq集合、Set集合、Map集合以及统计字符个数案例)

Scala第十八章节 章节目标 掌握Iterable集合相关内容.掌握Seq集合相关内容.掌握Set集合相关内容.掌握Map集合相关内容.掌握统计字符个数案例. 1. Iterable 1.1 概述 Iterable代表一个可以迭代的集合, 它继承了Traversable特质, 同时也是其他集合的父特质. 最重要的是, 它定…

题目:串变换(蓝桥OJ 4360)

问题描述&#xff1a; 解题思路&#xff1a; 题目说可以挑选任意个操作&#xff0c;因此我们枚举全部的子集。题目说以任意顺序执行&#xff0c;因此我们枚举每种子集的全排列。如果存在一种子集的一种排列可以使s变成t就返回yes并结束&#xff0c;反之&#xff0c;遍历完全部没…

Java 哈希表

一、哈希表的由来 我们的java程序通过访问数据库来获取数据&#xff0c;但是当我们对数据库所查询的信息进行大量分析后得知&#xff0c;我们要查询的数据满足二八定律&#xff0c;一般数据库的数据基本存储在磁盘当中。这使得每次查询数据将变得无比缓慢。为此我们可以将经常…

leetcode代码记录(第一个出现两次的字母

目录 1. 题目&#xff1a;2. 我的代码&#xff1a;小结&#xff1a; 1. 题目&#xff1a; 给你一个由小写英文字母组成的字符串 s &#xff0c;请你找出并返回第一个出现 两次 的字母。 注意&#xff1a; 如果 a 的 第二次 出现比 b 的 第二次 出现在字符串中的位置更靠前&…

【随笔】Git 高级篇 -- 分离 HEAD(十一)

&#x1f48c; 所属专栏&#xff1a;【Git】 &#x1f600; 作  者&#xff1a;我是夜阑的狗&#x1f436; &#x1f680; 个人简介&#xff1a;一个正在努力学技术的CV工程师&#xff0c;专注基础和实战分享 &#xff0c;欢迎咨询&#xff01; &#x1f496; 欢迎大…

【unity小技巧】unity最完美的CharacterController 3d角色控制器,实现移动、跳跃、下蹲、奔跑、上下坡、物理碰撞效果,复制粘贴即用

最终效果 文章目录 最终效果前言为什么使用CharacterControllerSimpleMove和Move如何选择&#xff1f;1. SimpleMove2. Move 配置CharacterController参数控制相机移动跳跃方式一方式二 下蹲处理下坡抖动问题实现奔跑和不同移速控制完整代码补充&#xff0c;简单版本 实现物理碰…

很详细的单应矩阵分解R、t过程

很详细的单应矩阵分解R、t过程 附赠自动驾驶学习资料和量产经验&#xff1a;链接 已有多种方法将单应矩阵H分解为R、t&#xff0c;在《Deeper understanding of the homography decomposition for vision-based control》一文中介绍了三种方法&#xff1a; O. Faugeras and F.…

Docker实战教程 第1章 Linux快速入门

2-1 Linux介绍 为什么要学Linux 三个不得不学习 课程需要&#xff1a;Docker开发最好在Linux环境下。 开发需要&#xff1a;作为一个后端程序员&#xff0c;是必须要掌握Linux的&#xff0c;这是找工作的基础门槛。 运维需要&#xff1a;在服务器端&#xff0c;主流的大型服…

【操作系统】STM32-操作系统——持续更新

【操作系统】STM32-操作系统——持续更新 文章目录 前言一、ucosii二、freertos1.介绍2.移植 总结 前言 提示&#xff1a;以下是本篇文章正文内容&#xff0c;下面案例可供参考 一、ucosii UCOSII移植到STM32F103C8T6上之移植记录&#xff08;一&#xff09; UCOSII移植到ST…

C++之类和对象(上)

目录 1.面向过程和面向对象初步认识 2.类的引入 3.类的定义 4.类的访问限定符及封装 4.1访问限定符 4.2 类的两种定义方式 第一种&#xff1a; 第二种&#xff1a; 4.3封装 5.类的实例化 6.类对象模型 1.面向过程和面向对象初步认识 C语言是面向过程的&#xff0c;…

京东云服务器地域和可用区选择方法,多因素考虑攻略

京东云服务器地域如何选择&#xff1f;根据地理位置就近选择地域。京东云主机地域支持北京、宿迁、上海和广州&#xff0c;华北地区用户选择北京地域&#xff0c;华东地区用户可以选择上海或宿迁地区&#xff0c;南方用户选择广州地域。云服务器吧yunfuwuqiba.com整理京东云主机…

IP地址获取不到的原因是什么?

在数字化时代的今天&#xff0c;互联网已成为我们日常生活和工作中不可或缺的一部分。而IP地址&#xff0c;作为互联网通信的基础&#xff0c;其重要性不言而喻。然而&#xff0c;有时我们可能会遇到IP地址获取不到的问题&#xff0c;这会给我们的网络使用带来诸多不便。那么&a…