XGBoost模型 0基础小白也能懂(附代码)

news/2024/12/23 10:21:07/文章来源:https://www.cnblogs.com/Mephostopheles/p/18397154

XGBoost模型 0基础小白也能懂(附代码)

原文链接

啥是XGBoost模型

XGBoost 是 eXtreme Gradient Boosting 的缩写称呼,它是一个非常强大的 Boosting 算法工具包,优秀的性能(效果与速度)让其在很长一段时间内霸屏数据科学比赛解决方案榜首,现在很多大厂的机器学习方案依旧会首选这个模型。

XGBoost 在并行计算效率、缺失值处理、控制过拟合、预测泛化能力上都变现非常优秀。本文我们给大家详细展开介绍 XGBoost,包含「算法原理」和「工程实现」两个方面。

关于 XGBoost 的原理,其作者陈天奇本人有一个非常详尽的Slides做了系统性的介绍。

Boosted Tree

Boosted Tree(提升树)是一种常用的机器学习方法,属于集成学习的一种。它通过将多个弱学习器(通常是决策树)组合起来,以提升整个模型的预测性能。Boosted Tree的核心思想是通过逐步训练多个决策树,每个树都试图修正前一个树的错误,最终得到一个更强大的模型。

模型:假设我们有\(K\)棵树\(\hat{y_i}=\sum_{k=1}^Kf_k(x_i),f_k\in{F}\)\(F\)为包含所有回归树的函数空间。
目标函数:\(Obj=\sum_{i=1}^nl(y_i,\hat{y_i})+\sum_{k=1}^K\Omega(f_k)\)
\(\sum_{i=1}^nl(y_i,\hat{y_i})\)是成本函数
\(\sum_{k=1}^K\Omega(f_k)\)是正则化项,代表树的复杂程度,树越复杂正则化项的值越高(正则化项如何定义我们会在后面详细说)。

当我们讨论决策树或相关的树模型时,通常是启发式的。启发式(heuristic)在机器学习中指的是使用经验法则或近似方法来解决问题,而不保证找到最优解。

Gradient Boosting(如何学习)

在做 GBDT 的时候,我们没有办法使用 SGD(Stochastic Gradient Descent,随机梯度下降),因为它们是树,而非数值向量——也就是说从原来我们熟悉的参数空间变成了函数空间。Gradient Boosting Decision Trees(GBDT)与深度学习或线性模型不同,它的核心不是直接通过参数更新来优化,而是通过构建新的决策树来逐步降低误差。

解决方案:初始化一个预测值,每次迭代添加一个新函数\((f)\)

1)目标函数变换

根据解决方案可以对目标函数进行初步变形

其中constant是常数项,比如\(\Omega(f_1),\Omega(f_2)\)之类的,然后第三行就是考虑平方损失,\(l(y_i,\hat{y_i})=\frac{1}{2}(y_i-\hat{y_i})^2\),代进去就行

所以我们的目的就是找到\(f(t)\)使得目标函数最低。然而,经过上面初次变形的目标函数仍然很复杂,目标函数会产生二次项。引入泰勒公式

这图也多少有点问题,是在还没考虑平方损失的地方引入泰勒公式,然后泰勒公式也有问题,后面两项应该是\(f(x)\)的一阶导数和二阶导数,所以才是\(g_i,h_i\)

再把里面的常数项提取出,和\(f_t\)无关

2)重新定义树

前面已经用\(f_t(x)\)表示一棵树,在本小节,我们重新定义一下树:我们通过叶子结点中的分数向量和将实例映射到叶子结点的索引映射函数来定义树:(有点儿抽象,具体请看下图)

图里有问题,第一个叶子结点权重是+2

3)定义树的复杂程度

其中\(T\)才是叶子节点的个数,\(\gamma\)是控制树的复杂度的参数,树的叶子节点越多,复杂度越高。通过调节
\(\gamma\)可以控制模型的复杂度。后面一堆是 L2 Norm正则化系数

4)重新审视目标函数

