基于全连接神经网络模型的手写数字识别

基于全连接神经网络模型的手写数字识别

  • 一. 前言
  • 二. 设计目的及任务描述
    • 2.1 设计目的
    • 2.2 设计任务
  • 三. 神经网络模型
    • 3.1 全连接神经网络模型方案
    • 3.2 全连接神经网络模型训练过程
    • 3.3 全连接神经网络模型测试
  • 四. 程序设计

一. 前言

手写数字识别要求利用MNIST数据集里的70000张手写体数字的图像,建立神经网络模型,进行0到9的分类,并能够对其他来源的图片进行识别,识别准确率大于97%。图片示例如下。
alt

图1.1 mnist数据集图片示例

该设计要求学生基于TensorFlow深度学习平台,利用自动下载的MNIST数据集,建立全连接或者CNN神经网络模型,对MNIST或者其他图片中的数字进行正确识别。同时,在数据获取、处理和分析过程中考虑数据安全、技术经济、工程伦理、行业规范等要素。

通过该题目的训练,使学生对深度学习技术有一定的了解,掌握深度学习模型建立、训练、测试和调优的过程,理解监督学习、数据处理、神经网络、卷积计算等概念并通过实例进行实践,学习TensorFlow并搭建深度学习平台,加深学生对深度学习技术的理解和实际引用,并能够利用深度学习方法解决实际问题。

二. 设计目的及任务描述

2.1 设计目的

深入学习TensorFlow深度学习平台,通过构建全连接神经网络和卷积神经网络的手写数字识别模型,实现对MNIST数据集中的数字0到9的分类,并具备对其他来源的图片进行准确识别的能力,要求识别准确率大于97%。这一设计旨在深入理解深度学习技术,并掌握模型的建立、训练、测试和调优的全过程。

首先,进行文献资料查阅,至少阅读5篇相关文献,以确保对深度学习领域的最新进展有所了解。通过文献的学习,将为设计过程提供前沿的理论支持,在实践中融入最新的研究成果。

学习TensorFlow深度学习平台的搭建是课程设计的第二步,这将提供一个强大而灵活的工具,用以实现神经网络的建模和训练。通过掌握TensorFlow,学生将具备在深度学习领域进行实际工作的基本能力。

在全连接神经网络的学习中,理解神经网络的基本原理,包括监督学习、数据处理、损失率函数的构建方法等。通过构建手写数字识别模型,亲身经历模型训练、测试和调优的过程,深入理解各参数的作用及其对模型准确率的影响。

通过这个课程设计,不仅获得深度学习技术的实际应用经验,还将培养文献查阅、团队协作、数据伦理等方面的能力,为将来深入科研或产业实践打下坚实基础。

2.2 设计任务

  1. 查阅文献资料,一般在5篇以上;
  2. 学习TensorFlow深度学习平台的搭建。
  3. 学习全连接神经网络,建立全连接网络的手写数字识别模型,并进行模型训练、测试和调优。
  4. 理解学习率、衰减率等参数的作用。
  5. 理解监督学习的过程。
  6. 学习损失率函数构建方法。
  7. 经过模型调优,理解模型中各参数的作用以及影响模型准确率的因素。
  8. 模型识别准确率大于97%。
  9. 撰写课程设计说明书,须达到以下要求:
    (1) 陈述设计题目、设计任务;
    (2) 描述TensorFlow深度学习平台的搭建过程;
    (3) 写出全连接神经网络模型方案;
    (4) 记录全连接神经网络模型训练过程;
    (5) 记录全连接神经网络模型测试准确率;
    (6) 陈述模型调优过程,包括调优过程中遇到的主要问题,是如何解决的;对模型设计和编码的回顾、反思和体会等,与同学对问题的讨论、分析、改进设想以及收获等。同时,分析数据处理及分析过程中面临的数据安全、工程伦理等问题。

三. 神经网络模型

3.1 全连接神经网络模型方案

设计中使用的全连接神经网络模型采用了典型的多层感知器(Multi-Layer Perceptron,MLP)架构,旨在解决手写数字识别任务。模型的输入层与输出层之间,有两个隐藏层负责提取和学习输入图像的特征。

