机器学习优化算法

news/2025/3/31 11:07:08/文章来源:https://www.cnblogs.com/Skienz/p/18797707

优化算法——SGD、Momentum、Adagrad、RMSprop、Adam、AdamW

  • 统一数学表达:设损失函数为$\mathcal{L}(\theta) $,学习率为$\eta$。
    • 每次迭代仅使用一个随机小批量(mini-batch)数据计算梯度
    • 从训练集中采样包含小批量$m$个样本${x{(1)},\cdots,x{(m)}}$,其对应的目标为${y{(1)},\cdots,y{(m)}}$。则用于计算的梯度$\displaystyle g=\frac{1}{m}\sum_{i=1}^m\nabla_\theta \mathcal{L} (f(x{(i)};\theta),y)$。
  • 本文中出现的数学表达式中的参数是单个元素,当参数为矩阵时,对矩阵中的每个元素进行相同的更新操作。比如$g$是矩阵,则$g^2=g\odot g$。

1. SGD

1.1 基本概念

  • 随机梯度下降,stochastic gradient descent。

  • 更新公式$\theta_{t+1}\leftarrow\theta_t-\eta\cdot g$。

  • PyTorch中调用方法

# (params: _params_t, lr: float, momentum: float = ..., dampening: float = ..., weight_decay: float = ..., nesterov: bool = ...) -> Noneoptimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

1.2 Case study

  • 更新不稳定
  • 设$J(x,y)=x2+9y2$,初始$(x_0,y_0)=(2,2)$,设置学习率$\eta=0.1$。假设用mini-batch求出来的梯度就是理论值,则更新公式为$(x,y)=(x-\eta\nabla_xJ,y-\eta\nabla_yJ)=(0.8x,-0.8y)$。根据SGD优化到最低点$(0,0)$:$(2,2)\rightarrow(1.6,-1.6)\rightarrow(1.28,1.28)\rightarrow(1.024,-1.024)\rightarrow\cdots$。由此看出梯度大时导致收敛不稳定,产生震荡。

2 Momentum

  • 动量法
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

2.1 指数加权平均

day 观测值$\theta$ 代替$\theta$的估计值$v$
1 100 30
2 107 53.1
3 105 68.67
4 110 81.069
5 126 94.5483
6 120 102.18381
7 130
  • 有什么用:刻画数据变化的趋势。比如上述表格中观测值虽然不是单调递增的,但其整体的趋势是增长的,因此用$v$来刻画该增长趋势。

  • 现在希望预测第7天的值,可以想到用加权平均$\displaystyle v_7=\frac{1}{7}\sum_{i=1}^{7}\theta_i$。但实际上越近期的数据权重应该更大,因此用指数加权平均:$v_t=\beta v_{t-1}+(1-\beta)\theta_t$。

  • 初始$v_0=0,\beta=0.7$。填入表格。可以发现当时间序列较短时,最早几天的$v$值很小,不准确。时间序列够长,$v_0$的权重越小,即影响越小。

  • 修正:$\displaystyle v_t{correct}=\frac{v_t}{1-\betat}$。

2.2 基本概念

  • 引入历史梯度加权平均,在梯度方向一致时加速收敛,减少SGD中的震荡,从而更加稳定。
  • 数学表示:
    • 速度更新(累计梯度):$v_t=\gamma v_{t-1}+(1-\gamma)\cdot g$。其中$\gamma$为动量系数,一般设置为0.9。
    • 参数更新:$\theta_t=\theta_{t-1}-\eta\cdot v_t$。

3 Adagrad

  • Adaptive Gradient Algorithm。自适应学习率优化算法,根据参数的历史梯度动态调整学习率,尤其适用于稀疏数据和高维优化问题。

  • 数学表达:

    • 累计梯度平方和:$G=G+g^2$。初始化$G=0$。
    • 参数更新:$\displaystyle \theta_{t+1}=\theta_t-\frac{\eta}{\sqrt{G+\varepsilon}}\cdot g$。其中$\eta$是全局学习率(超参数),$\varepsilon$是防止除0的很小的数。
  • 问题:累计越来越大,导致后期收敛缓慢。

3.1 优化:RMSprop

  • Root Mean Square Propagation。
  • PyTorch调用:
optimizer = torch.optim.RMSprop(model.parameters(), lr=0.01, momentum=0.9)
  • 与Adagrad唯一不同的地方:
    • 累计梯度平方换成了指数加权平均:$G=\beta\cdot G+(1-\beta)\cdot g^2$。
  • 但后期容易在小范围内产生震荡。

