实战:基于卷积的MNIST手写体分类

前面实现了基于多层感知机的MNIST手写体识别,本章将实现以卷积神经网络完成的MNIST手写体识别。

1.  数据的准备

在本例中,依旧使用MNIST数据集,对这个数据集的数据和标签介绍,前面的章节已详细说明过了,相对于前面章节直接对数据进行“折叠”处理,这里需要显式地标注出数据的通道,代码如下:

import numpy as npimport einops.layers.torch as elt#载入数据x_train = np.load("../dataset/mnist/x_train.npy")y_train_label = np.load("../dataset/mnist/y_train_label.npy")x_train = np.expand_dims(x_train,axis=1)   #在指定维度上进行扩充print(x_train.shape)

这里是对数据的修正,np.expand_dims的作用是在指定维度上进行扩充,这里在第二维(也就是PyTorch的通道维度)进行扩充,结果如下:

(60000, 1, 28, 28)

2.  模型的设计

下面使用PyTorch 2.0框架对模型进行设计,在本例中将使用卷积层对数据进行处理,完整的模型如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()#前置的特征提取模块self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),  	#第一个卷积层nn.ReLU(),nn.Conv2d(12,24,kernel_size=5), 	#第二个卷积层nn.ReLU(),nn.Conv2d(24,6,kernel_size=3)  	#第三个卷积层)#最终分类器层self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)        #elt.Rearrange的作用是对输入数据的维度进行调整,读者可以使用torch.nn.Flatten函数完成此工作x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
model = MnistNetword()
torch.save(model,"model.pth")

这里首先设定了3个卷积层作为前置的特征提取层,最后一个全连接层作为分类器层,需要注意的是,对于分类器的全连接层,输入维度需要手动计算,当然读者可以一步一步尝试打印特征提取层的结果,依次将结果作为下一层的输入维度。最后对模型进行保存。

3.  基于卷积的MNIST分类模型

下面进入本章的最后示例部分,也就是MNIST手写体的分类。完整的训练代码如下:

import torch
import torch.nn as nn
import numpy as np
import einops.layers.torch as elt
#载入数据
x_train = np.load("../dataset/mnist/x_train.npy")
y_train_label = np.load("../dataset/mnist/y_train_label.npy")
x_train = np.expand_dims(x_train,axis=1)
print(x_train.shape)
class MnistNetword(nn.Module):def __init__(self):super(MnistNetword, self).__init__()self.convs_stack = nn.Sequential(nn.Conv2d(1,12,kernel_size=7),nn.ReLU(),nn.Conv2d(12,24,kernel_size=5),nn.ReLU(),nn.Conv2d(24,6,kernel_size=3))self.logits_layer = nn.Linear(in_features=1536,out_features=10)def forward(self,inputs):image = inputsx = self.convs_stack(image)x = elt.Rearrange("b c h w -> b (c h w)")(x)logits = self.logits_layer(x)return logits
device = "cuda" if torch.cuda.is_available() else "cpu"
#注意记得将model发送到GPU计算
model = MnistNetword().to(device)
model = torch.compile(model)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
batch_size = 128
for epoch in range(42):train_num = len(x_train)//128train_loss = 0.for i in range(train_num):start = i * batch_sizeend = (i + 1) * batch_sizex_batch = torch.tensor(x_train[start:end]).to(device)y_batch = torch.tensor(y_train_label[start:end]).to(device)pred = model(x_batch)loss = loss_fn(pred, y_batch)optimizer.zero_grad()loss.backward()optimizer.step()train_loss += loss.item()  # 记录每个批次的损失值# 计算并打印损失值train_loss /= train_numaccuracy = (pred.argmax(1) == y_batch).type(torch.float32).sum().item() / batch_sizeprint("epoch:",epoch,"train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

在这里,我们使用了本章新定义的卷积神经网络模块作为局部特征抽取,而对于其他的损失函数以及优化函数,只使用了与前期一样的模式进行模型训练。最终结果如下所示,请读者自行验证。

(60000, 1, 28, 28)
epoch: 0 train_loss: 2.3 accuracy: 0.11
epoch: 1 train_loss: 2.3 accuracy: 0.13
epoch: 2 train_loss: 2.3 accuracy: 0.2
epoch: 3 train_loss: 2.3 accuracy: 0.18
…
epoch: 58 train_loss: 0.5 accuracy: 0.98
epoch: 59 train_loss: 0.49 accuracy: 0.98
epoch: 60 train_loss: 0.49 accuracy: 0.98
epoch: 61 train_loss: 0.48 accuracy: 0.98
epoch: 62 train_loss: 0.48 accuracy: 0.98Process finished with exit code 0

