第83步 时间序列建模实战:Catboost回归建模

基于WIN10的64位系统演示

一、写在前面

这一期,我们介绍Catboost回归。

同样,这里使用这个数据:

《PLoS One》2015年一篇题目为《Comparison of Two Hybrid Models for Forecasting the Incidence of Hemorrhagic Fever with Renal Syndrome in Jiangsu Province, China》文章的公开数据做演示。数据为江苏省2004年1月至2012年12月肾综合症出血热月发病率。运用2004年1月至2011年12月的数据预测2012年12个月的发病率数据。

二、Catboost回归

(1)参数解读

无论是回归还是分类,CatBoost的大部分参数都是通用的,但任务的不同性质意味着一些参数可能只在一个任务中有意义。

以下是一些关键参数的简要概述:

(a)通用参数:

learning_rate: 学习率,决定了模型每一步的步长。常用的值为0.01, 0.03, 0.1等。

iterations: 树的数量。

depth: 树的深度。

l2_leaf_reg: L2正则化项的系数。

cat_features: 分类特征的列索引列表。

loss_function: 损失函数。对于分类,常见的是Logloss(二分类)或MultiClass(多分类)。对于回归,常见的是RMSE。

border_count: 用于数值特征的分箱数量。较高的值可能会导致过拟合,较低的值可能会导致欠拟合。

verbose: 显示的训练日志的详细程度。

(b)专用于分类的参数:

classes_count: 在多分类任务中,类别的数量。

class_weights: 各类的权重,用于不平衡分类任务。

auto_class_weights: 用于处理类不平衡的自动权重计算方法。

(c)专用于回归的参数:

scale_pos_weight: 用于不平衡的回归任务。

(d)异同点:

相同点: 大部分参数(如learning_rate, depth, l2_leaf_reg等)在回归和分类任务中都是相同的,并且它们的含义和效果也是一致的。

不同点: 损失函数loss_function是根据任务(回归或分类)来确定的。此外,某些参数(如classes_count和class_weights)仅在分类任务中有意义,而scale_pos_weight更倾向于回归任务。

此外,在使用CatBoost时,建议始终查阅其官方文档,因为该库可能会经常更新,新的参数或功能可能会被添加进来。网址如下:

https://catboost.ai/docs/

(2)单步滚动预测

import pandas as pd
import numpy as np
from sklearn.metrics import mean_absolute_error, mean_squared_error
from catboost import CatBoostRegressor
from sklearn.model_selection import GridSearchCV# 读取数据
data = pd.read_csv('data.csv')# 将时间列转换为日期格式
data['time'] = pd.to_datetime(data['time'], format='%b-%y')# 创建滞后期特征
lag_period = 6
for i in range(lag_period, 0, -1):data[f'lag_{i}'] = data['incidence'].shift(lag_period - i + 1)# 删除包含 NaN 的行
data = data.dropna().reset_index(drop=True)# 划分训练集和验证集
train_data = data[(data['time'] >= '2004-01-01') & (data['time'] <= '2011-12-31')]
validation_data = data[(data['time'] >= '2012-01-01') & (data['time'] <= '2012-12-31')]# 定义特征和目标变量
X_train = train_data[['lag_1', 'lag_2', 'lag_3', 'lag_4', 'lag_5', 'lag_6']]
y_train = train_data['incidence']
X_validation = validation_data[['lag_1', 'lag_2', 'lag_3', 'lag_4', 'lag_5', 'lag_6']]
y_validation = validation_data['incidence']# 初始化 CatBoostRegressor 模型
catboost_model = CatBoostRegressor(verbose=0)# 定义参数网格
param_grid = {'iterations': [50, 100, 150],'learning_rate': [0.01, 0.05, 0.1, 0.5, 1],'depth': [4, 6, 8],'loss_function': ['RMSE']
}# 初始化网格搜索
grid_search = GridSearchCV(catboost_model, param_grid, cv=5, scoring='neg_mean_squared_error')# 进行网格搜索
grid_search.fit(X_train, y_train)# 获取最佳参数
best_params = grid_search.best_params_# 使用最佳参数初始化 CatBoostRegressor 模型
best_catboost_model = CatBoostRegressor(**best_params, verbose=0)# 在训练集上训练模型
best_catboost_model.fit(X_train, y_train)# 对于验证集,我们需要迭代地预测每一个数据点
y_validation_pred = []for i in range(len(X_validation)):if i == 0:pred = best_catboost_model.predict([X_validation.iloc[0]])else:new_features = list(X_validation.iloc[i, 1:]) + [pred[0]]pred = best_catboost_model.predict([new_features])y_validation_pred.append(pred[0])y_validation_pred = np.array(y_validation_pred)# 计算验证集上的MAE, MAPE, MSE 和 RMSE
mae_validation = mean_absolute_error(y_validation, y_validation_pred)
mape_validation = np.mean(np.abs((y_validation - y_validation_pred) / y_validation))
mse_validation = mean_squared_error(y_validation, y_validation_pred)
rmse_validation = np.sqrt(mse_validation)# 计算训练集上的MAE, MAPE, MSE 和 RMSE
y_train_pred = best_catboost_model.predict(X_train)
mae_train = mean_absolute_error(y_train, y_train_pred)
mape_train = np.mean(np.abs((y_train - y_train_pred) / y_train))
mse_train = mean_squared_error(y_train, y_train_pred)
rmse_train = np.sqrt(mse_train)print("Train Metrics:", mae_train, mape_train, mse_train, rmse_train)
print("Validation Metrics:", mae_validation, mape_validation, mse_validation, rmse_validation)

