最小二乘法原理推导+代码实现[Python]

news/2024/11/19 12:36:53/文章来源:https://www.cnblogs.com/hello-nullptr/p/18347123

0.前言

  • 本文主要介绍了最小二乘法公式推导,并且使用Python语言实现线性拟合。
  • 读者需要具备高等数学、线性代数、Python编程知识。
  • 请读者按照文章顺序阅读。
  • 绘图软件为:geogebra5。

1.原理推导

1.1应用

最小二乘法在购房中的应用通常涉及房价预测和房屋定价方面。这种统计方法通过拟合数据来找到一条最符合实际观测值的直线(或曲线),从而帮助预测房屋的合理市场价格。例如某地的房价与房屋面积大小关系如下图(图1-1)所示。
image
为了方便操作,请读者不要考虑数据是否真实有效,当然这样的房价笔者是不会买。笔者将数据以CSV格式保存,具体数据如下图(1-2)所示。
image

点击查看数据
其中x表示房屋的面积,单位平方米,y表示房屋的价格,单位万元。
x	y
12.3	11.8
14.3	12.7
14.5	13
14.8	11.8
16.1	14.3
16.8	15.3
16.5	13.5
15.3	13.8
17	14
17.8	14.9
18.7	15.7
20.2	18.8
22.3	20.1
19.3	15
15.5	14.5
16.7	14.9
17.2	14.8
18.3	16.4
19.2	17
17.3	14.8
19.5	15.6
19.7	16.4
21.2	19
23.04	19.8
23.8	20
24.6	20.3
25.2	21.9
25.7	22.1
25.9	22.4
26.3	22.6

1.2定义直线方程

image

1.3定义拟合误差

假设房屋面积、房屋价格、预测价格如下图(1-3)所示。
image
此时需要一个函数去衡量房屋预测价格与真实的房屋价格之间的误差,若预测价格和真实价格之间的误差很小,约等于0,则表明该拟合函数预测房屋价格十分准确。具体的误差函数如下所示。
image
有时L又称作损失函数。

1.4梯度下降优化

image
梯度下降思想如下图(图1-4)所示。
image
下面笔者举出一个简单的例子帮助理解。
image
image
image
image
image
比如在x=3这一点,为了使g(x)的值变小,即往山谷方向移动,因此x需要向左移动,即x需要变小,示例图如下(图1-7)所示。
image
例如在x=-1这一点,为了使g(x)值变小,需要x不断变大,往山谷处靠近,示例图如下图(图1-8)所示。
image
在结合x<1、x>1时g'(x)的符号可以总结出以下梯度下降公式。
image
其中上述公式的=是指编程语言中的赋值操作。根据公式不难看出,当x<1时,g'(x)<0,x-g'(x)的值相较于x变大了;当x>1时,g'(x)>0,x-g'(x)的值相较于x变小了,这里非常巧妙,通过不断的计算和赋值,就好像一步一步的走动到山谷,值得注意的是,这里还不算是一小步一小步。
考虑到一种情况,若g'(x)非常大或者非常小,导致赋值后的x太大或太小。例如当x=-1时,计算出来的导数为-9999,那么再执行x=x-g'(x)后x的值为9998,x的值从-1变为9998,这个步子也太大了吧,显然是不合适的,此时可能会出现反复震荡的情况,具体示例如下图(图1-9)所示。

image
为了克服反复震荡的情况,所以要引进学习率。

1.5学习率

对于梯度优化函数x=x-g'(x),引进学习率后如下所示。
image
其中η为学习率,其含义类似于迈步子的力度,力度越大,迈的步子越远,力度越小迈的步子越小,一般情况下,学习率设置为很小(0.001、0.0001)。
例如当x=-1时,g'(x)为-1,η为1000,则x更新后的值为999,若η为0.0001,则x更新后的值为-0.9999,仅仅是挪动了一小小小小步。

1.6更新所有参数

通过上文,相信你已经懂得了梯度更新的原理,那么对于损失函数L来说,怎么进行梯度更新呢?更新公式如下图(图1-10)所示。
image
之前的简单示例是对x求导,这里因为是多元函数,所以要求偏导,只要掌握梯度下降的基本思想,对于二次函数拟合也是类似。

