// https://github.com/lilipads/gradient_descent_viz
#include <math.h>
#include <iostream>
#include <string>
namespace Function {

// Identifies which surface GradientDescent::f() evaluates.
// The trailing notes are translations of the original author's descriptions.
enum FunctionName {
    local_minimum,   // many small pits
    global_minimum,  // one big pit
    saddle_point,    // undulating mountain road
    ecliptic_bowl,   // undulating big pit
    hills,           // big pit + small pits + valley
    plateau          // big pit + ripples
};

}  // namespace Function
// A pair of doubles addressed as (x, z).
class Point {
public:
    // Default construction yields (0, 0).
    Point() : Point(0., 0.) {}

    // Constructs a point from explicit components.
    Point(double x1, double z1) : x(x1), z(z1) {}

    double x;  // first component
    double z;  // second component
};
// Scale constant, presumably the ball radius expressed per graph unit.
// NOTE(review): meaning inferred from the name only -- confirm the units with
// the rendering code that owns this value.
const float kBallRadiusPerGraph = 24.63;
// Abstract base class for the optimizers in this file.
//
// It stores the current position, the gradient at that position and the last
// step taken. Concrete optimizers supply the update rule through
// updateGradientDelta() and may clear their own accumulators in resetState().
class GradientDescent {
public:
    // Resets to the (default-constructed) starting position and computes the
    // gradient there.
    GradientDescent();
    virtual ~GradientDescent() {}

    // Step-size multiplier used by the concrete update rules.
    double learning_rate = 0.001;

    // Surface evaluated by f(). Static: shared by every optimizer instance.
    static Function::FunctionName function_name;

    // simple getters and setters
    // The read-only accessors are const so they can be used on const objects.
    Point position() const { return p; }
    // Only records the starting position; it takes effect on the next call to
    // resetPositionAndComputeGradient().
    void setStartingPosition(double x, double z) { starting_p.x = x; starting_p.z = z; }
    bool isConverged() const { return is_converged; }
    double gradX() const { return grad.x; }
    double gradZ() const { return grad.z; }
    Point gradPosition() const { return grad; }
    Point delta() const { return m_delta; }

    // core methods
    // Value of the currently selected surface at (x, z).
    static double f(double x, double z);
    // Performs one optimizer step (no-op once converged); returns the position.
    Point takeGradientStep();
    // Clears convergence flag, last step and optimizer state, then moves back
    // to the starting position and recomputes the gradient.
    void resetPositionAndComputeGradient();

protected:
    Point p; // current position
    Point starting_p; // starting position
    Point m_delta; // movement in each direction after a gradient step
    Point grad; // gradient at the current position
    bool is_converged = false;

    // Moves to (x, z) and refreshes `grad` for the new position.
    void setPositionAndComputeGradient(double x, double z);
    // Recomputes `grad` at the current position.
    void computeGradient();
    // Update rule: must write the next step into `m_delta`.
    virtual void updateGradientDelta() = 0;
    // Hook called on reset so subclasses can clear their accumulators.
    // Note: the call made from the base constructor reaches only this base
    // version, because the derived part is not constructed yet at that point.
    virtual void resetState() {}
};
// Plain gradient descent: the step is the negative gradient scaled by
// learning_rate (see VanillaGradientDescent::updateGradientDelta).
class VanillaGradientDescent : public GradientDescent {
public:
    VanillaGradientDescent() {}

protected:
    // Implements the GradientDescent update-rule hook.
    void updateGradientDelta();
};
// Gradient descent with momentum: part of the previous step is carried over
// into the next one.
class Momentum : public GradientDescent {
public:
    Momentum() {}

    // Fraction of the previous step that is kept in the next step.
    double decay_rate = 0.8;

protected:
    // Implements the GradientDescent update-rule hook.
    void updateGradientDelta();
};
// AdaGrad optimizer: scales each component of the step by the square root of
// the running sum of squared gradients for that component.
class AdaGrad : public GradientDescent {
public:
    AdaGrad() {}

    // Accumulated squared gradient per component (const: read-only accessor).
    Point gradSumOfSquared() const { return grad_sum_of_squared; }

protected:
    // Implements the GradientDescent update-rule hook.
    void updateGradientDelta();
    // Clears the accumulated squared gradients.
    void resetState();

private:
    Point grad_sum_of_squared;  // starts at (0, 0) via Point's default constructor
};
// RMSProp optimizer: like AdaGrad, but the squared-gradient accumulator is an
// exponentially decayed average controlled by decay_rate.
class RMSProp : public GradientDescent {
public:
    RMSProp() {}

