K最近邻算法:简单高效的分类和回归方法(三)

文章目录

  • 🍀引言
  • 🍀训练集和测试集
  • 🍀sklearn中封装好的train_test_split
  • 🍀超参数

🍀引言

本节以KNN算法为主,简单介绍一下训练集和测试集超参数


🍀训练集和测试集

训练集和测试集是机器学习和深度学习中常用的概念。在模型训练过程中,通常将数据集划分为训练集和测试集,用于训练和评估模型的性能。

训练集是用于模型训练的数据集合。模型通过对训练集中的样本进行学习和参数调整来提高自身的预测能力。训练集应该尽可能包含各种不同的样本,以使模型能够学习到数据集中的模式和规律,并能够适应新的数据。

测试集是用于评估模型性能的数据集合。模型训练完成后,使用测试集中的样本进行预测,并与真实标签进行对比,以评估模型的精度、准确度和其他性能指标。测试集应该与训练集相互独立,以确保对模型的泛化能力进行准确评估。

一般来说,训练集和测试集的划分比例是80:20或者70:30。有时候还会引入验证集,用于在训练过程中调整模型的超参数。训练集、验证集和测试集是机器学习中常用的数据集拆分方式,以确保模型的准确性和泛化能力。

接下来我们回顾一下KNN算法的简单原理,选取离待预测最近的k个点,再使用投票进行预测结果

from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
from sklearn.datasets import load_iris  # 因为我们并没有数据集,所以从库里面调出来一个
iris = load_iris()
X = iris.data
y = iris.target
knn_clf.fit(X,y)
knn_clf.predict()

那么我们如何评价KNN模型的好坏呢?

这里我们将数据集分为两部分,一部分为训练集,一部分为测试集,因为这里的训练集和测试集都是有y的,所以我们只需要将训练集进行训练,然后产生的模型应用到测试集,再将预测的y和原本的y进行对比,这样就可以了

接下来进行简易代码演示讲解

from sklearn.datasets import load_iris
iris = load_iris()
X = iris.data
y = iris.target

我们可以把y打印出来看看
在这里插入图片描述
这里我们不妨思考一下,如果训练集和测试集是8:2的话,测试集的y岂不是都是2了,那么还有啥子意义,所以我们需要将其打乱一下下,当然我们这里打乱的是index也就是下标,可不要自以为是的将y打乱了

import numpy as np
indexs = np.random.permutation(len(X))

导入必要的库后,我们将数据集下标进行打乱并保存于indexs中,接下来迎来重头戏分割数据集

test_ratio = 0.2
test_size = int(len(X) * test_ratio)
test_indexs = shuffle_indexs[:test_size] # 测试集
train_indexs = shuffle_indexs[test_size:] # 训练集

不信的小伙伴可以使用如下代码进行检验

test_indexs.shape
train_indexs.shape

在这里插入图片描述
接下来将打乱的下标进行分别赋值

X_train = X[train_indexs]
y_train = y[train_indexs]
X_test = X[test_indexs]
y_test = y[test_indexs]

分割好数据集后,我们就可以使用KNN算法进行预测了

from sklearn.neighbors import KNeighborsClassifier
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train,y_train)
y_predict = knn_clf.predict(X_test)

我们这里可以打印一下y_predict和y_test进行肉眼对比一下
在这里插入图片描述
在这里插入图片描述
最后一步就是将精度求出来

np.sum(np.array(y_predict == y_test,dtype='int'))/len(X_test)

在这里插入图片描述


🍀sklearn中封装好的train_test_split

上面我们只是简单演示了一下,接下来我们使用官方的train_test_split

from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y) # 注意这里返回四个结果

这里你可以试着看一眼,分割的比例与之前手动分割的比例大不相同
最后按部就班来就行

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train,y_train)
knn_clf.predict(X_test) 
knn_clf.score(X_test,y_test)

在这里插入图片描述


🍀超参数

什么是超参数,可以点击链接查看

在pycharm中我们可以查看一些参数
在这里插入图片描述

接下来通过简单的演示来介绍一下

from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import load_iris
knn_clf = KNeighborsClassifier(weights='distance') 
from sklearn.model_selection import train_test_split
iris = load_iris()
X = iris.data
y = iris.target
X_train,X_test,y_train,y_test = train_test_split(X,y)

上面是老熟人了就不一一赘述了,但是注意这里面有个超参数(weights),这个参数有两种,一个是distance一个是uniform,前者和距离有关联,后者无关


首先测试一下n_neighbors这个参数代表的就行之前的那个k,邻近点的个数

%%time
best_k = 0
best_score = 0.0
best_clf = None
for k in range(1,21):knn_clf = KNeighborsClassifier(n_neighbors=k)knn_clf.fit(X_train,y_train)score = knn_clf.score(X_test,y_test)if score>best_score:best_score = scorebest_k = kbest_clf = knn_clf
print(best_k)
print(best_score)
print(best_clf)

在这里插入图片描述
测试完参数n_neighbors,我们再来试试weights

%%time
best_k = 0
best_score = 0.0
best_clf = None
best_method = None
for weight in ['uniform','distance']:for k in range(1,21):knn_clf = KNeighborsClassifier(n_neighbors=k,weights=weight)knn_clf.fit(X_train,y_train)score = knn_clf.score(X_test,y_test)if score>best_score:best_score = scorebest_k = kbest_clf = knn_clfbest_method = weight
print(best_k)
print(best_score)
print(best_clf)
print(best_method)

