DMS:直接可微的网络搜索方法,最快仅需单卡10分钟 | ICML 2024

news/2024/11/15 15:49:37/文章来源:https://www.cnblogs.com/VincentLee/p/18368785

Differentiable Model ScalingDMS)以直接、完全可微的方式对宽度和深度进行建模,是一种高效且多功能的模型缩放方法。与先前的NAS方法相比具有三个优点:1)DMS在搜索方面效率高,易于使用。2)DMS实现了高性能,可与SOTA NAS方法相媲美。3)DMS是通用的,与各种任务和架构兼容。

来源:晓飞的算法工程笔记 公众号

论文: Differentiable Model Scaling using Differentiable Topk

  • 论文地址:https://arxiv.org/abs/2405.07194

Introduction


  在近年来,像GPTViT这样的大型模型展示了出色的性能。值得注意的是,GPT-4的涌现强调了通过扩展网络来实现人工通用智能(AGI)的重要性。为了支持这个扩展过程,论文引入了一种通用而有效的方法来确定网络在扩展过程中的最佳宽度和深度。

  目前,大多数网络的结构设计仍然依赖于人类专业知识。通常需要大量资源来调整结构超参数,导致很难确定最佳结构。与此同时,神经架构搜索(NAS)方法已经被引入到自动化网络结构设计中。根据搜索策略将NAS方法分为两类:随机搜索方法和基于梯度的方法。

  随机搜索方法需要对大量子网络进行采样以比较性能。然而,这些方法的搜索效率受到样本评估周期的限制,导致性能降低和搜索成本增加。

  与随机搜索方法不同,基于梯度的方法采用梯度下降法来优化结构参数、 提高效率,使其更善于平衡搜索成本和最终性能。然而,一个巨大的挑战依然存在:如何以直接和可微的方式为结构超参数建模?早期的方法一直在努力应对这一挑战,结果导致性能下降、成本增加。具体来说,根据建模策略将先前的方法分为三类:

  1. 多元素选择:在搜索卷积层中的通道数时,将通道数建模为通道选择(比如PaS通过可学习二值卷积生成0/1掩码对通道进行剪枝),如图1 a.1所示。
  2. 单数字选择:在搜索卷积层中的通道数时,将通道数建模为从多个候选数字中的选择一个(比如FBNetV2在一层中学习不同大小的候选卷积的权重),如图1 a.2所示。
  3. 梯度估计topk:尝试直接建模宽度和深度(比如通过梯度估计学习动态的k值以及给每个通道生成重要性分数,随后选择topk通道),如图1 a.3所示。这个可能跟多元素选择有点类似,核心区别是多元素选择是为了生成掩码,而这个则是为了生成动态k值。

  但所有上述策略都无法以直接和完全可微分的方式对结构超参数进行建模。为了解决上述挑战,论文引入了一个完全可微分的topk运算符,可以无缝地以直接和可微分的方式对深度和宽度进行建模。值得注意的是,每个可微分的topk运算符都有一个可学习参数,表示深度或宽度结构超参数,可以基于任务损失和资源约束损失的指导进行优化。与现有的基于梯度的方法相比,论文的方法在优化效率方面表现出色。

  基于可微分topk,论文提出了一种可微分模型缩放(DMS)算法来搜索网络的最佳宽度和深度。为了验证功效和效率,在各种任务中进行了严格的测试,包括视觉任务和NLP任务,以及不同的架构,包括CNNTransformer。由于可微分topk具有高效的搜索效率,DMS在性能或搜索成本方面均优于先前的SOTA方法。

  总的来说,论文的贡献如下:

  1. 引入了可微分的topk运算符,可以以直接和可微分的方式对结构超参数进行建模,因此很容易进行优化。
  2. 基于可微分的topk提出了一种可微分模型缩放(DMS)算法,用于搜索网络的最佳宽度和深度。
  3. 评估了DMS在各种任务和架构上的性能。例如,DMS在搜索过程中只需0.4 GPU天,就能比最先进的zero-shot NAS方法ZiCo表现出1.3%的优势。与性能相当的one-shot NAS方法ScaleNet以及multi-shot NAS方法Amplification相比,DMS所需的搜索成本仅为几十分之一。此外,DMS是一种广泛适用的方法,在COCO数据集上将Yolo-v8-n的提高了2.0%,并提高了裁剪Llama-7B模型样本分类精度。