看结果:

(3)多步滚动预测-vol. 1

对于Catboost回归,目标变量y_train不能是多列的DataFrame,所以你们懂的。

(4)多步滚动预测-vol. 2

同上。

(5)多步滚动预测-vol. 3

import pandas as pd
import numpy as np
from catboost import CatBoostRegressor  # 导入CatBoostRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_absolute_error, mean_squared_error# 数据读取和预处理
data = pd.read_csv('data.csv')
data_y = pd.read_csv('data.csv')
data['time'] = pd.to_datetime(data['time'], format='%b-%y')
data_y['time'] = pd.to_datetime(data_y['time'], format='%b-%y')n = 6for i in range(n, 0, -1):data[f'lag_{i}'] = data['incidence'].shift(n - i + 1)data = data.dropna().reset_index(drop=True)
train_data = data[(data['time'] >= '2004-01-01') & (data['time'] <= '2011-12-31')]
X_train = train_data[[f'lag_{i}' for i in range(1, n+1)]]
m = 3X_train_list = []
y_train_list = []for i in range(m):X_temp = X_trainy_temp = data_y['incidence'].iloc[n + i:len(data_y) - m + 1 + i]X_train_list.append(X_temp)y_train_list.append(y_temp)for i in range(m):X_train_list[i] = X_train_list[i].iloc[:-(m-1)]y_train_list[i] = y_train_list[i].iloc[:len(X_train_list[i])]# 模型训练
param_grid = {'iterations': [50, 100, 150],'learning_rate': [0.01, 0.05, 0.1, 0.5, 1],'depth': [4, 6, 8]
}best_catboost_models = []for i in range(m):grid_search = GridSearchCV(CatBoostRegressor(verbose=0), param_grid, cv=5, scoring='neg_mean_squared_error')  # 使用CatBoostRegressorgrid_search.fit(X_train_list[i], y_train_list[i])best_catboost_model = CatBoostRegressor(**grid_search.best_params_, verbose=0)best_catboost_model.fit(X_train_list[i], y_train_list[i])best_catboost_models.append(best_catboost_model)validation_start_time = train_data['time'].iloc[-1] + pd.DateOffset(months=1)
validation_data = data[data['time'] >= validation_start_time]X_validation = validation_data[[f'lag_{i}' for i in range(1, n+1)]]
y_validation_pred_list = [model.predict(X_validation) for model in best_catboost_models]
y_train_pred_list = [model.predict(X_train_list[i]) for i, model in enumerate(best_catboost_models)]def concatenate_predictions(pred_list):concatenated = []for j in range(len(pred_list[0])):for i in range(m):concatenated.append(pred_list[i][j])return concatenatedy_validation_pred = np.array(concatenate_predictions(y_validation_pred_list))[:len(validation_data['incidence'])]
y_train_pred = np.array(concatenate_predictions(y_train_pred_list))[:len(train_data['incidence']) - m + 1]mae_validation = mean_absolute_error(validation_data['incidence'], y_validation_pred)
mape_validation = np.mean(np.abs((validation_data['incidence'] - y_validation_pred) / validation_data['incidence']))
mse_validation = mean_squared_error(validation_data['incidence'], y_validation_pred)
rmse_validation = np.sqrt(mse_validation)
print("验证集:", mae_validation, mape_validation, mse_validation, rmse_validation)mae_train = mean_absolute_error(train_data['incidence'][:-(m-1)], y_train_pred)
mape_train = np.mean(np.abs((train_data['incidence'][:-(m-1)] - y_train_pred) / train_data['incidence'][:-(m-1)]))
mse_train = mean_squared_error(train_data['incidence'][:-(m-1)], y_train_pred)
rmse_train = np.sqrt(mse_train)
print("训练集:", mae_train, mape_train, mse_train, rmse_train)

结果:

三、数据

链接:https://pan.baidu.com/s/1EFaWfHoG14h15KCEhn1STg?pwd=q41n