定义在叶子结点\(j\)中的实例的集合为:\(I_j=\{i|q(x_i)=j\}\),这么定义也是为了能够构建出第三个式子,都写成\(\sum_{j=1}^T\)

同时也会发现上式是\(T\)个独立二次函数的和

5)计算叶子结点的值

搞了一大坨,其实也就是先把值换成\(G_j,H_j\),然后用一元二次方程求一个最优值就完了。

下图是前面公式讲解对应的一个实际例子。

这里再次总结一下,我们已经把目标函数变成了仅与\(G,H,\gamma,\lambda,T\)这五项已知参数有关的函数,把之前的变量\(f_t\)消灭掉了,也就不需要对每一个叶子进行打分了!

那么现在问题来,刚才我们提到,以上这些是假设树结构确定的情况下得到的结果。但是树的结构有好多种,我们应该如何确定呢?

6) 贪婪算法生成树

上一部分中我们假定树的结构是固定的。但是,树的结构其实是有无限种可能的,本小节我们使用贪婪算法生成树:

首先生成一个深度为0的树(只有一个根结点,也叫叶子结点)

对于每棵树的每个叶子结点,尝试去做分裂(生成两个新的叶子结点,原来的叶子结点不再是叶子结点)。在增加了分裂后的目标函数前后变化为(我们希望增加了树之后的目标函数小于之前的目标函数,所以用之前的目标函数减去之后的目标函数):

\(Gain=\frac{1}{2}(\frac{G_L^2}{H_L+\lambda}+\frac{G_R^2}{H_R+\lambda}-\frac{(G_L+G_R)^2}{H_L+H_R+\lambda})-\gamma\)

接下来要考虑的是如何寻找最佳分裂点。

例如,如果\(x_j\)是年龄,当分裂点是\(a\)的时候的增益\(Gain\)是多少?

其实这里对排序后的实例进行从左到右的线性扫描就足以决定特征的最佳分裂点。从左到右依次扫描:一旦数据按照特征值进行了排序,我们从第一个样本开始,依次计算每个可能的分裂点。对于每个分裂点,我们把样本分为“左侧”和“右侧”两个子集,分别计算划分前后目标函数的变化。下面还有别的一些办法

7)如何处理分类型变量

在很多情况下,我们不需要为分类变量设计特殊的处理方式,可以将其转换为one-hot 编码来处理。

\(z_j= \begin{cases} 0& \text{if x is in category y}\\ 1& \text{otherwise} \end{cases} \)

如果有太多的分类的话,矩阵会非常稀疏,算法会优先处理稀疏数据。

8) 修剪和正则化

回顾之前的增益,当训练损失减少的值小于正则化带来的复杂度时,增益有可能会是负数,此时就是模型的简单性和可预测性之间的权衡

XGBoost核心原理归纳解析

铺垫了那么多,总算到这里了。XGBoost 也是一个 Boosting 加法模型,每一步迭代只优化当前步中的子模型。

\(m\)步我们有:\(F_m(x_i)=F_{m-1}(x_i)+f_m(x_i)\)

\(f_m(x_i)\)为当前步的子模型。
\(F_{m-1}(x_i)\)为前\(m-1\)个完成训练且固定了的子模型。

泰勒展开

然后去掉常数,带入复杂度(和之前一样)

1)近似算法

基于性能的考量,XGBoost 还对贪心准则做了一个近似版本,简单说,处理方式是「将特征分位数作为划分候选点」。这样将划分候选点集合由全样本间的遍历缩减到了几个分位数之间的遍历。

展开来看,特征分位数的选取还有 global 和 local 两种可选策略:

精确贪心准则:这是默认的精确算法,遍历所有可能分裂点,找到能最大化增益的点。计算量最大,但分裂效果最优。
Global 近似分裂:使用全体样本的特征分位数进行一次性划分,分裂点在所有节点中复用,计算量大幅减少,适合较大的数据集。
Local 近似分裂:在每个节点分裂前根据当前节点的样本重新计算特征分位数,能够更加灵活适应不同节点的特征分布,适合样本分布差异较大的情况。

