深入探讨梯度下降:优化机器学习的关键步骤(三)

文章目录

  • 🍀引言
  • 🍀随机、批量梯度下降的差异
  • 🍀随机梯度下降的实现
  • 🍀随机梯度下降的调试

🍀引言

随机梯度下降是一种优化方法,主要作用是提高迭代速度,避免陷入庞大计算量的泥沼。在每次更新时,随机梯度下降只使用一个样本中的一个例子来近似所有的样本,来调整参数,虽然不是全局最优解,但很多时候是可接受的。

前两篇主要介绍了一下批量梯度下降,本节前部分主要介绍一下随机梯度下降


🍀随机、批量梯度下降的差异

随机梯度下降和批量梯度下降都是常用的优化方法,它们在处理大规模数据集时都有自己的优点和缺点。以下是它们的不同点:

  • 相同点:
    两种方法都用于优化目标函数,通过迭代地更新参数来最小化目标函数。在每一步迭代中,它们都会根据当前参数的梯度来更新参数。

  • 不同点:
    (1)样本的使用方式:在随机梯度下降中,每次迭代只使用**一个样本**来计算梯度;而在批量梯度下降中,每次迭代会使用整个数据集来计算梯度。因此,随机梯度下降在处理大规模数据集时更高效,因为它不需要加载整个数据集到内存中。

    (2)收敛速度:由于随机梯度下降每次只使用一个样本来计算梯度,因此它的收敛速度通常比批量梯度下降更快。但是,随机梯度下降的收敛可能更加波动,因为每次迭代的样本可能不同。

    (3)准确度:批量梯度下降的准确度通常比随机梯度下降更高。因为批量梯度下降会使用整个数据集来计算梯度,因此它的更新更精确。但是,在处理大规模数据集时,批量梯度下降可能会遇到内存不足的问题。

这里可以通过下列图来进行简单的说明
请添加图片描述

上面这种图是批量梯度下降的主要公式,前两篇文章已经介绍了
请添加图片描述
上面的这张图指的就是随机梯度下降的主要公式了,我们可以看到求个符号消失了

🍀随机梯度下降的实现

导入必要的库

import numpy as np

选取100000个数据作为测试数据

m = 100000
x = np.random.random(size=m)
y = x*3+4+np.random.normal(size=m)  # 后面的添加的噪音

注意:后面加了一个噪音目的是使得原有的数据添加一些随机性,省的太假了~
之后我们需要编写两个函数,前一个函数主要是用来计算样本的梯度,后一个函数主要包括计算学习率以及循环判断

def sgd(X_b,y,initial_theta,n_iters,epsilon=1e-8):def learning_rate(i_iter):t0=5t1 = 50return t0/(i_iter+t1)theta = initial_thetai_iter = 1while i_iter<=n_iters:index=np.random.randint(0,len(X_b))x_i = X_b[index]y_i = y[index]gradient = dj_sgd(theta,x_i,y_i)theta = theta-gradient*learning_rate(i_iter)i_iter+=1return theta

注意:在学习率的计算采用模拟退火思想,目的是为了控制参数的变化来影响行为,从而达到更好的优化效果。
请添加图片描述
之后我们需要使用numpy库中的hstack函数在x左侧添加一列

X_b = np.hstack([np.ones((len(x),1)),x])  # 左测增加一列

在添加前,我们需要将x转成矩阵

x = x.reshape(-1,1)

运行结果如下
在这里插入图片描述
之后我们需要设置initial_theta初始值

initial_theta = np.zeros(X_b.shape[1])

前提的准备做完就可以验证了