提取码:q41n

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/131011.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

简单好用的CHM文件阅读器 CHM Viewer Star最新 for mac

CHM Viewer Star 是一款适用于 Mac 平台的 CHM 文件阅读器软件&#xff0c;支持本地和远程 CHM 文件的打开和查看。它提供了直观易用的界面设计&#xff0c;支持多种浏览模式&#xff0c;如书籍模式、缩略图模式和文本模式等&#xff0c;并提供了丰富的功能和工具&#xff0c;如…

数字人解决方案——ER-NeRF实时对话数字人模型训练与项目部署

前言 1、算法概述 ER-NeRF是基于NeRF用于生成数字人的方法&#xff0c;可以达到实时生成的效果。具体来说&#xff0c;为了提高动态头部重建的准确性&#xff0c;ER-NeRF引入了一种紧凑且表达丰富的基于NeRF的三平面哈希表示法&#xff0c;通过三个平面哈希编码器剪枝空的空间…

STM32:GPIO模拟SPI驱动ADS8361

ADS8361是TI公司开发的一款模拟量输入芯片。ADS8361有四种工作模式&#xff0c;本文主要针对模式三进行通信驱动。官方方案使用两路SPI来通信&#xff0c;一路SPI Master&#xff0c;一路SPI Slave。我在使用STM32主控芯片的两路SPI进行通信的时候&#xff0c;发现只有SPI Mast…

ES6 class类的静态方法static有什么用

在项目中&#xff0c;工具类的封装经常使用静态方法。 // amap.jsimport AMapLoader from amap/amap-jsapi-loader; import { promiseLock } from triascloud/utils; /*** 高德地图初始化工具*/ class AMapHelper {static getAMap window.AMap? window.AMap: promiseLock(AM…

【1】MongoDB的安装以及连接

今天是2023年10月11日&#xff0c;MongoDB最新版本是7.0.2 最近闲着没事学习一下MongoDB这个NoSQL数据库&#xff0c;有时间就顺手记录一下我学习的笔记吧~ 学习笔记来自黑马程序员《MongoDB基础入门到高级进阶&#xff0c;一套搞定mongodb》 配套资料&#xff1a;点此资料链接…

AMEYA360分享:村田电子搭载了Onsemi公司IoT设备专用IC的新Bluetooth® Low Energy模块开始量产

近年来&#xff0c;所有远程监控、远程控制的用例均要求具备可无线连接的电池驱动IoT设备&#xff0c;而长寿命电池与安全的数据通信功能是其关键。为此&#xff0c;在IoT边缘设备的设计方面&#xff0c;最大的课题是要提高功率效率和安全性。 Type 2EG由于无线与内置微处理器两…

vscode 连接ubuntu git下载缓慢

在ubuntu20.04下载&#xff1a; git clone https://github.com/introlab/rtabmap.git src/rtabmap 挂掉情况 export https_proxyhttp://10.10.10.176:7890export http_proxyhttp://10.10.10.176:7890 其中 10.10.10.176是我本机的ip地址&#xff0c;7890是我的代理后几位 如…

idea compile项目正常,启动项目的时候build失败,报“找不到符号”等问题

1、首先往上找&#xff0c;看能不能找到如下报错信息 You aren’t using a compiler supported by lombok, so lombok will not work and has been disabled. 2、这种问题属于lombok编译失败导致&#xff0c;可能原因是依赖jar包没有更新到最新版本 3、解决方案 1&#xff09…

什么是强缓存、协商缓存?

为了减少资源请求次数,加快资源访问速度,浏览器会对资源文件如图片、css文件、js文件等进行缓存,而浏览器缓存策略又分为强缓存和协商缓存,什么是强缓存?什么是协商缓存?两者之间的区别又是什么?接下来本文就带大家深入了解这方面的知识。 强缓存 所谓强缓存,可以理解…

使用postman 调用 Webservice 接口

1. 先在浏览器地址栏 访问你的webService地址 地址格式: http://127.0.0.1:8092/xxxx/ws(这个自己的决定)/xxxxXccv?wsdl 2. post man POST 访问wwebService接口 地址格式: http://127.0.0.1:8092/xxxx/ws(这个自己的决定)/xxxxXccv <soapenv:Envelope xmlns:soapenv…

Excel 快速分析

文章目录 格式化图表汇总计数 表超级表 迷你图 快捷键: Ctrl Q 先选中数据, 再按快捷键或快速分析按钮. 格式化 查看规则: 前提是先在表中添加某种规则, 再全选该表, 这样在查看规则时才会显示出这个规则. 图表 汇总 计数 表 超级表 迷你图

Android 项目增加 res配置

main.res.srcDirs "src/main/res_test" build->android->sourceSets