Related Work


Stochastic Search Methods

  随机搜索方法通常通过采样和评估的循环过程进行操作。在每一步中,它们会对具有不同构的模型进行采样,然后对其进行评估。这种策略非常灵活,可以处理连续和离散的搜索空间。然而,它的一个显著缺点搜索效率低下,导致资源消耗高和性能不理想。具体而言,基于随机搜索的方法可以分为三种:

  1. multi-shot NAS:需要训练多个模型,这非常耗时,如EfficientNet用了1714TPU天来进行搜索。
  2. one-shot NAS:需要训练一个庞大的超网络,也需要大量资源,如ScaleNet用了379GPU天来训练一个超网络。
  3. zero-shot NAS:通过消除训练任何模型来减少成本,但其性能尚未达到所期望的标准。

Gradient-based Methods

  基于梯度的结构搜索方法使用梯度下降来探索模型的结构,这些方法一般比随机搜索方法更高效。基于梯度的方法的关键在于如何使用可学习参数来建模结构超参数并计算其梯度,理想情况下,可学习参数应直接建模结构超参数并且其梯度应以完全可微的方式计算。然而,先前的方法在建模网络的宽度和深度时往往难以同时满足这两个条件,可以将它们分为三类:(1)多元素选择,(2)单数字选择和(3)梯度估计topk。前两类间接地建模结构超参数,而第三类不可微分,需要进行梯度估计。

  为了提高结构搜索的优化效率,论文引入了一种新的可微分的topk方法,可以直接建模宽度和深度,并且是完全可微分的。从实验结果来看,论文的方法更加高效和有效。

Method


Differentiable Top-k

  假设存在一个由 \(k\) 表示的结构超参数,表示元素的数量,比如卷积层中的 \(k\) 个通道或网络阶段中的 \(k\) 个残差块。 \(k\) 的最大值为\(N\),使用 \({\mathbf{c}} \in \mathbb{R}^N\)来表示元素的重要性,其中较大的值表示更高的重要性。可微分topk方法的目标是输出一个软掩码 \({\mathbf{m}} \in 0,1^N\),代表具有前 \(k\) 个重要分数的选定元素。

topk运算符使用可学习的参数 \(a\) 作为阈值,选择那些重要性值大于 \(a\) 的元素。 \(a\) 能够直接建模元素数量 \(k\),因为 \(k\) 可以看作是 \(a\) 的一个函数,其中\(k=\sum^N\_{i=1}{1c\_i>a}\)\(1A\) 是一个指示函数,如果 \(A\) 为真则等于1,否则等于0\(c\_i\) 表示第 \(i\) 个元素的重要性。将topk表示为一个函数 \(f\),如下所示:

\[\begin{align} m\_i = f(a) \approx \begin{cases} 1 & \text{if } c\_i > a \ 0 & \text{otherwise} \end{cases} \end{align} \]

  在先前的方法中, \(f\) 通常是一个分段函数,不平滑也不可微分,且 \(a\) 的梯度是通过估计计算得出的。论文认为采用相对于 \(a\) 完全可微分的 \(f\) 的最大挑战是重要性分数分布不均匀。具体来说,不均匀的分布导致重要性值排序中的两个相邻元素之间的差异较大。假设每次迭代时通过固定值更新 \(a\),当前后元素的重要性差异很大时,则需要许多步才能使 \(a\) 跨越这两个元素。当差异很小时, \(a\) 可以在一步内跨越许多元素。因此,在元素重要性不均匀时,以完全可微分的方式优化 \(a\) 是非常困难。

  为了解决这个挑战,论文采用了一种重要性归一化过程,将不均匀分布的重要性强制转换为均匀分布的值,使得topk函数在可微分的情况下变得平滑且易于优化。总结起来,可微分topk有两个步骤:重要性归一化和软掩码生成。

