基于Pytorch深度学习神经网络MNIST手写数字识别系统源码(带界面和手写画板)

 第一步:准备数据

mnist开源数据集

第二步:搭建模型

我们这里搭建了一个LeNet5网络

参考代码如下:

import torch
from torch import nnclass Reshape(nn.Module):def forward(self, x):return x.view(-1, 1, 28, 28)class LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.net = nn.Sequential(Reshape(),# CONV1, ReLU1, POOL1nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),# nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),# CONV2, ReLU2, POOL2nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),# FC1nn.Linear(in_features=16 * 5 * 5, out_features=120),nn.ReLU(),# FC2nn.Linear(in_features=120, out_features=84),nn.ReLU(),# FC3nn.Linear(in_features=84, out_features=10))# 添加softmax层self.softmax = nn.Softmax()def forward(self, x):logits = self.net(x)# 将logits转为概率prob = self.softmax(logits)return probif __name__ == '__main__':model = LeNet5()X = torch.rand(size=(256, 1, 28, 28), dtype=torch.float32)for layer in model.net:X = layer(X)print(layer.__class__.__name__, '\toutput shape: \t', X.shape)X = torch.rand(size=(1, 1, 28, 28), dtype=torch.float32)print(model(X))

第三步:训练代码

import torch
from torch import nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoaderfrom model import LeNet5# DATASET
train_data = datasets.MNIST(root='./data',train=False,download=True,transform=ToTensor()
)test_data = datasets.MNIST(root='./data',train=False,download=True,transform=ToTensor()
)# PREPROCESS
batch_size = 256
train_dataloader = DataLoader(dataset=train_data, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_data, batch_size=batch_size)
for X, y in train_dataloader:print(X.shape)		# torch.Size([256, 1, 28, 28])print(y.shape)		# torch.Size([256])break# MODEL
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = LeNet5().to(device)# TRAIN MODEL
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters())def train(dataloader, model, loss_func, optimizer, epoch):model.train()data_size = len(dataloader.dataset)for batch, (X, y) in enumerate(dataloader):X, y = X.to(device), y.to(device)y_hat = model(X)loss = loss_func(y_hat, y)optimizer.zero_grad()loss.backward()optimizer.step()loss, current = loss.item(), batch * len(X)print(f'EPOCH{epoch+1}\tloss: {loss:>7f}', end='\t')# Test model
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batchescorrect /= sizeprint(f'Test Error: Accuracy: {(100 * correct):>0.1f}%, Average loss: {test_loss:>8f}\n')if __name__ == '__main__':epoches = 80for epoch in range(epoches):train(train_dataloader, model, loss_func, optimizer, epoch)test(test_dataloader, model, loss_func)# Save modelstorch.save(model.state_dict(), 'model.pth')print('Saved PyTorch LeNet5 State to model.pth')

第四步:统计训练过程