近似算法的性能与精确贪心算法几乎相同,但大大降低了计算成本。

2)加权分位数

在 XGBoost 中,加权分位数(Weighted Quantile Sketch)用于加速分裂点的寻找过程。加权分位数算法并不是直接根据样本的特征值来划分分位点,而是考虑了样本的二阶导数(Hessian)作为权重,从而更好地平衡分裂点的选择,特别是在近似算法中。

令偏导为0易得\(f_m^*(x_i)=-\frac{g_i}{h_i}\)

3) 列采样与学习率

列采样指的是在构建每棵决策树时,XGBoost 不会使用全部特征,而是随机选择部分特征用于分裂。这种方法源自于随机森林的思想,目的是增加模型的多样性,从而防止过拟合。

学习率在梯度提升树(GBDT)中是一个非常重要的超参数,用于控制每棵树对模型的贡献。学习率可以防止模型更新过快,从而提升模型的稳定性和性能。也叫步长、shrinkage,具体的操作是在每个子模型前(即每个叶节点的回归值上)乘上该系数,不让单颗树太激进地拟合,留有一定空间,使迭代更稳定。XGBoost默认设定为 。

4) 特征缺失与稀疏性

简单说,它的做法是将缺失值和稀疏\(0\)值等同视作缺失值,将其「绑定」在一起,分裂节点的遍历会跳过缺失值的整体。这样大大提高了运算效率。

比如在下面的例子中有六种划分情况,XGBoost 会遍历以上6种情况(3个非缺失值的切分点×缺失值的两个方向),最大的分裂收益就是本特征上的分裂收益

XGBoost工程优化

1)并行列块设计

XGBoost 将每一列特征提前进行排序,以块(Block)的形式储存在缓存中,并以索引将特征值和梯度统计量对应起来,每次节点分裂时会重复调用排好序的块。而且不同特征会分布在独立的块中,因此可以进行分布式或多线程的计算。

2)缓存访问优化

特征值排序后通过索引来取梯度\(g_i,h_i\)会导致访问的内存空间不一致,进而降低缓存的命中率,影响算法效率。为解决这个问题,XGBoost为每个线程分配一个单独的连续缓存区,用来存放梯度信息。

3) 核外块计算

数据量非常大的情形下,无法同时全部载入内存。XGBoost 将数据分为多个 blocks 储存在硬盘中,使用一个独立的线程专门从磁盘中读取数据到内存中,实现计算和读取数据的同时进行。
为了进一步提高磁盘读取数据性能,XGBoost 还使用了两种方法:

① 压缩 block,用解压缩的开销换取磁盘读取的开销。
② 将 block 分散储存在多个磁盘中,提高磁盘吞吐量。

XGBoost vs GBDT

GBDT 是机器学习算法,XGBoost 在算法基础上还有一些工程实现方面的优化。

GBDT 使用的是损失函数一阶导数,相当于函数空间中的梯度下降;XGBoost 还使用了损失函数二阶导数,相当于函数空间中的牛顿法。

正则化:XGBoost 显式地加入了正则项来控制模型的复杂度,能有效防止过拟合。

列采样:XGBoost 采用了随机森林中的做法,每次节点分裂前进行列随机采样。

缺失值:XGBoost 运用稀疏感知策略处理缺失值,GBDT无缺失值处理策略。

并行高效:XGBoost 的列块设计能有效支持并行运算,效率更优。

代码实现

需要先下载xgboost

pip install xgboost

代码如下