4 ✅Adam

  • Adam = RMSprop + Momentum。对学习率(步长)不敏感,建议默认0.001。
  • $s,r$初始化均为0;$\beta_1=0.9,\beta_2=0.999$。数学表达:
    • 一阶矩估计(Monmentum部分):$s=\beta_1s+(1-\beta_1)\cdot g$。
    • 二阶矩估计(RMSprop部分):$r=\beta_2r+(1-\beta_2)\cdot g^2$。
    • 修正:$\displaystyle\hat{s}=\frac{s}{1-\beta_1t},\hat{r}=\frac{r}{1-\beta_2t}$。
    • 更新:$\displaystyle \theta\leftarrow\theta-\frac{\eta}{\sqrt{\hat{r}}+\varepsilon}\hat{s}$。
  • PyTorch调用:
optimizer = torch.optim.Adam(params, lr=learning_rate, weight_decay=weight_decay)

4.1 AdamW

  • W:weight decay,权重衰减系数$\lambda$。达到泛化。
    • 不是L2正则化(L2 Regularization):$\displaystyle \mathcal{L}=\mathcal{L}(\theta)+\frac{\lambda}{2}\left | \theta \right |^2$。因为修改了损失函数。
  • 唯一不同:$\theta\leftarrow\theta-\frac{\eta}{\sqrt{\hat{r}}+\varepsilon}\hat{s}-{\color{Red} \lambda\cdot \eta\theta} $。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/907058.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数码管静态显示

前言 目标 控制LED数码管,静态显示数字 原理 51 单片机的 LED 数码管有8个 每个数码管又由 8 个数码段组成选择要点亮的 LED 数码管的位置, 一共8个位置点亮特定 LED 数码管的数码段, 通过不同的组合,从而显示出想要的字符效果图参考资料 [4-1]静态数码管显示 位码 一共是8个…

上线Steam好评如潮!《刺客信条:影》真被玩家骂到逆袭了?

发售前被喷成筛子,上线后直接真香? 最近Steam玩家圈被《刺客信条:影》彻底刷屏了!虽然预告片公布时因为黑人武士主角、历史细节争议被疯狂吐槽,但游戏上线后却上演大型打脸现场——Steam好评率飙到77%,首周销量直接冲进全球热销榜TOP3,连日本玩家都直呼“忍者跑图太带感…

Gitee DevSecOps:构建智能化军工软件工厂,突破版本管理瓶颈

在军工软件研发向工业化转型的背景下,“软件工厂”模式成为提升研发效率与资源优化配置的核心路径。然而,传统版本管理方法难以应对大规模、跨团队的协同开发需求,导致依赖关系混乱、版本变更失控等问题,严重制约项目交付效率。Gitee DevSecOps平台基于软件工厂的标准化、流…

行政管理系统推荐几个比较好的?

之前写过一篇关于行政管理资料的,指路>> HR猫姐:公司行政究竟是干什么的?这份1000+行政资料收好! 这篇就分享一个我们团队现在正在用的行政管理系统吧——戳此自取模板>> 简道云行政管理系统下面来详细介绍下我们现在主要在用的几个功能: 01 应付/应收合同管理…

【Java】【XXL-job】自己的项目调度任务中心

之前,我们已经学习了xxl-job的入门:https://www.cnblogs.com/luyj00436/p/18780550 。这里的任务执行,调用的是demo。 那么我们自己的项目,如果使用xxl-job?自己的项目,相当于执行器,只要把自己的项目,仿造xxl-job-executor-sample-springboot,即可。 步骤新建Springb…

【Vue】自定义滚动条

<!-- 滚动条开始 --><div class="custom-scrollbar-container"><!-- 添加左右箭头按钮 --><div class="scroll-arrow left-arrow" @click="scrollBy(-100)"><i class="iconfont"style="transform: ro…

重庆软航NTKO WebOffice控件在谷歌Chrome 133版提示扩展已停用解决方案!

NTKO WebOffice‌是重庆软航公司的一款能够在浏览器中直接编辑Microsoft Office、WPS、金山电子表等文档的控件,支持Word、Excel等多种文档格式。该控件能够在IE、Chrome等浏览器中运行,并支持强制痕迹保留、禁止拷贝、模版套红、全文批注等功能‌。 但是软航NTKO WebOffice‌…

5个关键步骤优化IPD流程实施效果

IPD(Integrated Product Development)流程即集成产品开发流程,是一套产品开发的模式、理念与方法。它强调将产品开发视为一个完整的流程,涵盖从市场需求分析、产品规划、设计开发到生产制造、上市销售等各个环节,旨在通过跨部门的团队协作,高效、高质量地推出满足市场需求…