模型的泛化性能度量:方法、比较与实现

news/2025/3/29 10:53:12/文章来源:https://www.cnblogs.com/wang_yb/p/18792713

在机器学习领域,模型的泛化性能度量是评估模型在未知数据上表现的关键环节。

通过合理的性能度量,不仅能了解模型的优劣,还能为模型的优化和选择提供科学依据。

本文将深入探讨泛化性能度量的重要性、各种度量方法、它们之间的区别与适用场景,并通过scikit-learn代码示例来展示如何实现这些度量方法。

1. 为什么要做泛化性能度量

模型的最终目标是在面对新数据时能够准确、稳定地进行预测或分类。

然而,在训练过程中,模型可能会出现过拟合(对训练数据拟合得过于紧密,导致在新数据上表现不佳)或欠拟合(未能充分学习数据特征)等问题。

泛化性能度量能够帮助我们:

  1. 客观评估模型优劣:通过量化的指标,准确判断模型在未知数据上的表现,避免主观臆断。
  2. 指导模型优化:明确模型的不足之处,为调整模型参数、选择更合适的算法提供方向。
  3. 比较不同模型:在多个模型之间进行公平、科学的比较,选出最适合特定任务的模型。
  4. 提前预警问题:及时发现模型可能存在的过拟合或欠拟合倾向,采取相应措施加以解决。

2. 度量泛化性能的方法

2.1. 错误率和精度

错误率(Error Rate)是指分类错误的样本数量占样本总数的比例。它直观地反映了模型预测出错的频率。

计算公式:$ \text{Error Rate} = \frac{\text{错误样本数}}{\text{总样本数}} \times 100% $

假设在 100 个测试样本中,模型错误分类了 10 个样本,那么错误率10/100 = 0.1

精度(Accuracy)是指分类正确的样本数量占样本总数的比例,与错误率相对应,反映了模型预测正确的概率。

计算公式:$ \text{Accuracy} = 1 - \text{Error Rate} $

在上述 100 个测试样本中,模型正确分类了 90 个样本,精度为 90/100 = 0.9

错误率精度是分类问题的重要指标,它们能够快速给出模型整体的错误情况和正确率。

它们适用于各类别样本分布均衡的情况。

sckit-learn库中有对应的错误率和精度的计算函数,直接使用即可:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score# 生成一个二分类数据集
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42
)# 训练一个决策树分类器
clf = DecisionTreeClassifier(random_state=42)
clf.fit(X_train, y_train)# 获取预测结果
y_pred = clf.predict(X_test)
y_proba = clf.predict_proba(X_test)[:, 1]# 计算错误率,精度
error_rate = 1 - accuracy_score(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)print(f"错误率: {error_rate:.2f}")
print(f"精度: {accuracy:.2f}")# 输出结果:
'''
错误率: 0.14
精度: 0.86
'''

2.2. 查准率,查全率和 F1

查准率(Precision)关注的是模型预测为正类的样本中,实际真正为正类的比例,它强调预测结果的可靠性

计算公式:$ \text{Precision} = \frac{TP}{TP+FP} $

查全率(Recall)衡量的是实际正类样本中,被模型正确预测为正类的比例,它关注的是模型对正类样本的覆盖能力

计算公式:$ \text{Recall} = \frac{TP}{TP+FN} $

F1 分数查准率查全率的调和平均数,综合考虑了两者的关系,提供了一个平衡的指标。

计算公式:$ F1 = 2 \times \frac{\text{Precision} \times \text{Recall}}{\text{Precision} + \text{Recall}} $

假设在某个二分类问题中,模型预测出 50 个正例,其中 40 个是真正的正例,实际正例总数为 60 个。

那么,

  • 查准率 = 40 / 50 = 0.8
  • 查全率 = 40 / 60 ≈ 0.6667
  • F1 分数 = 2 * (0.8 * 0.6667) / (0.8 + 0.6667) ≈ 0.7273

在处理不平衡数据集或对正类样本的预测准确性有特殊要求的任务中,查准率查全率和** F1 分数**能更全面地评估模型性能。

例如在医疗诊断中,高查全率意味着尽可能多地检测出患病个体,而高查准率则确保被诊断为患病的个体确实是真正的患者。

这三种指标在sckit-learn库中也有对应的方法:

from sklearn.metrics import (precision_score,recall_score,f1_score,
)# 计算查准率,查全率和F1
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)print(f"查准率: {precision:.2f}")
print(f"查全率: {recall:.2f}")
print(f"F1 分数: {f1:.2f}")# 运行结果:
'''
查准率: 0.86
查全率: 0.86
F1 分数: 0.86
'''

2.3. ROC 和 AUC

