回归树模型 0基础小白也能懂(附代码)

回归树模型 0基础小白也能懂(附代码)

啥是回归树模型

大家在前面的部分学习到了使用决策树进行分类,实际决策树也可以用作回归任务,我们叫作回归树。而回归树的结构还是树形结构,但是属性选择与生长方式和分类的决策树有不同。

要讲回归树,我们一定会提到CART树,CART树全称Classification And Regression Trees,包括分类树与回归树。

CART的特点是:假设决策树是二叉树,内部结点特征的取值为「是」和「否」,右分支是取值为「是」的分支,左分支是取值为「否」的分支。这样的决策树等价于「递归地二分每个特征」,将输入空间(特征空间)划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出的条件概率分布。

这是人话吗......看半天没看懂,回归树相对于决策树来说用于处理连续型数值的目标变量。也就是说,回归树的预测输出是一个连续的实数值,例如预测房价、温度等,之前学的决策树都是处理离散的,看下面的图吧

设有数据集\(D\),构建回归树的大体思路如下:

  • ① 考虑数据集上所有特征\(j\),遍历每一个特征下可能的取值或者切分点(怎么选取的后面会说),将数据集划分为两部分\(D_1,D_2\)
  • ② 分别计算\(D_1,D_2\)的平方误差和,选择最小的平方误差对应的特征与分割点,生成两个子节点(将数据划分为两部分)。
  • ③ 对上述两个子节点递归调用步骤 ① ②,直到满足停止条件(比如最小样本数,最大数深度之类的)。

回归树构建完成后,就完成了对整个输入空间的划分(即完成了回归树的建立)。将整个输入空间划分为多个子区域,每个子区域输出为该区域内所有训练样本的平均值。我们知道了回归树其实是将输入空间划分为\(M\)个单元,每个区域的输出值是该区域内所有点\(y\)值的平均数。但我们希望构建最有效的回归树:预测值与真实值差异度最小。下面部分我们展开讲讲,回归树是如何生长的。

2.启发式切分与最优属性选择

又是最优属性选择,决策树中是信息增益和基尼系数之类的,那这里会是什么呢?

下面是我们基础的划分思路

RSS(残差平方和,Residual Sum of Squares)是用于衡量分裂质量的一个标准

  • \(y\)为每个训练样本的标签构成的标签向量,向量中的每个元素\(y_i\)对应的是每个样本的标签。
  • \(X\)为特征的集合,\(x_1,x_2,...,x_p\)为第一个特征到第p个特征
  • \(R_1,R_2,...,R_j\)为整个特征空间划分得来的J个不重叠的区域
  • \(\widetilde{y}_{R_j}\) 为划分到第\(j\)个区域\(R_j\)的样本的平均标签值,用这个值作为该区域的预测值,即如果有一个测试样本在测试时落入到该区域,就将该样本的标签值预测为\(\widetilde{y}_{R_j}\)

但是这个最小化和探索的过程,计算量是非常非常大的。我们采用「探索式的递归二分」来尝试解决这个问题。

递归二分

回归树采用的是「自顶向下的贪婪式递归方案」。这里的贪婪,指的是每一次的划分,只考虑当前最优,而不回头考虑之前的划分。

我们再来看看「递归切分」。下方有两个对比图,其中左图是非递归方式切分得到的,而右图是二分递归的方式切分得到的空间划分结果(下一次划分一定是在之前的划分基础上将某个区域一份为二)。

(感觉思路就是不一次性划分完,根据当前现状一步一步来)

回归树总体流程类似于分类树:分枝时穷举每一个特征可能的划分阈值,来寻找最优切分特征和最优切分点阈值,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。

但通常在处理具体问题时,单一的回归树模型能力有限且有可能陷入过拟合,我们经常会利用集成学习中的Boosting思想,对回归树进行增强,得到的新模型就是提升树(Boosting Decision Tree),进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。通过多棵回归树拟合残差,不断减小预测值与标签值的偏差,从而达到精准预测的目的,会在后面介绍这些高级算法。

