GBDT模型 0基础小白也能懂(附代码)

GBDT模型 0基础小白也能懂(附代码)

原文链接

啥是GBDT

GBDT(Gradient Boosting Decision Tree),全名叫梯度提升决策树,是一种迭代的决策树算法,又叫 MART(Multiple Additive Regression Tree),它通过构造一组弱的学习器(树),并把多颗决策树的结果累加起来作为最终的预测输出。该算法将决策树与集成思想进行了有效的结合。

Gradient Boosting里的boosting是啥?

Boosting方法训练基分类器时采用串行的方式,各个基分类器之间有依赖。它的基本思路是将基分类器层层叠加,每一层在训练的时候,对前一层基分类器分错的样本,给予更高的权重(并根据前一个基分类器的表现计算误差。对于分类正确的样本,它们的权重保持不变或减少;对于分类错误的样本,算法会增加它们的权重)。测试时,根据各层分类器的结果的加权得到最终结果。

Bagging 与 Boosting 的串行训练方式不同,Bagging 方法在训练过程中,各基分类器之间无强依赖,可以进行并行训练。

GBDT详解

所有弱分类器的结果相加等于预测值。
每次都以当前预测为基准,下一个弱分类器去拟合误差函数对预测值的残差(预测值与真实值之间的误差)。
GBDT的弱分类器使用的是树模型。

实际工程实现里,GBDT 是计算负梯度,用负梯度近似残差,而不是像这样简单相减

1)GBDT与负梯度近似残差

回归任务下,GBDT在每一轮的迭代时对每个样本都会有一个预测值,此时的损失函数为均方差损失函数:

可以看出,当损失函数选用「均方误差损失」时,每一次拟合的值就是(真实值-预测值),即残差。

2)GBDT训练过程

我们来借助1个简单的例子理解一下 GBDT 的训练过程。假定训练集只有4个人\(A,B,C,D\),他们的年龄分别是\((14,16,24,26)\)。身份分别是高一学生,高三学生,应届毕业生,已工作两年。为了按照特征预测年龄

先用回归树训练后看结果,里面分的节点只是例子,没啥具体含义,先按购物金额分一下,再按上网时间或者上网时段分一下

接下来改用 GBDT 来训练。由于样本数据少,我们限定叶子节点最多为2即每棵树都只有一个分枝),并且限定树的棵树为2。先按照购物金额来分出一棵树:

上图中的树很好理解:\(A,B\) 年龄较为相近,\(C,D\) 年龄较为相近,被分为左右两支,每支用平均年龄作为预测值。

  • 我们计算残差(即「实际值」-「预测值」),所以\(A\)的残差 \(15-1=14\)
  • 这里 \(A\) 的「预测值」是指前面所有树预测结果累加的和,在当前情形下前序只有一棵树,所以直接是 \(15\),其他多树的复杂场景下需要累加计算作为 \(A\) 的预测值。

那么到这里预测完成,接下来就是要用一个弱分类器(树)去拟合误差函数对预测值的残差(预测值与真实值之间的误差)

上图中的树就是残差学习的过程了,里面提问和回答同样也只是例子:

  • \(A,B,C,D\) 的值换作残差 \(-1,1,-1,1\),再构建一棵树学习,这棵树只有两个值 \(1\)\(-1\),直接分成两个节点:\(A,C\) 在左边,\(B,D\) 在右边。
  • 这棵树学习残差,在我们当前这个简单的场景下,已经能保证预测值和实际值(上一轮残差)相等了。
  • 我们把这棵树的预测值累加到第一棵树上的预测结果上,就能得到真实年龄,这个简单例子中每个人都完美匹配,得到了真实的预测值。

最终的预测过程是这样的:

  • \(A\):高一学生,购物较少,经常问学长问题,真实年龄 14 岁,预测年龄 15-1
  • \(B\):高三学生,购物较少,经常被学弟提问,真实年龄 16 岁,预测年龄 15+1
  • \(C\):应届毕业生,购物较多,经常问学长问题,真实年龄 24 岁,预测年龄 25-1
  • \(D\):工作两年员工,购物较多,经常被学弟提问,真实年龄 26 岁,预测年龄 25+1

综上,GBDT 需要将多棵树的得分累加得到最终的预测得分,且每轮迭代,都是在现有树的基础上,增加一棵新的树去拟合前面树的预测值与真实值之间的残差。

梯度提升 vs 梯度下降

下面我们来对比一下「梯度提升」与「梯度下降」。这两种迭代优化算法,都是在每1轮迭代中,利用损失函数负梯度方向的信息,更新当前模型,只不过:

梯度提升(比如前面的GBDT):通过构建多个弱学习器,在函数空间中逼近最优解,不需要模型参数化,利用损失函数的负梯度逐步优化模型。

梯度下降(比如线性回归,逻辑回归):通过参数化的模型,利用损失函数的梯度更新参数,最终找到使损失最小的参数。

优缺点

随机森林 vs GBDT

何时使用哪个模型?

使用随机森林:如果你需要快速构建一个可用的模型,并且数据量较大,可以考虑使用随机森林。它对参数的敏感性较低,容易调参,适用于初步探索。
使用GBDT:如果你在追求更高的模型精度,尤其是在较复杂的数据集上,GBDT 通常表现更好,但需要更多的时间来调整超参数。

代码实现

还是用加州房价数据集