2.代码解释

image

3.完整代码

点击查看代码
import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('TkAgg')
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签SimHei
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号class LinerRegression():#初始化类#learning_rate:学习率                  浮点数#end_error:相邻两轮损失函数之间的差值    浮点数#max_it:最大迭代次数                   整形def __init__(self, learning_rate=0.00001, end_error=0.01,max_it=1000):self.f_lr = learning_rate       # 学习率self.f_end_error=end_error      # 前一轮loss与当前轮loss之差小于等于end_error时结束迭代self.f_diff=2147483647          # 保存前一轮loss与当前轮loss之差self.f_w=np.random.normal(0,1)  # y=wx+b中的wself.f_b=np.random.uniform(0,1) # y=wx+b中的bself.i_max_iterator=max_it      # 最大迭代次数,当end_error与max_iterator其一满足便会停止迭代#获得w和b的偏导数def get_partial_derivative(self):f_p_w =(2/len(self.arr_x)) * np.sum( (self._f(self.f_w,self.arr_x,self.f_b)-self.arr_y)*self.arr_x  )f_p_b =(2/len(self.arr_x)) * np.sum(  self._f(self.f_w,self.arr_x,self.f_b)-self.arr_y              )return f_p_w, f_p_bdef standardize(self):  # 标准化f_mu =    np.mean(self.arr_x)f_sigma = np.std(self.arr_x)self.arr_x= (self.arr_x - f_mu) / f_sigmadef fit(self, x, y):self.arr_x = xself.arr_y = y#标准化# self.standardize()#当前损失值f_origin_loss=self.get_loss(y_true=self.arr_y,y_pred=self._f(self.f_w,self.arr_x,self.f_b))i_it_cnt=0#迭代次数while self.f_diff>self.f_end_error and i_it_cnt<self.i_max_iterator:self.next_step()#更新w b#学习后的损失值f_cur_loss=self.get_loss(y_true=self.arr_y,y_pred=self._f(self.f_w,self.arr_x,self.f_b))#损失值之差f_diff=f_origin_loss-f_cur_lossself.f_diff=f_diffi_it_cnt+=1print("第{}次训练,w={:.2f},b={:.2f},loss={:.2f},diff={}".format(i_it_cnt,self.f_w,self.f_b,f_cur_loss,f_diff))print("训练结果函数式:y={:.2f}x+{:.2f}".format(self.f_w,self.f_b))#绘制结果图plt.scatter(self.arr_x, self.arr_y)arr_new_x = np.linspace(10,28,28-10+1)arr_new_y = self.f_w * arr_new_x + self.f_bplt.plot(arr_new_x, arr_new_y,'r--')plt.show()#一元线性函数#w:斜率   浮点数#x:自变量 整形/浮点型/整形数组/浮点型数组#b:截距   浮点数#返回值:  整形/浮点型/整形数组/浮点型数组def _f(self, w, x, b):return w*x+bdef predict(self, new_x):"""预测"""y_pred = self._f(self.f_w, new_x, self.f_b)return y_preddef get_loss(self, y_true, y_pred):"""损失y_true:[x,x,x,x,x]  <class 'numpy.ndarray'>y_pred:[x,x,x,x,x]  <class 'numpy.ndarray'>"""return (1/len(y_true))*np.sum((y_pred-y_true)**2)def next_step(self):"""梯度学习,往前走"""d_w, d_b = self.get_partial_derivative()self.f_w = self.f_w - self.f_lr * d_wself.f_b = self.f_b - self.f_lr * d_bif __name__ == '__main__':train = np.loadtxt('./Datasets/白话机器学习/线性回归.csv',delimiter=',', dtype='float', skiprows=1)x = train[:,0]y = train[:,1]lg=LinerRegression(learning_rate=1e-5,end_error=1e-3,max_it=1e3)lg.fit(x,y)print("x=21时,预测为{}".format(lg.predict(new_x=np.array([21]))))

4.运行结果

控制台输出:
image
拟合结果:
image
损失(代码中没有):
image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/779439.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot项目中HTTP请求体只能读一次?试试这方案

