机器学习 sklearn 中的超参数搜索方法

✅作者简介:人工智能专业本科在读,喜欢计算机与编程,写博客记录自己的学习历程。
🍎个人主页:小嗷犬的个人主页
🍊个人网站:小嗷犬的技术小站
🥭个人信条:为天地立心,为生民立命,为往圣继绝学,为万世开太平。


本文目录

    • 超参数搜索
    • 默认参数
    • GridSearchCV
    • RandomizedSearchCV
    • HalvingGridSearchCV
    • HalvingRandomSearchCV


超参数搜索

在建模时模型的超参数往往会对精度造成一定影响,而设置和调整超参数的取值,往往称为调参

在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜索,找到最优的超参数组合。

在 sklearn 中,提供了四种超参数搜索方法:

  • GridSearchCV
  • RandomizedSearchCV
  • HalvingGridSearchCV
  • HalvingRandomSearchCV

默认参数

为了方便起见,我们先定义一个默认参数的模型,用于后续的超参数搜索。

# 导入相关库
import random
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.experimental import enable_halving_search_cv
from sklearn.model_selection import (train_test_split,GridSearchCV,RandomizedSearchCV,HalvingGridSearchCV,HalvingRandomSearchCV,
)
from sklearn.ensemble import RandomForestRegressor# 设置随机种子
seed = 1
random.seed(seed)
np.random.seed(seed)# 加载数据集
data = datasets.load_diabetes()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=seed)# 使用默认参数模型进行分类并评分
reg = RandomForestRegressor(random_state=seed)
reg.fit(X_train, y_train)
print(round(reg.score(X_test, y_test), 6))

最终默认参数的模型在测试集上的 R 2 R^2 R2 分数约为 0.269413

GridSearchCV

GridSearchCV 是一种网格搜索超参数的方法,它会遍历所有的超参数组合,然后评估模型的性能,最终选择性能最好的一组超参数。

# 设置超参数搜索范围
param_grid = {"max_depth": [2, 4, 5, 6, 7],"min_samples_leaf": [1, 2, 3],"min_weight_fraction_leaf": [0, 0.1],"min_impurity_decrease": [0, 0.1, 0.2]
}# 使用 GridSearchCV 进行超参数搜索
reg = GridSearchCV(RandomForestRegressor(random_state=seed),param_grid,cv=5,n_jobs=-1,
)
reg.fit(X_train, y_train)# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))# 输出最优超参数组合
print(reg.best_params_)

在这个超参数搜索空间中,一共有 5 * 3 * 2 * 3 = 90 种超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{'max_depth': 4,'min_impurity_decrease': 0.2,'min_samples_leaf': 1,'min_weight_fraction_leaf': 0
}

RandomizedSearchCV

RandomizedSearchCV 是一种随机搜索超参数的方法,它的使用方法与 GridSearchCV 类似,但是它不会遍历所有的超参数组合,而是在超参数的取值范围内随机选择一组超参数进行训练,然后评估模型的性能,最终选择性能最好的一组超参数。

# 使用 RandomizedSearchCV 进行超参数搜索
reg = RandomizedSearchCV(RandomForestRegressor(random_state=seed),param_grid,cv=5,n_jobs=-1,n_iter=20,  # 设置迭代次数random_state=seed,
)
reg.fit(X_train, y_train)# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))# 输出最优超参数组合
print(reg.best_params_)

RandomizedSearchCV 一共进行了 20 次迭代,即尝试了 20 组超参数组合。
最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{'min_weight_fraction_leaf': 0,'min_samples_leaf': 1,'min_impurity_decrease': 0.1,'max_depth': 6
}

HalvingGridSearchCV

HalvingGridSearchCVGridSearchCV 类似,但在迭代的过程中采用减半超参数搜索空间的方法,以此来减少超参数搜索的时间。

在搜索的最开始,HalvingGridSearchCV 使用很少的数据样本来在完整的超参数搜索空间中进行搜索,筛选其中最优的超参数,之后再增加数据进行进一步筛选。

# 使用 HalvingGridSearchCV 进行超参数搜索
reg = HalvingGridSearchCV(RandomForestRegressor(random_state=seed),param_grid,cv=5,n_jobs=-1,random_state=seed,
)
reg.fit(X_train, y_train)# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.287919
最优超参数组合为:

{'max_depth': 4,'min_impurity_decrease': 0.2,'min_samples_leaf': 1,'min_weight_fraction_leaf': 0
}

可以看到,HalvingGridSearchCV 得到的最优超参数组合与 GridSearchCV 得到的最优超参数组合相同。

HalvingRandomSearchCV

HalvingRandomSearchCVHalvingGridSearchCV 类似,都是逐步增加样本数量,减少超参数组合,但是 HalvingRandomSearchCV 每次生成的超参数组合是随机的。

# 使用 HalvingRandomSearchCV 进行超参数搜索
reg = HalvingRandomSearchCV(RandomForestRegressor(random_state=seed),param_grid,cv=5,n_jobs=-1,random_state=seed,
)
reg.fit(X_train, y_train)# 输出模型在测试集上的 R 方分数
print(round(reg.score(X_test, y_test), 6))# 输出最优超参数组合
print(reg.best_params_)