过拟合与正则化

过拟合问题处理

(1)约束控制树的过度生长
限制树的深度:当达到设置好的最大深度时结束树的生长。
分类误差法:当树继续生长无法得到客观的分类误差减小,就停止生长。
叶子节点最小数据量限制:一个叶子节点的数据量过小,树停止生长。

(2)剪枝
约束树生长的缺点就是提前扼杀了其他可能性,过早地终止了树的生长,我们也可以等待树生长完成以后再进行剪枝,即所谓的后剪枝,而后剪枝算法主要有以下几种:
Reduced-Error Pruning(REP,错误率降低剪枝)。
Pesimistic-Error Pruning(PEP,悲观错误剪枝)。
Cost-Complexity Pruning(CCP,代价复杂度剪枝)。
Error-Based Pruning(EBP,基于错误的剪枝)。

正则化

剪枝的目标是找到使得以下表达式最小的子树\(T_a\)

\(T_a=RSS+\alpha|T|\)

  • 其中\(\alpha\)是正则化项的系数,可以通过交叉验证去选择。
  • \(|T|\)是回归树叶子节点的个数(即树的复杂度)

代码实现

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing  # 加载加州房价数据集
from sklearn.model_selection import train_test_split  # 用于划分训练集和测试集
from sklearn.tree import DecisionTreeRegressor  # 使用回归树
from sklearn.metrics import mean_squared_error, r2_score  # 用于评估模型性能# 1. 加载加州房价数据集
data = fetch_california_housing()  # 加载加州房价数据
X = data.data  # 特征矩阵(包含了多个影响房价的因素,如人口密度、纬度、经度等)
y = data.target  # 目标变量(房价,单位为千美元)# 2. 划分训练集和测试集
# 我们将数据集分为训练集和测试集,70%用于训练模型,30%用于测试模型的表现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 3. 创建回归树模型
# max_depth=5 限制了树的最大深度为5,防止过拟合
# random_state=42 确保每次运行代码时模型的结果是可重复的
regressor = DecisionTreeRegressor(max_depth=5, random_state=42)# 4. 训练模型
# fit() 函数用于训练模型,使其学习训练集中的特征与房价之间的关系
regressor.fit(X_train, y_train)# 5. 进行预测
# 使用训练好的模型对训练集和测试集进行预测
y_pred_train = regressor.predict(X_train)  # 对训练集的预测结果
y_pred_test = regressor.predict(X_test)  # 对测试集的预测结果# 6. 评估模型# 计算测试集的均方误差(MSE)
# MSE 衡量模型预测值与实际值之间的平均误差,数值越小表示预测越准确
mse_test = mean_squared_error(y_test, y_pred_test)
print(f"Mean Squared Error (Test): {mse_test:.2f}")# 计算训练集的均方误差(MSE)
# 可以用来评估模型是否在训练集上过拟合
mse_train = mean_squared_error(y_train, y_pred_train)
print(f"Mean Squared Error (Train): {mse_train:.2f}")# 计算R²得分
# R² 是决定系数,衡量模型对数据的拟合程度,1.0表示完全拟合,0表示无法拟合
r2_test = r2_score(y_test, y_pred_test)
r2_train = r2_score(y_train, y_pred_train)
print(f"R² Score (Test): {r2_test:.2f}")
print(f"R² Score (Train): {r2_train:.2f}")# 7. 可视化回归树的预测结果(实际值 vs. 预测值)
# 我们绘制散点图来展示测试集上的实际房价与预测房价的对比
plt.scatter(y_test, y_pred_test)
plt.xlabel('Actual Prices')  # 横轴是实际的房价
plt.ylabel('Predicted Prices')  # 纵轴是模型预测的房价
plt.title('Actual vs Predicted Prices')  # 图表标题
plt.show()

结果如下

看下对角线发现很多店落在右下角,看来预测的结果还是低估了房价。

Mean Squared Error (Test): 0.52
Mean Squared Error (Train): 0.49
R² Score (Test): 0.60
R² Score (Train): 0.63

