随机梯度算法应用动画场景

https://github.com/lilipads/gradient_descent_viz


#include <math.h>
#include <iostream>

namespace Function {
    enum FunctionName {
        local_minimum,  // 许多小坑
        global_minimum,    // 大坑
        saddle_point,   // 起伏山路
        ecliptic_bowl, // 起伏大坑
        hills,  // 大坑 + 小坑 + 山谷
        plateau // 大坑 + 波纹
    };
}

class Point { 
public:
    Point() : x(0.), z(0.) {}
    Point(double x1, double z1) : x(x1), z(z1) {}
    double x; 
    double z;
};

const float kBallRadiusPerGraph = 24.63;

class GradientDescent {
public:
    GradientDescent();
    virtual ~GradientDescent() {}

    double learning_rate = 0.001;
    static Function::FunctionName function_name;

    // simple getters and setters
    Point position() { return p; }
    void setStartingPosition(double x, double z) { starting_p.x = x; starting_p.z = z; }
    bool isConverged() { return is_converged; };
    double gradX() { return grad.x; };
    double gradZ() { return grad.z; };
    Point gradPosition() { return grad; };
    Point delta() { return m_delta; }


    // core methods
    static double f(double x, double z);
    Point takeGradientStep();
    void resetPositionAndComputeGradient();

protected:
    Point p; // current position
    Point starting_p; // starting position
    Point m_delta; // movement in each direction after a gradient step
    Point grad; // gradient at the current position
    bool is_converged = false;

    void setPositionAndComputeGradient(double x, double z);
    void computeGradient();
    virtual void updateGradientDelta() = 0;
    virtual void resetState() {}
};


class VanillaGradientDescent : public GradientDescent {
public:
    VanillaGradientDescent() {}

protected:
    void updateGradientDelta();
};


class Momentum : public GradientDescent {
public:
    Momentum() {}

    double decay_rate = 0.8;

protected:
    void updateGradientDelta();
};

class AdaGrad : public GradientDescent {
public:
    AdaGrad() {}
    Point gradSumOfSquared() { return grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point grad_sum_of_squared;
};

class RMSProp : public GradientDescent {
public:
    RMSProp() {}

    double decay_rate = 0.99;
    Point decayedGradSumOfSquared() { return decayed_grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point decayed_grad_sum_of_squared;
};

class Adam : public GradientDescent {
public:
    Adam() {}

    double beta1 = 0.9;
    double beta2 = 0.999;
    Point decayedGradSum() { return decayed_grad_sum; }
    Point decayedGradSumOfSquared() { return decayed_grad_sum_of_squared; }

protected:
    void updateGradientDelta();
    void resetState();

private:
    Point decayed_grad_sum;
    Point decayed_grad_sum_of_squared;
};


const double kDivisionEpsilon = 1e-12;
const double kFiniteDiffEpsilon = 1e-12;
const double kConvergenceEpsilon = 1e-2;

Function::FunctionName GradientDescent::function_name = Function::local_minimum;


GradientDescent::GradientDescent()
{
    resetPositionAndComputeGradient();
}


double GradientDescent::f(double x, double z) {
    switch (function_name) {
    case Function::local_minimum: {
        z *= 1.4;
        return -2 * exp(-((x - 1) * (x - 1) + z * z) / .2) -
            6. * exp(-((x + 1) * (x + 1) + z * z) / .2) +
            x * x + z * z;
    }
    case Function::global_minimum: {
        return x * x + z * z;
    }
    case Function::saddle_point: {
        return sin(x) + z * z;
    }
    case Function::ecliptic_bowl: {
        x /= 2.;
        z /= 2.;
        return -exp(-(x * x + 5 * z * z)) + x * x + 0.5 * z * z;
    }
    case Function::hills: {
        z *= 1.4;
        return  2 * exp(-((x - 1) * (x - 1) + z * z) / .2) +
            6. * exp(-((x + 1) * (x + 1) + z * z) / .2) -
            2 * exp(-((x - 1) * (x - 1) + (z + 1) * (z + 1)) / .2) +
            x * x + z * z;
    }
    case Function::plateau: {
        x *= 10;
        z *= 10;
        double r = sqrt(z * z + x * x) + 0.01;
        return -sin(r) / r + 0.01 * r * r;
    }
    }
    return 0.;
}


void GradientDescent::computeGradient() {
    // use finite difference method
    grad.x = (f(p.x + kFiniteDiffEpsilon, p.z) -
        f(p.x - kFiniteDiffEpsilon, p.z)) / (2 * kFiniteDiffEpsilon);

    grad.z = (f(p.x, p.z + kFiniteDiffEpsilon) -
        f(p.x, p.z - kFiniteDiffEpsilon)) / (2 * kFiniteDiffEpsilon);
}

void GradientDescent::resetPositionAndComputeGradient() {
    is_converged = false;
    m_delta = Point(0, 0);
    resetState();
    setPositionAndComputeGradient(starting_p.x, starting_p.z);
}


void GradientDescent::setPositionAndComputeGradient(double x, double z) {
    /* set position and dirty gradient */

    p.x = x;
    p.z = z;
    computeGradient();
}

Point GradientDescent::takeGradientStep() {
    /* take a gradient step. return the new position
     * side effects:
     * - update delta to the step just taken
     * - update position to new position.
     * - update grad to gradient of the new position
     */

    if (abs(gradX()) < kConvergenceEpsilon &&
        abs(gradZ()) < kConvergenceEpsilon) {
        is_converged = true;
    }
    if (is_converged) return p;

    updateGradientDelta();
    setPositionAndComputeGradient(p.x + m_delta.x, p.z + m_delta.z);
    return p;
}

void VanillaGradientDescent::updateGradientDelta() {
    m_delta.x = -learning_rate * grad.x;
    m_delta.z = -learning_rate * grad.z;
}

void Momentum::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum */

    m_delta.x = decay_rate * m_delta.x - learning_rate * grad.x;
    m_delta.z = decay_rate * m_delta.z - learning_rate * grad.z;
}


void AdaGrad::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#AdaGrad */

    grad_sum_of_squared.x += pow(grad.x, 2);
    grad_sum_of_squared.z += pow(grad.z, 2);
    m_delta.x = -learning_rate * grad.x / (sqrt(grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(grad_sum_of_squared.z) + kDivisionEpsilon);
}


void AdaGrad::resetState() {
    grad_sum_of_squared = Point(0, 0);
}


void RMSProp::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#RMSProp */

    decayed_grad_sum_of_squared.x *= decay_rate;
    decayed_grad_sum_of_squared.x += (1 - decay_rate) * pow(grad.x, 2);
    decayed_grad_sum_of_squared.z *= decay_rate;
    decayed_grad_sum_of_squared.z += (1 - decay_rate) * pow(grad.z, 2);
    m_delta.x = -learning_rate * grad.x / (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}


void RMSProp::resetState() {
    decayed_grad_sum_of_squared = Point(0, 0);
}


void Adam::updateGradientDelta() {
    /* https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam */

    // first moment (momentum)
    decayed_grad_sum.x *= beta1;
    decayed_grad_sum.x += (1 - beta1) * grad.x;
    decayed_grad_sum.z *= beta1;
    decayed_grad_sum.z += (1 - beta1) * grad.z;
    // second moment (rmsprop)
    decayed_grad_sum_of_squared.x *= beta2;
    decayed_grad_sum_of_squared.x += (1 - beta2) * pow(grad.x, 2);
    decayed_grad_sum_of_squared.z *= beta2;
    decayed_grad_sum_of_squared.z += (1 - beta2) * pow(grad.z, 2);

    m_delta.x = -learning_rate * decayed_grad_sum.x /
        (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * decayed_grad_sum.z /
        (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}


void Adam::resetState() {
    decayed_grad_sum_of_squared = Point(0, 0);
    decayed_grad_sum = Point(0, 0);
}


void print(std::string title, const Point& pt) {
    double scrollOffset = 2000;
    double yOffset = scrollOffset / kBallRadiusPerGraph;
    yOffset = 0;
    std::cout << title << "(" << pt.x << ", " << GradientDescent::f(pt.x, pt.z) + yOffset << ", " << pt.z << ")" << std::endl;
    
}

void test() {
    // 可用于滑动后缓动慢停止
    VanillaGradientDescent  vanillaGradientDescent;
    vanillaGradientDescent.learning_rate = 0.01; //速度
    vanillaGradientDescent.function_name = Function::global_minimum; // 场景
    vanillaGradientDescent.setStartingPosition(100, 100);  // 顶部起始位置
    vanillaGradientDescent.resetPositionAndComputeGradient();   // 更新梯度和当前位置
    print("pt", vanillaGradientDescent.position());
    print("gradient", vanillaGradientDescent.gradPosition());
    print("delta", vanillaGradientDescent.delta());
    for (size_t i = 0; i < 60; i++)
    {   // 60次
        vanillaGradientDescent.takeGradientStep();
        print("pt", vanillaGradientDescent.position());
        print("gradient", vanillaGradientDescent.gradPosition());
        print("delta", vanillaGradientDescent.delta());
    }

    // 可用于滑动后反弹多次逐渐停止
    Momentum  momentum;
    momentum.learning_rate = 0.01; //速度
    momentum.decay_rate = 0.9;  // 摩擦系数, 1表示理想情况,无摩擦
    momentum.setStartingPosition(100, 100);  // 顶部起始位置
    momentum.resetPositionAndComputeGradient();   // 更新梯度和当前位置
    print("pt", momentum.position());
    print("gradient", momentum.gradPosition());
    print("delta", momentum.delta());
    for (size_t i = 0; i < 60; i++)
    {   // 200次
        momentum.takeGradientStep();
        print("pt", momentum.position());
        print("gradient", momentum.gradPosition());
        print("delta", momentum.delta());
    }
}


创作不易,小小的支持一下吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/701595.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于GWO灰狼优化的CNN-GRU-Attention的时间序列回归预测matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1卷积神经网络&#xff08;CNN&#xff09;在时间序列中的应用 4.2 GRU网络 4.3 注意力机制&#xff08;Attention&#xff09; 4.4 GWO优化 5.算法完整程序工程 1.算法运行效果图预览…

【Android踩坑】 Constant expression required

gradle 8&#xff0c;报错 Constant expression required&#xff1a;意思是case语句后面要跟常量 解决1 单击switch语句&#xff0c;键盘按下altenter&#xff0c;将switch-case语句替换为if-else语句(或者手动修改) 解决2 在gradle.properties中添加 android.nonFinalRes…

IP代理网络协议介绍

在IP代理页面上&#xff0c;存在HTTP/HTTPS/Socks5三种协议。它们都是客户端与服务器之间交互的协议。 HTTP HTTP又称之为超文本传输协议&#xff0c;在因特网使用范围广泛。它是一种请求/响应模型&#xff0c;客户端向服务器发送请求&#xff0c;服务器解析请求后对客户端作出…

meshlab: pymeshlab合并多个物体模型并保存(flatten visible layers)

一、关于环境 请参考&#xff1a;pymeshlab遍历文件夹中模型、缩放并导出指定格式-CSDN博客 二、关于代码 本文所给出代码仅为参考&#xff0c;禁止转载和引用&#xff0c;仅供个人学习。 本文所给出的例子是https://download.csdn.net/download/weixin_42605076/89233917中的…

Centos 安装jenkins 多分支流水线部署前后端项目

1、安装jenkins 1.1 安装jdk 要求&#xff1a;11及以上版本 yum install yum install java-11-openjdk 1.2 安装jenkins 导入镜像 sudo wget -O /etc/yum.repos.d/jenkins.repo https://pkg.jenkins.io/redhat-stable/jenkins.repo出现以下错误 执行以下命令 sudo yum …

iview(viewUI) span-method 表格实现将指定列的值相同的行合并单元格

效果图是上面这样的&#xff0c;将第一列的名字一样的合并在一起&#xff1b; <template><div class"table-wrap"><Table stripe :columns"columns" :data"data" :span-method"handleSpan"></Table></div&…

喜大普奔!VMware Workstation Pro 17.5 官宣免费!

Broadcom 已经正式收购 VMware&#xff0c;【VMware中国】官方公众号已于3月11日更名为【VMware by Broadcom中国】。 13日傍晚&#xff0c;该公众号发表推文 V风拂面&#xff0c;好久不见 - 来自VMware 中国的问候 &#xff0c;意味着 VMware 带着惊喜和美好的愿景再次归来。 …

新书速览|MATLAB科技绘图与数据分析

提升你的数据洞察力&#xff0c;用于精确绘图和分析的高级MATLAB技术。 本书内容 《MATLAB科技绘图与数据分析》结合作者多年的数据分析与科研绘图经验&#xff0c;详细讲解MATLAB在科技图表制作与数据分析中的使用方法与技巧。全书分为3部分&#xff0c;共12章&#xff0c;第1…

IPSSL证书:为特定IP地址通信数据保驾护航

IPSSL证书&#xff0c;顾名思义&#xff0c;是专为特定IP地址设计的SSL证书。它不仅继承了传统SSL证书验证网站身份、加密数据传输的基本功能&#xff0c;还特别针对通过固定IP地址进行通信的场景提供了强化的安全保障。在IP地址直接绑定SSL证书的模式下&#xff0c;它能够确保…

分析 vs2019 cpp20 规范的 STL 库模板 function ,源码注释并探讨几个问题

&#xff08;1 探讨一&#xff09;第一个尝试弄清的问题是父类模板与子类模板的模板参数的对应关系&#xff0c;如下图&#xff1a; 我们要弄清的问题是创建 function 对象时&#xff0c;传递的模板参数 _Fty , 传递到其父类 _Func_class 中时 &#xff0c;父类的模板参数 _Ret…

k8s的整体架构及其内部工作原理,以及创建一个pod的原理

一、k8s整体架构 二、k8s的作用&#xff0c;为什么要用k8s&#xff0c;以及服务器的发展历程 1、服务器&#xff1a;缺点容易浪费资源&#xff0c;且每个服务器都要装系统&#xff0c;且扩展迁移成本高 2、虚拟机很好地解决了服务器浪费资源的缺点&#xff0c;且部署快&#x…

短视频语音合成:成都鼎茂宏升文化传媒公司

短视频语音合成&#xff1a;技术革新与创意融合的新篇章 随着科技的飞速发展&#xff0c;短视频已经成为人们生活中不可或缺的一部分。在这个快速变化的时代&#xff0c;短视频语音合成技术正逐渐崭露头角&#xff0c;以其独特的魅力和广泛的应用前景&#xff0c;吸引了众多创…