【机器学习】单变量线性回归

文章目录

  • 线性回归模型(linear regression model)
  • 损失/代价函数(cost function)——均方误差(mean squared error)
  • 梯度下降算法(gradient descent algorithm)
  • 参数(parameter)和超参数(hyperparameter)
  • 代码实现样例
  • 运行结果

线性回归模型(linear regression model)

  • 线性回归模型:

f w , b ( x ) = w x + b f_{w,b}(x) = wx + b fw,b(x)=wx+b

其中, w w w 为权重(weight), b b b 为偏置(bias)

  • 预测值(通常加一个帽子符号):

y ^ ( i ) = f w , b ( x ( i ) ) = w x ( i ) + b \hat{y}^{(i)} = f_{w,b}(x^{(i)}) = wx^{(i)} + b y^(i)=fw,b(x(i))=wx(i)+b

损失/代价函数(cost function)——均方误差(mean squared error)

  • 一个训练样本: ( x ( i ) , y ( i ) ) (x^{(i)}, y^{(i)}) (x(i),y(i))
  • 训练样本总数 = m m m
  • 损失/代价函数是一个二次函数,在图像上是一个开口向上的抛物线的形状。

J ( w , b ) = 1 2 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] 2 = 1 2 m ∑ i = 1 m [ w x ( i ) + b − y ( i ) ] 2 \begin{aligned} J(w, b) &= \frac{1}{2m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}]^2 \\ &= \frac{1}{2m} \sum^{m}_{i=1} [wx^{(i)} + b - y^{(i)}]^2 \end{aligned} J(w,b)=2m1i=1m[fw,b(x(i))y(i)]2=2m1i=1m[wx(i)+by(i)]2

  • 为什么需要乘以 1/2?因为对平方项求偏导后会出现系数 2,是为了约去这个系数。

梯度下降算法(gradient descent algorithm)

  • α \alpha α:学习率(learning rate),用于控制梯度下降时的步长,以抵达损失函数的最小值处。若 α \alpha α 太小,梯度下降太慢;若 α \alpha α 太大,下降过程可能无法收敛。
  • 梯度下降算法:

r e p e a t { t m p _ w = w − α ∂ J ( w , b ) w t m p _ b = b − α ∂ J ( w , b ) b w = t m p _ w b = t m p _ b } u n t i l c o n v e r g e \begin{aligned} repeat \{ \\ & tmp\_w = w - \alpha \frac{\partial J(w, b)}{w} \\ & tmp\_b = b - \alpha \frac{\partial J(w, b)}{b} \\ & w = tmp\_w \\ & b = tmp\_b \\ \} until \ & converge \end{aligned} repeat{}until tmp_w=wαwJ(w,b)tmp_b=bαbJ(w,b)w=tmp_wb=tmp_bconverge

其中,偏导数为

∂ J ( w , b ) w = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] x ( i ) ∂ J ( w , b ) b = 1 m ∑ i = 1 m [ f w , b ( x ( i ) ) − y ( i ) ] \begin{aligned} & \frac{\partial J(w, b)}{w} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] x^{(i)} \\ & \frac{\partial J(w, b)}{b} = \frac{1}{m} \sum^{m}_{i=1} [f_{w,b}(x^{(i)}) - y^{(i)}] \end{aligned} wJ(w,b)=m1i=1m[fw,b(x(i))y(i)]x(i)bJ(w,b)=m1i=1m[fw,b(x(i))y(i)]

参数(parameter)和超参数(hyperparameter)

  • 超参数(hyperparameter):训练之前人为设置的任何数量都是超参数,例如学习率 α \alpha α
  • 参数(parameter):模型在训练过程中创建或修改的任何数量都是参数,例如 w , b w, b w,b

代码实现样例

import numpy as np
import matplotlib.pyplot as plt# 计算误差均方函数 J(w,b)
def cost_function(x, y, w, b):m = x.shape[0] # 训练集的数据样本数cost_sum = 0.0for i in range(m):f_wb = w * x[i] + bcost = (f_wb - y[i]) ** 2cost_sum += costreturn cost_sum / (2 * m)# 计算梯度值 dJ/dw, dJ/db
def compute_gradient(x, y, w, b):m = x.shape[0] # 训练集的数据样本数d_w = 0.0d_b = 0.0for i in range(m):f_wb = w * x[i] + bd_wi = (f_wb - y[i]) * x[i]d_bi = (f_wb - y[i])d_w += d_wid_b += d_bidj_dw = d_w / mdj_db = d_b / mreturn dj_dw, dj_db# 梯度下降算法
def linear_regression(x, y, w, b, learning_rate=0.01, epochs=1000):J_history = [] # 记录每次迭代产生的误差值for epoch in range(epochs):dj_dw, dj_db = compute_gradient(x, y, w, b)# w 和 b 需同步更新w = w - learning_rate * dj_dwb = b - learning_rate * dj_dbJ_history.append(cost_function(x, y, w, b)) # 记录每次迭代产生的误差值return w, b, J_history# 绘制线性方程的图像
def draw_line(w, b, xmin, xmax, title):x = np.linspace(xmin, xmax)y = w * x + b# plt.axis([0, 10, 0, 50]) # xmin, xmax, ymin, ymaxplt.xlabel("X-axis", size=15)plt.ylabel("Y-axis", size=15)plt.title(title, size=20)plt.plot(x, y)# 绘制散点图
def draw_scatter(x, y, title):plt.xlabel("X-axis", size=15)plt.ylabel("Y-axis", size=15)plt.title(title, size=20)plt.scatter(x, y)# 从这里开始执行
if __name__ == '__main__':# 训练集样本x_train = np.array([1, 2, 3, 5, 6, 7])y_train = np.array([15.5, 19.7, 24.4, 35.6, 40.7, 44.8])w = 0.0 # 权重b = 0.0 # 偏置epochs = 10000 # 迭代次数learning_rate = 0.01 # 学习率J_history = [] # # 记录每次迭代产生的误差值w, b, J_history = linear_regression(x_train, y_train, w, b, learning_rate, epochs)print(f"result: w = {w:0.4f}, b = {b:0.4f}") # 打印结果# 绘制迭代计算得到的线性回归方程plt.figure(1)draw_line(w, b, 0, 10, "Linear Regression")plt.scatter(x_train, y_train) # 将训练数据集也表示在图中plt.show()# 绘制误差值的散点图plt.figure(2)x_axis = list(range(0, 10000))draw_scatter(x_axis, J_history, "Cost Function in Every Epoch")plt.show()

