Python 中的机器学习简介:多项式回归

一、说明

        多项式回归可以识别自变量和因变量之间的非线性关系。本文是关于回归、梯度下降和 MSE 系列文章的第三篇。前面的文章介绍了简单线性回归、回归的正态方程和多元线性回归。

二、多项式回归

        多项式回归用于最适合曲线拟合的复杂数据。它可以被视为多元线性回归的子集。

        请注意,X₀ 是偏差的一列;这允许在第一篇文章中讨论的广义公式。使用上述等式,每个“自变量”都可以被视为 X₁ 的指数版本。

        这允许从多元线性回归使用相同的模型,因为只需要识别每个变量的系数。可以创建一个简单的三阶多项式模型作为示例。其等式如下:

        模型、梯度下降和 MSE 的广义函数可用于前面的文章:

# line of best fit
def model(w, X):"""Inputs:w: array of weights | (num features, 1)X: array of inputs  | (n samples, num features)Output:returns the output of X@w | (n samples, 1)"""return torch.matmul(X, w)
# mean squared error (MSE)
def MSE(Yhat, Y):"""Inputs:Yhat: array of predictions | (n samples, 1)Y: array of expected outputs | (n samples, 1)Output:returns the loss of the model, which is a scalar"""return torch.mean((Yhat-Y)**2) # mean((error)^2)
# optimizer
def gradient_descent(w):"""Inputs:w: array of weights | (num features, 1)Global Variables / Constants:X: array of inputs  | (n samples, num features)Y: array of expected outputs | (n samples, 1)lr: learning rate to scale the gradientOutput:returns the updated weights""" n = X.shape[0]return w - (lr * 2/n) * (torch.matmul(-Y.T, X) + torch.matmul(torch.matmul(w.T, X.T), X)).reshape(w.shape)

三、创建数据

        现在,所需要的只是一些用于训练模型的数据。可以使用“蓝图”功能,并且可以添加随机性。这遵循与前面文章相同的方法。蓝图如下所示:

        可以创建大小为 (800, 4) 的训练集和大小为 (200, 4) 的测试集。请注意,除偏差外,每个特征都是第一个特征的指数版本。

import torchtorch.manual_seed(5)
torch.set_printoptions(precision=2)# features
X0 = torch.ones((1000,1))
X1 = (100*(torch.rand(1000) - 0.5)).reshape(-1,1) # generates 1000 random numbers from -50 to 50
X2, X3 = X1**2, X1**3
X = torch.hstack((X0,X1,X2,X3))# normal distribution with a mean of 0 and std of 8
normal = torch.distributions.Normal(loc=0, scale=8)# targets
Y = (3*X[:,3] + 2*X[:,2] + 1*X[:,1] + 5 + normal.sample(torch.ones(1000).shape)).reshape(-1,1)# train, test
Xtrain, Xtest = X[:800], X[800:]
Ytrain, Ytest = Y[:800], Y[800:]

        定义初始权重后,可以使用最佳拟合线绘制数据。

torch.manual_seed(5)
w = torch.rand(size=(4, 1))
w
tensor([[0.83],[0.13],[0.91],[0.82]])
import matplotlib.pyplot as pltdef plot_lbf():"""Output:prints the line of best fit in comparison to the train and test data"""# plot the train and test setsplt.scatter(Xtrain[:,1],Ytrain,label="train")plt.scatter(Xtest[:,1],Ytest,label="test")# plot the line of best fitX1_plot = torch.arange(-50, 50.1,.1).reshape(-1,1) X2_plot, X3_plot = X1_plot**2, X1_plot**3X0_plot = torch.ones(X1_plot.shape)X_plot = torch.hstack((X0_plot,X1_plot,X2_plot,X3_plot))plt.plot(X1_plot.flatten(), model(w, X_plot).flatten(), color="red", zorder=4)plt.xlim(-50, 50)plt.xlabel("$X$")plt.ylabel("$Y$")plt.legend()plt.show()plot_lbf()
图片来源:作者

四、训练模型

        为了部分最小化成本函数,可以使用 5e-11 和 500,000 epoch 的学习率与梯度下降一起使用。

lr = 5e-11
epochs = 500000# update the weights 1000 times
for i in range(0, epochs):# update the weightsw = gradient_descent(w)# print the new values every 10 iterationsif (i+1) % 100000 == 0:print("epoch:", i+1)print("weights:", w)print("Train MSE:", MSE(model(w,Xtrain), Ytrain))print("Test MSE:", MSE(model(w,Xtest), Ytest))print("="*10)plot_lbf()
epoch: 100000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.87)
Test MSE: tensor(162.55)
==========
epoch: 200000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.52)
Test MSE: tensor(162.22)
==========
epoch: 300000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(163.19)
Test MSE: tensor(161.89)
==========
epoch: 400000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.85)
Test MSE: tensor(161.57)
==========
epoch: 500000
weights: tensor([[0.83],[0.13],[2.00],[3.00]])
Train MSE: tensor(162.51)
Test MSE: tensor(161.24)
==========
图片来源:作者

        即使有 500,000 个 epoch 和极小的学习率,该模型也无法识别前两个权重。虽然当前的解决方案非常准确,MSE为161.24,但可能需要数百万个epoch才能完全最小化它。这是多项式回归梯度下降的局限性之一。

五、正态方程

        作为替代方案,可以使用第二篇文章中的正态方程直接计算优化权重:

def NormalEquation(X, Y):"""Inputs:X: array of input values | (n samples, num features)Y: array of expected outputs | (n samples, 1)Output:returns the optimized weights | (num features, 1)"""return torch.inverse(X.T @ X) @ X.T @ Yw = NormalEquation(Xtrain, Ytrain)
w
tensor([[4.57],[0.98],[2.00],[3.00]])

        正态方程能够立即识别每个权重的正确值,并且每组的MSE比梯度下降时低约100点:

MSE(model(w,Xtrain), Ytrain), MSE(model(w,Xtest), Ytest)
(tensor(60.64), tensor(63.84))

六、结论

        通过实现简单线性、多重线性和多项式回归,接下来的两篇文章将介绍套索和岭回归。这些类型的回归在机器学习中引入了两个重要概念:过拟合和正则化。

 参考文章:

亨特·菲利普斯

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/63686.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【C语言】小游戏-三字棋

大家好,我是深鱼~ 目录 一、游戏介绍 二、文件分装 三、代码实现步骤 1.制作简易游戏菜单 2.初始化棋盘 3.打印棋盘 4.玩家下棋 5.电脑随机下棋 6.判断输赢 7.判断棋盘是否满了 四、完整代码 game.h(相关函数的声明,整个代码要引用的头文件以及宏…

CSS 中的优先级规则是怎样的?

聚沙成塔每天进步一点点 ⭐ 专栏简介⭐内联样式(Inline Styles)⭐ID 选择器(ID Selectors)⭐类选择器、属性选择器和伪类选择器(Class, Attribute, and Pseudo-class Selectors)⭐元素选择器和伪元素选择器…

如何让ES低成本、高性能?滴滴落地ZSTD压缩算法的实践分享

前文分别介绍了滴滴自研的ES强一致性多活是如何实现的、以及如何提升ES的性能潜力。由于滴滴ES日志场景每天写入量在5PB-10PB量级,写入压力和业务成本压力大,为了提升ES的写入性能,我们让ES支持ZSTD压缩算法,本篇文章详细展开滴滴…

hive 字段注释乱码

hive 字段注释乱码: 在mysql中运行: alter table COLUMNS_V2 modify column COMMENT varchar(256) character set utf8;OK

Signal Desktop for Mac(专业加密通讯软件)中文版安装教程

想让您的聊天信息更安全和隐藏吗? Mac版本的Signal Desktop是MACOS上的专业加密通信工具,非常安全。使用信号协议,该协议结合了固定前密钥,双重RATCHES算法和3-DH握手信号,该信号可以确保第三方实体将不会传达您的消息…

LeetCode150道面试经典题--同构字符串(简单)

1.题目 给定两个字符串 s 和 t ,判断它们是否是同构的。如果 s 中的字符可以按某种映射关系替换得到 t ,那么这两个字符串是同构的。每个出现的字符都应当映射到另一个字符,同时不改变字符的顺序。不同字符不能映射到同一个字符上&#xff0c…

VR内容定制 | VR内容中控管理平台可以带来哪些价值?

随着科技的不断发展,虚拟现实(VR)技术已经逐渐渗透到各个领域,其中教育领域也不例外。通过VR技术,学生可以身临其境地参与到各种场景中,获得更加直观、生动的学习体验。为了让教师更好地进行VR教学的设计和管理,提高教…

01_什么是ansible、基本架构、ansible工作机制、Ansible安装、配置主机清单、设置SSH无密码登录等

1.什么是ansible 1.1.基本介绍 1.2.基本架构 1.3.基本特征 1.4.优点 1.5.ansible工作机制 2.Ansible安装 2.1.机器准备 2.2.安装ansible 2.2.1.安装epel源 2.2.2.安装ansible 2.2.3.查看ansible版本 2.2.4.树状结构展示文件夹 2.2.4.1.其中ansible.cfg的内容如下 2.2.4.2.host的…

棒球和垒球的区别·棒球联盟

棒球和垒球的区别 1. 定义和起源 棒球起源于19世纪中叶的美国,最初被认为是一种游戏,而并非体育运动。那时,棒球常常被孩子们用来进行休闲娱乐。在20世纪初,它才开始被纳入体育运动的范畴。 垒球则是棒球的近亲,同样…

构建Docker容器监控系统(Cadvisor +Prometheus+Grafana)

Cadvisor PrometheusGrafana 1.1、Cadvisor产品简介 Cadvisor是Google开源的一款用于展示和分析容器运行状态的可视化工具。通过在主机上运行Cadvisor用户可以轻松的获取到当前主机上容器的运行统计信息,并以图表的形式向用户展示。 1.2、安装docker-ce [rootloc…

Hadoop理论及实践-HDFS读写数据流程(参考Hadoop官网)

NameNode与DataNode回顾 主节点和副本节点通常指的是Hadoop分布式文件系统(HDFS)中的NameNode和DataNode。 NameNode(主节点):NameNode是Hadoop集群中的一个核心组件,它负责管理文件系统的命名空间和元数据…

Ubuntu 20.04 安装 Stable Diffusionn

步骤 1:安装 wget、git、Python3 和 Python3虚拟环境(如果已安装可忽略这步骤) sudo apt install wget git python3 python3-venv步骤 2:克隆 SD 项目到本地 git clone https://github.com/AUTOMATIC1111/stable-diffusion-webu…