[机器学习]KNN——K邻近算法实现

一.K邻近算法概念

二.代码实现 

# 0. 引入依赖
import numpy as np
import pandas as pd# 这里直接引入sklearn里的数据集,iris鸢尾花
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split  # 切分数据集为训练集和测试集
from sklearn.metrics import accuracy_score # 计算分类预测的准确率# 1. 数据加载和预处理
iris = load_iris()
# print(iris)df = pd.DataFrame(data = iris.data, columns = iris.feature_names)
df['class'] = iris.target
df['class'] = df['class'].map({0: iris.target_names[0], 1: iris.target_names[1], 2: iris.target_names[2]})
df.head(10)
# df.describe()
# print(df)x = iris.data
y = iris.target.reshape(-1,1)
# print(x.shape, y.shape)# 划分训练集和测试集
# test_size:测试比例,random_state:随机划分,stratify:按照y的分布等比例分割
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=35, stratify=y)
# print(x_train.shape, y_train.shape)
# print(x_test.shape, y_test.shape)# 2. 核心算法实现
# 距离函数定义
def l1_distance(a, b):return np.sum(np.abs(a - b), axis=1) # 曼哈顿距离def l2_distance(a, b):return np.sqrt(np.sum((a - b) ** 2, axis=1)) # 欧氏距离# 分类器实现
class kNN(object):# 定义一个初始化方法,__init__ 是类的构造方法def __init__(self, n_neighbors=1, dist_func=l1_distance):self.n_neighbors = n_neighborsself.dist_func = dist_func# 训练模型方法def fit(self, x, y):self.x_train = xself.y_train = y# 模型预测方法def predict(self, x):# 初始化预测分类数组:初始化一个0数组,x.shape[0]:行数,1:列数,dtype:定义此数据类型y_pred = np.zeros((x.shape[0], 1), dtype=self.y_train.dtype)# 遍历输入的x数据点,取出每一个数据点的序号i和数据x_test。enumerate:可同时拿出两个(序号和值)for i, x_test in enumerate(x):# x_test跟所有训练数据计算距离distances = self.dist_func(self.x_train, x_test)# 得到的距离按照由近到远排序,取出索引值nn_index = np.argsort(distances)# 选取最近的k个点,保存它们对应的分类类别,n_neighbors:表示取k个邻近的值nn_y = self.y_train[nn_index[0:self.n_neighbors]].ravel()# 统计类别中出现频率最高的那个,赋给y_pred[i]y_pred[i] = np.argmax(np.bincount(nn_y))return y_pred"""a = np.array([[3,2,4,2],[2,1,4,23],[12,3,2,3],[2,3,15,23],[1,3,2,3],[13,3,2,2],[213,16,3,63],[23,62,23,23],[23,16,23,43]])b = np.array([[1,1,1,1]])print("a-b:",a-b) # 下面的a-b:a表示数组,b表示向量np.sum(np.abs(a - b), axis=1)dist = np.sqrt( np.sum((a-b) ** 2, axis=1) )nn_index = np.argsort(dist)print("dist: ", dist)print("nn_index: ", nn_index)nn_y = y_train[nn_index[:9]].ravel()print("未转换前的y:",y_train[:8])print("nn_y:", nn_y)print("y计数:",np.bincount(nn_y))print("取出现次数最多的y:",np.argmax(np.bincount(nn_y)))
"""# 3. 测试
# 定义一个knn实例
knn = kNN(n_neighbors = 3)
# 训练模型
knn.fit(x_train, y_train)
# 传入测试数据,做预测
y_pred = knn.predict(x_test)
print("y测试值: ", y_test.ravel())
print("y预测值: ", y_pred.ravel())
# 求出预测准确率
accuracy = accuracy_score(y_test, y_pred)
print("预测准确率: ", accuracy)# 定义一个knn实例
knn = kNN()
# 训练模型
knn.fit(x_train, y_train)
# 保存结果list
result_list = []
# 针对不同的参数选取,做预测
for p in [1, 2]:knn.dist_func = l1_distance if p == 1 else l2_distance# 考虑不同的k取值,步长为2(取奇数1,3,5,7,9)for k in range(1, 10, 2):knn.n_neighbors = k# 传入测试数据,做预测y_pred = knn.predict(x_test)# 求出预测准确率accuracy = accuracy_score(y_test, y_pred)result_list.append([k, '曼哈顿距离' if p == 1 else '欧氏距离', accuracy])
df = pd.DataFrame(result_list, columns=['k', '距离函数', '预测准确率'])
print(df)

y测试值:  [2 1 2 2 0 0 2 0 1 1 2 0 1 1 1 2 2 0 1 2 1 0 0 0 1 2 0 2 0 0 2 1 0 2 1 0 2 1 2 2 1 1 1 0 0]
y预测值:  [2 1 2 2 0 0 2 0 1 1 1 0 1 1 1 2 2 0 1 2 1 0 0 0 1 2 0 2 0 0 2 1 0 2 1 0 2 1 2 1 1 2 1 0 0]
预测准确率:  0.9333333333333333
   k   距离函数     预测准确率
0  1  曼哈顿距离  0.933333
1  3  曼哈顿距离  0.933333
2  5  曼哈顿距离  0.977778
3  7  曼哈顿距离  0.955556
4  9  曼哈顿距离  0.955556
5  1   欧氏距离  0.933333
6  3   欧氏距离  0.933333
7  5   欧氏距离  0.977778
8  7   欧氏距离  0.977778
9  9   欧氏距离  0.977778

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/439466.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

来聊聊大厂面试题:求Java对象的大小

写在文章开头 日常使用Java进行业务开发时,我们基本不关心一个Java对象的大小,所以经常因为错误的估算导致大量的内存空间在无形之间被浪费了,所以今天笔者就基于这篇文章来聊聊一个Java对象的大小。 你好,我叫sharkchili&#x…

【数据结构与算法】之哈希表系列-20240129

这里写目录标题 一、217. 存在重复元素二、219. 存在重复元素 II三、242. 有效的字母异位词四、268. 丢失的数字五、290. 单词规律六、349. 两个数组的交集七、350. 两个数组的交集 II 一、217. 存在重复元素 简单 给你一个整数数组 nums 。如果任一值在数组中出现至少两次 &a…

力扣hot100 子集 回溯 超简洁

Problem: 78. 子集 文章目录 思路复杂度Code 思路 &#x1f468;‍&#x1f3eb; 参考题解 复杂度 时间复杂度: 添加时间复杂度, 示例&#xff1a; O ( n ) O(n) O(n) 空间复杂度: 添加空间复杂度, 示例&#xff1a; O ( n ) O(n) O(n) Code class Solution {List<Li…

闪测影像|智能影像测量仪高精度快速批量检测

在现代工业制造领域&#xff0c;快速批量测量零部件尺寸能确保产品质量、提升生产效率、优化生产过程、降低成本以及增强市场竞争力等。 通过快速批量测量&#xff0c;迅速检测出不合格的零部件&#xff0c;避免生产过程中的浪费和延误&#xff0c;优化生产过程并提高生产效率。…

MG7050HAN 基于声表的差分多输出 晶体振荡器 (HCSL)

基于MG7050 HAN的声表差分多输出晶体振荡器(HCSL)&#xff0c;采用两路或四路差分HCSL&#xff08;高速电流驱动逻辑&#xff09;输出&#xff0c;可以减少外部扇出缓冲区&#xff0c;特别适用于需要超低抖动、高频率范围内稳定工作的应用场合。其输出特性曲线超低抖动&#xf…

【C/C++ 05】快速排序

快速排序是Hoare于1962年提出的一种二叉树结构的交换排序算法&#xff0c;其基本思想是&#xff1a;任取待排序序列中的某元素作为基准值&#xff0c;按照该基准值将待排序集合分割成两个子序列&#xff0c;左子序列中所有元素均小于基准值&#xff0c;右子序列中所有元素均大于…

集简云数据表新增动态下拉,一键拉取相关数据,快速实现业务场景自动化

为了提升数据表相关场景的数据交互的效率和准确性&#xff0c;本周集简云数据表新增了动态下拉字段&#xff0c;可直接在该字段中关联应用动作获取&#xff0c;无需搭建复杂流程&#xff0c;可搭配按钮使用&#xff0c;直接调用和配置应用动作获取相关字段数据&#xff0c;手动…

MySQL前百分之N问题--percent_rank()函数

PERCENT_RANK()函数 PERCENT_RANK()函数用于将每行按照(rank - 1) / (rows - 1)进行计算,用以求MySQL中前百分之N问题。其中&#xff0c;rank为RANK()函数产生的序号&#xff0c;rows为当前窗口的记录总行数 PERCENT_RANK()函数返回介于 0 和 1 之间的小数值 selectstudent_…

elementui中的tree自定义图标

需求&#xff1a;实现如下样式的树形列表 自定义树的图标以及点击时&#xff0c;可以根据子级的关闭&#xff0c;切换图标 <el-tree :data"treeList" :props"defaultProps"><template #default"{ node, data }"><span class&quo…

Mac安装及配置MySql及图形化工具MySQLworkbench安装

Mac下载配置MySql mysql下载及安装 下载地址&#xff1a;https://dev.mysql.com/downloads/mysql/ 根据自己电脑确定下载x86还是ARM版本的 如果不确定&#xff0c;可以查看自己电脑版本&#xff0c;终端输入命令 uname -a 点击Download下载&#xff0c;可跳过登录注册&…

【论文阅读】Long-Tailed Recognition via Weight Balancing(CVPR2022)

目录 论文使用方法weight decayMaxNorm 如果使用原来的代码报错的可以看下面这个 论文 问题&#xff1a;真实世界中普遍存在长尾识别问题&#xff0c;朴素训练产生的模型在更高准确率方面偏向于普通类&#xff0c;导致稀有的类别准确率偏低。 key:解决LTR的关键是平衡各方面&a…

VRRP协议原理

目录 VRRP的产生单网关的缺陷多网关存在的问题VRRP基本概述VRRP基本结构状态机 VRRP主备备份工作过程VRRP的工作过程如果Master发生故障&#xff0c;则主备切换的过程如果原Master故障恢复&#xff0c;则主备回切的过程 VRRP联动功能 VRRP负载分担工作过程 VRRP的产生 单网关的…