    // Weight kept from the previous accumulator value on each update.
    double decay_rate = 0.99;

    // Decayed squared-gradient average per component (const: read-only accessor).
    Point decayedGradSumOfSquared() const { return decayed_grad_sum_of_squared; }

protected:
    // Implements the GradientDescent update-rule hook.
    void updateGradientDelta();
    // Clears the decayed squared-gradient average.
    void resetState();

private:
    Point decayed_grad_sum_of_squared;  // starts at (0, 0) via Point's default constructor
};
// Adam-style optimizer: keeps a decayed average of the gradient (first moment)
// and of the squared gradient (second moment).
class Adam : public GradientDescent {
public:
    Adam() {}

    double beta1 = 0.9;    // decay weight of the first-moment average
    double beta2 = 0.999;  // decay weight of the second-moment average

    // Read-only accessors for the two running averages.
    Point decayedGradSum() const { return decayed_grad_sum; }
    Point decayedGradSumOfSquared() const { return decayed_grad_sum_of_squared; }

protected:
    // Implements the GradientDescent update-rule hook.
    void updateGradientDelta();
    // Clears both running averages.
    void resetState();

private:
    Point decayed_grad_sum;             // first moment, starts at (0, 0)
    Point decayed_grad_sum_of_squared;  // second moment, starts at (0, 0)
};
// Added to the denominators of the adaptive update rules so they never divide
// by zero.
const double kDivisionEpsilon = 1e-12;
// Half-width h of the central difference (f(x+h) - f(x-h)) / (2h).
// 1e-12 was too close to double round-off: at x = 100 on f = x^2 it produced
// a gradient of about 198.27 instead of 200. A step of 1e-6 keeps both the
// truncation error (~h^2) and the round-off error (~1e-16 / h) small.
const double kFiniteDiffEpsilon = 1e-6;
// The optimizer is considered converged once both gradient components are
// smaller than this in magnitude.
const double kConvergenceEpsilon = 1e-2;
// Definition of the static surface selector shared by all optimizers; the
// initial selection is the local_minimum surface.
Function::FunctionName GradientDescent::function_name = Function::local_minimum;
// Brings a new optimizer into its reset state: position equal to the
// default-constructed starting position, with the gradient computed there.
GradientDescent::GradientDescent() {
    this->resetPositionAndComputeGradient();
}
// Evaluates the surface selected by `function_name` at (x, z).
// Returns 0 when `function_name` holds a value not handled below.
double GradientDescent::f(double x, double z) {
    switch (function_name) {
    case Function::local_minimum: {
        // Two Gaussian wells (depths 2 and 6) on top of a quadratic bowl.
        z *= 1.4;
        const double z_sq = z * z;
        const double shallow_well = -2 * exp(-((x - 1) * (x - 1) + z_sq) / .2);
        const double deep_well = 6. * exp(-((x + 1) * (x + 1) + z_sq) / .2);
        return shallow_well - deep_well + x * x + z_sq;
    }
    case Function::global_minimum: {
        // Plain quadratic bowl.
        return x * x + z * z;
    }
    case Function::saddle_point: {
        // Sinusoid along x, quadratic along z.
        return sin(x) + z * z;
    }
    case Function::ecliptic_bowl: {
        // Anisotropic bowl with a Gaussian dip, evaluated at half scale.
        const double hx = x / 2.;
        const double hz = z / 2.;
        const double dip = -exp(-(hx * hx + 5 * hz * hz));
        return dip + hx * hx + 0.5 * hz * hz;
    }
    case Function::hills: {
        // Two Gaussian bumps and one Gaussian well on a quadratic bowl.
        z *= 1.4;
        const double z_sq = z * z;
        const double right_dx_sq = (x - 1) * (x - 1);
        const double small_bump = 2 * exp(-(right_dx_sq + z_sq) / .2);
        const double large_bump = 6. * exp(-((x + 1) * (x + 1) + z_sq) / .2);
        const double well = 2 * exp(-(right_dx_sq + (z + 1) * (z + 1)) / .2);
        return small_bump + large_bump - well + x * x + z_sq;
    }
    case Function::plateau: {
        // -sin(r)/r ripples plus a weak quadratic term, at 10x scale.
        const double sx = x * 10;
        const double sz = z * 10;
        // The 0.01 offset keeps r away from zero in the division below.
        const double r = sqrt(sz * sz + sx * sx) + 0.01;
        return -sin(r) / r + 0.01 * r * r;
    }
    }
    return 0.;
}
// Estimates the gradient of f at the current position with a central finite
// difference of half-width kFiniteDiffEpsilon along each axis, and stores the
// result in `grad`.
void GradientDescent::computeGradient() {
    const double h = kFiniteDiffEpsilon;
    const double span = 2 * h;

    const double x_forward = f(p.x + h, p.z);
    const double x_backward = f(p.x - h, p.z);
    grad.x = (x_forward - x_backward) / span;

    const double z_forward = f(p.x, p.z + h);
    const double z_backward = f(p.x, p.z - h);
    grad.z = (z_forward - z_backward) / span;
}
void GradientDescent::resetPositionAndComputeGradient() {
is_converged = false;
m_delta = Point(0, 0);
resetState();
setPositionAndComputeGradient(starting_p.x, starting_p.z);
}
void GradientDescent::setPositionAndComputeGradient(double x, double z) {
/* set position and dirty gradient */
p.x = x;
p.z = z;
computeGradient();
}
// Takes one optimizer step and returns the resulting position.
//
// Side effects when a step is taken:
//   - m_delta is updated to the step just taken,
//   - the position is advanced by m_delta,
//   - grad is recomputed at the new position.
// Once both gradient components are below kConvergenceEpsilon in magnitude the
// optimizer is flagged as converged; from then on (until reset) the call is a
// no-op that returns the current position.
Point GradientDescent::takeGradientStep() {
    // Use fabs rather than unqualified abs: depending on which headers are
    // visible, abs(double) can resolve to the C function int abs(int), which
    // truncates every |gradient| < 1 to 0 and reports convergence at once.
    if (fabs(gradX()) < kConvergenceEpsilon &&
        fabs(gradZ()) < kConvergenceEpsilon) {
        is_converged = true;
    }
    if (is_converged) return p;

    updateGradientDelta();
    setPositionAndComputeGradient(p.x + m_delta.x, p.z + m_delta.z);
    return p;
}
void VanillaGradientDescent::updateGradientDelta() {
m_delta.x = -learning_rate * grad.x;
m_delta.z = -learning_rate * grad.z;
}
// Momentum update (https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Momentum):
// new step = decay_rate * previous step - learning_rate * gradient.
void Momentum::updateGradientDelta() {
    const double carried_x = decay_rate * m_delta.x;
    const double carried_z = decay_rate * m_delta.z;
    m_delta.x = carried_x - learning_rate * grad.x;
    m_delta.z = carried_z - learning_rate * grad.z;
}
// AdaGrad update (https://en.wikipedia.org/wiki/Stochastic_gradient_descent#AdaGrad):
// accumulate the squared gradient per component, then scale the step by the
// inverse square root of that accumulator.
void AdaGrad::updateGradientDelta() {
    // Square with a multiply instead of the general-purpose pow().
    grad_sum_of_squared.x += grad.x * grad.x;
    grad_sum_of_squared.z += grad.z * grad.z;
    // kDivisionEpsilon keeps the denominator non-zero while the accumulator is 0.
    m_delta.x = -learning_rate * grad.x / (sqrt(grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(grad_sum_of_squared.z) + kDivisionEpsilon);
}
// Clears the accumulated squared gradients.
void AdaGrad::resetState() {
    grad_sum_of_squared.x = 0;
    grad_sum_of_squared.z = 0;
}
// RMSProp update (https://en.wikipedia.org/wiki/Stochastic_gradient_descent#RMSProp):
// keep an exponentially decayed average of the squared gradient per component
// and scale the step by the inverse square root of that average.
void RMSProp::updateGradientDelta() {
    // Square with a multiply instead of the general-purpose pow().
    decayed_grad_sum_of_squared.x *= decay_rate;
    decayed_grad_sum_of_squared.x += (1 - decay_rate) * (grad.x * grad.x);
    decayed_grad_sum_of_squared.z *= decay_rate;
    decayed_grad_sum_of_squared.z += (1 - decay_rate) * (grad.z * grad.z);
    // kDivisionEpsilon keeps the denominator non-zero while the average is 0.
    m_delta.x = -learning_rate * grad.x / (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * grad.z / (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}
// Clears the decayed squared-gradient average.
void RMSProp::resetState() {
    decayed_grad_sum_of_squared.x = 0;
    decayed_grad_sum_of_squared.z = 0;
}
// Adam-style update (https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam):
// step = -learning_rate * first_moment / (sqrt(second_moment) + epsilon).
// NOTE(review): no bias-correction terms are applied to the two moments here,
// unlike the formulation on the linked page -- confirm this is intentional.
void Adam::updateGradientDelta() {
    // first moment (momentum)
    decayed_grad_sum.x *= beta1;
    decayed_grad_sum.x += (1 - beta1) * grad.x;
    decayed_grad_sum.z *= beta1;
    decayed_grad_sum.z += (1 - beta1) * grad.z;

    // second moment (rmsprop); square with a multiply instead of pow()
    decayed_grad_sum_of_squared.x *= beta2;
    decayed_grad_sum_of_squared.x += (1 - beta2) * (grad.x * grad.x);
    decayed_grad_sum_of_squared.z *= beta2;
    decayed_grad_sum_of_squared.z += (1 - beta2) * (grad.z * grad.z);

    // kDivisionEpsilon keeps the denominator non-zero while the second moment is 0.
    m_delta.x = -learning_rate * decayed_grad_sum.x /
                (sqrt(decayed_grad_sum_of_squared.x) + kDivisionEpsilon);
    m_delta.z = -learning_rate * decayed_grad_sum.z /
                (sqrt(decayed_grad_sum_of_squared.z) + kDivisionEpsilon);
}
// Clears both running averages (second moment first, then first moment).
void Adam::resetState() {
    decayed_grad_sum_of_squared.x = 0;
    decayed_grad_sum_of_squared.z = 0;
    decayed_grad_sum.x = 0;
    decayed_grad_sum.z = 0;
}
// Writes one line of the form "<title>(x, y, z)" to std::cout and flushes it,
// where x and z come from `pt` and y is GradientDescent::f(pt.x, pt.z) plus
// the offset below.
void print(std::string title, const Point& pt) {
    const double scrollOffset = 2000;
    double yOffset = scrollOffset / kBallRadiusPerGraph;
    // NOTE(review): the scroll-derived offset is overwritten right away, so the
    // printed height is currently just f(pt.x, pt.z).
    yOffset = 0;

    const double height = GradientDescent::f(pt.x, pt.z) + yOffset;
    std::cout << title << "(" << pt.x << ", " << height << ", " << pt.z << ")" << std::endl;
}
void test() {
// 可用于滑动后缓动慢停止
VanillaGradientDescent vanillaGradientDescent;
vanillaGradientDescent.learning_rate = 0.01; //速度
vanillaGradientDescent.function_name = Function::global_minimum; // 场景
vanillaGradientDescent.setStartingPosition(100, 100); // 顶部起始位置
vanillaGradientDescent.resetPositionAndComputeGradient(); // 更新梯度和当前位置
print("pt", vanillaGradientDescent.position());
print("gradient", vanillaGradientDescent.gradPosition());
print("delta", vanillaGradientDescent.delta());
for (size_t i = 0; i < 60; i++)
{ // 60次
vanillaGradientDescent.takeGradientStep();
print("pt", vanillaGradientDescent.position());
print("gradient", vanillaGradientDescent.gradPosition());
print("delta", vanillaGradientDescent.delta());
}
// 可用于滑动后反弹多次逐渐停止
Momentum momentum;
momentum.learning_rate = 0.01; //速度
momentum.decay_rate = 0.9; // 摩擦系数, 1表示理想情况,无摩擦
momentum.setStartingPosition(100, 100); // 顶部起始位置
momentum.resetPositionAndComputeGradient(); // 更新梯度和当前位置
print("pt", momentum.position());
print("gradient", momentum.gradPosition());
print("delta", momentum.delta());
for (size_t i = 0; i < 60; i++)
{ // 200次
momentum.takeGradientStep();
print("pt", momentum.position());
print("gradient", momentum.gradPosition());
print("delta", momentum.delta());
}
}