在这里插入图片描述
最后我们测试一下参数p

%%time
best_k = 0
best_score = 0.0
best_clf = None
best_p = None
for p in range(1,6):for k in range(1,21):knn_clf = KNeighborsClassifier(n_neighbors=k,weights='distance',p=p)knn_clf.fit(X_train,y_train)score = knn_clf.score(X_test,y_test)if score>best_score:best_score = scorebest_k = kbest_clf = knn_clfbest_p = pprint(best_k)
print(best_score)
print(best_clf)
print(best_p)

或许大家不知道这个参数p的含义,下面我根据几个公式带大家简单了解一下
请添加图片描述

请添加图片描述
请添加图片描述

三张图分别代表欧拉距离曼哈顿距离明科夫斯基距离,细心的小伙伴就可以发现了,p=1位曼哈顿距离,p=2位欧拉距离,这里不做详细的说明,感兴趣的小伙伴可以翻阅相关数学书籍

请添加图片描述

挑战与创造都是很痛苦的,但是很充实。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/58299.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

一文了解 Android Auto 车载开发~

作者:牛蛙点点申请出战 背景 我的的产品作为一个海外音乐播放器,在车载场景听歌是一个很普遍的需求。在用户反馈中,也有很多用户提到希望能在车上播放音乐。同时车载音乐也可以作为提升用户消费时长一个抓手。 出海产品,主要服务…

初中信息技术考试编程题,初中信息技术python教案

大家好,小编来为大家解答以下问题,初中信息技术python编程题库 网盘,初中信息技术python编程教学,今天让我们一起来看看吧! ID:12450455 资源大小:934KB 资料简介: 2019-2020学年初中信息技术【轻松备课】P…

svg使用技巧

什么是svg SVG 是一种基于 XML 语法的图像格式,全称是可缩放矢量图(Scalable Vector Graphics)。其他图像格式都是基于像素处理的,SVG 则是属于对图像的形状描述,所以它本质上是文本文件,体积较小&#xf…

瞅一眼nginx

目录 🦬什么是nginx? 🦬nginx配置官方yum源: 🦬nginx优点 🦬nginx 缺点 🦬查看nginx默认模块 🐌nginx新版本的配置文件: 🐌nginx目录索引 🐌nginx状态…

机器学习实战1-kNN最近邻算法

文章目录 机器学习基础机器学习的关键术语 k-近邻算法(KNN)准备:使用python导入数据实施kNN分类算法示例:使用kNN改进约会网站的配对效果准备数据:从文本文件中解析数据分析数据准备数据:归一化数值测试算法…

JavaWeb(9)——前端综合案例3(悬停显示下拉列表)

一、实例需求 ⌛ 实现类似百度首页的“一个简单的鼠标悬停显示的下拉列表效果”。 二、代码实现 ☕ <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><style>.dropdown-cont…

Flink-串讲面试题

1. 概念 有状态的流式计算框架 可以处理源源不断的实时数据&#xff0c;数据以event为单位&#xff0c;就是一条数据。 2. 开发流程 先获取执行环境env&#xff0c;然后添加source数据源&#xff0c;转换成datastream&#xff0c;然后使用各种算子进行计算&#xff0c;使用…

【数据结构OJ题】轮转数组

原题链接&#xff1a;https://leetcode.cn/problems/rotate-array/ 目录 1. 题目描述 2. 思路分析 3. 代码实现 1. 题目描述 2. 思路分析 1. 方法一&#xff1a;暴力求解&#xff0c;将数组的第一个元素用临时变量tmp存起来&#xff0c;再将数组其他元素往右挪动一步&…

SpringBoot+MyBatis多数据源配置

1.先在配置文件application.yml中配置好数据源 spring:datasource:type: com.alibaba.druid.pool.DruidDataSourcedb1:driver-class-name: com.mysql.cj.jdbc.Driverusername: rootpassword: rootjdbc-url: jdbc:mysql://192.168.110.128:3306/CampusHelp?useUnicodeyes&…

基于dockerfile构建sshd、httpd、nginx、tomcat、mysql、lnmp、redis镜像

一、镜像概述 Docker 镜像是Docker容器技术中的核心&#xff0c;也是应用打包构建发布的标准格式。一个完整的镜像可以支撑多个容器的运行&#xff0c;在Docker的整个使用过程中&#xff0c;进入一个已经定型的容器之后&#xff0c;就可以在容器中进行操作&#xff0c;最常见的…

【深度学习】Collage Diffusion,拼接扩散,论文,实战

论文&#xff1a;https://arxiv.org/abs/2303.00262 代码&#xff1a;https://github.com/VSAnimator/collage-diffusion 文章目录 AbstractIntroductionProblem Definition and Goals论文其他内容实战 Abstract 基于文本条件的扩散模型能够生成高质量、多样化的图像。然而&a…

CentOS 7中,配置了Oracle jdk,但是使用java -version验证时,出现的版本是OpenJDK,如何解决?

1.首先&#xff0c;检查已安装的jdk版本 sudo yum list installed | grep java2.移除、卸载圈红的系统自带的openjdk sudo yum remove java-1.7.0-openjdk.x86_64 sudo yum remove java-1.7.0-openjdk-headless.x86_64 sudo yum remove java-1.8.0-openjdk.x86_64 sudo yum r…