%%time 
sgd(X_b,y,initial_theta,n_iters=m//4)

运行结果如下
在这里插入图片描述
返回的值,分别近似截距和系数


我们可以将代码再优化一下

def sgd(X_b, y, initial_theta, n_iters, epsilon=1e-8):def learning_rate(i_iter):t0 = 5t1 = 50return t0 / (i_iter + t1)theta = initial_theta  # 初始化模型参数m = len(X_b)  # 样本数量for cur_iter in range(n_iters):  # 迭代n_iters次,每轮迭代看一遍整个样本random_indexs = np.random.permutation(m)  # 随机打乱样本的顺序,用于随机梯度下降X_random = X_b[random_indexs]  # 打乱后的特征数据y_random = y[random_indexs]  # 打乱后的标签数据for i in range(m):  # 遍历每个样本# 使用学习率learning_rate(cur_iter*m+i)来更新模型参数theta,通过梯度dj_sgd计算theta = theta - learning_rate(cur_iter * m + i) * dj_sgd(theta, X_random[i], y_random[i])return theta  # 返回优化后的模型参数

这个函数使用了随机梯度下降算法来更新模型参数,通过不断地随机选择一个样本进行参数更新,逐渐优化模型以适应训练数据。学习率随着迭代次数变化,初始较大然后逐渐减小,以有利于收敛到最优解。


🍀随机梯度下降的调试

首先还是做前期的准备

import numpy as np
X = np.random.random(size=(1000,10))
X_b = np.hstack([np.ones((len(X),1)),X])
true_theta = np.arange(1,12,dtype='float') # 这里代表有11个特征值(10个系数,1个截距)
y = X_b.dot(true_theta) + np.random.normal(size=len(X))

之后我们分别才有两种方法进行调试
首先是dj_math

这个函数用于计算线性回归中的成本函数(通常是均方误差)相对于参数 theta 的梯度,采用了矢量化的方法。这是数学公式:

在这里插入图片描述

  • X_b 是包含偏置项的特征矩阵(通常是原始特征矩阵的一列加上全部为 1 的列)。
  • y 是目标向量。
  • theta 是待更新的参数向量。
  • m 是训练样本的数量。
def dj_math(theta,X_b,y):return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)

其次是dj_debug

这个函数使用数值逼近方法来计算成本函数相对于参数的梯度。它通过轻微地扰动每个参数 theta[i] 并测量成本函数 j 的变化来估计梯度。这是数学公式:

在这里插入图片描述

  • theta 是参数向量。
  • X_b 是包含偏置项的特征矩阵。
  • y 是目标向量。
  • i 是被扰动的参数的索引。
  • epsilon 是用于扰动的小值。
def dj_debug(theta,X_b,y):res=np.empty(len(theta))epsilon = 0.01for i in range(len(theta)):theta1 = theta.copy()theta2 = theta.copy()theta1[i] +=epsilontheta2[i] -=epsilonres[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)return res

这种数值逼近通常用于调试和验证梯度计算的正确性,特别是在梯度下降等基于梯度的优化算法中,有助于优化参数 theta 的训练过程

完整代码如下

def j(theta,X_b,y):try:return np.sum((X_b.dot(theta)-y)**2)/len(X_b)except:return float('inf')def dj_math(theta,X_b,y):return X_b.T.dot(X_b.dot(theta)-y)*2./len(X_b)def dj_debug(theta,X_b,y):res=np.empty(len(theta))epsilon = 0.01for i in range(len(theta)):theta1 = theta.copy()theta2 = theta.copy()theta1[i] +=epsilontheta2[i] -=epsilonres[i] = (j(theta1,X_b,y)-j(theta2,X_b,y))/(2*epsilon)return resdef gradient_descent(dj,X_b,y,eta,initial_theta,n_iters=1e4,epsilon=1e-8):theta = initial_thetai_iter = 1while i_iter<n_iters:last_theta = thetatheta =theta- eta*dj(theta,X_b,y)if abs(j(theta,X_b,y)-j(last_theta,X_b,y))<epsilon:breaki_iter+=1return theta

可以分别进行测试一下,显然前者更快一点
在这里插入图片描述

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/103984.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一篇文章教会你什么是高度平衡二叉搜索(AVL)树

高度平衡二叉搜索树 AVL树的概念1.操作2.删除3.搜索4.实现描述 AVL树的实现1.AVL树节点的定义2.AVL树的插入3.AVL树的旋转3.1 新节点插入较高右子树的右侧---右右:左单旋3.2 新节点插入较高左子树的左侧---左左:右单旋3.3 新节点插入较高左子树的右侧---左右&#xff1a;先左单…

滑动谜题 -- BFS

