Why Transformers Need Adam: A Hessian Perspective

news/2024/9/20 9:41:04/文章来源:https://www.cnblogs.com/MTandHJ/p/18381295

目录
  • 符号说明
  • 所有参数的 Hessian 矩阵
  • Block-wise Hessian
  • 代码

Zhang Y., Chen C., Ding T., Li Z., Sun R. and Luo Z. Why transformers need adam: a hessian perspective. arXiv preprint, 2024.

本文从 Hessian 矩阵的角度回答为什么 Adam 相较于其它方法, 比如 SGD 在 transformer 的训练上格外有效.

符号说明

  • 假设一个网络分为 \(L\) 个 block, 每个 block 有可学习的参数 \(w \in \mathbb{R}^{d_l}\);
  • \(\mathcal{L}\), 损失函数, \(w = [w_1, w_2 ,\ldots, w_L]\) 记为所有的参数;
  • \(\nabla^2 \mathcal{L} (w_l) \in \mathbb{R}^{d_l \times d_l}\), 第 \(l\) 个 block 的参数的所对应的 Hessian 矩阵

所有参数的 Hessian 矩阵

  • 作者考虑 ResNet18, VGG16 在 ImageNet 上的实验, 以及 GPT2 在 OpenWebText 上的实验, ViT-base 在 ImageNet 上的hi眼, BERT 在 Cornell Movie-Dialogs Corpus 上的实验, GPT2-nano 在 English corpus 上的实验.

  • 首先我们观察一下所有参数的 Hessian 上的差距, 从上图 (Adam, SGD 表现差不多, 所以作者只放了一个), 可以发现, 其实不同的模型, 即使一个是 CNN 另一个是 Transformer, 他们训练的时候的参数的 Hessian 矩阵的整体的谱是相差不大的. 所以我们没法直接从这个指标上回答为什么 Adam 会比 SGD 好一点.

Block-wise Hessian

  • 接着, 我们检查每一个 block, 这里的 block 可以简单理解为 PyTorch 自带的分割, 比如 MLP, Query/Key/Value projection, embedding layer 等.

  • 可以很明显地发现, Transformer 的不同 block 的谱 (分布) 相差是很大的, 而 CNN 的则很一致.

  • 进一步, 我们可以计算不同模型的不同 block 的 hessian 的谱间的 Jensen-Shannon 距离, 可以发现, CNN 的模型一致地低, 而 Transformer 模型不同 block 间差异很大.

  • 我们可以这么认为, 因为 transformer 不同 block 差异很大, 所以很难通过设定一个学习率去统一, 所以需要 Adam 这种每个位置单独设定学习的优化器.

  • 作者认为, 这主要和 Transformer 的层次的堆叠不那么具有序列性有关, 一个简单的例子是, MLP-mixer, 它仅由 MLP 组成, 但是运算方式是模仿 Transformer 的, 可以发现, 它的不同 block 间的距离也呈现类似的情况.

  • 上表列出了不同 block 的一个平均 JS 距离.

  • 作者进一步给出了一个二次方程的优化的例子, 并给予了理论分析, 有兴趣的可以回看原文.

代码

[official]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/787608.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

VL24 边沿检测

这个就是需要对a 进行打一拍last_a<=a; 需要理解的点是打一拍的last_a是落后a一个时钟周期的,也就是对当前时刻使用a时候,此时的last_a是a的上一时刻的值。`timescale 1ns/1ns module edge_detect(input clk,input rst_n,input a,output reg rise,output reg down ); reg …

RE入门第三天---TEA算法

OK,老规矩,先复习一下昨天的内容 ..... 几分钟就复习了,直接开干今天的内容 先找大佬的wp 来源: TEA系列加密解密 | Gruges Blog (g2uge.github.io) 逆向算法之TEA算法 - Sk2rw - 博客园 (cnblogs.com) 一.TEA加密解密简介 在密码学中,微型加密算法(Tiny Encryption Algo…

vue3 控制el-dialog 双向绑定显示隐藏

父组件<Contact v-model:isView="isView" /> 子组件<template><div><el-dialogwidth="400"title="微信二维码":model-value="props.isView"@closed="handleClose"><div class="dialog-div…

Typora使用PicGo自动上传图片

Gitee配置PicGo图床 简介 由于我们使用Markdown写博客时需要上传一些图片,以便于理解。但是md文件不像Word文件一样能承载图片传输,所以我们使用md文件进行多设备协作,或者传输发给其他人的时候,图片的传输成了很大的问题。一般情况下我们可以搭建一个文件服务器,但是这样…

【网络安全C10-2024.8.24】-docker、数据库、Web应用程序安全

1、在docker中分别以后台方式和交互方式启动centos,对比启动后的容器状态,实现退出容器也能保持其运行状态。docker run -d --name centos7-001 centos docker run -it --name centos7-002 centos /bin/bash docker run -d -t --name centos7-003 centos2、在docker并部署DVW…

网络安全C10-2024.8.24-docker、数据库、Web应用程序安全

docker run -d --name centos7-001 centos docker run -it --name centos7-002 centos /bin/bash docker run -d -t --name centos7-003 centos docker pull sagikazarmark/dvwa docker run -d -p 8082:80 -p 33060:3306 --name dvwa sagikazarmark/dvwa

CodeForces VP Record

CodeForces Round 767 (contest 1628) A Meximum Array 考虑二分。二分的时候计算区间 $ \text{mex} $,参考 P4137 Rmq Problem / mex,主席树即可。时间复杂度 $ \Theta(n \log^2 n) $,无需卡常。 B Peculiar Movie Preferences 首先,对于一个合法的回文串,容易证明首尾两…

数据库监控运维方案,保障高性能及高可用

通过构建对关键指标的监控,实现对数据库性能和资源的实时追踪,识别并解决影响的数据库问题,保障数据库的高性能及高可用性,更全面地支持业务及应用的稳定、持续运行。 随着企业对数据高可用的需求日益增长,对于数据库的实时监控和故障自动恢复方案愈发重要。作为关…

VL22 根据状态转移图实现时序电路

和上一题的第一段完全相同,第二段只是根据状态转移有部分改变,本体使用三段式状态机来写,第三段写法和上一题不一样。`timescale 1ns/1nsmodule seq_circuit(input C ,input clk ,input rst_n,output wire Y );…

Datawhale X 李宏毅苹果书 AI夏令营 深度学习01

神经网络的优化,通常我们使用梯度下降的方法对获取最优的参数,已达到优化神经网络的目的。另外,我们也可以对学习率进行调整,通过使用自适应学习率和学习率调度,最后,批量归一化改变误差表面,达到优化的目的。 同样,也会存在优化失败的时候,在收敛在局部极限值或者鞍点…