emm准度一般,也在情理之中,R²得分越接近1,模型的预测效果越好。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/792016.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

如何通过API接口实现库存的精准掌控

https://img2024.cnblogs.com/blog/3506472/202409/3506472-20240904105309327-1011277110.png在电子商务的快速发展中,库存管理已成为衡量企业运营效率的关键指标。随着消费者对快速配送和商品可用性的期望不断提高,电商企业必须找到更智能、更高效的库存管理方法。电商库存…

manim边学边做--曲线类

manim中曲线,除了前面介绍的圆弧类曲线,也可以绘制任意的曲线。 manim中提供的CubicBezier模块,可以利用三次贝塞尔曲线的方式绘制任意曲线。 关于贝塞尔曲线的介绍,可以参考:https://en.wikipedia.org/wiki/B%C3%A9zier_curve。 本文主要介绍贝塞尔曲线和两种带箭头的曲线…

adb获取手机电池信息

1、获取手机电池信息adb shell dumpsys battery字段说明Current Battery Service state:AC powered: true #交流供电USB powered: false #usb供电Wireless powered: false #无线供电Max charging current: 75000 #最大充电电流Max charging volt…

在pycharm中使用copilot

一、注册、获取使用权限 什么双密码验证、学生验证的过程就不重复了,按网上的教程来就行。 需要注意的是,Github学生认证通过之后,并不是能够立马使用copilot,得等三天copilot的免费使用权限才会批下来。 二、在pycharm中使用copilot 1、安装插件、登录Github等,按照网上的…

若依项目pom文件添加jar包已依赖报红,dependency not found,提示找不到jar包

原因很简单,因为我写在了父项目的pom文件中,写在了 里面。这里只是对依赖的版本进行管理。点击查看代码<!-- 依赖声明 --><dependencyManagement><dependencies>正确的做法应该是在子项目中的pom文件中引入对应依赖,在父项目的pom文件中填上对应的依赖版…

0 JavaScript高级程序设计(第4版)【JS红宝书】【详细思维导图】【持续更新】

ProcessOn访问链接 JavaScript高级程序设计(第4版)阅读路线图,涵盖:基本知识进阶内容BOM和DOMJavascript APIJavaScript设计模式和实践策略ProcessOn访问链接本文来自博客园,作者:muling9955,转载请注明原文链接:https://www.cnblogs.com/muling-blog/p/18395904

架构师备考的一些思考

前言 之前的python-pytorch的系列文章还没有写完,只是写到卷积神经网络。因为我报名成功了系统架构师的考试,所以决定先备考,等考完再继续写。 虽然架构师证书不能证明技术水平,但在现实生活中的某些情况下是有意义的。考试虽然无聊,但有些考题还是蛮有意思的。 思考 看了…

vissim检测路段通过车辆数-cnblog

vissim记录(vs4.3) 目录vissim记录(vs4.3)1.数据收集点设置2. 数据采集配置设置3. 结果查看设置数据收集点,进行截面数据统计1.数据收集点设置2. 数据采集配置设置我是五岔口,就设置了五组 还没完继续配置,选择要采集的数据。如果只需要统计通过车辆数,则只选择number veh3…

OpenTelemetry 实战:gRPC 监控的实现原理

前言最近在给 opentelemetry-java-instrumentation 提交了一个 PR,是关于给 gRPC 新增四个 metrics:rpc.client.request.size: 客户端请求包大小 rpc.client.response.size:客户端收到的响应包大小 rpc.server.request.size:服务端收到的请求包大小 rpc.server.response.si…

ICMAN液位检测方案

非接触式液位检测提醒方案TA是什么? ICMAN液位检测是基于双通道比较电容式液位检测原理,来判断容器中是否有液体或者液体是否达到一定高度。 有什么用? ICMAN液位检测可以实现非接触式检测,起到高低、不同液位提醒、缺水提醒、溢水提醒等作用,让我们的生产生活更加安全、便…