机器学习实战第3天:手写数字识别

 

☁️主页 Nowl

🔥专栏《机器学习实战》 《机器学习》

📑君子坐而论道,少年起而行之 

文章目录

一、任务描述

二、数据集描述

三、主要代码

(1)主要代码库的说明与导入方法

(2)数据预处理

(3)模型训练

(4)模型预测与性能评估

(5)除数据预处理外的完整代码

四、本章总结​ 


一、任务描述

手写数字识别是机器学习中的一个经典问题,通常涉及将手写数字的图像与其对应的数字进行关联。这种问题通常被认为是计算机视觉领域的一个入门任务,也是许多深度学习框架和算法的基础测试案例之一。

二、数据集描述

手写数字识别数据集包含了一列数字标签,每个数字标签有784个像素值,代表这个数字图片的像素值

三、主要代码

(1)主要代码库的说明与导入方法

import pandas as pd

pandas 是一个数据分析库,提供了灵活的数据结构,如 DataFrame,用于处理和分析结构化数据。它常被用于数据清洗、处理和分析。

import matplotlib.pyplot as plt

matplotlib 是一个用于绘制图表和可视化数据的库。pyplot模块是 matplotlib 的一个子模块,用于创建各种类型的图表,如折线图、散点图、直方图等。

import numpy as np

NumPy 是用于科学计算的库,提供了高性能的数组对象和各种数学函数。它在数据处理和数值计算中被广泛使用,尤其是在机器学习中。

import matplotlib as mpl

这里再次导入 matplotlib 库,但是这次将其别名设置为mpl。这样做是为了在代码中使用更短的别名,以提高代码的可读性。

from sklearn.model_selection import train_test_split

scikit-learn(sklearn)是一个用于机器学习的库。train_test_split函数用于将数据集划分为训练集和测试集,这是机器学习模型评估的一种常见方式。

from sklearn.neighbors import KNeighborsClassifier

这里导入了 scikit-learn 中的KNeighborsClassifier类,该类实现了 k-近邻分类器,用于进行基于邻近样本的分类。

from sklearn.metrics import accuracy_score

从 scikit-learn 中导入accuracy_score函数,用于计算分类模型的准确度分数。准确度是分类模型预测的正确样本数占总样本数的比例。

为确保代码能正常运行,请先复制以下代码,导入本文用到的所有库

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

 当代码无法运行时,还有可能是文件路径问题,注意改成自己的文件路径

(2)数据预处理

1.导入数据

使用pandas库导入数据集文件,文件路径要换成自己的

digit = pd.read_csv("datasets/digit-recognizer/train.csv")

2.划分训练集与测试集

使用train_test_split函数将数据集分为训练集和测试集,测试集比例为0.2

再将特征和标签分离出来

train, test = train_test_split(digit, test_size=0.2)train_x = train.drop(columns="label")
train_y = train["label"]
test_x = test.drop(columns="label")
test_y = test["label"]

 3.图片显示

我们可以使用matplotlib库将图片显示出来

  • train_x.iloc[2]选取训练集的第3行数据
  • np.array()将数组转化为numpy数组,以便使用reshape函数
  • .reshape(28,28)将原来的784个特征转化为(28,28)格式的数据,这代表一个正方形图片
  • cmap=mpl.cm.binary使图片颜色为黑白
  • plt.imshow()函数可以将一个像素数组转化为图片
plt.imshow(np.array(train_x.iloc[2]).reshape(28, 28), cmap=mpl.cm.binary)
plt.show()
print(train_y.iloc[2])

显示图片并打印数据标签

我们可以看到图像是一个数字9,打印标签也确实是9,接下来我们就来训练一个数字识别机器学习模型

(3)模型训练

由于这是一个分类任务,我们可以选择使用KNN近邻算法,第一步设置模型,第二步训练模型

model = KNeighborsClassifier(n_neighbors=3)
model.fit(train_x, train_y)

(4)模型预测与性能评估

寻找最优参数

对于大部分机器学习模型来说,设置不同的参数得到的模型性能都不同,我们可以绘制不同参数的准确率曲线图来寻找最优参数

accuracy = []for i in range(1, 10):model = KNeighborsClassifier(n_neighbors=i)model.fit(train_x, train_y)prediction = model.predict(test_x)accuracy.append(accuracy_score(prediction, test_y))print()plt.plot(range(1, 10), accuracy)
plt.xlabel("neighbors")
plt.ylabel("accuracy")
plt.show()

可以看到当neighbors为3时模型效果最好,我们在应用时就将模型参数设置为3

(5)除数据预处理外的完整代码

