基于卷积神经网络实现手写数字识别

基于卷积神经网络实现手写数字识别

基于卷积神经网络实现手写数字识别。具体过程如下:

(1) 定义ConvNet结构类及其前向传播方式

(2) 设置超参数以及导入相关的包。

(3) 定义训练网络函数和绘图函数,并在main函数中完成调用过程

程序
import os 
import numpy as np 
#from sklearn.datasets import fetch_openml # 引入openml数据源
from matplotlib import pyplot as plt # 引入绘图工具
import torch
from torchvision.datasets import mnist
#from mnist_models import AlexNet, ConvNet
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.autograd import VariableBASE_PATH = os.path.dirname(__file__)# 设置模型超参数
EPOCHS = 50
SAVE_PATH = './models''''
# 载入MNIST数据集并显示部分样本
def load_mnist():# 从openml源载入MNIST数据集mnist = fetch_openml('mnist_784', version=1, data_home=os.path.join(BASE_PATH, './dataset'))X, y = mnist['data'], mnist['target']#X = mnist['data']#.astype(np.float32)#y = mnist['target']#.astype(np.int32)print('MNIST数据集大小:{}'.format(X.shape))# 显示其中25张样本图片for i in range(25):#print(i)digit = X.iloc[i * 2500]# 将图片恢复到28*28大小digit_image = digit.values.reshape(28, 28)# 绘制图片plt.subplot(5, 5, i + 1)# 隐藏坐标轴plt.axis('off')# 按灰度图绘制图片plt.imshow(digit_image, cmap='gray')# 显示图片plt.show()return X, y
'''# 定义卷积网络结构
class ConvNet(torch.nn.Module):def __init__(self):super(ConvNet, self).__init__()self.conv1 = torch.nn.Sequential(torch.nn.Conv2d(1, 10, 5, 1, 1),torch.nn.MaxPool2d(2),torch.nn.ReLU(),torch.nn.BatchNorm2d(10))self.conv2 = torch.nn.Sequential(torch.nn.Conv2d(10, 20, 5, 1, 1),torch.nn.MaxPool2d(2),torch.nn.ReLU(),torch.nn.BatchNorm2d(20))self.fc1 = torch.nn.Sequential(torch.nn.Linear(500, 60),torch.nn.Dropout(0.5),torch.nn.ReLU())self.fc2 = torch.nn.Sequential(torch.nn.Linear(60, 20),torch.nn.Dropout(0.5),torch.nn.ReLU())self.fc3 = torch.nn.Linear(20, 10)# 定义网络前向传播方式def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x.view(-1, 500)x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return x# 定义AlexNet结构
class AlexNet(torch.nn.Module):def __init__(self, num_classes=10):super(AlexNet, self).__init__()self.features = torch.nn.Sequential(torch.nn.Conv2d(1, 64, kernel_size=5, stride=1, padding=2),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=1),torch.nn.Conv2d(64, 192, kernel_size=3, padding=2),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2),torch.nn.Conv2d(192, 384, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(384, 256, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.Conv2d(256, 256, kernel_size=3, padding=1),torch.nn.ReLU(inplace=True),torch.nn.MaxPool2d(kernel_size=3, stride=2))self.classifier = torch.nn.Sequential(torch.nn.Dropout(),torch.nn.Linear(256 * 6 * 6, 4096),torch.nn.ReLU(inplace=True),torch.nn.Dropout(),torch.nn.Linear(4096, 4096),torch.nn.ReLU(inplace=True),torch.nn.Linear(4096, num_classes))# 定义AlexNet前向传播过程def forward(self, x):x = self.features(x)x = x.view(x.size(0), 256 * 6 * 6)x = self.classifier(x)return x    # 训练网络函数
def train_net(net, train_data, test_data):losses = []acces = []# 测试集上Loss变化情况eval_losses = []eval_acces = []# 损失函数设置为交叉熵函数criterion = torch.nn.CrossEntropyLoss()# 优化方法选用SGD,初始学习率为1e-2optimizer = torch.optim.SGD(net.parameters(), 1e-2)for e in range(EPOCHS):train_loss = 0train_acc = 0# 将网络设置为训练模型net.train()for image, label in train_data:image = Variable(image)label = Variable(label)# 前向传播out = net(image)loss = criterion(out, label)# 反向传播optimizer.zero_grad()loss.backward()optimizer.step()# 记录误差train_loss += loss.data# 计算分类的准确率_, pred = out.max(1)num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()acc = num_correct / image.shape[0]train_acc += acctrain_loss_rate = train_loss / len(train_data)train_acc_rate = train_acc / len(train_data)losses.append(train_loss_rate)acces.append(train_acc_rate)# 在测试集上检验效果eval_loss = 0eval_acc = 0net.eval() # 将模型改为预测模式for image, label in test_data:image = Variable(image)label = Variable(label)out = net(image)loss = criterion(out, label)# 记录误差eval_loss += loss.data# 记录准确率_, pred = out.max(1)num_correct = (np.array(pred, dtype=np.int32) == np.array(label, dtype=np.int32)).sum()acc = num_correct / image.shape[0]eval_acc += acceval_loss_rate = eval_loss / len(test_data)eval_acc_rate = eval_acc / len(test_data)eval_losses.append(eval_loss_rate)eval_acces.append(eval_acc_rate)print('epoch:{}, Train Loss: {:.6f}, Train Acc:{:.6f}, Eval Loss:{:.6f}, Eval Acc:{:.6f}'.format(e, train_loss_rate, train_acc_rate, eval_loss_rate, eval_acc_rate))torch.save(net.state_dict(), os.path.join(BASE_PATH, SAVE_PATH, 'Alex_model_epoch' + str(e) + '.pkl'))return eval_losses, eval_accesdef draw_result(eval_losses, eval_acces):x = range(1, EPOCHS + 1)fig, left_axis = plt.subplots()p1, = left_axis.plot(x, eval_losses, 'ro-')right_axis = left_axis.twinx()p2, = right_axis.plot(x, eval_acces, 'bo-')plt.xticks(x, rotation=0)# 设置左坐标轴以及右坐标轴的范围、精度left_axis.set_ylim(0, 0.5)left_axis.set_yticks(np.arange(0, 0.5, 0.1))right_axis.set_ylim(0.9, 1.01)right_axis.set_yticks(np.arange(0.9, 1.01, 0.02))# 设置坐标及标题的大小、颜色left_axis.set_xlabel('Labels')left_axis.set_ylabel('Loss', color='r')left_axis.tick_params(axis='y', colors='r')right_axis.set_ylabel('Accuracy', color='b')right_axis.tick_params(axis='y', colors='b')plt.show()if __name__ == '__main__':#x, y = load_mnist()print("基于卷积神经网络实现手写数字识别")train_set = mnist.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())//需要转化成tensor数据格式test_set = mnist.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())train_data = DataLoader(train_set, batch_size=64, shuffle=True)test_data = DataLoader(test_set, batch_size=64, shuffle=False)a, a_label = next(iter(train_data))#net = AlexNet()net = ConvNet()eval_losses, eval_acces = train_net(net, train_data, test_data)draw_result(eval_losses, eval_acces)
结果:

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/544340.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux系统——Session ID(负载均衡如何保持会话)

目录 一、实验环境搭建 二、部署Nginx代理服务器配置 三、部署后端真是服务器Tomcat配置 四、配置Tomcat的Session ID会话保持 五、测试 此次实验是Tomcat后端服务器如何做Session ID会话保持 一、实验环境搭建 [rootlocalhost ~]#systemctl stop firewalld [rootlocalho…

Ubuntu Desktop - gnome-calculator (计算器)

Ubuntu Desktop - gnome-calculator [计算器] 1. Ubuntu Software -> gnome-calculator -> Install -> Continue2. Search your computer -> Calculator -> Lock to LauncherReferences 1. Ubuntu Software -> gnome-calculator -> Install -> Continu…

2024年3月退伍大学生士兵报名参加2024年天津专升本的通知

3月20日开始,2024年3月退伍的大学生士兵可报名参加2024年天津市高职升本科招生考试 为落实有关退役大学生士兵免试专升本工作的文件精神,按照《市高招办关于印发2024年天津市高职升本科招生实施办法的通知》(津招办高发〔2023〕14号&#xf…

D 咖智能饮品机入驻万达,引领时尚饮品新潮流!

近日,D 咖智能饮品机正式入驻万达广场,为广大消费者带来全新的时尚饮品体验。作为国内领先的智能饮品设备品牌,D 咖智能饮品机以其多样化的口味选择、便捷的操作方式和个性化的定制服务,受到了众多消费者的喜爱。 D 咖智能饮品机提…

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:RowSplit)

将子组件横向布局,并在每个子组件之间插入一根纵向的分割线。 说明: 该组件从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 子组件 可以包含子组件。 RowSplit通过分割线限制子组件的宽度。初始化…

C语言 1000内完数、素数判断

一、一个数如果恰好等于它的因子之和,这个数就称为“完数”。例如,6旳因子为1,2,3,而6123,因此6是“完数”。编程序找出1000以内的所有“完数”,并按照下面格式输出其因子:6 its fac…

Centos7安装Clickhouse单节点部署

🎈 作者:互联网-小啊宇 🎈 简介: CSDN 运维领域创作者、阿里云专家博主。目前从事 Kubernetes运维相关工作,擅长Linux系统运维、开源监控软件维护、Kubernetes容器技术、CI/CD持续集成、自动化运维、开源软件部署维护…

Stargo 管理部署 Starrocks 集群

配置主机间 ssh 互信 ssh-copy-id hadoop02 ssh-copy-id hadoop03配置系统参数 ############################ Swap检查 ############################ echo 0 | sudo tee /proc/sys/vm/swappiness########################### 内核参数检查 ########################## echo…

Swift 面试题及答案整理,最新面试题

Swift 中如何实现单例模式? 在Swift中,单例模式的实现通常采用静态属性和私有初始化方法来确保一个类仅有一个实例。具体做法是:定义一个静态属性来存储这个单例实例,然后将类的初始化方法设为私有,以阻止外部通过构造…

鸿蒙Harmony应用开发—ArkTS声明式开发(容器组件:TabContent)

仅在Tabs中使用,对应一个切换页签的内容视图。 说明: 该组件从API Version 7开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 子组件 支持单个子组件。 说明: 可内置系统组件和自定义组件,支…

Java基础 - 9 - 集合进阶(二)

一. Collection的其他相关知识 1.1 可变参数 可变参数就是一种特殊形参,定义在方法、构造器的形参列表里,格式是:数据类型…参数名称; 可变参数的特点和好处 特点:可以不传数据给它;可以传一个或者同时传多个数据给…

Spring Web MVC入门(2)

学习Spring MVC Postman介绍 在软件工程中, 我们需要具有前后端分离的思想, 以降低耦合性. 但是在测试后端代码时,我们还得写前端代码测试,这是个令人头疼的问题. 那么我们如何测试自己的后端程序呢, 这就用到了一个工具: Postman. 界面介绍: 传参的介绍 1.普通传参, 也就…