机器学习11-前馈神经网络识别手写数字1.0

在这个示例中,使用的神经网络是一个简单的全连接前馈神经网络,也称为多层感知器(Multilayer Perceptron,MLP)。这个神经网络由几个关键组件构成:

1. 输入层
输入层接收输入数据,这里是一个 28x28 的灰度图像,每个像素值表示图像中的亮度值。

2. Flatten 层
Flatten 层用于将输入数据展平为一维向量,以便传递给后续的全连接层。在这里,我们将 28x28 的图像展平为一个长度为 784 的向量。

3. 全连接层(Dense 层)
全连接层是神经网络中最常见的层之一,每个神经元与上一层的每个神经元都连接。在这里,我们有一个包含 128 个神经元的隐藏层,以及一个包含 10 个神经元的输出层。隐藏层使用 ReLU(Rectified Linear Unit)激活函数,输出层使用 softmax 激活函数。

4. 输出层
输出层产生神经网络的输出,这里是一个包含 10 个元素的向量,每个元素表示对应类别的概率。softmax 函数用于将网络的原始输出转换为概率分布。

5. 编译模型
在编译模型时,我们指定了优化器(optimizer)和损失函数(loss function)。在这里,我们使用 Adam 优化器和稀疏分类交叉熵损失函数。

6. 训练模型
使用训练数据集对模型进行训练,以学习如何将输入映射到正确的输出。在训练过程中,模型通过优化损失函数来调整权重和偏置,使其尽可能准确地预测输出。

总的来说,这个神经网络是一个经典的多层感知器(MLP),它在输入层和输出层之间包含一个或多个隐藏层,通过学习逐步提取和组合特征来进行分类或回归任务。

代码:

import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten# 加载 MNIST 数据集
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()# 数据预处理
train_images = train_images / 255.0
test_images = test_images / 255.0# 构建神经网络模型
model = Sequential([Flatten(input_shape=(28, 28)),Dense(128, activation='relu'),Dense(10, activation='softmax')
])# 编译模型
model.compile(optimizer='adam',loss='sparse_categorical_crossentropy',metrics=['accuracy'])# 训练模型
model.fit(train_images, train_labels, epochs=5)# 评估模型
test_loss, test_acc = model.evaluate(test_images, test_labels)
print('Test accuracy:', test_acc)# 保存模型
model.save('mnist_model.h5')# 加载模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 使用加载的模型进行预测
predictions = loaded_model.predict(test_images)

结果:

Epoch 1/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.2586 - accuracy: 0.9265
Epoch 2/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.1136 - accuracy: 0.9656
Epoch 3/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0773 - accuracy: 0.9768
Epoch 4/5
1875/1875 [==============================] - 3s 2ms/step - loss: 0.0587 - accuracy: 0.9823
Epoch 5/5
1875/1875 [==============================] - 4s 2ms/step - loss: 0.0462 - accuracy: 0.9855
313/313 [==============================] - 0s 1ms/step - loss: 0.0750 - accuracy: 0.9775

Test accuracy: 0.9775000214576721

识别准确率挺高,然后我们也得到了训练好的模型

应用测试:

import tensorflow as tf
import numpy as np
from PIL import Image# 加载保存的模型
loaded_model = tf.keras.models.load_model('mnist_model.h5')# 打开手写图片文件
image_path = 'pic/handwritten_digit_thick_5.png'  # 修改为你的手写图片文件路径
image = Image.open(image_path).convert('L')  # 转换为灰度图像# 调整图片大小为 28x28 像素
image = image.resize((28, 28))# 将图片转换为 NumPy 数组并进行归一化处理
image_array = np.array(image) / 255.0# 将图片转换为模型输入的格式(添加批次维度)
input_image = np.expand_dims(image_array, axis=0)# 使用模型进行预测
predictions = loaded_model.predict(input_image)# 获取预测结果(最大概率的类别)
predicted_class = np.argmax(predictions)print('Predicted digit:', predicted_class)

准备了4张图片,3张自己手写,1张摘自minst:

前两张画笔比较细,第三张是minst的5,第四张是用了粗笔自己写的5,最终结果是就minst预测对了。

Predicted digit: 2

Predicted digit: 8

Predicted digit: 5

Predicted digit: 3


结论:

可见这个模型的扩展适应性能还是不够,只能预测正确训练过的minst数字。

改进:

想办法提升训练的质量,让预测能力达标

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/461752.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于华为云欧拉操作系统(HCE OS)容器化部署传统应用(Redis+Postgresql+Git+SpringBoot+Nginx)

