KNN算法回归问题介绍和实现

上篇博客中,介绍了使用KNN算法实现分类问题,本篇文章介绍使用KNN算法实现回归问题。介绍思路是先使用sklearn包提供的方法实现一个KNN算法的回归问题。再自定义实现一个KNN算法的回归问题工具类。

一、sklearn包使用KNN算法

1. 准备数据

使用sklearn包提供的make_regression模块制作回归类型数据。

from sklearn.datasets import make_regression

除了make_regession外,sklearn包还提供了制作分类问题的数据等方法,如下图:
在这里插入图片描述
我们在需要测试数据时,可以根据需求引入不同的模块来创建数据。
生成数据:

X, y = make_regression(n_samples=10000, n_features=20, n_informative=15, random_state=0)

其中,n_samples是样本数,n_features是每个样本的特征数,n_informative是有效特征数,random_state是随机生成数的种子,种子相同,生成的X和y的值都相同。
回归问题生成的X,均值接近0,标准差接近1,是去中心化后的数据。而分类问题生成的数据,就不具有此特点,如下所示为make_regression生成的数据:

X.mean() #0.0033349709157105382
X.std() #0.998015035291231

2. 切分数据

使用sklearn提供的train_test_split方法对生成的数据进行切分:

from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

3. 使用sklearn包进行KNN回归问题的验证

from sklearn.neighbors import KNeighborsRegressor
# 第一步,构建模型
knn = KNeighborsRegressor(n_neighbors=7)
# 第二步,训练模型
knn.fit(X=X_train, y=y_train)
#第三步,模拟测试
y_pred = knn.predict(X =X_test)
#第四步,使用MSE评测预测结果
((y_test - y_pred)**2).mean()

首先需要知道的是,单独看最后MSE的计算结果的大小,不代表预测准确与否。而是要通过调整进邻的参数,来比较不同近邻下MSE的结果,来看选择哪个近邻参数最合适。而对于现实问题,是否要用KNN的回归问题算法来解决,是另当别论的,不能通过MSE的结果去判断是否用KNN算法正确与否。

MSE: 求预测结果与实际结果的差值的绝对值(平方),然后再求差值绝对值的平均数。具体MSE的含义和定义,参考通俗易懂讲解均方误差 (MSE)

在回归问题中,一般使用MSE表示预测结果。在分类问题中,使用预测值=实际值的平均数[(y_predict == y_test).mean()]表示预测结果。

二、自定义回归问题实现

class MyKNeighborsRegressor(object):"""自定义KNN 回归器"""def __init__(self, n_neighbors=5):"""挂载超参数"""self.n_neighbors=n_neighborsdef fit(self, X, y):"""训练过程"""self.X = np.array(X)self.y = np.array(y)def predict(self, X):X = np.array(X)results = []for x in X:# 计算两个向量之间的距离,sqrt((x1-x2)**2+(y1-y2)**2+......)。x是行向量,self.X是测试用例的矩阵,self.X-x用到了向量的广播机制,进行对齐然后相减。由此计算出来的距离是测试集中单个向量与训练集中所有行向量的距离distances = ((self.X - x) ** 2).sum(axis=1)#选出距离最近的向量的脚标indices = distances.argsort(axis=0)[:self.n_neighbors]#根据脚标获取训练集中的对应脚本的元素labels = self.y[indices]# 取距离最近的训练集的标签,然后求均值,就是回归问题的预测结果y_pred = labels.mean()results.append(y_pred)return np.array(results)

三、总结

对比KNN算法的分类问题和回归问题的自定义实现,需要捋清楚几点:

  1. 对于上面的例子而言,样本矩阵就是一个二维数组,二维数组中的每个一维数组,就是一行,每列都代表一个特征。而一行样本数据,都会对应一个标签值,即y。
  2. 使用KNN算法,就是将训练数据中的样本数据和标签数据提供给模型后,在预测测试数据时,模型根据测试数据的每行样本,去查找之前提供的训练数据中所有的样本中,距离这个测试数据样本最近的n个训练样本数据。
  3. 找到邻近的n个训练样本数据后,找出这n个训练样本数据对应的标签值。然后在分类问题中,找出n个训练样本对应标签值中最多的那个标签,就认为是测试这条测试样本的标签值。在回归问题中,找出n个训练样本对应的标签后,求这些标签的均值,就认为是当前测试样本的标签。
  4. 上述是针对测试样本中每个向量的处理逻辑,当循环找出所有测试样本的标签值后,就可以返回总体的预测数据了。
  5. 最后通过预测数据与真实数据比对,查看是否适合用KNN算法以及近邻参数如何设置最准确。

