基于MNIST的手写数字识别

上次我们基于CIFAR-10训练一个图像分类器,梳理了一下训练模型的全过程,并且对卷积神经网络有了一定的理解,我们再在GPU上搭建一个手写的数字识别cnn网络,加深巩固一下

步骤

  1. 加载数据集
  2. 定义神经网络
  3. 定义损失函数
  4. 训练网络
  5. 测试网络

MNIST数据集简介

MINIST是一个手写数字数据库(官网地址:http://yann.lecun.com/exdb/mnist/),它有6w张训练样本和1w张测试样本,每张图的像素尺寸为28*28,如下图一共4个图片,这些图片文件均被保存为二进制格式

训练全过程

1.加载数据集

import torch
import torchvision
from torchvision import transforms
trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))]))
trainloader = torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
test_loader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

展示一些训练图片

import numpy as np
import matplotlib.pyplot as plt
def imshow(img):img = img / 2 + 0.5     # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))plt.show()
# 得到batch中的数据
dataiter = iter(train_loader)
images, labels = dataiter.next()imshow(torchvision.utils.make_grid(images))

2.定义卷积神经网络

import torch
import torch.nn as nn
import torch.nn.functional as F#可以调用一些常见的函数,例如非线性以及池化等
class Net(nn.Module):def __init__(self):super(Net, self).__init__()# input image channel, 6 output channels, 5x5 square convolutionself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# 全连接 从16 * 4 * 4的维度转成120self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)#(2,2)也可以直接写成数字2x = x.view(-1, self.num_flat_features(x))#将维度转成以batch为第一维 剩余维数相乘为第二维x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # 第一个维度batch不考虑num_features = 1for s in size:num_features *= sreturn num_features
net = Net()
print(net)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)
net.to(device)

3.定义损失和优化器

criterion = nn.CrossEntropyLoss()
import torch.optim as optim
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

这里设置了 momentum=0.9 ,训练一轮的准确率由90%提到了98%

4.训练网络

def train(epochs):net.train()for epoch in range(epochs):running_loss = 0.0for i, data in enumerate(trainloader):# 得到输入 和 标签inputs, labels = datainputs, labels = inputs.to(device), labels.to(device)# 消除梯度optimizer.zero_grad()# 前向传播 计算损失 后向传播 更新参数outputs = net(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()# 打印日志running_loss += loss.item()if i % 100 == 0:    # 每100个batch打印一次print('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 100))running_loss = 0.0
torch.save(net, 'mnist.pth')

net.train():调用方法时,模型将进入训练模式。在训练模式下,一些特定的模块,例如Dropout和Batch Normalization,将被启用。这是因为在训练过程中,我们需要使用Dropout来防止过拟合,并使用Batch Normalization来加速收敛

net.eval():调用方法时,模型将进入评估模式。在评估模式下,一些特定的模块,例如Dropout和Batch Normalization,将被禁用。这是因为在评估过程中,我们不需要使用Dropout来防止过拟合,并且Batch Normalization的统计信息应该是固定的。

5.测试网络

在其它地方导入模型测试时需要将类的定义添加到加载模型的这个py文件中

from mnist.py import Net  # 导入会运行mnist.py
net = torch.load('mnist.pth')testset = torchvision.datasets.MNIST('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))
]))
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)correct = 0
total = 0
net.to('cpu') 
print(net)with torch.no_grad():  # 或者model.eval()for data in testloader:images, labels = dataoutputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

训练一轮速度

GPU:10s

CPU:10s

训练三轮速度

GPU:24.5s

CPU:28.6s

得出结论:训练数据计算量少的时候,无论在CPU上还是GPU,性能几乎都是接近的,而当训练数据计算量达到一定多的时候,GPU的优势就比较显著直观了

小小实验:

(1)加载并测试一张图片,正确则输出True

import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torch.nn.functional as F
import cv2
import numpy as npclass Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 4 * 4, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))x = F.max_pool2d(F.relu(self.conv2(x)), 2)  x = x.view(-1, self.num_flat_features(x))  x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  num_features = 1for s in size:num_features *= sreturn num_featurescorrect = 0
total = 0
net = torch.load('mnist.pth')
net.to('cpu')
# print(net)with torch.no_grad(): imgdir = '3.jpeg'img = cv2.imread(imgdir, 0)img = cv2.resize(img, (28, 28))trans = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])image = trans(img)image = image.unsqueeze(0)label = torch.tensor([int(imgdir.split('.')[0])])outputs = net(image)_, predicted = torch.max(outputs.data, 1)print(predicted)print((predicted == label).item())