运行结果

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/458538.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

服装设计公司,如何用钉钉实现企业数字化成功转型?

钉钉作为数字化工作平台,为某服装设计公司实现了组织管理的数字化转型,构建了一站式的工作平台。通过钉钉赋能,有利于企业推进组织架构、员工沟通、产品运营和客户服务等方面的数字化、智能化转型。 借助钉钉平台,该服设公司轻松实…

常见的 MIME(媒体)类型速查

一、简介 MIME(Multipurpose Internet Mail Extensions)多用途互联网邮件扩展类型,是设定某种扩展名的文件用一种应用程序来打开的方式类型,当该扩展名文件被访问的时候,浏览器会自动使用指定应用程序来打开。多用于指定一些客户端自定义的文…

使用Qt创建项目 Qt中输出内容到控制台 设置窗口大小和窗口标题 Qt查看说明文档

按windows键,找到Qt Creator ,打开 一.创建带模板的项目 新建项目 设置项目路径QMainWindow是带工具栏的窗口。 QWidget是无工具栏的窗口。 QDuakig是对话框窗口。创建好的项目如下: #include "widget.h"// 构造函数&#xff…

VSCode开发常用扩展记录

1、Chinese 2、document this 可以自动为ts和js文件生成jsDoc注释 3、ESLint 能够查找并修复js代码中的问题 4、koroFileHeader 5、Prettier 代码格式化

numa网卡绑定

#概念 参考:https://www.jianshu.com/p/0f3b39a125eb(opens new window) chip:芯片,一个cpu芯片上可以包含多个cpu core,比如四核,表示一个chip里4个core。 socket:芯片插槽,颗,跟…

高速接口PCB布局指南(五)高速差分信号布线(三)

高速接口PCB布局指南(五)高速差分信号布线(三) 1.表面贴装器件焊盘不连续性缓解2.信号线弯曲3.高速信号建议的 PCB 叠层设计4.ESD/EMI 注意事项5.ESD/EMI 布局规则 tips:资料主要来自网络,仅供学习使用。 …

《Git 简易速速上手小册》第8章:保护你的代码(2024 最新版)

文章目录 8.1 使用 .gitignore 优化你的仓库8.1.1 基础知识讲解8.1.2 重点案例:为 Python 项目配置 .gitignore8.1.3 拓展案例 1:使用全局 .gitignore8.1.4 拓展案例 2:忽略已经被跟踪的文件 8.2 管理敏感数据8.2.1 基础知识讲解8.2.2 重点案…

顺序图(Sequence Diagram)

也叫时序图、序列图 一、定义 顺序图是用来描述对象自身及对象间信息传递顺序的视图。 二、要素 活动者,对象,生命线,控制焦点,消息(同步消息,异步消息,返回消息,自关联消息) 1、 活动者 活动者发出情况或者接收系统的服务。 2、 对象 对象是特定行为与属性的集合。 表…

win10系统连接WiFi,输入正确密码,但还是提示错误

情况 电信宽带 mac和小米手机都可以连上wifi dell上的windows输入正确的密码还是提示错误 解决办法 根据路由器上的终端配置进入网页进行配置,我的是192.168.1.1,账户:useradmin 修改无线网络设置中的加密方式,由Mixed WPA2/WPA-PSK改为W…

深度学习(15)--PyTorch构建卷积神经网络

目录 一.PyTorch构建卷积神经网络(CNN)详细流程 二.graphviz torchviz使PyTorch网络可视化 2.1.可视化经典网络vgg16 2.2.可视化自己定义的网络 一.PyTorch构建卷积神经网络(CNN)详细流程 卷积神经网络(Convolutional Neural Networks)是一种深度学…

go 版本 LeeCode 刷题 在线

https://books.halfrost.com/leetcode/ChapterFour/0001~0099/0001.Two-Sum/ 参考 https://github.com/anzhihe/learning/tree/master/shell/book/abs-3.9.1_cn

LabVIEW动平衡测试与振动分析系统

LabVIEW动平衡测试与振动分析系统 介绍了利用LabVIEW软件和虚拟仪器技术开发一个动平衡测试与振动分析系统。该系统旨在提高旋转机械设备的测试精度和可靠性,通过精确测量和分析设备的振动数据,以识别和校正不平衡问题,从而保证机械设备的高…