如何信任机器学习模型的预测结果?

 在本篇中,我将通过一个例子演示在 MATLAB 如何使用 LIME 进行复杂机器学习模型预测结果的解释

我使用数据集 carbig(MATLAB 自带的数据集)训练一个回归模型,用于预测汽车的燃油效率。数据集 carbig 是 70 年代到 80 年代生产的汽车的一些数据,包括:

图片

其中:MPG 为响应变量(预测结果),其它变量为预测变量(数据特征),训练一个回归模型 f,该回归模型可以通过汽车的气缸数量、排量、生产年份等信息预测汽车的燃油效率,数学表达如下:

MPG = f(Cylinders,Displacement,Horsepower,modelyear,Weight,Acceleration)

利用 LIME 技术对回归模型 f 的预测结果进行解释,查看是那些特征对预测结果产生影响。具体的实现过程如下。

训练机器学习模型

导入数据集并构建数据表:

rng(0);

load carbig

tbl =table(Acceleration,Cylinders,Displacement,…

Horsepower,Model_Year,Weight,MPG);

进行数据预处理,去除带有缺失值的行:

tbl =rmmissing(tbl);

生成变量数据表:

tblX =removevars(tbl,'MPG');

head(tblX)

将变量按数据类型进行划分。其中第二列 Cylinders 和第五列 Model_Year 是分类变量,其它列是数值变量

图片

tblX_num= removevars(tblX,{'Cylinders','Model_Year'});

tblX_cate= tblX(:,{'Cylinders','Model_Year'});

对于数值变量查看变量之间的相关性

cor =corr(tblX_num{:,:});

h =heatmap(cor);

h.XDisplayLabels= tblX_num.Properties.VariableNames;

h.YDisplayLabels= tblX_num.Properties.VariableNames;

图片

 图 1

从图 1 的计算结果得出,Displacement 与 Weight 具有强相关性,Displacement 与 Horsepower 的相关性也较大。因此,去除 Displacement 变量。

tblX_num= removevars(tblX_num,{'Displacement'});

再次计算变量之间的相关性:

cor =corr(tblX_num{:,:});

h =heatmap(cor);

h.XDisplayLabels= tblX_num.Properties.VariableNames;

h.YDisplayLabels= tblX_num.Properties.VariableNames;

图片

图 2

图 2 显示,变量之间的相关性都小于 0.9,因此保留相关性小于 0.9 的变量作为预测变量。

对数值型预测变量进行标准化,缩放到[0,1]之间,以消除量纲对预测结果的影响。

tblX_num= normalize(tblX_num,"range");

tbl.MPG =normalize(tbl.MPG,"range");

head(tblX_num)

将数值变量和分类变量合并成训练数据集:

图片

将数值变量和分类变量合并成训练数据集:

tblX =[tblX_num tblX_cate];

head(tblX)

图片

训练一个随机森林回归模型,预测变量是表 tblX 中的变量,响应变量是 MPG,并指明第 4 和第 5 列是分类变量。

mdl_bag= fitrensemble(tblX,tbl.MPG,'Method',"Bag",…

'CategoricalPredictors',[4 5]);

对机器学习模型进行解释

利用 LIME 对训练好的回归模型进行解释。

首先构建使用 lime 函数构建一个 LIME,简单的解释模型选择决策树,同时 lime 中也指明了原始回归模型的训练样本的第 4 和第 5 列是分类变量。

lime_bag = lime(mdl_bag,'CategoricalPredictors',[4 5],…        'SimpleModelType',"tree");

我们从训练集中选取一个样本作为预测数据(即 QueryPoint),测试模型的预测结果,以及模型的解释结果。选择训练集中的第 257 个样本作为预测数据。

num =257;

queryPoint= tblX(num,:)

图片

以预测数据为基础生成合成数据,并训练一个可解释模型(决策树模型)。对于可解释模型,指定变量个数为 5。也就是说,我们最多只分析 5 个对预测结果产生影响的变量。

lime_bag= fit(lime_bag,queryPoint,5);

根据预测变量对预测结果的影响程度进行排序并可视化。

f =plot(lime_bag);

title('随机森林回归模型的LIME');

f.CurrentAxes.TickLabelInterpreter= 'none';

图片

图 3

从图 3 可以看出,预测变量 Weight,对回归模型预测的结果影响最大,其次是 Cylinders 和 Horsepower。

可以解释为:基于输入数据预测出的汽车燃油效率,主要依次考虑了汽车的重量(Weight)、汽车的动力(Horsepower)、汽车汽缸数量(Cylinders)。

