随机梯度下降法
- 前言
- 正文
- 代码实现
- 可运行代码
- 结果
前言
随机梯度下降法 (Stochastic Gradient Descent,SGD) 是一种梯度下降法的变种,用于优化损失函数并更新模型参数。与传统的梯度下降法不同,SGD每次只使用一个样本来计算梯度和更新参数,而不是使用整个数据集。这种随机性使得SGD在大型数据集上更加高效,因为它在每次迭代中只需要处理一个样本。
以下是关于随机梯度下降法的详细描述:
- 初姶化参数:与梯度下降法类似,首先需要初始化模型的参数,通常使用随机的初始值。
- 选代过程:
- 对于每个训练样本 i i i :
- 计算损失函数关于当前参数的梯度,即 ∇ f i ( θ ) \nabla f_i(\theta) ∇fi(θ) ,其中 f i ( θ ) f_i(\theta) fi(θ) 是针对第 i i i 个样本的损失。
- 使用计算得到的梯度来更新模型参数: θ = θ − η ⋅ ∇ f i ( θ ) \theta=\theta-\eta \cdot \nabla f_i(\theta) θ=θ−η⋅∇fi(θ) ,其中 η \eta η 是学习率。
- 重复迭代: 重复以上过程,直到达到预定的迭代次数或满足停止条件(例如梯度的范数足够小)。
相比于传统的梯度下降法,SGD的优点包括:
- 高效:特别适用于大型数据集,因为每次迭代只使用一个样本。
- 在线学习: 可以用于在线学习,即在接收到新数据时立即更新模型。
然而,由于随机性的引入,SGD的参数更新可能会更加不稳定,因此学习率的选择变得尤为重要。为了解决这个问题,有一些SGD的变种,如Mini-batch SGD,它在每次迭代中使用小批量的样本来计算梯度。这样可以在保持高效性的同时减小参数更新的方差。
正文
对于给出的函数 f ( x ) f(x) f(x) :
f ( x ) = x ( 1 ) 2 + x ( 2 ) 2 − 2 ⋅ x ( 1 ) ⋅ x ( 2 ) + sin ( x ( 1 ) ) + cos ( x ( 2 ) ) f(x)=x(1)^2+x(2)^2-2 \cdot x(1) \cdot x(2)+\sin (x(1))+\cos (x(2)) f(x)=x(1)2+x(2)2−2⋅x(1)⋅x(2)+sin(x(1))+cos(x(2))
- 初始化参数: 随机选择初始参数 x x x ,通常使用某种随机的初始值。
- 选择学习率: 选择一个适当的学习率 η \eta η ,这是一个重要的超参数,影响着参数更新的步长。
- 设置迭代次数和停止条件: 确定迭代次数的上限或设置停止条件,例如当梯度的范数小于某个容许误差时停止迭代。
- 随机梯度下降选代:
- 对于每次迭代 t t t ,从训练集中随机选择一个样本 i i i 。
- 计算该样本的梯度: ∇ f i ( x ( t ) ) \nabla f_i\left(x^{(t)}\right) ∇fi(x(t))
- 使用梯度更新参数: x ( t + 1 ) = x ( t ) − η ⋅ ∇ f i ( x ( t ) ) x^{(t+1)}=x^{(t)}-\eta \cdot \nabla f_i\left(x^{(t)}\right) x(t+1)=x(t)−η⋅∇fi(x(t))
- 检查是否满足停止条件。如果满足,停止迭代;否则,继续下一次迭代。
- 输出结果: 输出最终的参数 x x x ,以及在最优点的目标函数值 f ( x ) f(x) f(x) 。
代码实现
可运行代码
% 定义目标函数
f = @(x) x(1)^2 + x(2)^2 - 2*x(1)*x(2) + sin(x(1)) + cos(x(2));% 定义目标函数的梯度
grad_f = @(x) [2*x(1) - 2*x(2) + cos(x(1)); 2*x(2) - 2*x(1) - sin(x(2))];% 设置参数
learning_rate = 0.01;
max_iterations = 1000;
tolerance = 1e-6;% 初始化起始点
x = [0; 0];% 随机梯度下降
for iteration = 1:max_iterations% 随机选择一个样本i = randi(2);% 计算梯度gradient = grad_f(x);% 更新参数x = x - learning_rate * gradient;% 检查收敛性if norm(gradient) < tolerancebreak;end
end% 显示结果
fprintf('Optimal solution: x = [%f, %f]\n', x(1), x(2));
fprintf('Optimal value of f(x): %f\n', f(x));
fprintf('Number of iterations: %d\n', iteration);