深入探讨梯度下降:优化机器学习的关键步骤(一)

文章目录

  • 🍀引言
  • 🍀什么是梯度下降?
  • 🍀损失函数
  • 🍀梯度(gradient)
  • 🍀梯度下降的工作原理
  • 🍀梯度下降的变种
    • 🍀随机梯度下降(SGD)
    • 🍀批量梯度下降(BGD)
    • 🍀小批量梯度下降(Mini-Batch GD)
  • 🍀如何选择学习率?
  • 🍀梯度下降的相关数学公式
  • 🍀梯度下降的实现(代码)
  • 🍀总结

🍀引言

在机器学习领域,梯度下降是一种核心的优化算法,它被广泛应用于训练神经网络、线性回归和其他机器学习模型中。本文将深入探讨梯度下降的工作原理,并且进行简单的代码实现


🍀什么是梯度下降?

梯度下降是一种迭代优化算法,旨在寻找函数的局部最小值(或最大值)以最小化(或最大化)一个损失函数。在机器学习中,我们通常使用梯度下降来最小化模型的损失函数,以便训练模型的参数。
这里顺便提一嘴,与梯度下降齐名的梯度上升算法目的是使效用函数最大。


🍀损失函数

在使用梯度下降之前,我们首先需要定义一个损失函数。损失函数是一个用于衡量模型预测值与实际观测值之间差异的函数。通常,我们使用均方误差(MSE)作为回归问题的损失函数,使用交叉熵作为分类问题的损失函数。


🍀梯度(gradient)

梯度是损失函数相对于模型参数的偏导数。它告诉我们如果稍微调整模型参数,损失函数会如何变化。梯度下降算法利用梯度的信息来不断调整参数,以减小损失函数的值。

🍀梯度下降的工作原理

梯度下降的核心思想是沿着损失函数的负梯度方向调整参数,直到达到损失函数的局部最小值。具体来说,梯度下降的步骤如下:

  • 初始化模型参数:首先,随机初始化模型参数或使用某种启发式方法。

  • 计算损失和梯度:使用当前模型参数计算损失函数的值,并计算损失函数相对于参数的梯度。

  • 参数更新:根据梯度的方向和学习率(learning rate)本文我称其为eta,更新模型参数。学习率是一个控制步长大小的超参数,它决定了每次迭代中参数更新的大小。

  • 重复迭代:重复步骤2和3,直到损失函数的值收敛到一个稳定的值,或达到预定的迭代次数。

🍀梯度下降的变种

在梯度下降的基础上,发展出了多种变种算法,以应对不同的问题和挑战。其中一些常见的包括

🍀随机梯度下降(SGD)

随机梯度下降每次只使用一个随机样本来估计梯度,从而加速收敛速度。它特别适用于大规模数据集和在线学习。

🍀批量梯度下降(BGD)

批量梯度下降在每次迭代中使用整个训练数据集来计算梯度。尽管计算开销较大,但通常能够更稳定地收敛到全局最小值。

🍀小批量梯度下降(Mini-Batch GD)

小批量梯度下降综合了SGD和BGD的优点,它使用一个小批量样本来估计梯度,平衡了计算效率和收敛性能。

🍀如何选择学习率?

学习率是梯度下降的关键超参数之一。选择合适的学习率可以加速收敛,但过大的学习率可能导致不稳定的训练过程。通常,我们可以采用以下方法选择学习率:

  • 网格搜索:尝试不同的学习率值,通过验证集的性能来选择最佳值。

  • 学习率衰减:开始时使用较大的学习率,随着训练的进行逐渐减小学习率。

  • 自适应学习率:使用自适应学习率算法,如Adam、Adagrad或RMSprop,它们可以自动调整学习率以适应梯度的变化。

🍀梯度下降的相关数学公式

本人数学不好,这里有说的不清楚的地方还请见谅,谢谢佬~
首先我们通过图像认识一下损失函数
在这里插入图片描述
这里的步长指的是,可能有些人会好奇为啥有一个负号呢?因为对称轴左侧的导数都是负值,这里加一个负号不就正了嘛
在这里插入图片描述

具体推导过程请查看相关佬的文章(哭~)

🍀梯度下降的实现(代码)

首先我们导入我们需要的库

import numpy as np
import matplotlib.pyplot as plt

之后我们需要举一个例子,这里我们采用numpy里面的一个分割函数linspace,同时我们举一个函数的例子

plt_x = np.linspace(-1,6,141)
plt_y = (plt_x-2.5)**2-1

之后我们使用show进行展示一下图像

plt.plot(plt_x,ply_y)
plt.show()

运行结果如下
在这里插入图片描述

上图看起来就是一个普通的曲线,方便我们进行理解

接下来我们需要两个函数,一个为了返回导数,一个为了返回对应的y值

def dj(thera):return 2*(thera-2.5) # 求导
def j(thera)return (thera-2.5)**2-1  # 求对应的值

接下来是梯度下降的关键位置了,这里我们需要初始化两个参数以及一个范围参数,同时设置一个while循环,将前一个thera保存在last_thera中,后一个thera是前一个thera和步长的差值,这里的步长就是梯度个参数eta的乘积,最后使用if函数来终结循环,最终我们将最小值点的值、导数、以及自变量打印出来

eta = 0.1
theta =0.0
epsilon = 1e-8
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))

