人工智能|机器学习——强大的 Scikit-learn 可视化让模型说话

一、显示 API 简介

使用 utils.discovery.all_displays 查找可用的 API。

Sklearn 的utils.discovery.all_displays可以让你看到哪些类可以使用。

from sklearn.utils.discovery import all_displays
displays = all_displays()
displays

Scikit-learn (sklearn) 总是会在新版本中添加 "Display "API,因此这里可以了解你的版本中有哪些可用的 API 。例如,在我的 Scikit-learn 1.4.0 中,就有这些类:

[('CalibrationDisplay', sklearn.calibration.CalibrationDisplay),('ConfusionMatrixDisplay',sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay),('DecisionBoundaryDisplay',sklearn.inspection._plot.decision_boundary.DecisionBoundaryDisplay),('DetCurveDisplay', sklearn.metrics._plot.det_curve.DetCurveDisplay),('LearningCurveDisplay', sklearn.model_selection._plot.LearningCurveDisplay),('PartialDependenceDisplay',sklearn.inspection._plot.partial_dependence.PartialDependenceDisplay),('PrecisionRecallDisplay',sklearn.metrics._plot.precision_recall_curve.PrecisionRecallDisplay),('PredictionErrorDisplay',sklearn.metrics._plot.regression.PredictionErrorDisplay),('RocCurveDisplay', sklearn.metrics._plot.roc_curve.RocCurveDisplay),('ValidationCurveDisplay',sklearn.model_selection._plot.ValidationCurveDisplay)]

二、显示决策边界

使用 inspection.DecisionBoundaryDisplay 显示决策边界

如果使用 Matplotlib 来绘制,会很麻烦:

  • 使用 np.linspace 设置坐标范围;

  • 使用 plt.meshgrid 计算网格;

  • 使用 plt.contourf 绘制决策边界填充;

  • 然后使用 plt.scatter 绘制数据点。

现在,使用 inspection.DecisionBoundaryDisplay 可以简化这一过程:

from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.datasets import load_iris
from sklearn.svm import SVC
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as pltiris = load_iris(as_frame=True)
X = iris.data[['petal length (cm)', 'petal width (cm)']]
y = iris.targetsvc_clf = make_pipeline(StandardScaler(), SVC(kernel='linear', C=1))
svc_clf.fit(X, y)display = DecisionBoundaryDisplay.from_estimator(svc_clf, X, grid_resolution=1000,xlabel="Petal length (cm)",ylabel="Petal width (cm)")
plt.scatter(X.iloc[:, 0], X.iloc[:, 1], c=y, edgecolors='w')
plt.title("Decision Boundary")
plt.show()

使用 DecisionBoundaryDisplay 绘制三重分类模型。

请记住,Display 只能绘制二维数据,因此请确保数据只有两个特征或更小的维度。

三、概率校准

要比较分类模型,使用 calibration.CalibrationDisplay 进行概率校准,概率校准曲线可以显示模型预测的可信度。

CalibrationDisplay使用的是模型的 predict_proba。如果使用支持向量机,需要将 probability 设为 True:

from sklearn.calibration import CalibrationDisplay
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.ensemble import HistGradientBoostingClassifierX, y = make_classification(n_samples=1000,n_classes=2, n_features=5,random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
proba_clf = make_pipeline(StandardScaler(), SVC(kernel="rbf", gamma="auto", C=10, probability=True))
proba_clf.fit(X_train, y_train)CalibrationDisplay.from_estimator(proba_clf, X_test, y_test)hist_clf = HistGradientBoostingClassifier()
hist_clf.fit(X_train, y_train)ax = plt.gca()
CalibrationDisplay.from_estimator(hist_clf,X_test, y_test,ax=ax)
plt.show()

CalibrationDisplay.

四、显示混淆矩阵

在评估分类模型和处理不平衡数据时,需要查看精确度和召回率。使用 metrics.ConfusionMatrixDisplay绘制混淆矩阵(TP、FP、TN 和 FN)。

from sklearn.datasets import fetch_openml
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplaydigits = fetch_openml('mnist_784', version=1)
X, y = digits.data, digits.target
rf_clf = RandomForestClassifier(max_depth=5, random_state=42)
rf_clf.fit(X, y)ConfusionMatrixDisplay.from_estimator(rf_clf, X, y)
plt.show()

五、Roc 和 Det 曲线

因为经常并列评估Roc 和 Det 曲线,因此把metrics.RocCurveDisplay 和 metrics.DetCurveDisplay两个图表放在一起。

  • RocCurveDisplay比较模型的 TPR 和 FPR。对于二分类,希望 FPR 低而 TPR 高,因此左上角是最佳位置。Roc 曲线向这个角弯曲。

由于 Roc 曲线停留在左上角附近,右下角是空的,因此很难看到模型差异。

  • 使用 DetCurveDisplay 绘制一条带有 FNR 和 FPR 的 Det 曲线。它使用了更多空间,比 Roc 曲线更清晰。Det 曲线的最佳点是左下角。

from sklearn.metrics import RocCurveDisplay
from sklearn.metrics import DetCurveDisplayX, y = make_classification(n_samples=10_000, n_features=5,n_classes=2, n_informative=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42,stratify=y)classifiers = {"SVC": make_pipeline(StandardScaler(), SVC(kernel="linear", C=0.1, random_state=42)),"Random Forest": RandomForestClassifier(max_depth=5, random_state=42)
}fig, [ax_roc, ax_det] = plt.subplots(1, 2, figsize=(10, 4))
for name, clf in classifiers.items():clf.fit(X_train, y_train)RocCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_roc, name=name)DetCurveDisplay.from_estimator(clf, X_test, y_test, ax=ax_det, name=name)

六、调整阈值

在数据不平衡的情况下,希望调整召回率和精确度。可以使用使用 metrics.PrecisionRecallDisplay 调整阈值

  • 对于电子邮件欺诈,需要高精确度。

  • 而对于疾病筛查,则需要高召回率来捕获更多病例。

那么可以调整阈值,但调整多少才合适呢?因此可以使用metrics.PrecisionRecallDisplay 来绘制相关图表。

from xgboost import XGBClassifier
from sklearn.datasets import load_wine
from sklearn.metrics import PrecisionRecallDisplaywine = load_wine()
X, y = wine.data[wine.target<=1], wine.target[wine.target<=1]
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,stratify=y, random_state=42)xgb_clf = XGBClassifier()
xgb_clf.fit(X_train, y_train)PrecisionRecallDisplay.from_estimator(xgb_clf, X_test, y_test)
plt.show()

这表明可以按照 Scikit-learn 的设计绘制模型,就像这里的 xgboost

七、回归模型评估

Scikit-learn 的 metrics.PredictionErrorDisplay 绘制残差图可以帮助评估回归模型。

from sklearn.svm import SVR
from sklearn.metrics import PredictionErrorDisplayrng = np.random.default_rng(42)
X = rng.random(size=(200, 2)) * 10
y = X[:, 0]**2 + 5 * X[:, 1] + 10 + rng.normal(loc=0.0, scale=0.1, size=(200,))reg = make_pipeline(StandardScaler(), SVR(kernel='linear', C=10))
reg.fit(X, y)fig, axes = plt.subplots(1, 2, figsize=(8, 4))
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[0], kind="actual_vs_predicted")
PredictionErrorDisplay.from_estimator(reg, X, y, ax=axes[1], kind="residual_vs_predicted")
plt.show()

图表展示预测值与实际值的比较,左图适合线性回归。然而,并非所有数据都是完全线性的,因此,请参考右图。右图展示了实际值与预测值的差异,即残差图。残差图的香蕉形状暗示我们的数据可能不适合线性回归。考虑将核函数从"线性" 转换为 "rbf" ,残差图会更好。

reg = make_pipeline(StandardScaler(), SVR(kernel='rbf', C=10))

八、绘制学习曲线

学习曲线主要研究模型的泛化效果和训练测试数据之间的差异或偏差。接下来,使用 model_selection.LearningCurveDisplay 绘制学习曲线,并比较了决策树分类器和梯度提升分类器在不同训练数据下的表现。

