Adafactor Adaptive Learning Rates with Sublinear Memory Cost

news/2024/11/15 18:00:11/文章来源:https://www.cnblogs.com/MTandHJ/p/18408312

目录
  • 符号说明
  • Adafactor
    • Factored Second Moment Estimation
    • No Momentum
    • Out-of-Date Second Moment Estimator
    • 算法
  • 代码

Shazeer N. and Stern M. Adafactor: Adaptive learning rates with sublinear memory cost. ICML, 2018.

本文介绍了一种 memory-efficient 的优化器: Adafactor.

符号说明

  • \(x\), parameters;
  • \(W \subset x\), a linear transformation, $ \in \mathbb{R}^{n \times m}$

Adafactor

下面, 我们一步步介绍 Adafactor 对于 Adam 的修改.

Factored Second Moment Estimation

  • 一般的 Adam 的更新流程如下:

  • 一个比较重要的点是 一阶和二阶 的动量估计, 这个估计导致了 Adam 至少需要 2x 的参数的缓存.

  • 假设对于 linear transformation \(W \in \mathbb{R}^{n \times m}\), 它所对应的二阶动量为: \(V \in \mathbb{R}^{n \times m}\), 作者希望将他分解成两个低秩矩阵: \(R \in \mathbb{R}^{n \times k}, S \in \mathbb{R}^{k \times m}\), 使得

    \[V \approx RS. \]

  • 由于 \(V\) 是非负的, 所以作者更倾向于 nonnegative matrix factorization, 并利用泛化的 KL 散度—— I-divergence:

    \[d(p, q) = p \log \frac{p}{q} - p + q \]

    作为度量.

  • 作者希望 \(R, S\) 能够满足:

    \[\min_{R \in \mathbb{R}^{n \times k}, S \in \mathbb{R}^{k \times m}} \quad \sum_{i=1}^n \sum_{j=1}^m d(V_{ij}, [RS]_{ij}) \\ s.t. \quad R_{ij} \ge 0, \quad S_{ij} \ge 0. \]

  • 特别的, 作者证明了, 在 \(k=1\) 的情况下, 一定有:

    \[RS = V1_m 1_n^T V / 1_n^T V 1_m, \quad 1_{\ell} := (1, \ldots, 1) \in \mathbb{R}^{\ell} \]

    成立. 于是, 在这种情况下, 不失一般性的, 可以领:

    \[R = V 1_m, C = 1^T V. \]

  • 于是, 作者给出了如下的 \(V_t\) 的更新方案:

    \[G_t = \nabla f_t(W_{t-1}) \\ R_t = \beta_2 R_{t-1} + (1 - \beta_2) (G_t^2 1_m) \\ C_t = \beta_2 C_{t-1} + (1 - \beta_2) (\mathbf{1}_n^T G_t^2) \\ \hat{V}_t = (R_t C_t / 1_n^T R_t) / (1 - \beta_2^t) \\ W_t = W_{t-1} - \alpha G_t / (\sqrt{\hat{V}_t} + \epsilon). \]

No Momentum

  • 为了进一步降低一阶动量的缓存, 作者直接令 \(\beta_1 = 0\), 即移除了一阶动量.

Out-of-Date Second Moment Estimator

  • 作者认为, 当模型变化特别快的时候, 二阶矩的估计很容易过时:

  • 如上图所示, 当我们用一个较大的 \(\beta_2\), 如果没有 warm-up (即模型缓慢更新) 阶段, 效果是特别差的.

  • 为了验证这一点, 作者统计:

    \[\text{RMS}(U_t) = \text{RMS}_{x \in X} (u_{xt}) = \sqrt{\text{Mean}_{x \in X} (\frac{g_{xt}^2}{\hat{v}_{xt}} )}. \]

    作者认为, 如果训练是稳定的, \(\text{RMS}(U_t) \approx 1\), 既然 Adam 的一个假设是:

    \[\mathbb{E}[\hat{v}] = \mathbb{E}[g^2]. \]

  • 如上图所示, \(\beta_2\) 取得比较大的时候, 结果并不是这样的. 于是:

    \[U_t= G_t / \sqrt{\hat{V}_t} \\ \hat{U}_t = U_t / \max(1, RMS(U_t) / d) \\ W_t = W_{t-1} - \alpha_t \hat{U}_t. \]

    即 Adafactor 会手动校准.

算法

  • Adafactor 对于 matrix:

  • Adafactor 对于 vector:

  • 默认的参数设置:

注: \(\rho\) 是人为设置的相对步长, 这里不多赘述了.

代码

[pytorch-optimizer]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/795494.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Oracle 19c OCP 认证考试 082 题库(第19题)- 2024年修正版

【优技教育】Oracle 19c OCP 082题库(Q 19题)- 2024年修正版 考试科目:1Z0-082 考试题量:90 通过分数:60% 考试时间:150min 本文为(CUUG 原创)整理并解析,转发请注明出处,禁止抄袭及未经注明出处的转载。 原文地址:http://www.cuug.com.cn/ocp/082kaoshitiku/3822886061…

如何用 Helm 安装指定版本的 GitLab Runner?

本分分享如何使用 Helm 来在 Kubernetes 集群上安装极狐GitLab Runner。整体步骤分为:Helm 的安装、vaules.yaml 文件的配置、Runner 的安装、Runner 的测试。 极狐GitLab 为 GitLab 在中国的发行版,中文版本对中国用户更友好。极狐GitLab 支持一键私有化部署,可以在 ubuntu…

安装程序在安装此程序包时遇到了错误2503

原文链接:https://blog.csdn.net/sisi_new/article/details/139180294 安装程序在安装此程序包时遇到错误2503通常是由于安装权限不足造成的解决方案如下:1.修改TEMP文件夹的权限:进入“C:\Windows\Temp”路径,右键单击Temp文件夹选择“属性”,在“安全”选项卡中修改权限…

安全测试工具(1)- Burp Suite Pro的安装教程

啥是Burp Suite 用于攻击web 应用程序的集成平台 程序员必备技能,不仅可以拿来做渗透测试、漏洞挖掘还能帮助程序员调试程序 Bug 它包含了许多Burp工具,这些不同的burp工具通过协同工作,有效的分享信息,支持以某种工具中的信息为基础供另一种工具使用的方式发起攻击。这些工…

高等数学 1.6 极限存在准则 两个重要极限

目录第一个准则第一个重要极限第二个准则第二个重要极限柯西(Cauchy)极限存在准则 第一个准则 准则Ⅰ:如果数列 \(\{ x_n \}\) ,\(\{ y_n \}\) 及 \(\{ z_n \}\) 满足下列条件: (1)从某项起,即 \(\exists n_0 \in \mathbb{N}_+\) ,当 \(n > n_0\) 时,有 \[y_n \le…

socket套接字通信---win和linux互通(1)

一、Windows下的网络调试工具-NetAssist 下载页面 下载后无需安装,解压缩就是个exe的执行文件。双击打开就可使用 软件界面二、linux下的网络调试工具 nc(netcat) 1、当前系统 $ cat /proc/version Linux version 6.6.47-current-x86 (build@armbian) (gcc (Ubuntu 11.4.0-1u…

第一次编程作业

这个作业属于哪个课程 计科22级34班这个作业要求在哪里 个人项目这个作业的目标 1.设计一个查重算法。2. 了解并学习项目的PSP表格3. 学习如何运用github进行代码管理4. 学习使用性能分析工具,分析代码性能5. 学习如何进行单元测试我的github仓库链接:https://github.com/zfi…

mysql 拼接字段

select spot_position,req_line,CONCAT(spot_position,-,req_line) from pdm_qc_apply where req_qctype != 2;结果展示:

Origin2024中绘制多因子分组柱状图,直观展示不同组别内的数据变化!

当我们需要对比多组平行数据时,采用Origin多因子分组柱状图,不仅可以直接的对比多组数据,同时还能够直观展示各个指标因子的数据变化及趋势操作步骤: 1、先打开Origin2024软件,然后在Book1中输入如下示例数据: 2、第一步,绘制分组柱形图图表,选中所有数据:3、点击菜单…

floorplan-reconsturtion-based-plane-triangle

一个iter算25s, 每个epoch31个iter,480个epoch需要 2531480/3600/24 = 4.3(天)改用30个epoch,训练5个小时Loss曲线

易基因:Adv Sci:ACE等揭示产前不良环境暴露通过DNA羟甲基化变化介导子代自闭症|国人佳作

大家好,这里是专注表观组学十余年,领跑多组学科研服务的易基因。 自闭症谱系障碍(Autism spectrum disorder,ASD)是一种神经发育障碍,以社交沟通障碍和刻板行为为主要特征。许多研究证明,妊娠期暴露于环境毒素会导致儿童中ASD患病率快速增长。1-硝基芘(1-Nitropyrene,…