02、Tensorflow实现手写数字识别(数字0-9)

02、Tensorflow实现手写数字识别(数字0-9)

开始学习机器学习啦,已经把吴恩达的课全部刷完了,现在开始熟悉一下复现代码。对这个手写数字实部比较感兴趣,作为入门的素材非常合适。

基于Tensorflow 2.10.0与pycharm

1、识别目标

识别手写仅仅是为了区分手写的0到9,所以实际上是一个多分类问题。
STEP1:导入相关包

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings

import numpy as np:这是引入numpy库,并为其设置一个缩写np。Numpy是Python中用于大规模数值计算的库,它提供了多维数组对象及一系列操作这些数组的函数。

import tensorflow as tf:这是引入tensorflow库,并为其设置一个缩写tf。TensorFlow是一个开源的深度学习框架,它被广泛用于各种深度学习应用。

from keras.models import Sequential:这是从Keras库中引入Sequential模型。Keras是一个高级神经网络API,它可以运行在TensorFlow之上。Sequential模型是Keras中的线性堆栈模型,允许你简单地堆叠多个网络层。

from keras.layers import Dense:这是从Keras库中引入Dense层。Dense层是神经网络中的全连接层,每个输入节点与输出节点都是连接的。

from sklearn.model_selection import train_test_split:这是从scikit-learn库中引入train_test_split函数。这个函数用于将数据分割为训练集和测试集。

from sklearn.metrics import classification_report 这行代码的主要作用是导入classification_report 函数,以便在后续的代码中使用它来评估分类模型的性能。

import matplotlib.pyplot as plt:这是引入matplotlib的pyplot模块,并为其设置一个缩写plt。Matplotlib是Python中的绘图库,而pyplot是其中的一个模块,用于绘制各种图形和图像。

import warnings:这是引入Python的标准警告库,它可以用来发出警告,或者过滤掉不需要的警告。

import logging:这是引入Python的标准日志库,用于记录日志信息,方便追踪和调试代码。


STEP2:屏蔽无用警告并允许中文

# 使用warnings模块来忽略特定类型的警告  
warnings.simplefilter(action='ignore', category=FutureWarning)  
# 配置tensorflow的日志记录级别  
logging.getLogger("tensorflow").setLevel(logging.ERROR)  
# 设置TensorFlow的autograph模块的详细级别  
tf.autograph.set_verbosity(0)  
# 设置numpy的打印选项  
np.set_printoptions(precision=2)  

STEP3:加载数据集并分割测试集

# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")return X, y# load dataset
X, y = load_data()print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)

原始的输入的数据集是5000* 400数组,共包含5000个手写数字的数据,其中400为20*20像素的图片,
在这里插入图片描述


STEP4:模型构建与训练

# 构建模型  
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的  
model = Sequential([### START CODE HERE ###  tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维   Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"  Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"  Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"  ### END CODE HERE ###  ], name="my_model"
)  # 定义模型名称为"my_model"  
model.summary()  # 打印模型的概述信息  # 配置模型的训练参数  
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)  optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001  
)# 训练模型  
history = model.fit(X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据  epochs=100  # 训练100轮  
)

STEP5:结果可视化与打印准确度信息

fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))prediction_p = tf.nn.softmax(prediction)yhat = np.argmax(prediction_p)# 错误结果标红if y_test[random_index, 0] == yhat:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)ax.set_axis_off()else:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')ax.set_axis_off()fig.suptitle("Label, yhat", fontsize=14)
plt.show()# 给出预测的测试集误差
def evaluation(y_test, y_predict):accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']precision=s['precision']recall=s['recall']f1_score=s['f1-score']#kappa=cohen_kappa_score(y_test, y_predict)return accuracy,precision,recall,f1_score #, kappay_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

3、运行结果

在这里插入图片描述

4、工程下载与全部代码

工程链接:Tensorflow实现手写数字识别(数字0-9)

