机器学习KNN最邻近分类算法

文章目录

  • 1、KNN算法简介
  • 2、KNN算法实现
  • 3、调用scikit-learn库中KNN算法
  • 4、使用scikit-learn库生成数据集
  • 5、自定义函数划分数据集
  • 6、使用scikit-learn库划分数据集
  • 7、使用scikit-learn库对鸢尾花数据集进行分类

1、KNN算法简介

KNN (K-Nearest Neighbor) 最邻近分类算法,其核心思想“近朱者赤,近墨者黑”,由你的邻居来推断你的类别。

图中绿色圆归为哪一类?
1、如果k=3,绿色圆归为红色三角形
2、如果k=5,绿色圆归为蓝色正方形
在这里插入图片描述

参考文章

knn算法实现原理:为判断未知样本数据的类别,以所有已知样本数据作为参照物,计算未知样本数据与所有已知样本数据的距离,从中选取k个与已知样本距离最近的k个已知样本数据,根据少数服从多数投票法则,将未知样本与K个最邻近样本中所属类别占比较多的归为一类。(我们还可以给邻近样本加权,距离越近的权重越大,越远越小)

2、KNN算法实现

1、k值选择:太小容易产生过拟合问题,过度相信样本数据,太大容易产生欠拟合问题,与数据贴合不够解密,决策效率低。

2、样本数据归一化:最简单的方式就是所有特征的数值都采取归一化处置。

3、一个距离函数计算两个样本之间的距离:通常使用的距离函数有:欧氏距离、曼哈顿距离、汉明距离等,一般选欧氏距离作为距离度量,但是这是只适用于连续变量。在文本分类这种非连续变量情况下,汉明距离可以用来作为度量。通常情况下,如果运用一些特殊的算法来计算度量的话,K近邻分类精度可显著提高。在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

4、KNN优点:
1.简单,易于理解,易于实现,无需估计参数,无需训练
2. 适合对稀有事件进行分类
3.特别适合于多分类问题(multi-modal,对象具有多个类别标签), KNN比SVM的表现要好

5、KNN缺点:
KNN算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数,如下图所示。该算法只计算最近的邻居样本,某一类的样本数量很大,那么或者这类样本并不接近目标样本,或者这类样本很靠近目标样本。无论怎样,数量并不能影响运行结果。可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。
可理解性差,无法给出像决策树那样的规则。

实现KNN算法简单实例

1、样本数据散点图展示在这里插入图片描述

# KNN算法实现
import numpy as np
import matplotlib.pyplot as plt# 样本数据 
data_X = [[1.3,6],[3.5,5],[4.2,2],[5,3.3],[2,9],[5,7.5],[7.2,4],[8.1,8],[9,2.5],]
# 样本标记数组
data_y = [0,0,0,0,1,1,1,1,1]# 将数组转换成np数组
X_train = np.array(data_X)
y_train = np.array(data_y)# 散点图绘制
# 取等于0的行中的第0列数据X_train[y_train==0,0]
plt.scatter(X_train[y_train==0,0],X_train[y_train==0,1],color='red',marker='x')
# 取等于1的行中的第1列数据X_train[y_train==1,0]
plt.scatter(X_train[y_train==1,0],X_train[y_train==1,1],color='black',marker='o')
plt.show()

2、新的样本数据,判断它属于哪一类
在这里插入图片描述

data_new = np.array([4,5])
plt.scatter(X_train[y_train==0,0],X_train[y_train==0,1],color='red',marker='x')
plt.scatter(X_train[y_train==1,0],X_train[y_train==1,1],color='black',marker='o')
plt.scatter(data_new[0],data_new[1],color='blue',marker='s')
plt.show()

3、计算新样本点与所有已知样本点的距离
在这里插入图片描述

Numpy使用

# 样本数据-新样本数据 的平方,然后开平,存储距离值到distances中
distances = [np.sqrt(np.sum((data-data_new)**2)) for data in X_train]
# 按照距离进行排序,返回原数组中索引 升序
sort_index = np.argsort(distances)# 随机选一个k值
k = 5# 距离最近的5个点进行投票表决
first_k = [y_train[i] for i in sort_index[:k]]# 使用计数库统计
from collections import Counter
# 取出结果为类别0
predict_y = Counter(first_k).most_common(1)[0][0]
predict_y

3、调用scikit-learn库中KNN算法

2007年,Scikit-learn首次被Google Summer of Code项目开发使用,现在已经被认为是最受欢迎的机器学习Python库。

