随机森林超参数的网格优化(机器学习的精华--调参)

随机森林超参数的网格优化(机器学习的精华–调参)

随机森林各个参数对算法的影响

影响力参数
⭐⭐⭐⭐⭐
几乎总是具有巨大影响力
n_estimators(整体学习能力)
max_depth(粗剪枝)
max_features(随机性)
⭐⭐⭐⭐
大部分时候具有影响力
max_samples(随机性)
class_weight(样本均衡)
⭐⭐
可能有大影响力
大部分时候影响力不明显
min_samples_split(精剪枝)
min_impurity_decrease(精剪枝)
max_leaf_nodes(精剪枝)
criterion(分枝敏感度)

当数据量足够大时,几乎无影响
random_state
ccp_alpha(结构风险)

探索的第一步(学习曲线)

先看看n_estimators的学习曲线(当然我们也可以看一下其他参数的学习曲线)

#参数潜在取值,由于现在我们只调整一个参数,因此参数的范围可以取大一些、取值也可以更密集
Option = [1,*range(5,101,5)]#生成保存模型结果的arrays
trainRMSE = np.array([])
testRMSE = np.array([])
trainSTD = np.array([])
testSTD = np.array([])#在参数取值中进行循环
for n_estimators in Option:#按照当下的参数,实例化模型reg_f = RFR(n_estimators=n_estimators,random_state=83)#实例化交叉验证方式,输出交叉验证结果cv = KFold(n_splits=5,shuffle=True,random_state=83)result_f = cross_validate(reg_f,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,n_jobs=-1)#根据输出的MSE进行RMSE计算train = abs(result_f["train_score"])**0.5test = abs(result_f["test_score"])**0.5#将本次交叉验证中RMSE的均值、标准差添加到arrays中进行保存trainRMSE = np.append(trainRMSE,train.mean()) #效果越好testRMSE = np.append(testRMSE,test.mean())trainSTD = np.append(trainSTD,train.std()) #模型越稳定testSTD = np.append(testSTD,test.std())

定义画图函数

def plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD):#一次交叉验证下,RMSE的均值与std的绘图xaxis = Optionplt.figure(figsize=(8,6),dpi=80)#RMSEplt.plot(xaxis,trainRMSE,color="k",label = "RandomForestTrain")plt.plot(xaxis,testRMSE,color="red",label = "RandomForestTest")#标准差 - 围绕在RMSE旁形成一个区间plt.plot(xaxis,trainRMSE+trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,trainRMSE-trainSTD,color="k",linestyle="dotted")plt.plot(xaxis,testRMSE+testSTD,color="red",linestyle="dotted")plt.plot(xaxis,testRMSE-testSTD,color="red",linestyle="dotted")plt.xticks([*xaxis])plt.legend(loc=1)plt.show()plotCVresult(Option,trainRMSE,testRMSE,trainSTD,testSTD)

在这里插入图片描述

当绘制学习曲线时,我们可以很容易找到泛化误差开始上升、或转变为平稳趋势的转折点。因此我们可以选择转折点或转折点附近的n_estimators取值,例如20。然而,n_estimators会受到其他参数的影响,例如:

  • 单棵决策树的结构更简单时(依赖剪枝时),可能需要更多的树
  • 单棵决策树训练的数据更简单时(依赖随机性时),可能需要更多的树

因此n_estimators的参数空间可以被确定为range(20,100,5),如果你比较保守,甚至可以确认为是range(15,25,5)。

探索第二步(随机森林中的每一棵决策树)

属性.estimators_,查看森林中所有的树

参数参数含义对应属性属性含义
n_estimators树的数量reg.estimators_森林中所有树对象
max_depth允许的最大深度.tree_.max_depth0号树实际的深度
max_leaf_nodes允许的最大
叶子节点量
.tree_.node_count0号树实际的总节点量
min_sample_split分枝所需最小
样本量
.tree_.n_node_samples0号树每片叶子上实际的样本量
min_weight_fraction_leaf分枝所需最小
样本权重
tree_.weighted_n_node_samples0号树每片叶子上实际的样本权重
min_impurity_decrease分枝所需最小
不纯度下降量
.tree_.impurity
.tree_.threshold
0号树每片叶子上的实际不纯度
0号树每个节点分枝后不纯度下降量