ROC曲线Receiver Operating Characteristic Curve):以真正例率(TPR)为横轴,假正例率(FPR)为纵轴绘制的曲线。

它反映了模型在不同阈值下的真正例率和假正例率之间的权衡关系。

其中,

  • 真正例率(TPR):TPR = 真正例数 / (真正例数 + 假反例数)
  • 假正例率(FPR):FPR = 假正例数 / (假正例数 + 真反例数)

AUC曲线Area Under ROC Curve):ROC曲线下的面积,用于衡量模型区分正负样本的能力。

AUC值越大,表示模型的区分能力越强。

ROCAUC 适用于评估二分类模型的性能,尤其在需要比较不同模型对正负样本的区分能力时非常有效。

它们能够全面地反映模型在不同阈值下的综合表现,而不受阈值选择的影响。

绘制ROC曲线的代码如下,模型的训练过程和上面的示例类似,这里不再重复:

import matplotlib.pyplot as plt
from sklearn.metrics import (roc_auc_score,roc_curve,
)plt.rcParams["font.sans-serif"] = ["SimHei"]  # 设置字体
plt.rcParams["axes.unicode_minus"] = False# 计算ROC AUC
roc_auc = roc_auc_score(y_test, y_proba)# 绘制ROC曲线
fpr, tpr, thresholds = roc_curve(y_test, y_proba)
plt.figure()
plt.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC curve (area = {roc_auc:.2f})")
plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel("假正例率(FPR)")
plt.ylabel("真正例率(TPR)")
plt.title("ROC 曲线")
plt.legend(loc="lower right")
plt.show()

2.4. 代价曲线

代价曲线考虑了不同分类错误所造成的实际损失(代价),通过绘制不同阈值下的总代价变化情况,帮助选择最优的分类阈值,使模型在实际应用中的损失最小。

它是对ROC曲线的一种扩展,考虑了不同错误分类的代价。

通过计算ROC曲线上每个点对应的期望总体代价,并在代价平面上绘制线段,取所有线段的下界围成的面积即为代价曲线

代价曲线的绘制方法稍微复杂一些,下面的的代码展示了不同ccp_alpha值对训练集和测试集错误率的影响,以及节点数量的变化。

# 演示代价复杂度剪枝
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impuritiesclfs = []
for ccp_alpha in ccp_alphas:clf = DecisionTreeClassifier(random_state=0, ccp_alpha=ccp_alpha)clf.fit(X_train, y_train)clfs.append(clf)node_counts = [clf.tree_.node_count for clf in clfs]
depth = [clf.tree_.max_depth for clf in clfs]train_errors = [1 - clf.score(X_train, y_train) for clf in clfs]
test_errors = [1 - clf.score(X_test, y_test) for clf in clfs]plt.figure(figsize=(12, 6))
plt.subplot(121)
plt.plot(ccp_alphas, train_errors, marker="o", drawstyle="steps-post", label="train")
plt.plot(ccp_alphas, test_errors, marker="o", drawstyle="steps-post", label="test")
plt.xlabel("有效 alpha")
plt.ylabel("错误率")
plt.title("错误率 vs alpha")
plt.legend()plt.subplot(122)
plt.plot(ccp_alphas, node_counts, marker="o", drawstyle="steps-post", label="number of nodes"
)
plt.xlabel("有效 alpha")
plt.ylabel("节点数")
plt.title("节点数 vs alpha")
plt.legend()plt.tight_layout()
plt.show()

3. 度量方法之间的比较

以上各个度量方法有各自的优缺点和使用场景,整理如下表,使用时请根据具体情况来选择。

度量方法 优点 缺点 适用场景
错误率 直观易懂,计算简单 未能区分不同类型的错误,可能在不平衡数据集上具有误导性 分类问题的初步评估,样本分布均衡的情况
精度 直观反映模型正确率 同错误率类似,在不平衡数据集上可能不够准确 快速了解模型整体正确性,各类别样本分布相对均匀的任务
查准率、查全率和 F1 全面考虑正类样本的预测情况,适用于不平衡数据集 指标较多,需要综合考虑 对正类样本预测准确性有特殊要求的任务,如医疗诊断、欺诈检测等
ROC 和 AUC 全面反映模型对正负样本的区分能力,与分类阈值无关 主要适用于二分类问题,且当正负样本分布极度不平衡时,可能对少数类的评估不够敏感 比较不同模型的分类性能,尤其是当需要综合考虑不同阈值下的表现时
代价曲线 考虑实际业务损失,针对性强 需要明确不同错误类型的代价,且曲线绘制和分析相对复杂 实际应用场景中对分类错误代价敏感的任务,如金融风控、营销策略制定等

4. 总结

模型的泛化性能度量是机器学习流程中不可或缺的一环。