Importance Normalization

  根据以下方式,通过将所有元素的重要性映射到从01的均匀分布的值来对所有元素的重要性归一化:

\[\begin{align} & c_i' = \frac{1}{N}\sum^N_{j=1}{1c\_i>c\_j}. \end{align} \]

  归一化后的元素重要性用 \({\mathbf{c}}'\) 表示。 \(1A\) 是与上面相同的指示函数, \({\mathbf{c}}\) 中的任意两个元素通常是不同的。值得注意的是,虽然 \(\mathbf{c}'\)01之间均匀分布,但 \(\mathbf{c}\) 可以遵循任何分布。

  直观地说, \(c'\_i\) 表示 \({\mathbf{c}}\) 中值小于 \(c\_i\) 的部分。此外,可学习的阈值 \(a\) 也变得有意义,表示元素的剪枝比例。 \(k\) 可以通过 \(k=\lfloor(1-a)N\rceil\) 计算,其中 \(\lfloor \, \rceil\) 是一个取整函数。 \(a\) 限制在 \(0,1\) 的范围内,其中 \(a=0\) 表示不剪枝, \(a=1\) 表示剪枝所有元素。

Soft Mask Generation

  在归一化之后,可以使用基于相对大小的剪枝比例 \(a\) 和归一化元素重要性\({\mathbf{c}}'\)的平滑可微函数轻松生成软掩码${\mathbf{m}} $。

\[\begin{align} & m\_i = f(a)= \text{Sigmoid}(\lambda({\mathbf{c}}'\_i- a)) = \frac{1}{1+e^{-\lambda({\mathbf{c}}'\_i - a)}}. \end{align} \]

  论文添加了一个超参数 \(\lambda\)来控制从公式3到硬掩码生成函数的逼近程度。当 \(\lambda\)趋近于无穷大时,公式3接近于硬掩码生成函数(根据固定阈值 \(a\) 直接得出0/1)。通常将 \(\lambda\)设置为\(N\),因为当\(c'\_i>a+3/N\)\(c'\_i<a-3/N\)时,$|(m_i-\lfloor m_i \rceil)|<0.05 $。这意味着除了重要性值接近剪枝比例的六个元素外,其他元素的掩码接近于01,近似误差小于0.05。因此, \(\lambda=N\)足以逼近topk的硬掩码生成函数。

  公式3的前向和反向图分别如图2(a) 和图2(b) 所示,可以观察到以下两点:

  1. topk直接使用可学习的剪枝比例 \(a\) 来建模元素数量 \(k\),并在前向过程中生成极化的软掩码\({\mathbf{m}}\),以完美模拟剪枝后的模型。
  2. 可微分topk完全可微分,并且能够稳定地进行优化。\(a\)相对于\(m_i\)的梯度为 $\frac{\partial m_i}{\partial a} = -\lambda(1-m_i)m_i $。我们的__topk直观地检测模糊区域中\(0.05<m\_i<0.95\)的掩码梯度。请注意,图2__(b) 描述的是 \(\frac{\partial m\_i}{\partial a}\)的值,而不是 \(a\) 的总梯度,\(a\) 的总梯度为 \(\sum_{i=1}^{N}{\frac{\partial task\_loss}{\partial m\_i}\frac{\partial m\_i}{\partial a}}+\frac{\partial resource\_loss}{\partial a}\)

Element Evaluation

  由于元素重要性会被归一化后再进行掩码生成,所以不限制元素重要性的分布,可以通过多种方法来量化元素重要性,例如L1-norm 等。论文以滑动平均方式实现了Taylor importance,具体如下所示:

\[\begin{align} c^{t+1}_i = c^t\_i \times decay + (m^t\_i \times g_{i})^2 \times (1-decay). \end{align} \]

  在这里,\(t\)表示训练步骤,\(g\_i\)\(m\_i\) 相对于训练损失的梯度,\(Decay\)是衰减率,\(c\_i^0\)的初始值设为零,衰减率设为0.99。请注意,元素的重要性不是通过梯度下降来更新的。通过利用Taylor importance,可以高效且稳定地估计元素的重要性。

Differentiable Model Scaling

  依靠可微分topk,论文提出了可微分模型缩放(Differentiable Model ScalingDMS)来优化网络的宽度和深度。DMS有三种基于基于训练的模型剪枝的流水线变体,如表1所示。

  • \(\text{DMS}\_{\text{p}}\)

\(\text{DMS}\_{\text{p}}\) 是基于训练的模型剪枝流水线,由预训练阶段、搜索阶段和重新训练阶段组成:

  1. 预训练阶段用于预训练一个超网络,通常需要大量时间和资源。
  2. 搜索阶段在特定资源约束下搜索超网络的最优宽度和深度,由于论文方法具有较高的搜索效率,因此搜索阶段只使用了大约1/10或更少的重新训练轮数。
  3. 在重新训练阶段,重新对已经进行了搜索的模型进行训练。与SOTA剪枝方法进行比较时,使用这个流水线。
  4. \(\text{DMS}\_{\text{np}}\)

\(\text{DMS}_{\text{np}}\) 是论文中默认和最常用的流水线。预训练阶段的高本占据了总成本的大部分,这是__NAS__和剪方法在实际应用中面临的一个重大障碍。为克服这个问题,从 \(\text{DMS}_{\text{}}\) 中去除了预训练阶段,直接从随机初始化超网络开始搜索。通过增加超网络大小,\(\text{DMS}_{\text{np}}\) 在性能和效率上都超过了 \(\text{DMS}_{\text{p}}\),并且比其他NAS方法更加高效。

  • \(\text{DMS}\_{\text{p-}}\)

\(\text{DMS}_{\text{p-}}\) 用于快速比较不同搜索方法。与 \({\text{DMS}_\text{p}}\) 相比,它只优化结构参数,不对搜索到的模型进行重新训练。利用现有的预训练超网络,也能输出合理的结果。此外,它只需要数百次迭代,在单个RTX3090上花费不到10分钟就可以搜索出一个模型。

  • Search Space

  如图1(b)所示,论文的搜索空间涵盖了网络的宽度和深度,这是模型扩展最关键的结构超参数。为了表示这些维度,使用了可微分的topk方法。在网络中,宽度通常涵盖了卷积层中的通道维度、全连接层中的特征维度等。关于深度,论文专注于具有残差连接的网络,并搜索每个阶段中的块数。具体来说,将可微分topk的软掩码合并到残差连接中,使得每个块可以表示为 \(x\_{i+1}=x\_i+f(x\_i)\times m\_i\)

  此外,对于结构超参数 \(x\),在范围 \(1, x\_{max}\) 内以步长1进行搜索,而大多数先前的NAS方法则在范围 \(x_{min}, x_{max}\) 内以步长32进行搜索。这个搜索空间是由人类专家设计的一个较好的子空间,然而论文的搜索空间更加通用且成本最低。实验结果显示,论文的方法在精细搜索空间上可以达到更好的性能,这个空间更难搜索,而先前的方法在粗粒度搜索空间上的表现较好,这种空间更容易搜索。

  • Resource Constraint Loss

  为了确保网络遵循特定的资源约束,在优化过程中加入了一个额外的组件,称为Resource Constraint Loss。因此,整体的损失函数为:

\[\begin{align} loss= & loss_{task}+\lambda_{resource}\times loss\_{resource}. \ & loss\_{resource}=\begin{cases} \log(\frac{r_{c}}{r_{t}}) & \text{if } r_{c} > r_{t} \ 0 & \text{otherwise} \end{cases}. \end{align} \]

  在这里, \(loss_{task}\) 表示任务损失。 \(loss_{resource}\) 表示额外的资源约束损失, \(\lambda\_{resource}\) 作为其权重系数。 \(r\) 表示当前资源消耗水平,根据可学习参数的不同topk操作符进行计算。 \(r\_t\) 代表目标资源耗水平,由用户指定。由于topk是完全可微分的,可学习结构参数可以在任务损失和资源约束失的指导下进行优化。

Experiment


如果本文对你有帮助,麻烦点个赞或在看呗

更多内容请关注 微信公众号【晓飞的算法工程笔记】

work-life balance.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/784143.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ArchLinux配置OpenCV C++环境

本文将简单介绍在 ArchLinux 中安装 OpenCV C++ 库并运行一个简单的 OpenCV 程序的过程。 参考:https://github.com/donaldssh/Install-OpenCV 我的环境最新的 ArchLinux KDE Plasma 6 桌面环境 OpenCV 4.10.0 clang 18.1.8 gcc 14.2.1安装 安装以下包: sudo pacman -S hdf5 …

聊聊如何利用ingress-nginx实现应用层容灾

前言 容灾是一种主动的风险管理策略,旨在通过构建和维护异地的冗余系统,确保在面临灾难性事件时,关键业务能够持续运作,数据能够得到保护,从而最大限度地减少对组织运营的影响和潜在经济损失。因此容灾的重要性不言而喻,今天的话题主要是聊下如何利用ingress-nginx实现应…

一个超全的go工具库Lancet

文档官网 https://www.golancet.cn 安装 使用 go1.18 及以上版本的用户,建议安装 v2.x.x。 因为 v2.x.x 应用 go1.18 的泛型重写了大部分函数。 go get github.com/duke-git/lancet/v2使用 go1.18 以下版本的用户,必须安装 v1.x.x。目前最新的 v1 版本是 v1.4.1。 go get git…

SHELL之变量

一、脚本 1、shell组成 #!脚本声明(使用哪种解释器) # 注释信息 脚本内容注意: 如果直接将解释器路径写死在脚本里,可能在某些系统就会存在找不到解释器的兼容性问题,所以可以使用: #!/bin/env 解释器 #!/bin/env bash2、执行脚本方式 sh -x 脚本文件路径 source 脚…

NSSCFT [SWPUCTF 2022 新生赛]ez_ez_php

NSSCFT [SWPUCTF 2022 新生赛]ez_ez_php进入之后就看见一段php代码,那就直接开始代码审计<?php error_reporting(0); if (isset($_GET[file])) {if ( substr($_GET["file"], 0, 3) === "php" ) {//截取字符串前三个字符,并与php做判断echo "Nic…

同一工程中的低复位

mark: 低复位和高复位在同一工程中一定要吻合。

序列容器

序列容器 序列容器以线性序列的方式存储元素。它没有对元素进行排序,元素的顺序和存储它们的顺序相同。5 种标准的序列容器,每种容器都具有不同的特性:array<T,N> (数组容器) :长度固定的序列,有 N 个 T 类型的对象,不能增加或删除元素。 vector<T> (向量容器…

Antd-React-TreeSelect前端搜索过滤

Antd-React-TreeSelect前端搜索过滤,antd本事是带有搜索的功能,但是在开发过程中发现自带的搜索功能与我们要使用的搜索过滤还是差了好多,在一些时候搜索为了迎合需要不得不这么操作,那么该操作结合了antd官方的搜索操作,因而在看了网上的一些操作后还是与需求不符合,最后…

方法的三种调用形式

在《可以调用Null的实例方法吗?》一文中,我谈到.NET方法的三种调用形式,现在我们就来着重聊聊这个话题。具体来说,这里所谓的三种方法调用形式对应着三种IL指令:Call、CallVirt和Calli。在《可以调用Null的实例方法吗?》一文中,我谈到.NET方法的三种调用形式,现在我们就…

关于隐藏Selenium绕过检测

. 浏览器指纹识别:网站通常通过浏览器指纹识别来检测访问者的身份。浏览器指纹是浏览器在访问网站时提供的一组信息,包括浏览器类型、版本、插件、用户代理字符串、屏幕分辨率、语言设置、操作系统等。 当你使用 Selenium 或其他自动化工具时,某些指纹信息可能会暴露自动化工…

依赖倒置原则

一、前言 依赖倒置原则也称依赖倒转原则(Dependence Inversion Principle) 看官方定义 高层模块不应该依赖底层模块,二者都应该依赖其抽象 抽象不应该依赖细节,细节应该依赖抽象 依赖倒置的中心思想是面向接口编程 如果你了解点设计模式,应该理解上面的话,但是如果不了解…