# 导入所需的库
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_diabetes  # 替换为 load_diabetes
from sklearn.metrics import mean_squared_error# 1. 加载糖尿病数据集
# 这个数据集包含442个样本,10个特征,用于预测一个连续目标变量
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target  # X是特征数据,y是标签(目标变量)# 2. 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)# 3. 将数据转换为 DMatrix 格式
dtrain = xgb.DMatrix(X_train, label=y_train)
dtest = xgb.DMatrix(X_test, label=y_test)# 4. 设置 XGBoost 模型的超参数
params = {'objective': 'reg:squarederror',  # 回归任务使用的目标函数,平方误差'max_depth': 3,                   # 决策树的最大深度,控制模型的复杂度'eta': 0.05,                       # 学习率,控制每棵树对整体模型的贡献'eval_metric': 'rmse' ,            # 评估指标,使用均方根误差(RMSE)'lambda': 2,                        # L2 正则化项,防止过拟合'alpha': 0.5   # L1 正则化项
}# 5. 设定训练轮数
num_round = 200  # 训练的轮数,即构建多少棵树# 6. 定义评估数据集
evals = [(dtrain, 'train'), (dtest, 'eval')]  # (数据集, 数据集名称)# 7. 训练 XGBoost 模型,加入 early_stopping_rounds早停机制,防止过拟合
bst = xgb.train(params, dtrain, num_round, evals, early_stopping_rounds=10)# 8. 使用训练好的模型对测试集进行预测
y_pred = bst.predict(dtest)# 9. 评估模型性能
mse = mean_squared_error(y_test, y_pred)
print(f"Mean Squared Error: {mse}")# 10. 保存训练好的模型
bst.save_model('xgboost_model.json')# 11. 加载已保存的模型
loaded_bst = xgb.Booster()
loaded_bst.load_model('xgboost_model.json')# 12. 使用加载的模型进行预测
y_pred_loaded = loaded_bst.predict(dtest)
mse_loaded = mean_squared_error(y_test, y_pred_loaded)
print(f"Mean Squared Error from loaded model: {mse_loaded}")

结果如下