滑动谜题 输入&#xff1a;board [[4,1,2],[5,0,3]] 输出&#xff1a;5 解释&#xff1a; 最少完成谜板的最少移动次数是 5 &#xff0c; 一种移动路径: 尚未移动: [[4,1,2],[5,0,3]] 移动 1 次: [[4,1,2],[0,5,3]] 移动 2 次: [[0,1,2],[4,5,3]] 移动 3 次: [[1,0,2],[4,5,3]…

hive指定字段插入数据,包含了分区表和非分区表

1、建表 语句如下&#xff1a; CREATE EXTERNAL TABLE ods_lineitem_full (l_shipdate date,l_orderkey bigint,l_linenumber int,l_partkey int,l_suppkey int,l_quantity decimal(15, 2),l_extendedprice decimal(15, 2),l_discount de…

java开发之个微机器人的二次开发

简要描述&#xff1a; 修改好友备注 请求URL&#xff1a; http://域名地址/modifyRemark 请求方式&#xff1a; POST 请求头Headers&#xff1a; Content-Type&#xff1a;application/jsonAuthorization&#xff1a;login接口返回 参数&#xff1a; 参数名必选类型说…

C++之红黑树

红黑树 红黑树的概念红黑树的性质红黑树结点的定义红黑树的插入红黑树的验证红黑树与AVL树的比较 红黑树的概念 红黑树&#xff0c;是一种二叉搜索树&#xff0c;但在每个结点上增加一个存储位表示结点的颜色&#xff0c;可以是Red或Black。 通过对任何一条从根到叶子的路径上…

Spark 6:Spark SQL DataFrame

SparkSQL 是Spark的一个模块, 用于处理海量结构化数据。 SparkSQL是用于处理大规模结构化数据的计算引擎 SparkSQL在企业中广泛使用&#xff0c;并性能极好 SparkSQL&#xff1a;使用简单、API统一、兼容HIVE、支持标准化JDBC和ODBC连接 SparkSQL 2014年正式发布&#xff0c;当…

Support for password authentication was removed on August 13, 2021 解决方案

打开你的github&#xff0c;Setting 点击Developer settings。 点击generate new token 按照需要选择scope 生成token&#xff0c;以后复制下来。 给git设置token样式的remote url git remote set-url origin https://你的tokengithub.com/你的git用户名/仓库名称.git然后就可…

【List篇】使用Arrays.asList生成的List集合,操作add方法报错

早上到公司&#xff0c;刚到工位&#xff0c;测试同事就跑来说"功能不行了&#xff0c;报服务器异常了&#xff0c;咋回事";我一脸蒙&#xff0c;早饭都顾不上吃&#xff0c;要来了测试账号复现了一下&#xff0c;然后仔细观察测试服务器日志&#xff0c;发现报了一个…

Redis缓存预热、缓存雪崩、缓存击穿、缓存穿透

文章目录 Redis缓存预热、缓存雪崩、缓存击穿、缓存穿透一、缓存预热1、问题排查2、解决方案&#xff08;1&#xff09;准备工作&#xff08;2&#xff09;实施&#xff08;3&#xff09;总结 二、缓存雪崩1、解决方案 三、缓存击穿1、解决方案&#xff08;1&#xff09;互斥锁…

Mavan进阶之多模块(聚合)

文章目录 Maven 多模块&#xff08;聚合&#xff09;非父子关系的多模块项目 Maven 多模块&#xff08;聚合&#xff09; Maven 继承和聚合是 2 个独立的概念。工程与工程之间可能毫无关系&#xff0c;也可能是继承关系&#xff0c;也可能是聚合关系&#xff0c;也可能既是继承…

C#学习 - 初识类与名称空间

类&#xff08;class&#xff09;& 名称空间&#xff08;namespace&#xff09; 类是最基础的 C# 类型&#xff0c;是一个数据结构&#xff0c;是构成程序的主体 名称空间以树型结构组织类 using System; //前面的using就是引用名称空间 //相当于C语言的 #include <..…

maven聚合工程的创建

父工程&#xff1a; parent-project 子工程&#xff1a;project-child project-child2 project-child3 创建父工程 将src目录删除了(在父工程中的src目录是没有用的&#xff09; 创建子工程 右击父工程------new------module 聚合工程创建完之后 在父工程的pom文件中 …