问题描述 在基于Spring开发Java项目时,可能需要重复读取HTTP请求体中的数据,例如使用拦截器打印入参信息等,但当我们重复调用getInputStream()或者getReader()时,通常会遇到类似以下的错误信息:大体的意思是当前request的getInputStream()已经被调用过了。那为什么会出现这…

类与类之间的基本关系

类与类之间的基本关系 类与类之间的六种关系 一、继承关系继承指的是一个类(称为子类、子接口)继承另外的一个类(称为父类、父接口)的功能,并可以增加它自己的新功能的能力。在Java中继承关系通过关键字extends明确标识,在设计时一般没有争议性。在UML类图设计中,继承用…

js 将十进制字符串转换成4字节的字节数组

函数function convertToHexArrays(input) {// 通过制表符分割输入字符串const numbers = input.split(\t);// 用于存储结果的数组const result = [];for (let num of numbers) {// 将字符串转换为数字const value = parseInt(num);// 创建一个 4 字节的 ArrayBufferconst buffe…

超异构计算杂谈

超异构计算杂谈 在这一节中要从更远的视角来看看计算机架构发展的黄金 10 年,主要将围绕异构计算和超异构来展开。在开始具体内容前,非常推荐观看以下两个视频:计算机架构的新黄金时代:A New Golden Age for Computer Architecture 编译器的黄金时代:The Golden Age of Co…

转发wsa和安卓模拟器网络

adb连接上设备后, 执行 执行端口转发 adb forward tcp:6789 tcp:888`就可以了, 把设备的8888端口转发到本机6789, 本机postman之类直接访问 127.0.0.1:6789即可 其他笔记:连接wsa: adb connect 127.0.0.1:58526 连接安卓模拟器: adb connect 127.0.0.1:58526 安装app adb -s 1…

09HTML+CSS

完成小兔鲜儿商城界面1 <!DOCTYPE html>2 <html lang="en">3 4 <head>5 <meta charset="UTF-8">6 <meta name="viewport" content="width=device-width, initial-scale=1.0">7 <!-- 提升…

macos上安装esp-idf v4.2版本

参考 https://docs.espressif.com/projects/esp-idf/en/release-v5.0/esp32/get-started/linux-macos-setup.html 安装 Prerequisites brew install cmake ninja dfu-utilgit下载idf 4.2版本并安装 git clone -b release/v4.2 --recursive https://github.com/espressif/esp-id…

VS设置 LLVM-Clang 编译器进行编译C++项目

在VS中默认的C++编译器一般为 MSVC 编译器,可以根据自己的需要将其设置为 LLVM-Clang 编译器。主要有两种方案: 1)直接使用 Visual Studio Installer来自动下载对应的 Clang 编译器和构建工具,后续无需再进行配置,便可直接使用。 2)使用自己编译或者单独下载的 LLVM-Clan…

记一次微信聊天记录导出工具的折腾

小记微信聊天记录选择性导出工具: WechatExporter 的使用目前的微信app(iOS端 v8.0.46)聊天记录中, 允许用户基于图片/视频进行筛选 单个或者少量保存到本机没啥问题 但是如果你量很大, 不好意思, 有批量操作功能, 但是我不支持全选, 因为我批量操作单次最多只支持 9 个文件 就…

《加缪情书集》-1944

用直白的话语,短句子,热烈表达感情。写很具体的细节打动人全文背诵,谢谢 【PS:加缪和玛丽亚这种不被世俗赞同的感情是不是可以直接拿来用...?】分手后

当你用bing搜索张云杰时

首页会跳出:总结一下:(张杰自称)张云杰现实中是完完全全的废物。打开张云杰相关的图片可以看到:只能说气质相符!

洛谷P3842 线段——题解

洛谷P3842题解传送锚点摸鱼环节 [TJOI2007] 线段 题目描述 在一个 \(n \times n\) 的平面上,在每一行中有一条线段,第 \(i\) 行的线段的左端点是\((i, L_{i})\),右端点是\((i, R_{i})\)。 你从 \((1,1)\) 点出发,要求沿途走过所有的线段,最终到达 \((n,n)\) 点,且所走的路…