可以通过对上述属性的调用查看当前模型每一棵树的各个属性,对我们对于参数范围的选择给予帮助。

正戏开始(网格搜索)

import numpy as np
import pandas as pd
import sklearn
import matplotlib as mlp
import matplotlib.pyplot as plt
import time #计时模块time
from sklearn.ensemble import RandomForestRegressor as RFR
from sklearn.model_selection import cross_validate, KFold, GridSearchCV
# 定义RMSE函数
def RMSE(cvresult,key):return (abs(cvresult[key])**0.5).mean()
#导入波士顿房价数据
data = pd.read_csv(r"D:\Pythonwork\datasets\House Price\train_encode.csv",index_col=0)
#查看数据
data.head()
X = data.iloc[:,:-1]
y = data.iloc[:,-1]

在这里插入图片描述

Step 1.建立benchmark

# 定义回归器
reg = RFR(random_state=83)
# 进行5折交叉验证
cv = KFold(n_splits=5,shuffle=True,random_state=83)result_pre_adjusted = cross_validate(reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#分别查看训练集和测试集在调参之前的RMSE
RMSE(result_pre_adjusted,"train_score")
RMSE(result_pre_adjusted,"test_score")

结果分别是
11177.272008319653
30571.26665524217

Step 2.创建参数空间

param_grid_simple = {"criterion": ["squared_error","poisson"], 'n_estimators': [*range(20,100,5)], 'max_depth': [*range(10,25,2)], "max_features": ["log2","sqrt",16,32,64,"auto"], "min_impurity_decrease": [*np.arange(0,5,10)]}

Step 3.实例化用于搜索的评估器、交叉验证评估器与网格搜索评估器

#n_jobs=4/8,verbose=True
reg = RFR(random_state=1412,verbose=True,n_jobs=-1)
cv = KFold(n_splits=5,shuffle=True,random_state=83)
search = GridSearchCV(estimator=reg,param_grid=param_grid_simple,scoring = "neg_mean_squared_error",verbose = True,cv = cv,n_jobs=-1)

Step 4.训练网格搜索评估器

#=====【TIME WARNING: 15mins】=====#   当然博主的电脑比较慢
start = time.time()
search.fit(X,y)
print(time.time() - start)

Step 5.查看结果

search.best_estimator_# 直接使用最优参数进行建模
ad_reg = RFR(n_estimators=85, max_depth=23, max_features=16, random_state=83)cv = KFold(n_splits=5,shuffle=True,random_state=83)
result_post_adjusted = cross_validate(ad_reg,X,y,cv=cv,scoring="neg_mean_squared_error",return_train_score=True,verbose=True,n_jobs=-1)
#查看调参后的结果
RMSE(result_post_adjusted,"train_score")
RMSE(result_post_adjusted,"test_score")

得出结果
11000.81099038192
28572.070208366855

调参前后对比

#默认值下随机森林的RMSE
xaxis = range(1,6)
plt.figure(figsize=(8,6),dpi=80)
#RMSE
plt.plot(xaxis,abs(result_pre_adjusted["train_score"])**0.5,color="green",label = "RF_pre_ad_Train")
plt.plot(xaxis,abs(result_pre_adjusted["test_score"])**0.5,color="green",linestyle="--",label = "RF_pre_ad_Test")
plt.plot(xaxis,abs(result_post_adjusted["train_score"])**0.5,color="orange",label = "RF_post_ad_Train")
plt.plot(xaxis,abs(result_post_adjusted["test_score"])**0.5,color="orange",linestyle="--",label = "RF_post_ad_Test")
plt.xticks([1,2,3,4,5])
plt.xlabel("CVcounts",fontsize=16)
plt.ylabel("RMSE",fontsize=16)
plt.legend()
plt.show()

在这里插入图片描述
不难发现,网格搜索之后的模型过拟合程度减轻,且在训练集与测试集上的结果都有提高,可以说从根本上提升了模型的基础能力。我们还可以根据网格的结果继续尝试进行其他调整,来进一步降低模型在测试集上的RMSE。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/453301.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

1.0 Hadoop 教程

Hadoop 是一个开源的分布式计算和存储框架,由 Apache 基金会开发和维护。 Hadoop 为庞大的计算机集群提供可靠的、可伸缩的应用层计算和存储支持,它允许使用简单的编程模型跨计算机群集分布式处理大型数据集,并且支持在单台计算机到几千台计…

一站式SpringBoot学习平台:让编程变得轻松有趣!

介绍:Spring Boot是一个开源的Java框架,旨在简化Spring应用程序的开发和部署过程。 Spring Boot由Pivotal团队设计并推出,它的核心优势在于极大地简化了传统Spring应用的初始搭建和开发流程。具体来说,Spring Boot的主要特点包括&…

27. 云原生流量治理之kubesphere灰度发布

云原生专栏大纲 文章目录 灰度发布介绍灰度发布策略KubeSphere中恢复发布策略蓝绿部署金丝雀发布流量镜像 灰度发布实战部署自制应用金丝雀发布创建金丝雀发布任务测试金丝雀发布情况 蓝绿部署创建蓝绿部署测试蓝绿部署情况 流量镜像创建流量进行任务测试流量镜像情况 灰度发布…

PHP安装后错误处理

一:问题 安装PHP后提示错误如下 二:解决 1:Warning: Module mysqli already loaded in Unknown on line 0解决 原因:通过php.ini配置文件开启mysqli扩展的时候,开启了多次 解决:将php.ini配置文件中多个…

飞讯SRM项目为广州某五金机械制造公司带来显著效益

前言 飞讯工业互联为广州某五金机械制造公司提供的SRM供应商管理解决方案已实现整体上线。在此过程中,得到了该客户采购、仓库、财务等部门领导与员工的鼎力支持和协作,确保了项目的顺利进行。项目背景 在项目启动初期,飞讯项目团队深入实地…

算法练习-环形链表(思路+流程图+代码)

难度参考 难度:中等 分类:链表 难度与分类由我所参与的培训课程提供,但需要注意的是,难度与分类仅供参考。且所在课程未提供测试平台,故实现代码主要为自行测试的那种,以下内容均为个人笔记,旨在…

ubuntu20.04安装最新版nginx

前言 记录下ubuntu服务器安装nginx 步骤 安装最新版本的 Nginx 可以通过添加 Nginx 官方的软件仓库并更新软件包信息。以下是在 Ubuntu 20.04 上安装最新版本 Nginx 的步骤: 添加 Nginx 软件仓库: 打开终端并执行以下命令: sudo sh -c echo…

GPT原始论文:Improving Language Understanding by Generative Pre-Training论文翻译

1 摘要 自然语理解包括文本蕴含、问题回答、语义相似性评估和文档分类等一系列多样化的任务。尽管大量未标注的文本语料库很丰富,但用于学习这些特定任务的标注数据却很稀缺,这使得基于区分性训练的模型难以充分发挥作用。我们展示了通过在多样化的未标…

《幻兽帕鲁》解锁基地和工作帕鲁数量上限

帕鲁私服的游戏参数通常可通过配置文件 PalWorldSettings.ini 来进行修改,然而这个配置文件有个别参数对游戏不生效,让人很是头疼。没错!我说的就是终端最大的帕鲁数量! 其实还有另外一种更加高级的参数修改方式,那就…

CTFshow web(php特性 105-108)

web105 <?php /* # -*- coding: utf-8 -*- # Author: Firebasky # Date: 2020-09-16 11:25:09 # Last Modified by: h1xa # Last Modified time: 2020-09-28 22:34:07 */ highlight_file(__FILE__); include(flag.php); error_reporting(0); $error你还想要flag嘛&…

spring boot打完jar包后使用命令行启动,提示.jar 中没有主清单属性

在对springBoot接口中间件开发完毕后&#xff0c;本地启动没有任何问题&#xff0c;在使用package命令打包也没异常&#xff0c;打完包后使用命令行&#xff1a;java -jar xxx.jar启动发现报异常&#xff1a;xxx.jar 中没有主清单属性&#xff0c;具体解决方法如下&#xff1a;…

让cgteamwork自动为Houdini载入相机,角色道具的abc文件

一 需求 最近接到个需求&#xff1a;在创建EFX文件时&#xff0c;自动加载动画出的缓存abc文件相机&#xff0c; 不用手动一个个的载入&#xff0c;还容易出错 ABC文件自动导入到Houndini里 二 过程/效果 在CGTeamwork里打开对应的镜头&#xff0c;下面的文件列表显示相机和角…