An Empirical Model of Large-Batch Training

news/2024/11/28 5:13:44/文章来源:https://www.cnblogs.com/MTandHJ/p/18570527

目录
  • Gradient Noise Scale

McCandlish S., Kaplan J., Amodei D. and OpenAI Dota Team. An empirical model of large-batch training. 2018.

本文讨论了随着 batch size 改变, sgd-style 的优化器的学习应该怎么调整.

Gradient Noise Scale

  • 考虑如下的优化问题:

    \[\tag{1} \min_{\theta \in \mathbb{R}^D} \quad L(\theta) = \mathbb{E}_{x \sim \rho} [L_x(\theta)], \]

    其中 \(\rho(x)\) 是数据 \(x\) 所服从的分布.

  • 通常来说, 精准地优化 (1) 需要计算整个数据集上的梯度, 这个不太现实, 所以实际中, 通常采用 mini-batch 更新策略:

    \[L_{batch}(\theta) = \frac{1}{B} \sum_{i=1}^B L_{x_i} (\theta), \quad x_i \sim \rho. \]

  • 所对应的, SGD 更新策略为:

    \[\theta_{t + 1} \leftarrow \theta_t - \epsilon \underbrace{\frac{1}{B} \sum_{i=1}^B \nabla_{\theta} L_{x_i} (\theta_t)}_{=: G_{est}}, \]

    其中 \(\epsilon\) 为步长.

  • 进一步假设 (\(G = \nabla_{\theta} L, H = \nabla_{\theta}^2 L\))

    \[L(\theta - \epsilon V) \approx L(\theta) - \epsilon G^T V + \frac{1}{2} \epsilon^2 V^T H V. \]

    容易发现, 此时最优的 \(\epsilon\)

    \[\epsilon_{\max} = \frac{|G|^2}{G^T H G}. \]

  • 对于 mini-batch 的更新情况, 类似有

    \[\begin{array}{ll} \mathbb{E}[L(\theta - \epsilon G_{est})] &= L(\theta) - \epsilon G^T \mathbb{E}[G_{est}] + \frac{1}{2} \epsilon^2 \mathbb{E}[G_{est}^T H G_{est}] \\ &= L(\theta) - \epsilon G^T G + \frac{1}{2} \epsilon^2 \mathbb{E}[G_{est}^T H G_{est}] \\ &= L(\theta) - \epsilon G^T G + \frac{1}{2} \epsilon^2 \mathbb{E}[G^T H G + \frac{\text{tr}(H\Sigma)}{B}], \end{array} \]

    其中

    \[\Sigma = \text{Cov}(\nabla_{\theta} L_x(\theta)). \]

注: 上述第二个等式成立的原因是:
$$
\begin{array}{ll}
\mathbb{E}_x[x^TAx]
&=\mathbb{E}[(A^{1/2} x)^{T} (A^{1/2}x)] \
&=\text{Tr}(\mathbb{E}[(A^{1/2} x)^{T} (A^{1/2}x)]) \
&=\mathbb{E}[\text{Tr}((A^{1/2} x)^{T} (A^{1/2}x))] \
&=\mathbb{E}[\text{Tr}((A^{1/2}x) (A^{1/2} x)^{T} )] \
&=\mathbb{E}[\text{Tr}(A{1/2}xxT A^{1/2})] \
&=\text{Tr}(A^{1/2} \mathbb{E}[xx^T] A^{1/2}) \
&=\text{Tr}(A^{1/2} (\text{Cov}(x, x) + \mathbb{E}[x]\mathbb{E}[x]^T]) A^{1/2}) \
&=\text{Tr}(A \text{Cov}(x, x)) + \mathbb{E}[x]^T A \mathbb{E}[x].
\end{array}
$$

  • 因此, 在这个情况下, 我们有

    \[\epsilon_{opt} (B) = \frac{\epsilon_{\max}}{1 + \mathcal{B}_{noise} / B}, \quad \mathcal{B}_{noise} = \frac{\text{tr}(H\Sigma)}{G^T H G}. \]

    其中 \(\mathcal{B}_{noise}\) 被称之为 noise scale.

  • 所以, 当 \(\mathcal{B}_{noise} \gg B\) 的时候, 增大 batch size \(B\) 应当相应的线性地增大学习率, 当 \(\mathcal{B}_{noise} < B\) 的时候, 再增大 batch size 对于学习率的调节就不需要那么灵敏 (实际上在这种情况下, 这种情况下再继续增大 batch size 所得到的效率的增益是很微弱的):