写在前面 博文内容为 华为云欧拉操作系统入门级开发者认证(HCCDA – Huawei Cloud EulerOS)实验笔记整理认证地址:https://edu.huaweicloud.com/certificationindex/developer/9bf91efb086a448ab4331a2f53a4d3a1博文内容涉及一个传统 Springboot 应用HCE部署&#x…

stm32软件安装以及创建工程

文章目录 前言一、软件安装软件破解 二、创建工程三、创建项目创建组配置启动文件添加到组 为项目添加头文件路径创建源文件(main函数文件)使用寄存器配置引脚拼接好STLINK与stm32最小电路板的接线编写程序配置STLink下载程序配置寄存器配置13号端口&…

奋斗与诗意的三纲八目

人生得有一个基调、总的宗旨、指导思想、根据、根本。当人做出一个重大决定时,绝非偶然,一定是背后的宗旨在起作用。你每天起床的动力,是否能热情洋溢地做事,也是这个宗旨在起作用。念天地之悠悠独怆然而涕下,忧思难忘…

云计算运营模式介绍

目录 一、云计算运营模式概述 1.1 概述 二、云计算服务角色 2.1 角色划分 2.1.1 云服务提供商 2.1.2 云服务消费者 2.1.3 云服务代理商 2.1.4 云计算审计员 2.1.5 云服务承运商 三、云计算责任模型 3.1 云计算服务模式与责任关系图 3.2 云计算服务模式与责任关系解析…

vue3 之 Pinia数据持久化

持久化用户数据说明 1️⃣用户数据中有一个关键的数据叫做token(用来标识当前用户是否登陆),而token持续一段时间才会过期 2️⃣Pinia的存储是基于内存,刷新就丢失,为了保持登陆状态就要做到刷新不丢失,需要…

肯尼斯·里科《C和指针》第13章 高级指针话题(1)进一步探讨指向指针的指针变量的高级声明

13.1 进一步探讨指向指针的指针 上一章使用了指向指针的指针,用于简化向单链表插入新值的函数。另外还存在许多领域,指向指针的指针可以在其中发挥重要的作用。这里有一个通用的例子: 这些声明在内存中创建了下列变量。如果它们是自动变量&am…

问题:必须坚持以中国式现代化推进中华民族伟大复兴,既不走封闭僵化的老路,也不走 #媒体#知识分享

问题:必须坚持以中国式现代化推进中华民族伟大复兴,既不走封闭僵化的老路,也不走 A、中国特色社会主义道路 B、改革开放之路 C、改旗易帜的邪路 D、中国式现代化之路 参考答案如图所示

Linux 36.2@Jetson Orin Nano基础环境构建

Linux 36.2Jetson Orin Nano基础环境构建 1. 源由2. 步骤2.1 安装NVIDIA Jetson Linux 36.2系统2.2 必备软件安装2.3 基本远程环境2.3.1 远程ssh登录2.3.2 samba局域网2.3.3 VNC远程登录 2.4 开发环境安装 3. 总结 1. 源由 现在流行什么,也跟风来么一个一篇。当然&…

RabbitMQ的延迟队列实现[死信队列](笔记一)

关于死信队列的使用场景不再强调,只针对服务端配置 注意: 本文只针对实现死信队列的rabbitMQ基本配置步骤进行阐述和实现 目录 1、docker-compose 安装rabbitMq2、查看对应的版本及插件下载3、安装插件和检测 1、docker-compose 安装rabbitMq a、使用d…

IT行业有哪些证书含金量高呢?

目录 引言: 一、 计算机网络类证书 二、 数据库管理类证书 三、 安全与信息技术管理类证书 四、 编程与开发类证书 五、 数据科学与人工智能类证书 六、结论: 悟已往之不谏,知来者犹可追 …

备战蓝桥杯---动态规划(基础3)

本专题主要介绍在求序列的经典问题上dp的应用。 我们上次用前缀和来解决,这次让我们用dp解决把 我们参考不下降子序列的思路,可以令f[i]为以i结尾的最大字段和,易得: f[i]max(a[i],a[i]f[i-1]); 下面是AC代码: #in…

猫头虎分享已解决Bug || JavaScript语法错误(Syntax Error):SyntaxError: Unexpected token

博主猫头虎的技术世界 🌟 欢迎来到猫头虎的博客 — 探索技术的无限可能! 专栏链接: 🔗 精选专栏: 《面试题大全》 — 面试准备的宝典!《IDEA开发秘籍》 — 提升你的IDEA技能!《100天精通鸿蒙》 …