拿刚刚训练的模型试了6张数字图片,只有一张2是预测对的....

unsuqeeze:通过unsuqeeze(int)中的int整数,增加一个维度,int整数表示维度增加到哪儿去,且维度为1,参数:【0, 1, 2】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/638417.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Android 性能优化之黑科技开道(二)

3. 其它可以黑科技优化的方向 3.1 核心线程绑定大核 3.1.1 定义 核心线程绑定大核的思路也很容易理解,现在的 CPU 都是多核的,大核的频率比小核要高不少,如果我们的核心线程固定运行在大核上,那么应用性能自然会有所提升。 核…

new[]与delete[]

(要理解之前关于new,delete的一些概念,看​​​​​​ CSDN) 引子: 相比new,new[]不仅仅是个数的增加,还有int大小记录空间的创建, 下图中错误的用模拟多个new来替代new[],释放步…

Git 原理及使用 (带动图演示)

文章目录 🌈 Ⅰ Git 安装🌙 01. Linux - centos 🌈 Ⅱ Git 工作区、暂存区和版本库🌙 01. 认识工作区、暂存区和版本库🌙 02. 使用 Git 管理工作区的文件 🌈 Ⅲ Git 基本操作🌙 01. 创建本地仓库…

Java客户端如何直接调用es的API

Java客户端如何直接调用es的API 一. 问题二. withJson 前言 这是我在这个网站整理的笔记,有错误的地方请指出,关注我,接下来还会持续更新。 作者:神的孩子都在歌唱 一. 问题 今天做项目的时候,想要直接通过java客户端调用es的api…

docker的安装以及docker中nginx配置

机器 test3 192.168.23.103 1机器初始化配置 1.1关闭防火墙,清空防火墙规则 systemctl stop firewalld iptables -F setenforce 01.2部署时间同步 yum install ntp ntpdate -y1.3安装基础软件包 yum install -y wget net-tools nfs-utils lrzsz gcc gcc-c make…

2023年网络安全行业:机遇与挑战并存

2023年全球网络安全人才概况 根据ISC2的《2023年全球网络安全人才调查报告》,全球的网络安全专业人才数量达到了550万,同比增长了8.7%。然而,这一年也见证了网络安全人才短缺达到了历史新高,缺口数量接近400万。尤其是亚太地区&am…

【Linux学习】Linux调试器-gdb使用

这里写目录标题 🌂背景🌂gdb使用🌂指令总结: 🌂背景 程序的发布方式有两种,debug模式和 release模式 其中,debug模式是可以被调试的,到那时release模式是不能被调试的; …

用Nest实现对数据库的增删改查~

概述 为了与 SQL和 NoSQL 数据库集成,Nest 提供了 nestjs/typeorm 包。Nest 使用TypeORM是因为它是 TypeScript 中最成熟的对象关系映射器( ORM )。因为它是用 TypeScript 编写的,所以可以很好地与 Nest 框架集成。 TypeORM 提供了对许多关系数据库的支…

数据库主从复制

一、主从复制概述 1、介绍: 主从复制是指将主数据库的 DDL 和 DML 操作写入到二进制日志中,将二进制日志传送到从库服务器,然后在从库上对这些日志重新执行(重做),从而使得从库和主库的数据保持同步。 M…

护眼台灯哪个牌子好?排名靠前的护眼台灯十大排名推荐!

护眼台灯哪个牌子好?目前,书客、松下、飞利浦等品牌备受关注。急需护眼的朋友,先不必焦虑。护眼台灯的选择,同样需要细致考虑,不是简单地亮起来就足够护眼。因为不当的光线可能对眼睛造成微妙而长远的伤害,…

怎样快速打造二级分销小程序

乔拓云是一个专门开发小程序模板的平台,致力于帮助商家快速上线自己的小程序。通过套用乔拓云提供的精美模板,商家无需具备专业的技术背景,也能轻松打造出功能齐全、美观大方的小程序。 在乔拓云的官网,商家可以免费注册账号并登录…

火力发电资质升级,河南企业申报周期一览

河南企业申报火力发电资质从丙级升级到乙级的周期,通常是一个涉及多个环节和因素的复杂过程。因此,具体的申报周期会因企业的实际情况、申报材料的准备情况、审批部门的工作效率等多种因素而有所差异。 一般来说,整个升级周期可能包含以下步骤…