注: 作者在 Appendix D.1 中证明了训练速度和训练样本所满足的一个等式关系 (但是其中的证明我没有推过去).

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/841716.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

算法网关视频分析网关拍照检测高空抛物检测算法:守护城市安全的“天眼”

高空抛物,一个看似微不足道的行为,实则隐藏着巨大的安全隐患。随着城市化进程的加快,高层建筑如雨后春笋般拔地而起,高空抛物现象也随之增多,给人们的生活带来了严重的威胁。从烟头、饮料瓶到花盆、垃圾,这些被随意抛掷的物品,一旦从高空落下,其破坏力不容小觑。 为了有…

jquery仿PPT幻灯片特效插件ppt.js

ppt.js是一款jquery仿PPT幻灯片特效插件。该jquery插件基于jquery来显示图片翻页效果,可全屏显示,以及自定义图片的宽度和高度。演示 下载使用方法 在页面中引入jquery和ppt.js文件,以及字体图标文件iconic和插件样式文件ppt.css。<link rel="stylesheet" hr…

性能指标详解

一、监听器中的插件 @gc - Active Threads Over Timeip 活动线程时间 @gc - AutoStop Listener 自动停止侦听器 @gc - Bytes Throughput Over Timejp 字节吞吐量随时间变化 @gc -Composite Graph 综合图 @gc - Connect Times Over Timejp 连接时间 @gc -Console Status Loggerj…

leetcode78 子集

leetcode78 子集思路:深度优先搜索回溯 分析此类问题可以先用树形结构模拟代码逻辑。那么根据这个解答树,首先我们的回溯搜索函数应该由这么几部分组成将搜索获得的答案加入到res中。 for循环遍历搜索下一个元素(比如在初始列表为空的时候,第一位可以选1,2,3显然需要通过…

bootstrap模态窗口美化特效

这是一款bootstrap模态窗口美化特效。该特效在原生bootstrap模态窗口的基础上,通过添加自定义的CSS样式,制作出效果非常炫酷的模态窗口。演示 下载使用方法 在页面中引入下面的文件。<link rel="stylesheet" href="http://jrain.oscitas.netdna-cdn.com…

重温经典,一网万游:在线红白机FC游戏平台:webgame.one

还记得小时候守在电视机前,手握红白机手柄,沉浸在《魂斗罗》紧张刺激的战斗、《超级马里奥兄弟》奇妙的冒险世界,或是与小伙伴一起在《坦克大战》里并肩作战的美好时光吗?那些经典的 FC 游戏,承载着我们童年最纯真的快乐与回忆。如今,有一个名为 https://webgame.one 的在…

明火识别视频分析服务器烟雾识别小区住宅智慧消防场景方案

随着城市化进程的加快和科技的不断进步,燃气安全和消防安全已成为城市安全管理的重要组成部分。为了响应国家政策的号召,提升城镇燃气安全水平,以及加强高层民用建筑的消防安全管理,迫切需要一套科学、高效的技术解决方案来应对当前的挑战。 本文将详细介绍如何利用明火识别…

vue2 数据导入excel

1、安装 npm install xlsx一、前端<el-uploadstyle="display: inline-block"actionaccept=".xlsx, .xls":auto-upload="false":show-file-list="false":on-change="handleUpload"><el-button type="primary&q…

小迪安全第10天HTTP数据包

请求包:request 回显包:response (1)请求方式:post get get:提交请求 post:向指定资源提交内容,登录/上传文件 •get:向特定资源发出请求(请求指定页面信息,并返回实体主体); •post:向指定资源提交数据进行处理请求(提交表单、上传文件),又可能导致新的资源的…

《安富莱嵌入式周报》第346期:开源2GHz带宽,12bit分辨率,3.2Gsps采样率示波,开源固件安全分析器, 开源口袋电源,开源健康测量,FreeCAD

周报汇总地址:http://www.armbbs.cn/forum.php?mod=forumdisplay&fid=12&filter=typeid&typeid=104 视频: https://www.bilibili.com/video/BV1TYBhYKECK/目录: 1、开源2GHz带宽,12bit分辨率,3.2Gsps采样率示波器 2、开源嵌入式固件安全分析器 3、TI分享的8…