从零开始学习线性回归:理论、实践与PyTorch实现

文章目录

  • 🥦介绍
  • 🥦基本知识
  • 🥦代码实现
  • 🥦完整代码
  • 🥦总结

🥦介绍

线性回归是统计学和机器学习中最简单而强大的算法之一,用于建模和预测连续性数值输出与输入特征之间的关系。本博客将深入探讨线性回归的理论基础、数学公式以及如何使用PyTorch实现一个简单的线性回归模型。

🥦基本知识

线性回归的数学基础
线性回归的核心思想是建立一个线性方程,它表示了自变量(输入特征)与因变量(输出)之间的关系。这个线性方程通常表示为:
y=b0+b1x1+b2x2+…+bpxp
y=b0​+b1​x1​+b2​x2​+…+bp​xp​

其中, y y y 是因变量, x 1 , x 2 , … , x p x_1, x_2, \ldots, x_p x1,x2,,xp 是自变量, b 0 , b 1 , b 2 , … , b p b_0, b_1, b_2, \ldots, b_p b0,b1,b2,,bp 是模型的参数, p p p 是特征的数量。我们的目标是找到最佳的参数值,以最小化模型的误差。

损失函数
为了找到最佳参数,我们需要定义一个损失函数来度量模型的性能。在线性回归中,最常用的损失函数是均方误差(MSE),它表示了模型预测值与实际值之间的平方差的平均值:
在这里插入图片描述
其中, n n n 是样本数量, y i y_i yi 是实际值, y ^ i \hat{y}_i y^i 是模型的预测值。

梯度下降优化
为了最小化损失函数,我们使用梯度下降算法。梯度下降通过计算损失函数相对于参数的梯度,并迭代地更新参数,以减小损失。更新规则如下:
在这里插入图片描述
其中, b j b_j bj 是第 j j j个参数, α \alpha α 是学习率, ∂ ∂ b j M S E \frac{\partial}{\partial b_j} MSE bjMSE 是损失函数对参数 b j b_j bj的偏导数。

🥦代码实现

如果你想知道实现线性回归的大体步骤,下图可以充分进行说明
在这里插入图片描述

  • 准备数据
  • 设计模型(计算) y ^ i \hat{y}_i y^i
  • 构造损失和优化器
  • 训练周期(前向,反向 ,更新)

本节还是以刘二大人的视频讲解为例,结尾会设置传送门

class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() # 调用父类的构造函数self.linear = torch.nn.Linear(1, 1)  # 参数详情下图展示def forward(self, x):y_pred = self.linear(x)   # x代表输入样本的张量return y_pred
model = LinearModel()

所以模型类都要继承Module,此类主要包含两个函数一个是构造函数(初始化对象时调用),另一个是前向计算

好奇的小伙伴会思考为何没有反向(backward),这是因为Module会帮你进行,但是如果后期自己有更高效的方法可以自行设置。
在这里插入图片描述

  • 第一个参数 in_features:这是输入特征的数量。在这里,表示我们的模型只有一个输入特征。如果你有多个输入特征,你可以将这个参数设置为输入特征的数量。

  • 第二个参数 out_features:这是输出特征的数量。这表示我们的模型将生成一个输出。在线性回归中,通常只有一个输出,因为我们试图预测一个连续的数值。

  • 第三个参数:意思是要不要偏置量。默认true

通常情况下特征代表列,比如我们有一个n×2的y和一个n×3的x,那么我们需要一个3×2的权重,有的书中会在两边做转置,但无论咋样目的都是为了让这个矩阵乘法成立

criterion = torch.nn.MSELoss(size_average=False)  # 使用均方误差损失 
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 使用随机梯度下降优化器

在这里插入图片描述
在这里插入图片描述
model.parameters() 用于告诉优化器哪些参数需要在训练过程中进行更新,这包括模型的权重和偏置项等。在线性回归示例中,模型的参数包括权重和偏置项。

优化器的选择有许多大家可以都试试看看
在这里插入图片描述

之后就进行训练了

for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad()   # 归零loss.backward()  # 反向optimizer.step()  # 更新
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)

🥦完整代码

x_data = torch.Tensor([[1.0], [2.0], [3.0]]) 
y_data = torch.Tensor([[2.0], [4.0], [6.0]])
class LinearModel(torch.nn.Module):def __init__(self):super(LinearModel, self).__init__() self.linear = torch.nn.Linear(1, 1)def forward(self, x):y_pred = self.linear(x) return y_pred
model = LinearModel()
criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(1000):y_pred = model(x_data)loss = criterion(y_pred, y_data) print(epoch, loss.item())optimizer.zero_grad() loss.backward()optimizer.step()
print('w = ', model.linear.weight.item()) 
print('b = ', model.linear.bias.item())
x_test = torch.Tensor([[4.0]]) 
y_test = model(x_test)
print('y_pred = ', y_test.data)
predicted = model(x_data).detach().numpy()
plt.scatter(x_data, y_data, label='Original data')
plt.plot(x_data, predicted, label='Fitted line', color='r')
plt.legend()
plt.show()