安装:pip install scikit-learn
在这里插入图片描述

# 使用scikit-learn中的KNN算法
from sklearn.neighbors import KNeighborsClassifier
# 初始化设置k大小
knn_classifier = KNeighborsClassifier(n_neighbors=5)
# 喂入数据集,以及数据类型
knn_classifier.fit(X_train,y_train)
# 放入新样本数据进行预测,需要先转换成二维数组
knn_classifier.predict(data_new.reshape(1,-1))

4、使用scikit-learn库生成数据集

生成的数据,画出的散点图
在这里插入图片描述

# 数据集生产
import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import make_blobs
x,y = make_blobs(n_samples=300, # 样本总数n_features=2, # 生产二维数据centers=3, # 种类数据cluster_std=1, # 类内的标注差center_box=(-10,10), # 取值范围random_state=233, # 随机数种子return_centers=False, # 类别中心坐标反回值
)
# c指定每个点颜色,s指定点大小
plt.scatter(x[:,0],x[:,1],c=y,s=15)
plt.show()
x.shape,y.shape

5、自定义函数划分数据集

将生成好的数据集,划分成训练数据集和测试数据集
在这里插入图片描述

# 数据集划分
np.random.seed(233)
# 随机生成数组排列下标
shuffle = np.random.permutation(len(x))
train_size = 0.7train_index = shuffle[:int(len(x)*train_size)]
test_index = shuffle[int(len(x)*train_size):]
train_index.shape,test_index.shape# 通过下标数组到数据集中取出数据
x_train = x[train_index]
y_train = y[train_index]
x_test = x[test_index]
y_test = y[test_index]# 训练数据集
plt.scatter(x_train[:,0],x_train[:,1],c=y_train,s=15)
plt.show()# 测试数据集
plt.scatter(x_test[:,0],x_test[:,1],c=y_test,s=15)
plt.show()

6、使用scikit-learn库划分数据集

# sklearn划分数据集
from sklearn.model_selection import train_test_split
# 保证3个样本数保持原来分布,添加参数stratify=y
x_train,x_test,y_train,y_test = train_test_split(x,y,train_size=0.7,random_state=233,stratify=y)
from collections import Counter
Counter(y_test)

7、使用scikit-learn库对鸢尾花数据集进行分类

# 使用鸢尾花数据集
import numpy as np
from sklearn import datasets# 加载数据集
iris = datasets.load_iris()
# 获取样本数组,样本类型数组
X = iris.data
y = iris.target# 拆分数据集
# 不能直接拆分因为现在的y已经是排序好的,需要先乱序数组
# shuffle_index = np.random.permutation(len(X))
# train_ratio = 0.8
# train_size = int(len(y)*train_ratio)
# train_index = shuffle_index[:train_size]
# test_index = shuffle_index[train_size:]
# X_train = X[train_index]
# Y_train = y[train_index]
# X_test = X[test_index]
# Y_test = y[test_index]from sklearn.model_selection import train_test_split
# 保证3个样本数保持原来分布,添加参数stratify=y
x_train,x_test,y_train,y_test = train_test_split(X,y,train_size=0.8,random_state=666)# 预测
from sklearn.neighbors import KNeighborsClassifier
# 初始化设置k大小
knn_classifier = KNeighborsClassifier(n_neighbors=5)
# 喂入数据集,以及数据类型
knn_classifier.fit(x_train,y_train)# 如果关心预测结果可以跳过下面所有score返回得分
knn_classifier.score(x_test,y_test)y_predict = knn_classifier.predict(x_test)# 评价预测结果 将y_predict和真是的predict进行比较就可以了
accuracy = np.sum(y_predict == y_test)/len(y_test)
# accuracy# sklearn中计算准确度的方法
from sklearn.metrics import accuracy_score
accuracy_score(y_test,y_predict)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/586270.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

QT - 日志:qDebug/qInfo/qWarning/qCritical

篇一、日志打印函数 头文件&#xff1a; #include <QDebug> 代码&#xff1a;qDebug()<<"hello world!"; 其他打印级别&#xff1a; qInfo(): 普通信息 qDebug(): 调试信息 qWarning(): 警告信息 qCritical(): 严重错误 qFatal(): 致命错误 1. qDebug…

软件工程知识体系 Chapter3 软件构造