import numpy as np
import tensorflow as tf
from keras.models import Sequential
from keras.layers import Dense
from sklearn.model_selection import train_test_split
from sklearn.metrics import  classification_report
import matplotlib.pyplot as plt
import logging
import warnings# 使用warnings模块来忽略特定类型的警告
warnings.simplefilter(action='ignore', category=FutureWarning)
# 配置tensorflow的日志记录级别
logging.getLogger("tensorflow").setLevel(logging.ERROR)
# 设置TensorFlow的autograph模块的详细级别
tf.autograph.set_verbosity(0)
# 设置numpy的打印选项
np.set_printoptions(precision=2)# load dataset
def load_data():X = np.load("Handwritten_Digit_Recognition_Multiclass_data/X.npy")y = np.load("Handwritten_Digit_Recognition_Multiclass_data/y.npy")return X, y# load dataset
X, y = load_data()print ('The shape of X is: ' + str(X.shape))
print ('The shape of y is: ' + str(y.shape))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)# # 绘图可选
# m, n = X.shape
# fig, axes = plt.subplots(8, 8, figsize=(5, 5))
# fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
# # fig.tight_layout(pad=0.5)
# for i, ax in enumerate(axes.flat):
#     # Select random indices
#     random_index = np.random.randint(m)
#     # Select rows corresponding to the random indices and
#     # reshape the image
#     X_random_reshaped = X[random_index].reshape((20, 20)).T
#     # Display the image
#     ax.imshow(X_random_reshaped, cmap='gray')
#     # Display the label above the image
#     ax.set_title(y[random_index, 0])
#     ax.set_axis_off()
#     fig.suptitle("Label, image", fontsize=14)
# plt.show()# 构建模型
tf.random.set_seed(1234)  # 设置随机种子以确保每次运行的结果是一致的
model = Sequential([### START CODE HERE ###tf.keras.Input(shape=(400,)),  # 输入层,输入数据的形状是400维Dense(100, activation='relu', name="L1"),  # 全连接层,100个神经元,使用ReLU激活函数,命名为"L1"Dense(75, activation='relu', name="L2"),  # 全连接层,75个神经元,使用ReLU激活函数,命名为"L2"Dense(10, activation='linear', name="L3"),  # 输出层,10个神经元,使用线性激活函数,命名为"L3"### END CODE HERE ###], name="my_model"
)  # 定义模型名称为"my_model"
model.summary()  # 打印模型的概述信息# 配置模型的训练参数
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),# 使用稀疏分类交叉熵作为损失函数,且输出是logits(即未经过softmax的原始输出)optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),  # 使用Adam优化器,并设置学习率为0.001
)# 训练模型
history = model.fit(X_train, y_train,  # 使用X_train作为输入数据,y_train作为目标数据epochs=100  # 训练100轮
)fig, axes = plt.subplots(20, 25, figsize=(20, 25))
fig.tight_layout(pad=0.13, rect=[0, 0.03, 1, 0.91])  # [left, bottom, right, top]
for i, ax in enumerate(axes.flat):# Select random indicesrandom_index = np.random.randint(X_test.shape[0])# Select rows corresponding to the random indices and# reshape the imageX_random_reshaped = X_test[random_index].reshape((20, 20)).T# Display the imageax.imshow(X_random_reshaped, cmap='gray')# Predict using the Neural Networkprediction = model.predict(X_test[random_index].reshape(1, 400))prediction_p = tf.nn.softmax(prediction)yhat = np.argmax(prediction_p)# Display the label above the imageif y_test[random_index, 0] == yhat:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10)ax.set_axis_off()else:ax.set_title(f"{y_test[random_index, 0]},{yhat}", fontsize=10, color='red')ax.set_axis_off()fig.suptitle("Label, yhat", fontsize=14)
plt.show()# 给出预测的测试集误差
def evaluation(y_test, y_predict):accuracy=classification_report(y_test, y_predict,output_dict=True)['accuracy']s=classification_report(y_test, y_predict,output_dict=True)['weighted avg']precision=s['precision']recall=s['recall']f1_score=s['f1-score']#kappa=cohen_kappa_score(y_test, y_predict)return accuracy,precision,recall,f1_score #, kappay_pred=model.predict(X_test)
prediction_p = tf.nn.softmax(y_pred)
yhat = np.argmax(prediction_p, axis=1)
accuracy,precision,recall,f1_score=evaluation(y_test,yhat)print("测试数据集准确率为:", accuracy)
print("测试数据集精确率为:", precision)
print("测试数据集召回率为:", recall)
print("测试数据集F1_score为:", f1_score)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/221314.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【洛谷算法题】P5715-三位数排序【入门2分支结构】