运行结果如下
在这里插入图片描述

在这里插入图片描述

🥦总结

在本篇博客中,我们使用PyTorch实现了一个简单的线性回归模型,并使用随机生成的数据对其进行了训练和可视化。线性回归是一个入门级的机器学习模型,但它为理解模型训练和预测的基本概念提供了一个很好的起点。

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/130878.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

JVM上篇之类加载子系统

目录 类加载子系统 内存结构 类的生命周期 类的加载过程 加载 加载class文件方式 连接 验证 验证阶段 准备 解析 初始化 类加载器 介绍 作用 分类 引导类加载器 自定义类加载器 ClassLoader 获取ClassLoader途径 双亲委派机制 介绍 执行流程 好处 打破…

常见的Web安全漏洞(2021年9月的OWASP TOP 10)

聊Web安全漏洞,就不得不提到OWASP TOP10。开放式Web应用程序安全项目(OpenWeb Application Security Project,OWASP)是一个开源的、非营利的组织,主要提供有关Web应用程序的实际可行、公正透明、有社会效益的信息&…

Jenkins 执行远程shell脚本部署jar文件问题起不来

如图:最开始的时候没有加: source /etc/profile 这一行, run.sh里面的java -jar xxxx.jar 一直执行不来。 一开始以为是Jenkins执行退出后会kill一切它启动的进程,所以加了在run.sh里面加了export BUILD_IDdontKillMe&#xff0…

42. QT中开发Android配置QFtp功能时遇到的编译问题

1. 说明 此问题仅适用在QT中开发Android程序时,需要适用QFtp功能的情况。一般情况下,如果开发的是Windows或者Linux系统下的程序,可能不会出现该问题。 2. 问题 【Android】在将QFtp的相关代码文件加入到项目中后,编译项目时会…

Python 中的 set 集合类型是可迭代的吗?

当我们运行以下代码时会报错。 a {1, 2, 4, 3, 4} for i in range(len(a)):print(a[i]) 所以我之前一直以为 set 类型是不可迭代的,后来发现这里的报错问题是:set object is not subscriptable,也就是说 set 是不可以通过下标来访问的。因为…

uniapp:swiper-demo效果

单元格轮播 <swiper class"swiper1" :circular"true" :autoplay"true" interval"3000" previous-margin"195rpx" next-margin"195rpx"><swiper-item v-for"(item,index) in 5" :key"inde…

华为、小鹏大定爆单,智驾苦尽甘来,车主终于愿意买单

‍作者|德新 编辑|王博 国庆假期结束&#xff0c;车圈的最大热点事件&#xff0c;当属问界M7卖爆&#xff0c;上市不到一个月时间内&#xff0c;狂揽5万张大定订单。 在华为手机强势回归&#xff0c;改款问界M7大热的高光之下&#xff0c;还有一个重要趋势值得关注&#xff1…

AIGC|利用大语言模型实现智能私域问答助手

随着ChatGPT的爆火&#xff0c;最近大家开始关注到大语言模型&#xff08;LLM&#xff09;这个领域。像雨后春笋一样&#xff0c;国内外涌现出了很多LLM。作为开发者&#xff0c;我们通常会关注LLM各自擅长的领域和能力&#xff0c;然后思考如何利用它们的能力来解决某个场景或…

大数据——Spark Streaming

是什么 Spark Streaming是一个可扩展、高吞吐、具有容错性的流式计算框架。 之前我们接触的spark-core和spark-sql都是离线批处理任务&#xff0c;每天定时处理数据&#xff0c;对于数据的实时性要求不高&#xff0c;一般都是T1的。但在企业任务中存在很多的实时性的任务需求&…

Docker 网络访问原理解密

How Container Networking Works: Practical Explanation 这篇文章讲得非常棒&#xff0c;把docker network讲得非常清晰。 分为三个部分&#xff1a; 1&#xff09;docker 内部容器互联。 2&#xff09;docker 容器 访问 外部root 网络空间。 3&#xff09;外部网络空间…

读书笔记:多Transformer的双向编码器表示法(Bert)-4

多Transformer的双向编码器表示法 Bidirectional Encoder Representations from Transformers&#xff0c;即Bert&#xff1b; 第二部分 探索BERT变体 从本章开始的诸多内容&#xff0c;以理解为目标&#xff0c;着重关注对音频相关的支持&#xff08;如果有的话&#xff09;…

解读非托管流动性协议Hover: 差异化、层次化的全新借贷体系

“Hover 是 DeFi 借贷赛道的另辟蹊径者&#xff0c;除了在自身机制&#xff08;借贷模型、治理体系&#xff09;上进行创新获得内生动力外&#xff0c;背靠日渐繁荣的 Kava、Cosmos 生态进一步获得外生动力&#xff0c;发展潜力俱佳” 与 DEX 类似&#xff0c;借贷也是 DeFi 世…