这里是舍弃了一些寻找特征等工作的完整模型训练代码

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_scoredigit = pd.read_csv("datasets/digit-recognizer/train.csv")train, test = train_test_split(digit, test_size=0.2)train_x = train.drop(columns="label")
train_y = train["label"]
test_x = test.drop(columns="label")
test_y = test["label"]model = KNeighborsClassifier(n_neighbors=3)
model.fit(train_x, train_y)
prediction = model.predict(test_x)
print(accuracy_score(prediction, test_y))

四、本章总结​ 

  • 学习了使用numpy处理图像数据的方法
  • 学习了打印准确率曲线来寻找最优参数的方法
  • 使用KNN模型来完成分类任务

当然,也可以自己处理特征,自己选择模型,调整参数,看看会不会获得更好的结果

 感谢阅读,觉得有用的话就订阅下本专栏吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/225743.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

实现了父类 纯虚函数为什么还有 无法解析外部符号错误

使用背景: 将C 的函数或接口使用 pybind11 封装成可以供python 使用调用的接口或函数,使用了CMake 编译(若之前可以编译通过,现在编译不通过,重新选择 source code 路径)成 VS 2019 可使用的目标解决方案&a…

四丶openlayer之瓦片地图

瓦片地图源于一种大地图解决方案,针对一整块非常大的地图进行切片,分成很多相同大小的小块地图,在用户访问的时候,再一块一块小地图加载,拼接在一起,从而还原成一整块大的地图。这样做的优点在于&#xff0…

zblog插件-zblog采集插件下载

在当今数字化的时代,博客已经成为人们分享思想、经验和知识的重要平台。而对于使用zblog博客系统的用户来说,充实博客内容是提高用户体验和吸引读者的不二法门。然而,手动收集内容对于博主来说既费时又繁琐。在这个背景下,zblog插…

QQ导出聊天记录,学会这3个方法,轻松解决!

QQ是我们与他人沟通、交流的重要方式之一。特别是对于那些将QQ作为主要通讯工具的人来说,聊天记录中可能包含了大量重要的信息。那么,如果想要导出QQ聊天记录该怎么办呢?qq导出聊天记录的方法有哪些?本文将为您介绍三种方法&#…

逆向 tg 发送图片

开发工具 工具名称工具类型说明AndroidStuduo编辑工具开发工具jadxjava工具将apk解成java项目xposed插件工具插件tg版本9.7.5 分析源码的点: 发送图片的点 获取sendMessageParams 获取TLRPC$TL_photo 回调 实现 public void sendImg(String path, String…

MyBatis使用教程详解<上>

一. 什么是MyBatis? Mybatis是一个持久层框架,用于简化JDBC的操作MyBatis原本是Apache的一个开源项目ibatis,后来更名为MyBatis 上面我们提到了一个概念----持久层 不知道小伙伴们有没有想到五大注解的关系,类似于下图 其中MyBatis就是Mapper层的框架,是基于JDBC的封装,可以帮…

IDEA删除的文件怎么找回更新

一、 查找本地历史记录IDEA在进行代码版本管理时,会自动创建本地历史记录,如果我们误删了文件,可以通过查找本地历史记录来找回文件。 1.在项目中,选中被删文件的父级目录,“File”->“Local History”->“Show…

UE4 UE5 使用SVN控制

关键概念:虚幻引擎中使用SVN,帮助团队成员共享资源。 1. UE4/UE5项目文件 如果不需要编译的中间缓存,则删除: DerivedDataCache、Intermediate、Saved 三个文件夹 2.更新、上传

鸿蒙应用开发-初见:入门知识、应用模型

基础知识 Stage模型应用程序包结构 开发并打包完成后的App的程序包结构如图 开发者通过DevEco Studio把应用程序编译为一个或者多个.hap后缀的文件,即HAP一个应用中的.hap文件合在一起称为一个Bundle,bundleName是应用的唯一标识 需要特别说明的是&…

SpringCloud 微服务全栈体系(十八)

第十一章 分布式搜索引擎 elasticsearch 八、RestClient 查询文档 文档的查询同样适用 RestHighLevelClient 对象,基本步骤包括: 准备 Request 对象准备请求参数发起请求解析响应 1. 快速入门 以 match_all 查询为例 1.1 发起查询请求 代码解读&…

出海企业客服服务大作战:打造卓越客户体验的六大法宝!

全球化的加速推进,使得越来越多的企业开始涉足国际市场,进行跨境业务拓展。在这个过程中,优质的客服服务成为了出海企业获取竞争优势和留住客户的关键。本文将探讨出海企业如何做好客服服务,并重点阐述客服系统在提升服务质量和效…

FFA 2023|字节跳动 7 项议题入选

Flink Forward 是由 Apache 官方授权的 Apache Flink 社区官方技术大会,作为最受 Apache Flink 社区开发者期盼的年度峰会之一,FFA 2023 将持续集结行业最佳实践以及 Flink 最新技术动态,是中国 Flink 开发者和使用者不可错过的的技术盛宴。 …