👨‍💻博客主页:花无缺 欢迎 点赞👍 收藏⭐ 留言📝 加关注✅! 本文由 花无缺 原创 收录于专栏 【洛谷算法题】 文章目录 【洛谷算法题】P5715-三位数排序【入门2分支结构】🌏题目描述🌏输入格式…

打包SpringBoot 项目为本地应用

使用工具:exe4j、Inno Setup Compiler 步骤: 1,将dll包放入项目根路径下; 2,idea 使用Maven打jar包; 3,使用exe4j 工具进行打包; 打开工具首页不动(直接 next&#xff…

光线追踪-Peter Shirley的RayTracing In One Weekend系列教程(book1-book3)代码分章节整理

自己码完了一遍了,把代码分章节整理了一下,可以按章节独立编译,运行, 也可以直接下载编译好的release版本直接运行。 项目地址: Github: https://github.com/disini/RayTracingInOneWeekendChaptByChapt ​ ​ ​ ​

MySQL的undo log 与MVCC

文章目录 概要一、undo日志1.undo日志的作用2.undo日志的格式3. 事务id(trx_id) 二、MVCC1.版本链2.ReadView3.REPEATABLE READ —— 在第一次读取数据时生成一个ReadView4.快照读与当前读 小结 概要 Undo Log:数据库事务开始之前&#xff0…

Redis序列化操作

目录 1.protostuff 的 Maven 依赖 2.定义实体类 3.序列化工具类 ProtostuffSerializer 提供了序列化和反序列化方法 4.测试 利用 Jedis 提供的字节数组参数方法,如: public String set(String key, String value) public String set(byte[] key…

基于python协同过滤推荐算法的音乐推荐与管理系统

欢迎大家点赞、收藏、关注、评论啦 ,由于篇幅有限,只展示了部分核心代码。 文章目录 一项目简介 二、功能三、系统四. 总结 一项目简介 基于Python的协同过滤推荐算法的音乐推荐与管理系统是一个集成了音乐推荐和管理的系统,它使用协同过滤算…

Linux(8):BASH

硬件、核心与 Shell 操作系统其实是一组软件,由于这组软件在控制整个硬件与管理系统的活动监测,如果这组软件能被用户随意的操作,若使用者应用不当,将会使得整个系统崩溃。因为操作系统管理的就是整个硬件功能。 应用程序在最外层…

虚拟化逻辑架构: LBR 网桥基础管理

目录 一、理论 1.Linux Bridge 二、实验 1.LBR 网桥管理 三、问题 1.Linux虚拟交换机如何增删 一、理论 1.Linux Bridge Linux Bridge(网桥)是用纯软件实现的虚拟交换机,有着和物理交换机相同的功能,例如二层交换&#…

1.9 字符数组

1.9 字符数组 一、字符数组概述二、练习 一、字符数组概述 所谓字符数组&#xff0c;就是char类型的数组&#xff0c;比如 char a[]&#xff0c;是C语言中最常用的数组类型&#xff0c;先看一个程序 #include <stdio.h> #define MAXLINE 1000 //最大行长度限制 int get…

【赠书第8期】工程效能十日谈

文章目录 前言 1 工程效能十日谈 1.1 制定清晰的目标和计划 1.2 引入先进的技术和工具 1.3 建立有效的沟通机制 1.4 灵活应对变化 1.5 确保资源充足 1.6 进行有效的风险管理 1.7 进行持续的监控和评估 1.8 优化团队合作 1.9 注重质量管理 1.10 进行项目总结和反思 …

NX二次开发UF_CURVE_ask_offset_parms 函数介绍

文章作者&#xff1a;里海 来源网站&#xff1a;https://blog.csdn.net/WangPaiFeiXingYuan UF_CURVE_ask_offset_parms Defined in: uf_curve.h int UF_CURVE_ask_offset_parms(tag_t offset_curve_object, UF_CURVE_offset_data_p_t offset_data_pointer ) overview 概述 …

[pyqt5]pyqt5设置窗口背景图片后上面所有图片都会变成和背景图片一样

pyqt5的控件所有都是集成widget&#xff0c;窗体设置背景图片后控件背景也会跟着改变&#xff0c;此时有2个办法。第一个办法显然我们可以换成其他方式设置窗口背景图片&#xff0c;而不是使用styleSheet样式表&#xff0c;网上有很多其他方法。还有个办法就是仍然用styleSheet…