回归树模型 0基础小白也能懂(附代码)
啥是回归树模型
大家在前面的部分学习到了使用决策树进行分类,实际决策树也可以用作回归任务,我们叫作回归树。而回归树的结构还是树形结构,但是属性选择与生长方式和分类的决策树有不同。
要讲回归树,我们一定会提到CART树,CART树全称Classification And Regression Trees,包括分类树与回归树。
CART的特点是:假设决策树是二叉树,内部结点特征的取值为「是」和「否」,右分支是取值为「是」的分支,左分支是取值为「否」的分支。这样的决策树等价于「递归地二分每个特征」,将输入空间(特征空间)划分为有限个单元,并在这些单元上确定预测的概率分布,也就是在输入给定的条件下输出的条件概率分布。
这是人话吗......看半天没看懂,回归树相对于决策树来说用于处理连续型数值的目标变量。也就是说,回归树的预测输出是一个连续的实数值,例如预测房价、温度等,之前学的决策树都是处理离散的,看下面的图吧
设有数据集\(D\),构建回归树的大体思路如下:
- ① 考虑数据集上所有特征\(j\),遍历每一个特征下可能的取值或者切分点(怎么选取的后面会说),将数据集划分为两部分\(D_1,D_2\)
- ② 分别计算\(D_1,D_2\)的平方误差和,选择最小的平方误差对应的特征与分割点,生成两个子节点(将数据划分为两部分)。
- ③ 对上述两个子节点递归调用步骤 ① ②,直到满足停止条件(比如最小样本数,最大数深度之类的)。
回归树构建完成后,就完成了对整个输入空间的划分(即完成了回归树的建立)。将整个输入空间划分为多个子区域,每个子区域输出为该区域内所有训练样本的平均值。我们知道了回归树其实是将输入空间划分为\(M\)个单元,每个区域的输出值是该区域内所有点\(y\)值的平均数。但我们希望构建最有效的回归树:预测值与真实值差异度最小。下面部分我们展开讲讲,回归树是如何生长的。
2.启发式切分与最优属性选择
又是最优属性选择,决策树中是信息增益和基尼系数之类的,那这里会是什么呢?
下面是我们基础的划分思路
RSS(残差平方和,Residual Sum of Squares)是用于衡量分裂质量的一个标准
- \(y\)为每个训练样本的标签构成的标签向量,向量中的每个元素\(y_i\)对应的是每个样本的标签。
- \(X\)为特征的集合,\(x_1,x_2,...,x_p\)为第一个特征到第p个特征
- \(R_1,R_2,...,R_j\)为整个特征空间划分得来的J个不重叠的区域
- \(\widetilde{y}_{R_j}\) 为划分到第\(j\)个区域\(R_j\)的样本的平均标签值,用这个值作为该区域的预测值,即如果有一个测试样本在测试时落入到该区域,就将该样本的标签值预测为\(\widetilde{y}_{R_j}\)
但是这个最小化和探索的过程,计算量是非常非常大的。我们采用「探索式的递归二分」来尝试解决这个问题。
递归二分
回归树采用的是「自顶向下的贪婪式递归方案」。这里的贪婪,指的是每一次的划分,只考虑当前最优,而不回头考虑之前的划分。
我们再来看看「递归切分」。下方有两个对比图,其中左图是非递归方式切分得到的,而右图是二分递归的方式切分得到的空间划分结果(下一次划分一定是在之前的划分基础上将某个区域一份为二)。
(感觉思路就是不一次性划分完,根据当前现状一步一步来)
回归树总体流程类似于分类树:分枝时穷举每一个特征可能的划分阈值,来寻找最优切分特征和最优切分点阈值,衡量的方法是平方误差最小化。分枝直到达到预设的终止条件(如叶子个数上限)就停止。
但通常在处理具体问题时,单一的回归树模型能力有限且有可能陷入过拟合,我们经常会利用集成学习中的Boosting思想,对回归树进行增强,得到的新模型就是提升树(Boosting Decision Tree),进一步,可以得到梯度提升树(Gradient Boosting Decision Tree,GBDT),再进一步可以升级到XGBoost。通过多棵回归树拟合残差,不断减小预测值与标签值的偏差,从而达到精准预测的目的,会在后面介绍这些高级算法。
过拟合与正则化
过拟合问题处理
(1)约束控制树的过度生长
限制树的深度:当达到设置好的最大深度时结束树的生长。
分类误差法:当树继续生长无法得到客观的分类误差减小,就停止生长。
叶子节点最小数据量限制:一个叶子节点的数据量过小,树停止生长。
(2)剪枝
约束树生长的缺点就是提前扼杀了其他可能性,过早地终止了树的生长,我们也可以等待树生长完成以后再进行剪枝,即所谓的后剪枝,而后剪枝算法主要有以下几种:
Reduced-Error Pruning(REP,错误率降低剪枝)。
Pesimistic-Error Pruning(PEP,悲观错误剪枝)。
Cost-Complexity Pruning(CCP,代价复杂度剪枝)。
Error-Based Pruning(EBP,基于错误的剪枝)。
正则化
剪枝的目标是找到使得以下表达式最小的子树\(T_a\)
\(T_a=RSS+\alpha|T|\)
- 其中\(\alpha\)是正则化项的系数,可以通过交叉验证去选择。
- \(|T|\)是回归树叶子节点的个数(即树的复杂度)
代码实现
# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing # 加载加州房价数据集
from sklearn.model_selection import train_test_split # 用于划分训练集和测试集
from sklearn.tree import DecisionTreeRegressor # 使用回归树
from sklearn.metrics import mean_squared_error, r2_score # 用于评估模型性能# 1. 加载加州房价数据集
data = fetch_california_housing() # 加载加州房价数据
X = data.data # 特征矩阵(包含了多个影响房价的因素,如人口密度、纬度、经度等)
y = data.target # 目标变量(房价,单位为千美元)# 2. 划分训练集和测试集
# 我们将数据集分为训练集和测试集,70%用于训练模型,30%用于测试模型的表现
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 3. 创建回归树模型
# max_depth=5 限制了树的最大深度为5,防止过拟合
# random_state=42 确保每次运行代码时模型的结果是可重复的
regressor = DecisionTreeRegressor(max_depth=5, random_state=42)# 4. 训练模型
# fit() 函数用于训练模型,使其学习训练集中的特征与房价之间的关系
regressor.fit(X_train, y_train)# 5. 进行预测
# 使用训练好的模型对训练集和测试集进行预测
y_pred_train = regressor.predict(X_train) # 对训练集的预测结果
y_pred_test = regressor.predict(X_test) # 对测试集的预测结果# 6. 评估模型# 计算测试集的均方误差(MSE)
# MSE 衡量模型预测值与实际值之间的平均误差,数值越小表示预测越准确
mse_test = mean_squared_error(y_test, y_pred_test)
print(f"Mean Squared Error (Test): {mse_test:.2f}")# 计算训练集的均方误差(MSE)
# 可以用来评估模型是否在训练集上过拟合
mse_train = mean_squared_error(y_train, y_pred_train)
print(f"Mean Squared Error (Train): {mse_train:.2f}")# 计算R²得分
# R² 是决定系数,衡量模型对数据的拟合程度,1.0表示完全拟合,0表示无法拟合
r2_test = r2_score(y_test, y_pred_test)
r2_train = r2_score(y_train, y_pred_train)
print(f"R² Score (Test): {r2_test:.2f}")
print(f"R² Score (Train): {r2_train:.2f}")# 7. 可视化回归树的预测结果(实际值 vs. 预测值)
# 我们绘制散点图来展示测试集上的实际房价与预测房价的对比
plt.scatter(y_test, y_pred_test)
plt.xlabel('Actual Prices') # 横轴是实际的房价
plt.ylabel('Predicted Prices') # 纵轴是模型预测的房价
plt.title('Actual vs Predicted Prices') # 图表标题
plt.show()
结果如下
看下对角线发现很多店落在右下角,看来预测的结果还是低估了房价。
Mean Squared Error (Test): 0.52
Mean Squared Error (Train): 0.49
R² Score (Test): 0.60
R² Score (Train): 0.63
emm准度一般,也在情理之中,R²得分越接近1,模型的预测效果越好。