介绍 软件构造一词指的是通过编码、验证、单元测试、集成测试和调试等组合详细创建工作软件的过程。 软件构建知识领域&#xff08;KA&#xff09;与所有其他KA都有关联&#xff0c;但它与软件设计和软件测试的关联最为紧密&#xff0c;因为软件构建过程涉及重要的软件设计和…

frp内网穿透,让外网可以访问内网

需求 我们的svn部署在内网&#xff0c;用的一直没问题&#xff0c;但是有时候有需求在外网访问svn&#xff0c;进行提交更新等操作&#xff0c;这时候就有了内网穿透这个需求。 当然&#xff0c;我们也可以借助花生壳等软件进行内网穿透&#xff0c;傻瓜化操作&#xff0c;也…

windows linux 安装 nvm

windows 一、下载nvm-windows 前往github https://github.com/coreybutler/nvm-windows 进入latest 往下滑下载nvm-setup.exe 二、下载好后直接一直点击下一步就好。 检查一下 nvm -v &#xff0c;会输出版本号 附带常用命令 nvm install 10.15.3 安装v10.15.3版本 nvm u…

了解游戏相关知识

个人笔记&#xff08;整理不易&#xff0c;有帮助点个赞&#xff09; 笔记目录&#xff1a;学习笔记目录_pytest和unittest、airtest_weixin_42717928的博客-CSDN博客 个人随笔&#xff1a;工作总结随笔_8、以前工作中都接触过哪些类型的测试文档-CSDN博客 目录 一&#xff1a…

C语言 输入输出语句讲解 标识符概念讲解

上文 C语言 预处理器 注释 基本案例讲解 我们讲了一些 预处理器等逻辑 那么 本文继续 C语言由一个或多个函数组成&#xff0c;每个程序都必须有一个main() 函数 因为每个程序总是从这个函数开始执行 main() 函数可以返回一个值&#xff0c;返回值为0表示程序正常结束 如果有多…

MS SQL Server STUFF 函数实战 统计记录行转为列显示

目录 范例运行环境 视图样本设计 数据统计要求 STUFF函数实现 小结 范例运行环境 操作系统&#xff1a; Windows Server 2019 DataCenter 数据库&#xff1a;Microsoft SQL Server 2016 视图样本设计 假设某一视图 [v_pj_rep1_lname_score] 可查询对某一被评价人的绩效…

校园通勤车可视化系统的设计与实现

1.需求分析&#xff1a; 校园通勤车可视化系统的设计与实现&#xff0c;不用管什么可视化&#xff0c;就是一个小程序就是可以知道校园车的路线&#xff0c;然后往简单了弄就可以。 校园通勤车可视化系统的设计与实现&#xff0c;不用管什么可视化&#xff0c;就是一个小程序…

uniapp开发App(一)登陆流程 判断是否登陆,是,进入首页,否,跳转到登录页

一、登陆流程 文字描述&#xff1a;用户进入App&#xff0c;之后就是判断该App是否有用户登陆过&#xff0c;如果有&#xff0c;直接进入首页&#xff0c;否则跳转到登陆页&#xff0c;登陆成功后&#xff0c;进入首页。 流程图如下&#xff1a; 二、在uniapp项目中代码实现 实…

zabbix图表时间与服务器时间不一致问题

部署完zabbix后&#xff0c;有时候会发现zabbix服务器的时间明明是对的&#xff0c;但是图标的时间不对&#xff0c;通过以下的配置可以快速解决。 登录zabbix-nginx容器 docker exec -u root -it docker-compose-zabbix-zabbix-web-nginx-mysql-1 bash修改php配置文件 vi /e…

2012年认证杯SPSSPRO杯数学建模B题(第一阶段)减缓热岛效应全过程文档及程序

2012年认证杯SPSSPRO杯数学建模 减缓热岛效应 B题 白屋顶计划 原题再现&#xff1a; 第一阶段问题   夏天的城市气温往往格外炎热&#xff0c;这被称为热岛效应。有专家提出&#xff0c;将城市建筑的屋顶漆成白色&#xff0c;减小对阳光的吸收率&#xff0c;可以使城市的气…

GROBID库文献解析

1. 起因 由于某些原因需要在大量的文献中查找相关内容&#xff0c;手动实在是太慢了&#xff0c;所以选择了GROBID库进行文献批量解析 2. GROBID介绍 GROBID是一个机器学习库&#xff0c;用于将PDF等原始文档提取、解析和re-structuring为结构化的XML/TEI编码文档&#xff0…