from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import LearningCurveDisplayX, y = make_classification(n_samples=1000, n_classes=2, n_features=10,n_informative=2, n_redundant=0, n_repeated=0)tree_clf = DecisionTreeClassifier(max_depth=3, random_state=42)
gb_clf = GradientBoostingClassifier(n_estimators=50, max_depth=3, tol=1e-3)train_sizes = np.linspace(0.4, 1.0, 10)
fig, axes = plt.subplots(1, 2, figsize=(10, 4))
LearningCurveDisplay.from_estimator(tree_clf, X, y,train_sizes=train_sizes,ax=axes[0],scoring='accuracy')
axes[0].set_title('DecisionTreeClassifier')
LearningCurveDisplay.from_estimator(gb_clf, X, y,train_sizes=train_sizes,ax=axes[1],scoring='accuracy')
axes[1].set_title('GradientBoostingClassifier')
plt.show()

从图中可以看出,虽然基于树的 GradientBoostingClassifier 在训练数据上保持了良好的准确性,但其在测试数据上的泛化能力与 DecisionTreeClassifier 相比并无明显优势。

九、可视化参数调整

为了改善泛化效果差的模型,可以尝试通过调整正则化参数来提高性能。传统的方法是使用 "GridSearchCV" 或 "Optuna" 等工具来实现模型调整,然而这些方法只能找出整体表现最佳的模型,且调整过程并不直观。如果需要调整特定参数以测试其对模型的影响,建议使用 model_selection.ValidationCurveDisplay 来直观地观察模型在参数变化时的表现。

from sklearn.model_selection import ValidationCurveDisplay
from sklearn.linear_model import LogisticRegressionparam_name, param_range = "C", np.logspace(-8, 3, 10)
lr_clf = LogisticRegression()ValidationCurveDisplay.from_estimator(lr_clf, X, y,param_name=param_name,param_range=param_range,scoring='f1_weighted',cv=5, n_jobs=-1)
plt.show()

十、讨论

尝试过所有这些显示后,我必须承认一些遗憾:

  • 最大的遗憾是这些 API 大多数缺乏详细的教程,这可能也是与 Scikit-learn 的详尽文档相比不为人知的原因。

  • 这些应用程序接口散布在不同的软件包中,因此很难从一个地方引用它们。

  • 代码仍然非常基础。通常需要将其与 Matplotlib 的 API 搭配使用才能完成工作。一个典型的例子是 "DecisionBoundaryDisplay",在绘制决策边界后,还需要使用 Matplotlib 来绘制数据分布。

  • 它们很难扩展。除了一些验证参数的方法外,很难用工具或方法来简化模型的可视化过程;最终需要重写了很多东西。

这些 API 希望得到更多关注,并且随着版本升级,可视化 API 也能更易用。

在机器学习中,用可视化方式解释模型与训练模型同样重要。