通过合理选择和运用不同的度量方法,我们能够全面、客观地评估模型在未知数据上的表现,为模型的优化和实际应用提供坚实的依据。

在实际项目中,应根据数据特点、业务需求以及模型类型等因素,灵活选择合适的度量指标,充分发挥各指标的优势,确保模型在复杂多变的现实场景中稳定、高效地运行。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/906010.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

掌握设计模式--访问者模式

访问者模式(Visitor Pattern) 访问者模式(Visitor Pattern)是一种行为设计模式,它允许你将操作(方法)封装到另一个类中,使得你可以在不修改现有类的情况下,向其添加新的操作。 核心思想是将数据结构和对数据的操作分离,通过访问者对象来对数据进行操作,而不是将操作…

双向广搜-BiDirectional BFS

双向广搜 文章目录 前言前言 复习acwing算法提高课的内容,本篇为讲解算法:双向广搜 一、双向广搜 双向广搜其实就是两个bfs,我们知道bfs是一种暴力的做题方法,搜索树长下图所示:我们会发现搜索树越来越宽,每一层的搜索量增加,如果数据范围很大的话,显然是会TLE的,那么…

读DAMA数据管理知识体系指南31参考数据和主数据概念(上)

读DAMA数据管理知识体系指南31参考数据和主数据概念(上)1. 业务驱动因素 1.1. 满足组织数据需求1.1.1. 组织中的多个业务领域需要访问相同的数据集,并且他们都相信这些数据集是完整的、最新的、一致的1.2. 管理数据质量1.2.1. 数据的不一致、质量问题和差异均会导致决策错误…

生成式 AI 和 LLM 简介 起源 历史记录

领域 年份 定义人工智能 (AI) 1956 计算机科学领域,旨在创造能够复制或超越人类智能的智能机器。机器学习 (Machine Learning) 1997 人工智能的子集,使机器能够从现有数据中学习并根据这些数据进行决策或预测。深度学习 (Deep Learning) 2012 一种机器学习技术,通过使用多层…

拿到代理对象,如何调用增强方法

步骤1 前面已经创建了MathCal的代理对象了,我们在调用方法时加一个断点这里返回的确实是代理对象,这个对象中保存了详细信息(增强器,原始对象等),我们进入bean.add(2, 10) 中,来到 org.springframework.aop.framework.CglibAopProxy.DynamicAdvisedInterceptor.intercept(…

如何保证消息队列的消息只能被消费一次

如何保证消息队列的消息只能被消费一次,首先先保证消息不会丢失 首先先生产者到消费者到消费者有哪些场景会消息丢失一、问题场景 场景一、生产者发送到消息队列失败 场景二、消息队列接受到消息磁盘化失败 场景三、消费者接受到消息消费失败 二、场景原因,如何解决 1、场景一…

Day22_java方法

Java方法 方法重载 package com.xiang.method;public class Demo02 {public static void main(String[] args) {int max = max(20, 100, 10);System.out.println(max);}// 比大小public static int max(int num1,int num2){int result = 0;if (num1 == num2){System.out.printl…

文献阅读《Spectral Networks and Deep Locally Connected Networks on Graphs》

参考博客 第一代图卷积网络:图的频域网络与深度局部连接网络 - 知乎 (zhihu.com) 论文解读一代GCN《Spectral Networks and Locally Connected Networks on Graphs》 - 别关注我了,私信我吧 - 博客园 (cnblogs.com) 论文核心 卷积神经网络得益于所处理的数据具有局部平移不变…

【CodeForces训练记录】Codeforces Round 1013 (Div. 3)

训练情况赛后反思 A题题目读半天,发现日期有前导零,div3还是比较基础一点,但是感觉自己还是不够熟练,D题看出来二分但是调了挺久的 A题 判断取多少个数之后才能构成 20250301,我们维护数字的出现次数,直到所有数字的出现次数全部大于等于 20250301 的出现次数时输出位置即…

字符串问题的江湖奇宝:进制哈希

江湖中,剑客以快制胜,而算法竞赛里,字符串哈希(String Hashing)便是那柄出招如电的快剑。 各种字符串问题纷乱复杂,各种字符串算法招式繁复,需苦练内功心法。但字符串哈希算法却只凭一招:将字符串化作数字,以数论为刃,至简之道斩尽来犯之敌。 但此招并非无懈可击。若…

HW-1

1.选项A是正确的,它表示的是极小项m6的正确形式。极小项m6对应的是变量a=0,b=1,c=1,d=0的情况,因此其表达式应为(\overline{a} \cdot b \cdot c \cdot \overline{d}),即选项A。 其他选项的分析:选项B是一个或项,不符合极小项的定义。 选项C缺少变量a和d,不是一个完整的…