模型的输入层包含了784个节点,对应于MNIST数据集中的每个图像像素。这个输入层将图像展平为一维向量,使得神经网络能够处理每个像素的信息。第一个隐藏层包含512个节点,通过ReLU激活函数引入非线性特性,帮助网络学习复杂的特征和模式。第二个隐藏层也有512个节点,并同样使用ReLU激活函数。这两个隐藏层的存在增强了网络对抽象特征的学习能力。

最后,输出层包含10个节点,对应于手写数字的10个可能类别。使用softmax激活函数,输出层将模型的原始输出转换为概率分布,表示每个类别的概率。

在模型的编译阶段,采用了交叉熵作为损失函数,这是多类别分类问题中常用的损失函数。模型的优化器选择了Adam,这是一种自适应学习率的优化算法。为了评估模型性能,选择了准确率作为指标,它度量了模型在训练和测试数据上的分类准确性。

3.2 全连接神经网络模型训练过程

训练过程是深度学习中至关重要的一部分,通过多次迭代优化模型参数,使其能够更好地适应训练数据。在这个训练过程中,采用了全连接神经网络模型,旨在实现手写数字的准确识别。

加载并预处理了MNIST数据集,将图像数据归一化到 [0, 1] 的范围,并进行了独热编码以适应模型的训练需求。构建了一个具有两个隐藏层的全连接神经网络模型,其中包含了512个节点,并使用ReLU激活函数,最终输出层具有10个节点,使用softmax激活函数进行多类别分类。

然后,对模型进行了编译,选择了交叉熵作为损失函数和Adam作为优化器。为了更充分地训练模型,将训练轮数设置为5。每次训练迭代,模型根据梯度下降的原理,不断更新权重和偏差,以最小化损失函数。

训练过程的 fit 函数的参数中,verbose=1表示在训练过程中输出详细信息,包括每个epoch的损失和准确率。模型的性能将在整个训练过程中逐渐提升,反映出它对训练数据的更好拟合能力。在迭代的过程中,我期望看到损失降低,而训练和验证准确率逐步提高。

通过增加训练轮数,提高模型学习的迭代次数,有望取得更好的性能和更强的泛化能力,使模型在未见过的数据上表现出色。
在这里插入图片描述

图 3-1 全连接神经网络_训练结果
如图3-1所示,通过5次训练模型的准确度达到97%。

3.3 全连接神经网络模型测试

使用 Keras 模型的 evaluate 方法在测试集上进行评估。

在这里插入图片描述

图 3-2 全连接神经网络_测试结果
经测试,如图3-2所示,模型准确度为97.66%。

四. 程序设计

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D
from tensorflow.keras.datasets import mnist
from tensorflow.keras.optimizers import Adamdef load_and_preprocess_data():# 加载并预处理MNIST数据集(x_train, y_train), (x_test, y_test) = mnist.load_data()# 重塑和手动归一化数据x_train = x_train.reshape((x_train.shape[0], 28, 28, 1)).astype('float32') / 255.0x_test = x_test.reshape((x_test.shape[0], 28, 28, 1)).astype('float32') / 255.0# 对标签进行多分类编码num_categories = 10y_train = tf.keras.utils.to_categorical(y_train, num_categories)y_test = tf.keras.utils.to_categorical(y_test, num_categories)return x_train, y_train, x_test, y_testdef build_model_Fully_connected():# 构建全连接神经网络模型model = Sequential()model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=(28, 28, 1)))model.add(MaxPooling2D(pool_size=(2, 2)))model.add(Flatten())model.add(Dense(units=512, activation='relu'))model.add(Dense(units=512, activation='relu'))model.add(Dense(units=10, activation='softmax'))model.summary()return modeldef compile_and_train_model(model, x_train, y_train, x_test, y_test):# 编译并训练模型optimizer = Adam(learning_rate=0.0001)model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])history = model.fit(x_train, y_train, epochs=10, verbose=1, validation_data=(x_test, y_test))return historyif __name__ == "__main__":# 加载并预处理数据x_train, y_train, x_test, y_test = load_and_preprocess_data()# 构建全连接神经网络模型model = build_model_Fully_connected()# 编译并训练模型history = compile_and_train_model(model, x_train, y_train, x_test, y_test)# 保存训练模型model.save("mnist_dnn_model.h5", include_optimizer=True)print("Model saved successfully.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/460258.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分享86个行业PPT,总有一款适合您

分享86个行业PPT,总有一款适合您 86个行业PPT下载链接:https://pan.baidu.com/s/1avbzwqK8ILLWYIOylK1aRQ?pwd8888 提取码:8888 Python采集代码下载链接:采集代码.zip - 蓝奏云 学习知识费力气,收集整理更不易…