最终超参数搜索后的模型在测试集上的 R 2 R^2 R2 分数约为 0.26959
最优超参数组合为:

{'min_weight_fraction_leaf': 0,'min_samples_leaf': 1,'min_impurity_decrease': 0.1,'max_depth': 6
}

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/256001.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【Docker二】docker网络模式、网络通信、数据管理

目录 一、docker网络模式: 1、概述 2、docker网络实现原理: 3、docker的网络模式: 3.1、bridge模式: 3.2、host模式: 3.3、container模式: 3.4、none模式: 3.5、自定义网络模式&#xf…

c语言指针详解下

指针下 1 指针与字符串 int main01(){//指针与字符串char a[] "helloworld";//定义了一个字符数组,字符数组内容为helloworld\0//定义一个指针用来保存数组首元素的地址char * p a;printf("%s\n",p);//%s打印一个字符串,要的是首个字符的地址printf(…

MySql概述及其性能说明

MySQL是一种开源的关系型数据库管理系统,由瑞典MySQL AB公司开发,现属于Oracle公司。MySQL是最流行的开源数据库之一,被广泛地应用于Web开发中。MySQL提供了一个高度稳定可靠的数据存储解决方案,同时也可以很容易地跨平台运行。My…

软件中提示找不到msvcp140.dll无法继续执行代码,运行打开软件怎么弄

今天打开CAD提示找不到msvcp140.dll,这是一个很常见的问题,可能是由于系统缺少这个重要的动态链接库文件导致的。本文将介绍五个解决方法,以及msvcp140.dll文件的作用和丢失原因。 一、msvcp140.dll文件的作用 msvcp140.dll是Microsoft Vis…

一对多聊天

服务端 import java.io.*; import java.net.*; import java.util.ArrayList; public class Server{public static ServerSocket server_socket;public static ArrayList<Socket> socketListnew ArrayList<Socket>(); public static void main(String []args){try{…

基于互一致性学习的半监督医学图像分割

Mutual consistency learning for semi-supervised medical image segmentation 基于互一致性学习的半监督医学图像分割背景贡献半监督学习 其它缓解过拟合的方法实验方法损失函数Thinking 基于互一致性学习的半监督医学图像分割 Medical Image Analysis 81 (2022) 102530 背…

Spring AOP带你了解整个流程,让面试官只能仰望

文章目录 一&#xff0c;介绍二&#xff0c;什么是JDK动态代理以及CGLIB代理三&#xff0c;源码流程图小结 一&#xff0c;介绍 提示&#xff1a;解析 A[“JavaConfig”] --> B[“EnableAspectJAutoProxy”]&#xff1a; 在Spring配置中&#xff0c;启用AspectJ自动代理功能…

【尘缘送书第五期】Java程序员:学习与使用多线程

目录 1 多线程对于Java的意义2 为什么Java工程师必须掌握多线程3 Java多线程使用方式4 如何学好Java多线程5 参与方式 摘要&#xff1a;互联网的每一个角落&#xff0c;无论是大型电商平台的秒杀活动&#xff0c;社交平台的实时消息推送&#xff0c;还是在线视频平台的流量洪峰…

【C语言快速学习基础篇】之二控制语句、循环语句、隐式转换

文章目录 一、控制语句1.1、for循环1.2、while循环1.3、注意&#xff1a;for循环和while循环使用上面等同1.4、do while循环1.4.1while条件成立时1.4.2、while条件不成立时 C语言介绍 C语言是一门面向过程的计算机编程语言&#xff0c;与C、C#、Java等面向对象编程语言有所不同…

【微服务】springboot整合quartz使用详解

目录 一、前言 二、quartz介绍 2.1 quartz概述 2.2 quartz优缺点 2.3 quartz核心概念 2.3.1 Scheduler 2.3.2 Trigger 2.3.3 Job 2.3.4 JobDetail 2.4 Quartz作业存储类型 2.5 适用场景 三、Cron表达式 3.1 Cron表达式语法 3.2 Cron表达式各元素说明 3.3 Cron表达…

配置BFD状态与接口状态联动示例

1、BFD检测IP链路。 在IP链路上建立BFD会话&#xff0c;利用BFD检测机制快速检测故障。BFD检测IP链路支持单跳检测和多跳检测&#xff1a; BFD单跳检测是指对两个直连系统进行IP连通性检测&#xff0c;“单跳”是IP链路的一跳。 BFD多跳检测是指BFD可以检测两个系统间的任意路…

UDP通讯

本章节主要讲解的是TCP和UDP两种通信方式它们都有着自己的优点和缺点 这两种通讯方式不通的地方就是TCP是一对一通信 UDP是一对多的通信方式 接下来会一一讲解 UDP通信 主要的方向是一对多通信方式 UDP通信就是一下子可以通信多个对象&#xff0c;这就是UDP对比TCP的优势&am…