本文节选自《PyTorch 2.0深度学习从零开始学》,本书实战案例丰富,可带领读者快速掌握深度学习算法及其常见案例。

   

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/93846.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

性能测试必备监控技能windows篇

前言 在手头没有专门的第三方监控时,该怎么监控服务指标呢?本篇就windows下监控进行分享,也是我们在进行性能测试时,必须掌握的。下面我们就windows下常用的三种监视工具进行说明: 任务管理器 资源监视器 性能监视器…

【python爬虫】4.爬虫实操(菜品爬取)

文章目录 前言项目:解密吴氏私厨分析过程代码实现(一)获取与解析提取最小父级标签一组菜名、URL、食材写循环,存列表 代码实现(二)复习总结 前言 上一关,我们学习了用BeautifulSoup库解析数据和…

数据结构(Java实现)-二叉树(下)

获取二叉树的高度 检测值为value的元素是否存在(前序遍历) 层序遍历 判断一棵树是不是完全二叉树 获取节点的路径 二叉树的最近公共祖先

【uniapp】 实现公共弹窗的封装以及调用

图例&#xff1a;红框区域为 “ 内容区域 ” 一、组件 <!-- 弹窗组件 --> <template> <view class"add_popup" v-if"person.isShowPopup"><view class"popup_cont" :style"{width:props.width&&props.width&…

基于微信小程序的汽车租赁系统的设计与实现ljx7y

汽车租赁系统&#xff0c;主要包括管理员、用户二个权限角色&#xff0c;对于用户角色不同&#xff0c;所使用的功能模块相应不同。本文从管理员、用户的功能要求出发&#xff0c;汽车租赁系统系统中的功能模块主要是实现管理员后端&#xff1b;首页、个人中心、汽车品牌管理、…

解决windows下安装python并终端运行python弹出windows商店的最终解决方案

最近在windows下安装了python&#xff0c;在终端中运行python命令的时候总是会弹出windows商店 为了解决这个问题&#xff0c;我修改系统的环境变量也一直不起作用&#xff0c;只有指定绝对路径时才能正常启动&#xff0c;这个问题困扰了我很长一段时间&#xff0c;后来了解到…

【K8S系列】深入解析k8s网络插件—Cilium

序言 做一件事并不难&#xff0c;难的是在于坚持。坚持一下也不难&#xff0c;难的是坚持到底。 文章标记颜色说明&#xff1a; 黄色&#xff1a;重要标题红色&#xff1a;用来标记结论绿色&#xff1a;用来标记论点蓝色&#xff1a;用来标记论点 在现代容器化应用程序的世界中…

WIFI与BT的PCB布局布线注意事项

1、模块整体布局时&#xff0c;WIFI模组要尽量远离DDR、HDMI、USB、LCD电路以及喇叭等易干扰模块或连接座&#xff1b; 2、晶体电路布局需要优先考虑&#xff0c;布局时应与芯片在同一层并尽量靠近放置以避免打过孔&#xff0c;晶体走线尽可能的短&#xff0c;远离干扰源&…

java八股文面试[多线程]——一个线程两次调用start()方法会出现什么情况

典型回答&#xff1a; Java 的线程是不允许启动两次的&#xff0c;第二次调用必然会抛出 IllegalThreadStateException&#xff0c;这是一种运行时异常&#xff0c;多次调用 start 被认为是编程错误。 通过线程的状态图&#xff0c;在第二次调用 start() 方法的时候&#xff…

若依富文本 html样式 被过滤问题

一.场景 进入页面&#xff0c;富文本编辑框里回显这条新闻内容&#xff0c;如下图&#xff0c; 然后可以在富文本编辑框里对它实现再编辑&#xff0c;编辑之后将html代码提交保存到后台数据库。可以点击详情页进行查看。 出现问题&#xff1a;在提交到后台controller时&#x…

Java项目-苍穹外卖-Day08

文章目录 前言导入地址簿代码导入需求分析代码导入功能测试 用户下单需求分析接口设计数据库设计 代码开发功能测试 前言 本篇博客主要是用户端的功能完善 主要是三个功能 1.导入地址簿 2.点击去结算弹出结算页面 3.微信支付功能 导入地址簿代码导入 这个地址簿就是一个表的…

WireShark流量抓包详解

目录 Wireshark软件安装Wireshark 开始抓包示例Wireshakr抓包界面介绍WireShark 主要界面 wireshark过滤器表达式的规则 Wireshark软件安装 软件下载路径&#xff1a;wireshark官网。按照系统版本选择下载&#xff0c;下载完成后&#xff0c;按照软件提示一路Next安装。 Wire…