本文介绍了当前版本 scikit-learn 中的各种绘图 API,利用这些 API,可以简化一些 Matplotlib 代码,缓解学习曲线,并简化模型评估过程。由于篇幅有限,未对每个 API 进行详细介绍。如果有兴趣,可以查看 [官方文档:https://scikit-learn.org/stable/visualizations.html?ref=dataleadsfuture.com] 了解更多详情。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/675767.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

小滴大课训练营-微服务架构-海量数据商用短链平台项目大课

小滴大课训练营-微服务架构-海量数据商用短链平台项目大课【2023最新升级版】 一、海量数据大课架构课程简介 海量数据项目大课&#xff0c;是小滴课堂准备耗时近1年的综合性实战项目&#xff0c;从0~1开发商用短链平台&#xff1b; 从产品需求文档定制出发&#xff0c;掌握核…

orbslam2基础

目录 一、 内容概要二、 orbslam2基础介绍三 、 orbslam2安装3.1 安装依赖3.2 安装orbslam23.3 下载Kitee数据集 四、 进行ORBSLAM2仿真五、 心得体会六、 参考链接 一、 内容概要 orbslam2基础介绍orbslam2安装orbslam2使用案例&#xff1a;orbslam2kitti数据集序列图像 二、…

ESP8266固件烧写

概述 因为手上有块闲置的ESP8266开发板&#xff0c;想着拿来倒腾一下WIFI探针&#xff0c;倒腾了一阵测试成功&#xff0c;博文记录用以备忘 硬件 ESP8266 NodeMCU 环境 Windows 11 步骤 1.下载esp32_win32_msys2_environment_and_toolchain-20181001.zip 2.下载xtensa…

scikit-learn实现单因子线性回归模型

1.是什么&#xff1a; 针对机器学习提供了数据预处理&#xff0c;分类&#xff0c;回归等常见算法的框架 2.基于scikit-learn求解线性回归的问题&#xff1a; 2.1.求解a&#xff0c;b对新数据进行预测&#xff1a; 2.2评估模型表现&#xff08;y和y’的方差MSE&#xff09;…

为什么会查询不到DNS信息?怎么排查?

DNS&#xff08;域名系统&#xff09;是将域名转换为相应 IP 地址的关键系统。查询 DNS 信息具有重要作用&#xff0c;通过查询 DNS 信息&#xff0c;我们可以知道域名对应的 IP 地址&#xff0c;这是最主要的信息&#xff0c;使设备能与目标服务器进行通信&#xff1b;其次是域…

Innodb实现的索引

概念 一种用于提高数据库查询性能的有序的数据结构。通过使用索引&#xff0c;数据库引擎可以快速定位到存储表中的特定数据&#xff0c;而不必逐行遍历整个表。在处理大量数据的时候可以显著加快数据检索的速度。 通过索引列队数据进行排序&#xff0c;降低数据排序的成本&a…

商家制作微信小程序有什么好处?微信小程序的制作有哪些步骤和流程

微信小程序全面指南 微信小程序是微信生态系统中一项革命性的功能&#xff0c;为希望与庞大的微信用户群体互动的企业提供了独特的融合便捷性和功能性的体验。本全面指南深入探讨了微信小程序的世界&#xff0c;强调了其重要性、工作原理以及实际用例&#xff0c;特别是针对企…

Llama 3 模型上下文长度扩展至1048K

AI苏妲己&#xff1a; Gradient AI 近日宣布&#xff0c;通过其创新的渐进式训练方法&#xff0c;成功将 Llama 3 系列模型的上下文长度扩展至超过100万令牌。 Llama 3 模型最初设计用于处理8000个令牌的上下文长度&#xff0c;这大约相当于6000字或10页文档。为了适应更长的…

LLMs之GPT4ALL:GPT4ALL的简介、安装和使用方法、案例应用之详细攻略

LLMs之GPT4ALL&#xff1a;GPT4ALL的简介、安装和使用方法、案例应用之详细攻略 目录 GPT4ALL的简介 0、新功能 1、特点 2、功能 3、技术报告 GPT4ALL的安装和使用方法 1、安装 2、使用方法 GPT4ALL的案例应用 LLMs之LLaMA3&#xff1a;基于GPT4ALL框架对LLaMA-3实现…

python turtle

名字动画 #SquareSpiral1.py import turtle t turtle.Pen() turtle.bgcolor("black")my_nameturtle.textinput("输入你的姓名","你的名字&#xff1f;") colors["red","yellow","purple","blue"] for…

开源代码分享(28)-含分布式光伏的配电网集群划分和集群电压协调控制

参考文献&#xff1a; [1] Chai Y , Guo L , Wang C ,et al.Network Partition and Voltage Coordination Control for Distribution Networks With High Penetration of Distributed PV Units[J].IEEE Transactions on Power Systems, 2018:3396-3407.DOI:10.1109/TPWRS.2018…

oracle 清理 trace 和 alert 日志文件

某天,发现磁盘空间被占满了&#xff0c;继续查询发现是 oracle 的日志文件占满了磁盘空间 其中: trace文件有35G, alert 有23G 目录地址是: diag/rdbms/orcl/orcl/trace, diag/rdbms/orcl/orcl/alert 都是在 oracle 目录下的 diag 目录内部 # 可以使用 以下命令对目录大小进行排…