这种解释也是符合我们的先验知识:汽车自重越大燃油效率越低,也就是每加仑行驶的里程数越少。汽车功率越大、汽缸数量越多,耗油越大。

因此说,回归模型预测的结果是可信的。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/329536.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

ros2在启动前准备工作:

ros2的准备工作就是:setup.bash内容等价于setup.sh 文件存放路径:/opt/ros/humble/ # generated from ament_package/template/prefix_level/local_setup.sh.in# since a plain shell script cant determine its own path when being sourced # either…

Dockerfile的ENV

文章目录 环境总结测试测试1测试2测试3测试4测试5测试6 参考 环境 RHEL 9.3Docker Community 24.0.7 总结 如果懒得看测试的详细信息,可以直接看结果: 一条 ENV 指令可以定义多个环境变量。Dockerfile里可以包含多条 ENV 指令。环境变量的值不需要用…

游戏开发中,你的游戏图片压缩格式使用ASTC了吗

文章目录 ASTC原理:使用要求 ASTC(Adaptive Scalable Texture Compression,自适应可伸缩纹理压缩)是一种高级的纹理压缩技术,由ARM公司开发并推广。它在图形处理领域中因其出色的压缩效率和灵活性而受到广泛关注。 AST…

[Linux] 一文理解HTTPS协议:什么是HTTPS协议、HTTPS协议如何加密数据、什么是CA证书(数字证书)...

之前的文章中, 已经分析介绍过了HTTP协议. HTTP协议在网络中是以明文的形式传输的. 无论是GET还是POST方法都是不安全的. 为什么不安全呢? 因为: HTTP协议以明文的形式传输数据, 缺乏对信息的保护. 如果在网络中传输数据以明文的形式传输, 网络中的任何人都可以轻松的获取数据…

软件装一送三了!还附带弹窗资讯,你确定不试一下?

前言 前几天一个朋友向我吐槽,说电脑太卡了。自己好像都没安装什么软件,怎么就那么多弹窗广告。 我看了一下他的电脑,笑了一下说:你的电脑真好,都会只能给你推荐美女看,这资讯来之不易啊,好好享…

20240105-工作安排的最大收益

题目要求 我们有 n 份工作,每份工作都安排在 startTime[i] 至 endTime[i] 期间完成,从而获得 profit[i] 的利润。 给你 startTime、endTime 和 profit 数组,返回你能获得的最大利润,使得子集中没有两个时间范围重叠的工作。 如…

【C++】几种常用的类型转换

类型转换 c语言中的类型转换C的类型转换static_castreinterpret_castconst_castdynamic_cast c语言中的类型转换 在C语言中我们经常会遇到类型转化的问题,主要分为两种:显式类型转换和隐式类型转换。 显式类型转换:就是程序员使用强制类型转…

Kali Linux——设置中文

【问题现象】 从下图可以看到,菜单全是英文的。对于英文不好的同学,使用起来很难受。 【解决方法】 1、获取root权限 su root 2、进入语言设置 dpkg-reconfigure locales 3、选择zh_CN.UTF-8 UTF-8 4、设置默认 5、安装完成 6、重启虚拟机 reboot…

关于java的多维数组

关于java的多维数组 在前面的文章中,我们了解了数组的使用,我们之前所了解的数组是一维数组,本篇文章我们来了解一下二维数组,多维数组😀 一、二维数组 首先我们知道一维数组的声明和创建的方式是。 int array ne…

Mysql SQL审核平台Yearning本地部署

文章目录 前言1. Linux 部署Yearning2. 本地访问Yearning3. Linux 安装cpolar4. 配置Yearning公网访问地址5. 公网远程访问Yearning管理界面6. 固定Yearning公网地址 前言 Yearning 简单, 高效的MYSQL 审计平台 一款MYSQL SQL语句/查询审计工具,为DBA与开发人员使用…

Python和Java环境搭建

小白搭建全流程 首先不建议装在C盘,一旦重置电脑,之前安装第三方包需要重新安装 relolver :解释器 1、Python解释器安装 资源包: 1、 python -version java -version–用于查看是否安装 where python whrer java–用于查看安装的位置【非常…

【强力推荐】GitCode AI开源搜索,面向开发者的专业AI搜索

一、GitCode AI开源搜索是什么? GitCode AI开源搜索 是面开发者的 AI 开源搜索工具,目的是为了帮助开发者快速寻找开源项目代码、解决开发问题和快速寻找答案,帮助开发者提升效率的同时利用代码仓托管能力建立自己个人知识库。 二、GitCode…