对模型性能进行评估(Machine Learning 研习十五)

在上一篇我们已然训练了一个用于对数字图像识别的模型,但我们目前还不知道该模型在识别数字图像效率如何?所以,本文将对该模型进行评估。

使用交叉验证衡量准确性

评估模型的一个好方法是使用交叉验证,让我们使用cross_val_score() 函数来评估我们的 SGDClassifier模型,使用三折的 k 折交叉验证。k-fold 交叉验证意味着将训练集分成 k 个折叠(在本例中是三个),然后训练模型 k 次,每次取出一个不同的折叠进行评估:

在这里插入图片描述

当您看到这组数字,是不是感到很兴奋?毕竟所有交叉验证折叠的准确率(预测准确率)均超过了 95%。然而,在您兴奋于这组数字前,还是让我们来看看一个假分类器,它只是将每张图片归入最常见的类别,在本例中就是负类别(即非 5):

from sklearn.dummy import DummyClassifierdummy_clf = DummyClassifier() 
dummy_clf.fit(X_train, y_train_5) 
print(any(dummy_clf.predict(X_train)))  # prints False: no 5s detected

您能猜出这个模型的准确度吗?让我们一探究竟:

在这里插入图片描述

没错,它的准确率超过 90%!这只是因为只有大约 10% 的图片是 5,所以如果你总是猜测图片不是 5,你就会有大约 90% 的时间是正确的。比诺斯特拉达穆斯还准。

这说明了为什么准确率通常不是分类器的首选性能指标,尤其是在处理偏斜``````数据集时(即某些类别的出现频率远高于其他类别)。评估分类器性能的更好方法是查看混淆矩阵(CM)。

实施交叉验证

Scikit-Learn现成提供的功能相比,您有时需要对交叉验证过程进行更多控制。在这种情况下,你可以自己实现交叉验证。下面的代码与 Scikit-Learn cross_val_score() 函数做了大致相同的事情,并会打印出相同的结果:

from sklearn.model_selection import StratifiedKFold 
from sklearn.base import cloneskfolds = StratifiedKFold(n_splits=3)  # add shuffle=True if the dataset is                                                # not already shuffled 
for train_index, test_index in skfolds.split(X_train, y_train_5):    clone_clf = clone(sgd_clf)    X_train_folds = X_train[train_index]    y_train_folds = y_train_5[train_index]    X_test_fold = X_train[test_index]    y_test_fold = y_train_5[test_index]clone_clf.fit(X_train_folds, y_train_folds)    y_pred = clone_clf.predict(X_test_fold)    n_correct = sum(y_pred == y_test_fold)    print(n_correct / len(y_pred))  # prints 0.95035, 0.96035, and 0.9604 

StratifiedKFold 类执行分层抽样,生成的折叠数包含每个类别的代表性比例。每次迭代时,代码都会创建分类器的克隆,在训练折叠上训练该克隆,并在测试折叠上进行预测。然后计算正确预测的次数,并输出正确预测的比例。

混淆矩阵

混淆矩阵的一般概念是计算在所有 A/B 对中,A 类实例被分类为 B 类的次数。例如,要知道分类器将 8 和 0 的图像混淆的次数,可以查看混淆矩阵的第 8 行第 0 列。

要计算混淆矩阵,首先需要有一组预测结果,以便与实际目标进行比较。你可以在测试集上进行预测,但最好暂时不要使用测试集(记住,只有在项目的最后阶段,也就是分类器准备好启动时,才会使用测试集)。相反,你可以使用 cross_val_predict() 函数:

from sklearn.model_selection import cross_val_predicty_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3) 

cross_val_score() 函数一样,cross_val_predict()也会执行 k 折交叉验证,但它返回的不是评估分数,而是在每个测试折上做出的预测。这意味着你可以得到训练集中每个实例的准确预测(我说的 "准确 "是指 “样本外”:模型对训练期间从未见过的数据进行预测)。

现在可以使用 confusion_matrix()函数获取混淆矩阵了。只需将目标类 (y_train_5) 和预测类 (y_train_pred) 传递给它即可:

在这里插入图片描述

混淆矩阵的每一行代表一个实际类别,每一列代表一个预测类别。矩阵的第一行是非 5 图像(负类): 其中 53 892 幅图像被正确分类为非 5 图像(称为真阴性图像),其余 687 幅图像被错误分类为 5 图像(称为假阳性图像,也称为 I 类错误)。第二行是 5 的图像(正类): 有 1 891 张图片被错误地归类为非 5(假阴性,也称为 II 类错误),而其余 3 530 张图片被正确地归类为 5(真阳性)。一个完美的分类器只有真阳性和真阴性,因此其混淆矩阵只有在主对角线上(从左上角到右下角)才有非零值:

在这里插入图片描述

混淆矩阵提供了大量信息,但有时您可能更喜欢更简洁的指标。一个有趣的指标是正向预测的准确度;这被称为分类器的精度(公式 见下图)。

在这里插入图片描述

TP 是正面的数量,FP是反面的数量。

要想获得完美的精度,一个简单的方法就是创建一个分类器,除了对它最有信心的实例进行一次正向预测外,它总是进行负向预测。如果这一个预测是正确的,那么分类器的精度就是 100%(精度 = 1/1 = 100%)。显然,这样的分类器用处不大,因为它会忽略除了一个正向实例之外的所有实例。因此,精度通常与另一个名为召回率的指标一起使用,召回率也称为灵敏度或真阳性率(TPR):这是分类器正确检测到的阳性实例的比率(公式见下图)。

在这里插入图片描述

FN当然是假不良的数量。

在这里插入图片描述

精确度和召回率

Scikit-Learn提供多种函数来计算分类器指标,包括精度和召回率:

在这里插入图片描述

现在,我们的 "5-检测器 "看起来不像我们观察它的准确性时那么闪亮了。当它声称一幅图像代表 5 时,正确率只有 83.7%。而且,它只能检测到 65.1% 的 5。

通常情况下,将精确度和召回率合并为一个称为 F1 分数的指标会比较方便,尤其是在需要用一个指标来比较两个分类器时。F1 分数是精确度和召回率的调和平均数(公式 见下图)。普通均值对所有值一视同仁,而调和均值对低值的权重要大得多。因此,分类器只有在召回率和精确率都很高的情况下才能获得较高的 F1 分数。

在这里插入图片描述

要计算 F1 分数,只需调用f1_score() 函数即可:

在这里插入图片描述

F1 分数有利于精确度和召回率相似的分类器。这并不总是你想要的:在某些情况下,你主要关心精度,而在另一些情况下,你真正关心的是召回率。例如,如果您训练了一个分类器来检测对儿童安全的视频,那么您可能更倾向于选择一个剔除了许多好视频(召回率低)但只保留安全视频(高精度)的分类器,而不是一个召回率高得多但却让一些非常糟糕的视频出现在您的产品中的分类器(在这种情况下,您甚至可能想要添加一个人工管道来检查分类器的视频选择)。另一方面,假设您训练了一个分类器来检测监控图像中的偷窃者:只要您的分类器的召回率达到 99%,即使它只有 30% 的精度也没有问题(当然,保安会收到一些错误警报,但几乎所有的偷窃者都会被抓住)。

不幸的是,鱼和熊掌不可兼得:提高精度会降低召回率,反之亦然。这就是所谓的精度/召回权衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/541798.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue/uniapp路由history模式下宝塔空间链接打开新窗口显示404解决方法

vue/uniapp路由history模式下宝塔空间链接打开新窗口显示404,或者域名后带路径参数刷新就报404 解决方法: 宝塔中站点配置修改:【配置文件】中添加下面代码,具体如图: location / {try_files $uri $uri/ /index.html…

汇总全网免费API,持续更新(新闻api、每日一言api、音乐。。。)

Public&FreeAPI 网址:apis.whyta.cn (推荐) UomgAPI 网址:https://api.uomg.com 教书先生 网址:https://api.oioweb.cn/ 山海API https://api.shserve.cn/ 云析API铺 https://api.a20safe.com/ 韩小韩…

浅析 Python 的一些底层原理与 CPython

🍉 CSDN 叶庭云:https://yetingyun.blog.csdn.net/ Python 是一门强大且易用的脚本语言,以其简洁的语法和全面的功能而闻名,能够有效地支持各种业务的快速实现。但 Python 的设计者有意地隐藏了背后的复杂细节。在解决项目问题时&…

hadoop单机ssh免密登录

1. 在hadoop目录下生成密钥对 [rootmaster centos]# cd /usr/apps/hadoop-2.7.1/ [rootmaster hadoop-2.7.1]# ssh-keygen -t rsa //在hadoop目录下生成密钥对 2.找到密钥对的位置 [rootmaster hadoop-2.7.1]# find / -name .ssh //找到密钥对的位置 cd [rootmaster hadoo…

WRF模型运行教程(ububtu系统)--III.运行WRF模型(官网案例)

零、创建DATA目录 # 1.创建一个DATA目录用于存放数据(一般为fnl数据,放在Build_WRF目录下)。 mkdir DATA # 2.进入 DATA cd DATA 一、WPS预处理 在模拟之前先确定模拟域(即模拟范围),并进行数据预处理&#xff08…

宠物医院管理系统{源码+报告}

目 录 1 绪论 1.1 课题背景 1.2 课题研究的现状 1.3 课题研究的意义 2 需求分析 2.1 需求描述 2.2 需求功能描述 2.3 用例模型 2.3.1 业务用例模型 2.3.2 系统用例模型 2.4 动态模型 2.4.1 项目泳道图 2.4.2 业务泳道图 2.5 静态类模型 2.5.1 分析类图 2.…

upload-labs第一关

上一篇文章中搭建好了upload-labs环境,接下来进行第一关的尝试,我也是第一次玩这个挺有意思。 1、第一关的界面是这样的先不看其他的源码,手动尝试下试试。 2、写一个简单的php一句话木马 3、直接上传,提示必须要照片格式的文…

NFT Insider #123:Solana NFT 市场 Tensor 将发行 TNSR 治理代币

引言:NFT Insider由NFT收藏组织WHALE Members (https://twitter.com/WHALEMembers)、BeepCrypto (https://twitter.com/beep_crypto)联合出品,浓缩每周NFT新闻,为大家带来关于NFT最全面、最新鲜…

Windows上Git LFS的安装和使用

到Git LFS官网下载 传送门 初始化GitHub LFS和Git仓库 在仓库目录中运行: git lfs install再运行: git init跟踪大文件 git lfs track "*.zip"添加并提交文件 git add . git commit -m "Add large files"上传到我的github 配…

cron表达式

发现了一个好用测cron表达式组件 一个基于vue3Ant-Design-vue的cron表达式组件 项目地址 效果图 使用方式 前置条件:项目基于vueAnt-Design-vue开发(用到Ant-Design-vue相关组件) 第一步安装组件npm i shiyzhangcron 或 pnpm i shiyzhangcro…

Linux下Arthas(阿尔萨斯)的简单使用-接口调用慢排查

使用环境 k8s容器内运行了一个springboot服务,服务的启动方法是main()方法 下载并启动 arthas curl -O https://arthas.aliyun.com/arthas-boot.jar java -jar arthas-boot.jar选择应用 java 进程 就一个进程org.apache.catalina.startup.Bootstrap,输…