EPOCH1	loss: 1.908403	Test Error: Accuracy: 58.3%, Average loss: 1.943602EPOCH2	loss: 1.776060	Test Error: Accuracy: 72.2%, Average loss: 1.750917EPOCH3	loss: 1.717706	Test Error: Accuracy: 73.6%, Average loss: 1.730332EPOCH4	loss: 1.719344	Test Error: Accuracy: 76.0%, Average loss: 1.703456EPOCH5	loss: 1.659312	Test Error: Accuracy: 76.6%, Average loss: 1.694500EPOCH6	loss: 1.647946	Test Error: Accuracy: 76.9%, Average loss: 1.691286EPOCH7	loss: 1.653712	Test Error: Accuracy: 77.0%, Average loss: 1.690819EPOCH8	loss: 1.653270	Test Error: Accuracy: 76.8%, Average loss: 1.692459EPOCH9	loss: 1.649021	Test Error: Accuracy: 77.5%, Average loss: 1.686158EPOCH10	loss: 1.648204	Test Error: Accuracy: 78.3%, Average loss: 1.678802EPOCH11	loss: 1.647159	Test Error: Accuracy: 78.4%, Average loss: 1.676133EPOCH12	loss: 1.647390	Test Error: Accuracy: 78.6%, Average loss: 1.674455EPOCH13	loss: 1.646807	Test Error: Accuracy: 78.4%, Average loss: 1.675752EPOCH14	loss: 1.630824	Test Error: Accuracy: 79.1%, Average loss: 1.668470EPOCH15	loss: 1.524222	Test Error: Accuracy: 86.3%, Average loss: 1.599240EPOCH16	loss: 1.524022	Test Error: Accuracy: 86.7%, Average loss: 1.594947EPOCH17	loss: 1.524296	Test Error: Accuracy: 87.1%, Average loss: 1.588946EPOCH18	loss: 1.523599	Test Error: Accuracy: 87.3%, Average loss: 1.588275EPOCH19	loss: 1.523655	Test Error: Accuracy: 87.5%, Average loss: 1.586576EPOCH20	loss: 1.523659	Test Error: Accuracy: 88.2%, Average loss: 1.579286EPOCH21	loss: 1.523733	Test Error: Accuracy: 87.9%, Average loss: 1.582472EPOCH22	loss: 1.523748	Test Error: Accuracy: 88.2%, Average loss: 1.578699EPOCH23	loss: 1.523788	Test Error: Accuracy: 88.0%, Average loss: 1.579700EPOCH24	loss: 1.523708	Test Error: Accuracy: 88.1%, Average loss: 1.579758EPOCH25	loss: 1.523683	Test Error: Accuracy: 88.4%, Average loss: 1.575913EPOCH26	loss: 1.523646	Test Error: Accuracy: 88.7%, Average loss: 1.572831EPOCH27	loss: 1.523654	Test Error: Accuracy: 88.9%, Average loss: 1.570528EPOCH28	loss: 1.523642	Test Error: Accuracy: 89.0%, Average loss: 1.570223EPOCH29	loss: 1.523663	Test Error: Accuracy: 89.0%, Average loss: 1.570385EPOCH30	loss: 1.523658	Test Error: Accuracy: 88.9%, Average loss: 1.571195EPOCH31	loss: 1.523653	Test Error: Accuracy: 88.4%, Average loss: 1.575981EPOCH32	loss: 1.523653	Test Error: Accuracy: 89.0%, Average loss: 1.570087EPOCH33	loss: 1.523642	Test Error: Accuracy: 88.9%, Average loss: 1.571018EPOCH34	loss: 1.523649	Test Error: Accuracy: 89.0%, Average loss: 1.570439EPOCH35	loss: 1.523629	Test Error: Accuracy: 90.4%, Average loss: 1.555473EPOCH36	loss: 1.461187	Test Error: Accuracy: 97.1%, Average loss: 1.491042EPOCH37	loss: 1.461230	Test Error: Accuracy: 97.7%, Average loss: 1.485049EPOCH38	loss: 1.461184	Test Error: Accuracy: 97.7%, Average loss: 1.485653EPOCH39	loss: 1.461156	Test Error: Accuracy: 98.2%, Average loss: 1.479966EPOCH40	loss: 1.461335	Test Error: Accuracy: 98.2%, Average loss: 1.479197EPOCH41	loss: 1.461152	Test Error: Accuracy: 98.7%, Average loss: 1.475477EPOCH42	loss: 1.461153	Test Error: Accuracy: 98.7%, Average loss: 1.475124EPOCH43	loss: 1.461153	Test Error: Accuracy: 98.9%, Average loss: 1.472885EPOCH44	loss: 1.461151	Test Error: Accuracy: 99.1%, Average loss: 1.470957EPOCH45	loss: 1.461156	Test Error: Accuracy: 99.1%, Average loss: 1.471141EPOCH46	loss: 1.461152	Test Error: Accuracy: 99.1%, Average loss: 1.470793EPOCH47	loss: 1.461151	Test Error: Accuracy: 98.8%, Average loss: 1.474548EPOCH48	loss: 1.461151	Test Error: Accuracy: 99.1%, Average loss: 1.470666EPOCH49	loss: 1.461151	Test Error: Accuracy: 99.1%, Average loss: 1.471546EPOCH50	loss: 1.461151	Test Error: Accuracy: 99.0%, Average loss: 1.471407EPOCH51	loss: 1.461151	Test Error: Accuracy: 98.8%, Average loss: 1.473795EPOCH52	loss: 1.461164	Test Error: Accuracy: 98.2%, Average loss: 1.480009EPOCH53	loss: 1.461151	Test Error: Accuracy: 99.2%, Average loss: 1.469931EPOCH54	loss: 1.461152	Test Error: Accuracy: 99.2%, Average loss: 1.469916EPOCH55	loss: 1.461151	Test Error: Accuracy: 98.9%, Average loss: 1.472574EPOCH56	loss: 1.461151	Test Error: Accuracy: 98.6%, Average loss: 1.476035EPOCH57	loss: 1.461151	Test Error: Accuracy: 98.2%, Average loss: 1.478933EPOCH58	loss: 1.461150	Test Error: Accuracy: 99.4%, Average loss: 1.468186EPOCH59	loss: 1.461151	Test Error: Accuracy: 99.4%, Average loss: 1.467602EPOCH60	loss: 1.461151	Test Error: Accuracy: 99.1%, Average loss: 1.471206EPOCH61	loss: 1.461151	Test Error: Accuracy: 98.8%, Average loss: 1.473356EPOCH62	loss: 1.461151	Test Error: Accuracy: 99.2%, Average loss: 1.470242EPOCH63	loss: 1.461150	Test Error: Accuracy: 99.1%, Average loss: 1.470826EPOCH64	loss: 1.461151	Test Error: Accuracy: 98.7%, Average loss: 1.474476EPOCH65	loss: 1.461150	Test Error: Accuracy: 99.3%, Average loss: 1.469116EPOCH66	loss: 1.461150	Test Error: Accuracy: 99.4%, Average loss: 1.467823EPOCH67	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466486EPOCH68	loss: 1.461152	Test Error: Accuracy: 99.3%, Average loss: 1.468688EPOCH69	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466256EPOCH70	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466588EPOCH71	loss: 1.461150	Test Error: Accuracy: 99.6%, Average loss: 1.465280EPOCH72	loss: 1.461150	Test Error: Accuracy: 99.4%, Average loss: 1.467110EPOCH73	loss: 1.461151	Test Error: Accuracy: 99.6%, Average loss: 1.465245EPOCH74	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466551EPOCH75	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466001EPOCH76	loss: 1.461150	Test Error: Accuracy: 99.3%, Average loss: 1.468074EPOCH77	loss: 1.461151	Test Error: Accuracy: 99.6%, Average loss: 1.465709EPOCH78	loss: 1.461150	Test Error: Accuracy: 99.5%, Average loss: 1.466567EPOCH79	loss: 1.461150	Test Error: Accuracy: 99.6%, Average loss: 1.464922EPOCH80	loss: 1.461150	Test Error: Accuracy: 99.6%, Average loss: 1.465109

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码,主要使用方法可以参考里面的“文档说明_必看.docx”

 代码的下载路径(新窗口打开链接)基于Pytorch深度学习神经网络MNIST手写数字识别系统源码(带界面和手写画板)