[0]	train-rmse:76.08309	eval-rmse:71.75905
[1]	train-rmse:74.34324	eval-rmse:70.47408
[2]	train-rmse:72.66427	eval-rmse:69.24759
[3]	train-rmse:71.10664	eval-rmse:68.09809
[4]	train-rmse:69.63498	eval-rmse:67.14668
[5]	train-rmse:68.24045	eval-rmse:66.09854
[6]	train-rmse:66.93042	eval-rmse:64.91738
[7]	train-rmse:65.73304	eval-rmse:64.08775
[8]	train-rmse:64.58640	eval-rmse:63.26052
[9]	train-rmse:63.51304	eval-rmse:62.49745
[10]	train-rmse:62.44810	eval-rmse:61.64759
[11]	train-rmse:61.51387	eval-rmse:60.96222
[12]	train-rmse:60.61767	eval-rmse:60.32972
[13]	train-rmse:59.77722	eval-rmse:59.74329
[14]	train-rmse:59.01348	eval-rmse:59.13121
[15]	train-rmse:58.24704	eval-rmse:58.55106
[16]	train-rmse:57.57392	eval-rmse:58.15165
[17]	train-rmse:56.92761	eval-rmse:57.68188
[18]	train-rmse:56.33319	eval-rmse:57.37781
[19]	train-rmse:55.72582	eval-rmse:56.97001
[20]	train-rmse:55.14420	eval-rmse:56.45029
[21]	train-rmse:54.61096	eval-rmse:55.97904
[22]	train-rmse:54.12594	eval-rmse:55.57225
[23]	train-rmse:53.68383	eval-rmse:55.39305
[24]	train-rmse:53.24822	eval-rmse:55.01127
[25]	train-rmse:52.85214	eval-rmse:54.85699
[26]	train-rmse:52.43814	eval-rmse:54.49904
[27]	train-rmse:52.07004	eval-rmse:54.42905
[28]	train-rmse:51.68191	eval-rmse:54.25354
[29]	train-rmse:51.28268	eval-rmse:54.09452
[30]	train-rmse:50.94229	eval-rmse:54.06703
[31]	train-rmse:50.58475	eval-rmse:53.88010
[32]	train-rmse:50.24739	eval-rmse:53.74475
[33]	train-rmse:49.97042	eval-rmse:53.49905
[34]	train-rmse:49.65855	eval-rmse:53.41597
[35]	train-rmse:49.38190	eval-rmse:53.34692
[36]	train-rmse:49.07203	eval-rmse:53.32202
[37]	train-rmse:48.81472	eval-rmse:53.22084
[38]	train-rmse:48.57124	eval-rmse:53.24058
[39]	train-rmse:48.33730	eval-rmse:53.13983
[40]	train-rmse:47.97171	eval-rmse:53.05406
[41]	train-rmse:47.75619	eval-rmse:52.87405
[42]	train-rmse:47.43067	eval-rmse:52.80852
[43]	train-rmse:47.18844	eval-rmse:52.70296
[44]	train-rmse:46.96694	eval-rmse:52.61260
[45]	train-rmse:46.79053	eval-rmse:52.58588
[46]	train-rmse:46.58746	eval-rmse:52.51602
[47]	train-rmse:46.38476	eval-rmse:52.50433
[48]	train-rmse:46.15591	eval-rmse:52.44922
[49]	train-rmse:46.00542	eval-rmse:52.36981
[50]	train-rmse:45.84480	eval-rmse:52.27445
[51]	train-rmse:45.63700	eval-rmse:52.23794
[52]	train-rmse:45.49250	eval-rmse:52.25740
[53]	train-rmse:45.31208	eval-rmse:52.16836
[54]	train-rmse:45.15374	eval-rmse:52.22044
[55]	train-rmse:45.00284	eval-rmse:52.15072
[56]	train-rmse:44.87677	eval-rmse:52.04112
[57]	train-rmse:44.71921	eval-rmse:52.08482
[58]	train-rmse:44.55626	eval-rmse:52.02783
[59]	train-rmse:44.41483	eval-rmse:52.09304
[60]	train-rmse:44.27997	eval-rmse:52.03098
[61]	train-rmse:44.15710	eval-rmse:52.08378
[62]	train-rmse:44.00683	eval-rmse:52.02136
[63]	train-rmse:43.84878	eval-rmse:52.06178
[64]	train-rmse:43.74180	eval-rmse:52.06495
[65]	train-rmse:43.59775	eval-rmse:52.08875
[66]	train-rmse:43.44009	eval-rmse:52.20317
[67]	train-rmse:43.29717	eval-rmse:52.14245
[68]	train-rmse:43.10437	eval-rmse:52.15464
[69]	train-rmse:43.00768	eval-rmse:52.17011
[70]	train-rmse:42.87951	eval-rmse:52.11852
[71]	train-rmse:42.79951	eval-rmse:52.21249
[72]	train-rmse:42.66769	eval-rmse:52.22331
Mean Squared Error: 2727.2736118611274
Mean Squared Error from loaded model: 2727.2736118611274

train-rmse是训练集上的预测值与真实值之间的误差。eval-rmse是模型在测试集上的 RMSE

分析下早停机制下最后的数据,42.66769 表示在训练集上,模型的预测误差为 42.67。RMSE 越低,表示模型在训练集上拟合得越好。52.22 说明模型在测试集上的预测误差明显高于训练集,表明模型可能存在一定的过拟合问题,模型在训练集上表现良好,但在新数据(测试集)上的泛化能力不如在训练集上的表现。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/793662.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

编程技术开发105本经典书籍推荐分享

最近整理了好多的技术书籍,对于提高自己能力来说还是很有用的,当然要有选择的看,不然估计退休了都不一定看得完,分享给需要的同学。 编程技术开发105本经典书籍推荐:https://zhangfeidezhu.com/?p=753 分享截图本文来自博客园,作者:张飞的猪,转载请注明原文链接:http…

ArcMap批量附色操作,并保存mxd

ArcMap批量附色操作,并保存mxd 1、对单文件操作 1、保存当前ArcMap中打开的shp文件为mxd文件 打开label_shp_root中的任意一个shp文件夹保存成mxd文件2、对当前在arcmap中打开的shp文件应用color配色 color配色是手动设置好一个shp文件夹的配色方案并保存成mxd文件应用color.m…