QML中常见热区及层级结构

目录 引言层级结构默认层级结构z值作用范围遮罩实现-1的作用 热区嵌套与普通元素与其他热区与Flickable 事件透传总结 引言 热区有很多种,诸如MouseArea、DropArea、PinchArea等等,基本都是拦截对应的事件,允许开发者在事件函数对事件进行响…

有道ai写作,突破免费限制,无限制使用

预览效果 文末提供源码包及apk下载地址有道ai写作python版 import hashlib import time import json import ssl import base64 import uuidfrom urllib.parse import quote import requests from requests_toolbelt.multipart.encoder import MultipartEncoder from Crypto.C…

C语言笔试题之实现C库函数 strstr()(设置标志位)

实例要求: 1、请你实现C库函数strstr()(stdio.h & string.h),请在 haystack 字符串中找出 needle 字符串的第一个匹配项的下标(下标从 0 开始);2、函数声明:int strStr(char* h…

jvm垃圾收集器之七种武器

目录 1.回收算法 1.1 标记-清除算法(Mark-Sweep) 1.2 复制算法(Copying) 1.3 标记-整理算法(Mark-Compact) 2.HotSpot虚拟机的垃圾收集器 2.1 新生代的收集器 Serial 收集器(复制算法) ParNew 收集器 (复制算法) Parallel Scavenge 收集器 (复制…

【递归】【前序中序后序遍历】【递归调用栈空间与二叉树深度有关】【斐波那契数】Leetcode 94 144 145

【递归】【前序中序后序遍历】【递归调用栈空间与二叉树深度有关】Leetcode 94 144 145 1.前序遍历(递归) preorder2.中序遍历(递归)inorder3.后序遍历(递归)postorder4. 斐波那契数 ---------------&…

Zoho Mail企业邮箱商业扩展第3部分:计算财务状况

在Zoho Mail商业扩展系列的压轴篇章中,王雪琳利用Zoho Mail的集成功能成功地完成了各项工作,并顺利地建立了自己的营销代理机构。让我们快速回顾一下她的成功之路。 一、使用Zoho Mail成功方法概述 首先她通过Zoho Mail为其电子邮件地址设置了自定义域…

【多模态大模型】GLIP:零样本学习 + 目标检测 + 视觉语言大模型

GLIP 核心思想GLIP 对比 BLIP、BLIP-2、CLIP 主要问题: 如何构建一个能够在不同任务和领域中以零样本或少样本方式无缝迁移的预训练模型?统一的短语定位损失语言意识的深度融合预训练数据类型的结合语义丰富数据的扩展零样本和少样本迁移学习 效果 论文:…

软件应用实例分享,电玩计时计费怎么算,佳易王PS5游戏计时器系统程序教程

软件应用实例分享,电玩计时计费怎么算,佳易王PS5游戏计时器系统程序教程 一、前言 以下软件教程以 佳易王电玩计时计费管理系统软件V17.9为例说明 软件文件下载可以点击最下方官网卡片——软件下载——试用版软件下载 点击开始计时后,图片…

【STC8A8K64D4开发板】第2-11讲:模数转换ADC

第2-11讲:模数转换ADC 学习目的 1. 了解ADC的基本概念:分辨率、精度等。 2. 掌握STC8A8K64D4单片机ADC的配置、采样数据计算为实际电压值的方法。 3. 掌握ADC多通道采样。 ADC基本概念 实际应用中,我们经常需要将模拟量转换为数字量供CPU…

HCIA-HarmonyOS设备开发认证V2.0-3.轻量系统内核基础

目录 一、前言二、LiteOS-M系统概述三、内核框架3.1、CMSIS 和 POSIX 整体架构3.2、LiteOS-M内核启动流程 四、内核基础4.1、任务管理4.2、时间管理(待续)4.3、中断管理(待续)4.4、软件定时器(待续) 五、内存管理5.1、静态内存(待续)5.2、动态内存(待续) 六、内核通信机制6.1、…

【制作100个unity游戏之24】unity制作一个3D动物AI生态系统游戏2(附项目源码)

最终效果 文章目录 最终效果系列目录前言添加捕食者动画控制源码完结 系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第24篇中,我们将探索如何用unity制作一个3D动物AI生态系统游戏…