有问题可以私信或者留言,有问必答

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/704690.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Lora训练Windows[笔记]

一. 使用kohya_ss的GUI版本(https://github.com/bmaltais/kohya_ss.git) 这个版本跟stable-diffusion-webui的界面很像,只不过是训练模型专用而已,打开的端口同样是7860。 1.双击setup.bat,选择1安装好xformers,pytorch等和cuda…

Qt多文档程序的一种实现

注&#xff1a;文中所列代码质量不高&#xff0c;但不影响演示我的思路 实现思路说明 实现DemoApplication 相当于MFC中CWinAppEx的派生类&#xff0c;暂时没加什么功能。 DemoApplication.h #pragma once#include <QtWidgets/QApplication>//相当于MFC中CWinAppEx的派生…

详细分析Vue3中的reactive(附Demo)

目录 1. 基本知识2. 用法3. Demo 1. 基本知识 reactive 是一个函数&#xff0c;用于将一个普通的 JavaScript 对象转换为响应式对象 当对象的属性发生变化时&#xff0c;Vue 会自动追踪这些变化&#xff0c;并触发相应的更新 Vue2没有&#xff0c;而Vue3中有&#xff0c;为啥…

Springboot+MybatisPlus如何实现带验证码的登录功能

实现带验证码的登录功能由两部分组成&#xff1a;&#xff1a;1、验证码的获取 2、登录&#xff08;进行用户名、密码和验证码的判断&#xff09; 获取验证码 获取验证码需要使用HuTool中的CaptchaUtil.createLineCaptcha()来定义验证码的长度、宽度、验证码位数以及干扰线…

【2024华为HCIP831 | 高级网络工程师之路】刷题日记(18)

个人名片&#xff1a;&#x1faaa; &#x1f43c;作者简介&#xff1a;一名大三在校生&#xff0c;喜欢AI编程&#x1f38b; &#x1f43b;‍❄️个人主页&#x1f947;&#xff1a;落798. &#x1f43c;个人WeChat&#xff1a;hmmwx53 &#x1f54a;️系列专栏&#xff1a;&a…

uniapp获取当前位置及检测授权状态——支持App、微信小程序

uniapp获取当前位置检测及定位权限——支持App、微信小程序 首先&#xff0c;祝天下母亲&#xff0c;节日快乐~ 文章目录 uniapp获取当前位置检测及定位权限——支持App、微信小程序效果图新增 兼容小程序方法manifest Tips&#xff1a; 上一篇介绍 App端 uniapp获取当前位置及…

C++ requires关键字简介

requires 是 C20 中引入的一个新关键字&#xff0c;用于在函数模板或类模板中声明所需的一组语义要求&#xff0c;它可以用来限制模板参数&#xff0c;类似于 typename 和 class 关键字。 requires关键字常与type_traits头文件下类型检查函数匹配使用&#xff0c;当requires后…

React渲染流程

在 React 渲染分为两个阶段&#xff0c;Render 和 Commit&#xff0c;Render 是修改 React 组件的状态&#xff0c;把需要更新的组件标记为待更新&#xff0c;在 Commit 阶段将待更新的组件进行渲染并最终更新到浏览器的 Dom 树中。 Render 阶段是可以并执行操作的&#xff0c…

【Image captioning】基于检测模型网格特征提取——以Sydeny为例

【Image captioning】基于检测模型网格特征提取——以Sydeny为例 今天,我们将重点探讨如何利用Faster R-CNN检测模型来提取Sydeny数据集的网格特征。具体而言,这一过程涉及通过Faster R-CNN模型对图像进行分析,进而抽取出关键区域的特征信息,这些特征在网格结构中被系统地…

【考研数学】准备开强化,更「张宇」还是「武忠祥」?

数一125学长前来回答&#xff0c;选择哪位老师的课程&#xff0c;这通常取决于你的个人偏好和学习风格&#xff01; 张宇老师和武忠祥老师都是非常有经验的数学老师&#xff0c;他们的教学方法各有特点。 张宇老师的教学风格通常被认为是通俗易懂&#xff0c;善于将复杂的概念…

国际化日期(inti)

我们可以使用国际化API自动的格式化数字或者日期&#xff0c;并且格式化日期或数字的时候是按照各个国家的习惯来进行格式化的&#xff0c;非常的简单&#xff1b; const now new Date(); labelDate.textContent new Intl.DateTimeFormat(zh-CN).format(now);比如说这是按照…

使用RN的kitten框架的日历组件的修改

官方网页地址 下面就是我参考官方封装的时间日期组件&#xff08;主要是功能和使用方法&#xff0c;页面粗略做了下&#xff0c;不好看勿怪&#xff09; import React, {useState} from react; import {StyleSheet, View, TouchableOpacity, SafeAreaView} from react-native; …