运行结果如下
在这里插入图片描述
这里我们也可以使用列表来看看到底进行了多少次thera的循环

eta = 0.1
theta =0.0
epsilon = 1e-8
theta_history = [theta]
while True:gradient = dj(theta)last_theta = thetatheta = theta-gradient*eta theta_history.append(theta)if np.abs(j(theta)-j(last_theta))<epsilon:breakprint(theta)
print(dj(theta))
print(j(theta))len(theta_history)

运行结果如下

在这里插入图片描述
还可以绘制图像进行直观查看

plt.plot(plt_x,plt_y)
plt.plot(theta_history,[(i-2.5)**2-1 for i in theta_history],color='r',marker='*')
plt.show()

运行结果如下
在这里插入图片描述
这样的话就很直观了吧~

🍀总结

本节只介绍梯度下降的简单实现,下节继续学习此法中eta参数的调节

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/94270.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

C++--动态规划其他问题

1.一和零 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台 给你一个二进制字符串数组 strs 和两个整数 m 和 n 。 请你找出并返回 strs 的最大子集的长度&#xff0c;该子集中 最多 有 m 个 0 和 n 个 1 。 如果 x 的所有元素也是 y 的元素&#xff0…

PlumeLog查不到日志

一 问题&#xff1a; PlumeLog查不到日志&#xff0c;记录遇到的情况 二 场景 1. 输入不全

AI人员打闹监测识别算法

AI人员打闹监测识别算法通过yolopython网络模型框架算法&#xff0c; AI人员打闹监测识别算法能够准确判断出是否有人员进行打闹行为&#xff0c;算法会立即发出预警信号。Yolo算法&#xff0c;其全称是You Only Look Once: Unified, Real-Time Object Detection&#xff0c;其…

【LeetCode】《LeetCode 101》第十二章:字符串

文章目录 12.1 字符串比较242 . 有效的字母异位词&#xff08;简单&#xff09;205. 同构字符串&#xff08;简单&#xff09;647. 回文子串&#xff08;中等&#xff09;696 . 计数二进制子串&#xff08;简单&#xff09; 12.2 字符串理解224. 基本计算器&#xff08;困难&am…

Axure RP软件安装包分享(附安装教程)

目录 一、软件简介 二、软件下载 一、软件简介 Axure RP是一款专业的原型设计工具&#xff0c;它能够帮助用户创建高保真度的交互式原型。 Axure RP具有以下特点&#xff1a; 强大的交互设计功能&#xff1a;Axure RP提供了丰富的交互设计工具&#xff0c;用户可以通过拖拽和…

Mysql 事物与存储引擎

MySQL事务 MySQL 事务主要用于处理操作量大&#xff0c;复杂度高的数据。比如说&#xff0c;在人员管理系统中&#xff0c; 要删除一个人员&#xff0c;即需要删除人员的基本资料&#xff0c;又需要删除和该人员相关的信息&#xff0c;如信箱&#xff0c; 文章等等。这样&#…

POSTGRESQL WAL 日志问题合集之WAL 如何解析

开头还是介绍一下群&#xff0c;如果感兴趣polardb ,mongodb ,mysql ,postgresql ,redis 等有问题&#xff0c;有需求都可以加群群内有各大数据库行业大咖&#xff0c;CTO&#xff0c;可以解决你的问题。加群请加 liuaustin3微信号 &#xff0c;在新加的朋友会分到3群 &#xf…

VBA快速插入签名(位置不固定)

实例需求&#xff1a;Excel中的多页表格如下图所示&#xff0c;其中包含多个“受益人签字”&#xff0c;其位置不固定&#xff0c;现在需要在其后插入签名图片。 签名图片为透明背景的PNG文件&#xff08;左上角方框内的部分&#xff09;&#xff0c;图片文件属性信息如下图所示…

李宏毅hw1_covid19预测_代码研读+想办法降低validation的loss(Kaggle目前用不了)

1.考虑调整这个neural network的结构尝试让这个loss降低 &#xff08;1&#xff09;Linear(inputdim,64) - ReLU-Linear(64,1), loss0.7174 &#xff08;2&#xff09;Linear(inputdim,64) - ReLU-Linear(64,64) -ReLU-Linear(64,1),loss 0.6996 &#xff08;3&#xff09;这…

linux内网yum源服务器搭建

1.nginx: location / {root /usr/local/Kylin-Server-V10-SP3-General-Release-2303-X86_64;autoindex on;autoindex_localtime on;autoindex_exact_size off; } 注:指定到镜像的包名 2.修改yum源地址 cd /etc/yum.repos.d/vim kylin_x86_64.repo 注: --enabled设置为1 3.重…

Java面向对象:非访问修饰符、继承

继承的概念&#xff1a; 继承是java面向对象编程技术的一块基石&#xff0c;因为它允许创建分等级层次的类。 继承就是子类继承父类的特征和行为&#xff0c;使得子类对象&#xff08;实例&#xff09;具有父类的实例域和方法&#xff0c;或子类从父类继承方法&#xf…

基于Open3D的点云处理16-特征点匹配

点云配准 将点云数据统一到一个世界坐标系的过程称之为点云配准或者点云拼接。&#xff08;registration/align&#xff09; 点云配准的过程其实就是找到同名点对&#xff1b;即找到在点云中处在真实世界同一位置的点。 常见的点云配准算法: ICP、Color ICP、Trimed-ICP 算法…