# 导入必要的库
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, r2_score# 1. 加载加州房价数据集
data = fetch_california_housing()
X = data.data  # 特征矩阵
y = data.target  # 目标变量(房价)# 2. 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)# 3. 创建GBDT回归模型
# n_estimators: 基分类器的数量
# learning_rate: 每个基分类器的学习率
# max_depth: 决策树的最大深度
gbdt_regressor = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)# 4. 训练模型
gbdt_regressor.fit(X_train, y_train)# 5. 进行预测
y_pred_train = gbdt_regressor.predict(X_train)
y_pred_test = gbdt_regressor.predict(X_test)# 6. 评估模型# 计算测试集的均方误差
mse_test = mean_squared_error(y_test, y_pred_test)
print(f"Mean Squared Error (Test): {mse_test:.2f}")# 计算训练集的均方误差
mse_train = mean_squared_error(y_train, y_pred_train)
print(f"Mean Squared Error (Train): {mse_train:.2f}")# 计算R²得分
r2_test = r2_score(y_test, y_pred_test)
r2_train = r2_score(y_train, y_pred_train)
print(f"R² Score (Test): {r2_test:.2f}")
print(f"R² Score (Train): {r2_train:.2f}")# 7. 可视化GBDT的预测结果(实际值 vs. 预测值)
plt.scatter(y_test, y_pred_test, alpha=0.5)
plt.xlabel('Actual Prices')
plt.ylabel('Predicted Prices')
plt.title('Actual vs Predicted Prices (GBDT)')
plt.show()

结果如下

Mean Squared Error (Test): 0.29
Mean Squared Error (Train): 0.26
R² Score (Test): 0.78
R² Score (Train): 0.81

和我们之前的决策树相比明显好了很多,毕竟这里数据规模不大。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/792173.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

二项式定理(二项式展开)

目录引入正题延伸 引入 首先有一个广为人知的结论: \[(a+b)^2=a^2+2ab+b^2 \]那么,如何求 \((a+b)^3\) 呢?手算,如下: \[\begin{aligned} (a+b)^3 &= (a+b)\times(a+b)^2\\ &=(a+b)\times(a^2+2ab+b^2)\\ &=[a\times(a^2+2ab+b^2)]+[b\times(a^2+2ab+b^2)]\\ …

机器视觉检测的速度六大影响因素

物料处理时间 材料处理时间是指待检测材料暴露在图像采集介质前面,以便能够充分聚焦在材料上以获取图像的时间。在工业环境中,材料通常位于装配线或传送带上。相机是固定的或可移动的,放置在装配线的某个点。当材料进入相机的焦点区域时,材料处理时间开始,当材料完全聚焦时…

BigDecimal使用注意的地方

BigDecimal 是 Java 中的一个类,这个相信大家都是知道的。它的作用就是可以表示任意精度的十进制数,BigDecimal 提供了精确的数字运算,适用于需要高精度计算的场景。 一、浮点精度 我们先来看一个例子:compareTo 方法比较中,a.compareTo(b) 返回:-1: a小于b0: a等于b1: a…

2024-09-04:用go语言,给定一个长度为n的数组 happiness,表示每个孩子的幸福值,以及一个正整数k,我们需要从这n个孩子中选出k个孩子。 在筛选过程中,每轮选择一个孩子时,所有尚未选

2024-09-04:用go语言,给定一个长度为n的数组 happiness,表示每个孩子的幸福值,以及一个正整数k,我们需要从这n个孩子中选出k个孩子。 在筛选过程中,每轮选择一个孩子时,所有尚未选中的孩子的幸福值都会减少 1。需要注意的是,幸福值不能降低到负数,只有在其为正数时才能…

初探编译链接原理

这篇博文由一个 bug 引出了编译链接的整个过程。我们可以看到一个源代码文件最终变成一个可执行文件中间经历了编译和链接两个过程,编译过程又分为预编译,编译,和汇编;预编译阶段主要处理#开头的代码,编译则是进行一些语法分析和优化,最终生成汇编代码,而汇编则是生成机…

canvas版本的五子棋

代码:<!Doctype html> <html lang="zh_cn"><head><meta http-equiv="Content-Type" content="text/html; charset=utf-8" /><title>五子棋</title><meta name="Keywords" content="&quo…

vue router路径重复时报错

参考——  https://blog.csdn.net/zz00008888/article/details/119566375 报错: Avoided redundant navigation to current location: "/Eee". NavigationDuplicated: Avoided redundant navigation to current location: "/Eee".在router的index下添加…

CUDA Toolkit常见安装问题一览

CUDA Toolkit常见安装问题一览关注TechLead,复旦博士,分享云服务领域全维度开发技术。拥有10+年互联网服务架构、AI产品研发经验、团队管理经验,复旦机器人智能实验室成员,国家级大学生赛事评审专家,发表多篇SCI核心期刊学术论文,阿里云认证的资深架构师,上亿营收AI产品…

如何从Docker镜像中提取恶意文件

原创 Bypass当发生容器安全事件时,需要从容器或镜像中提取恶意文件进行分析和处理。 本文主要介绍3种常见的方法: (1) 从运行的容器中复制文件 首先,需要从镜像运行启动一个容器,然后,使用docker cp命令从容器中提取文件到宿主机。 docker run -d --name test test:v1.0 …

网站内容页无法访问!内页无法正常访问

内页无法正常访问原因:通常是因为伪静态设置不正确。 解决方法:检查伪静态规则是否正确配置。 确认服务器是否开启了rewrite模块。 生成正确的.htaccess文件,并放置在根目录下。扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、H…

PHP配置文件(php.ini)中上传文件大小限制怎么调整

上传文件大小限制原因:PHP配置文件(php.ini)中上传文件大小限制过小。 解决方法:修改php.ini文件中的upload_max_filesize和post_max_size。iniupload_max_filesize = 20M post_max_size = 20M扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精…