因此,KNN算法的核心理念是通过找邻近训练样本的标签,来推算测试样本的标签进行返回。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/106291.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

单片机之硬件记录

一、概念 VBAT 当使用电池或其他电源连接到VBAT脚上时,当VDD断电时,可以保存备份寄存器的内容和维持RTC的功能。如果应用中没有使用外部电池,VBAT引脚应接到VDD引脚上。 VCC:Ccircuit 表示电路的意思,即接入电路的电压&#x…

java设计模式,简单工厂和抽象工厂有什么区别?

java设计模式,简单工厂和抽象工厂有什么区别? 简单工厂模式: 这个模式本身很简单而且使用在业务较简单的情况下。一般用于小项目或者具体产品很少扩展的情况(这样工厂类才不用经常更改)。 它由三种角色组成&#xf…

pip和conda的环境管理,二者到底应该如何使用

关于pip与conda是否能混用的问题,Anaconda官方早就给出了回答 先说结论,如果conda和pip在相同环境下掺杂使用,尤其是频繁使用这两个工具进行包的安装,可能会导致环境状态混乱 就像其他包管理器一样,大部分这些问题均…

Eviews用向量自回归模型VAR实证分析公路交通通车里程与经济发展GDP协整关系时间序列数据和脉冲响应可视化...

全文下载链接:http://tecdat.cn/?p27784 河源市是国务院1988年1月7日批准设立的地级市,为了深入研究河源市公路交通与经济发展的关系,本文选取了1988-2014年河源市建市以来24年的地区生产总值(GDP)和公路通…

Linux dup dup2函数

/*#include <unistd.h>int dup2(int oldfd, int newfd);作用&#xff1a;重定向文件描述符oldfd 指向 a.txt, newfd 指向b.txt,调用函数之后&#xff0c;newfd和b.txt close&#xff0c;newfd指向a.txtoldfd必须是一个有效的文件描述符 */ #include <unistd.h> #i…

selenium的Chrome116版驱动下载

这里写自定义目录标题 下载地址https://googlechromelabs.github.io/chrome-for-testing/#stable 选择chromedriver 对应的平台和版本 国内下载地址 https://download.csdn.net/download/dongtest/88314387

北斗高精度定位,破解共享单车停车乱象

如今&#xff0c;共享单车已经成为了许多人出行的首选方式&#xff0c;方便了市民们的“最后一公里”&#xff0c;给大家的生活带来了很多便利。然而&#xff0c;乱停乱放的单车也给城市治理带来了难题。在这种情况下&#xff0c;相关企业尝试将北斗导航定位芯片装载到共享单车…

Mysql->Hudi->Hive

一 准备 1.启动集群 /hive/mysql start-all.sh2.启动spark-shell spark-shell \--master yarn \ //--packages org.apache.hudi:hudi-spark3.1-bundle_2.12:0.12.2 \--jars /opt/software/hudi-spark3.1-bundle_2.12-0.12.0.jar \--conf spark.serializerorg.apache.spark.…

【数据结构】双向链表详解

当我们学习完单链表后&#xff0c;双向链表就简单的多了&#xff0c;双向链表中的头插&#xff0c;尾插&#xff0c;头删&#xff0c;尾删&#xff0c;以及任意位置插&#xff0c;任意位置删除比单链表简单&#xff0c;今天就跟着小张一起学习吧&#xff01;&#xff01; 双向链…

12个微服务架构模式最佳实践

微服务架构是一种软件开发技术&#xff0c;它将大型应用程序分解为更小的、可管理的、独立的服务。每个服务负责特定的功能&#xff0c;并通过明确定义的 API 与其他服务进行通信。微服务架构有助于实现软件系统更好的可扩展性、可维护性和灵活性。 接下来&#xff0c;我们将介…

vue中预览xml并高亮显示

项目中有需要将接口返回的数据流显示出来&#xff0c;并高亮显示&#xff1b; 1.后端接口返回blob,类型为xml,如图 2.页面中使用pre code标签&#xff1a; <pre v-if"showXML"><code class"language-xml">{{xml}}</code></pre> …

RJ45水晶头网线顺序出错排查

线序 网线水晶头RJ45常用的线序标准ANSI / TIA-568定义了T568A与T568B两种线序&#xff0c;一般使用T568B&#xff0c;水晶头8个孔对应的8条线颜色如下图&#xff1a; 那1至8的编号&#xff0c;是从水晶头哪一面为参考呢&#xff0c;如下图&#xff0c;是水晶头金手指一面&am…