Linux 忘记密码

最近需要搞几台虚拟机,之前的vm密码进不去 找了几个方法 不是很贴切, Centos7 重启页面 e--> grub ,在linux16 行 修改ro 为rw,最后加上 init=/bin/sh F10 或Ctrl+x进入这里的 rw代替 进os之后 mount -o remount,rw /passwd touch /.autorelabel exec /sbin/init

使用 `Roslyn` 分析器和修复器对.cs源代码添加头部注释

之前写过两篇关于Roslyn源生成器生成源代码的用例,今天使用Roslyn的代码修复器CodeFixProvider实现一个cs文件头部注释的功能, 代码修复器会同时涉及到CodeFixProvider和DiagnosticAnalyzer, 实现FileHeaderAnalyzer 首先我们知道修复器的先决条件是分析器,比如这里,如果要对代…

线性dp:LeetCode516 .最长回文子序列

LeetCode516 .最长回文子序列 题目叙述: 力扣题目链接(opens new window) 给你一个字符串 s ,找出其中最长的回文子序列,并返回该序列的长度。 子序列定义为:不改变剩余字符顺序的情况下,删除某些字符或者不删除任何字符形成的一个序列。 示例 1: 输入:s = "bbbab&…

202409071506,开始写代码,从0开始 验证基本架子

由于视频教程里面 用的VS2105 所以 照抄。开发环境是VS2015 ,WIN10. VS2015 在今天看来是一个很古老的开发环境了,估计都很难找到安装包。(各种安装包:https://www.cnblogs.com/zjoch/p/5694013.html) 用:vs2015.ent_chs.iso (3.88 GB (4,172,560,384 字节))这个安装…

PR出现冲突无法直接解决

举例:存在p-dev 分支,申请合入 master 分支,产生pr 无法直接自动将pr 合入到master中 需要在本地解决 解决:git checkout p-dev,切换分支dev git pull ,更新到最新的 git merge origin master, 此时会出现冲突,通过vscode 或者smartgit 去解决 解决完冲突的文件,需要…

彻底理解字节序

1.基本理论计算机发送数据从内存低地址开始. 计算机接收数据的保存从低地址开始.2.非数值型网络数据传输如上图例子所示,发送端发送了四个字节内容,分别为0x12,0x34,0x56,0x78,假设这四个字节不表示数值例如unsigned int,而是图片内容数据。发送端从低内存地址开始发送四个…

跳跃表

概述 跳跃表(SkipList)是链表加多级索引组成的数据结构。链表的数据结构的查询复条度是 O(N)。为了提高查询效率,可以在链表上加多级索引来实现快速查询。跳跃表不仅能提高搜索性能。也能提高插入和删除操作的性能。索引的层数也叫作跳跃表的高度查找 在跳跃表的结构中会首先从…

Docker 镜像的分层概念

来更深入地理解镜像的概念40.镜像的分层概念 来更深入地理解镜像的概念 ‍ 镜像的分层 镜像,是一种轻量级、可执行的独立软件包,它包含运行某个软件所需的所有内容,我们把应用程序和配置依赖打包好形成一个可交付的运行环境(包括代码、运行时需要的库、环境变量和配置文件等…

prometheus学习笔记之kube-state-metrics

一、kube-state-metrics简介Kube-state-metrics:通过监听 API Server 生成有关资源对象的状态指标,比如 Deployment、Node、Pod,需要注意的是 kube-state-metrics 只是简单的提供一个 metrics 数据, 并不会存储这些指标数据, 所以我 们可以使用 Prometheus 来抓取这些数据然…

事务发件箱模式在 .NET 云原生开发中的应用(基于Aspire)

原文:Transactional Outbox in .NET Cloud Native Development via Aspire 作者:Oleksii Nikiforov总览 这篇文章提供了使用 Aspire、DotNetCore.CAP、Azure Service Bus、Azure SQL、Bicep 和 azd 实现 Outbox 模式的示例。源代码: https://github.com/NikiforovAll/cap-as…