Pytorch实现基于MNIST的手写数字识别

news/2025/1/23 10:36:07/文章来源:https://www.cnblogs.com/w1ck/p/18292044

本文目的在于训练一个模型,使其能对手写的数字图片进行分类识别,并不断优化使其准确度尽可能地提高

一、数据预处理

(1)运行时所需库

import numpy as np  
import torch  
import torchvision  
from torch import nn  
from torch.utils.data import DataLoader  
from torchvision import datasets  
import matplotlib.pyplot as plt  
import os.path

(2)选择合适的设备进行训练

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

(3) 构建数据集

# 将图片转化为张量以及归一化处理  
Trans = torchvision.transforms.Compose(  [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])  # 下载MNIST对应的训练和测试数据集  
training_data = datasets.MNIST(  root="data",  train=True,  download=True,  transform=Trans,  
)  test_data = datasets.MNIST(  root="data",  train=False,  download=True,  transform=Trans,  
)  # 设定batch大小  
batch_size = 64  # 构建用于训练和测试的数据集的dataloader  
train_dataloader = DataLoader(training_data, batch_size=batch_size)  
test_dataloader = DataLoader(test_data, batch_size=batch_size)  for X, y in test_dataloader:  print("Shape of X [N,C,H,W]:", X.shape)  print("Shape of y: ", y.shape, y.dtype)  break

二、训练和测试

(1)模型网络构建

三层的全连接层网络

class NeuralNetwork(nn.Module):  def __init__(self):  super(NeuralNetwork, self).__init__()  self.flatten = nn.Flatten()  self.linear_relu_stack = nn.Sequential(  nn.Linear(28 * 28, 512),  nn.ReLU(),  nn.Linear(512, 512),  nn.ReLU(),  nn.Linear(512, 10),  nn.ReLU()  )  def forward(self, x):  x = self.flatten(x)  logits = self.linear_relu_stack(x)  return logits  model = NeuralNetwork().to(device)  
if os.path.exists(filename):  model.load_state_dict(torch.load(filename))  
print(model)

(2)定义损失函数和优化器

loss_fn = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

(3)定义训练函数

def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset)  for batch, (X, y) in enumerate(dataloader):  X, y = X.to(device), y.to(device)  pred = model(X)  loss = loss_fn(pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  if batch % 100 == 0:  loss, current = loss.item(), batch * len(X)  print(f"loss:{loss:>7f} [{current:>5d}/{size:>5d}]")

(4)定义测试函数

def test(dataloader, model):  size = len(dataloader.dataset)  model.eval()  test_loss, correct = 0, 0  global ok  with torch.no_grad():  for X, y in dataloader:  X, y = X.to(device), y.to(device)  pred = model(X)  test_loss += loss_fn(pred, y).item()  correct += (pred.argmax(1) == y).type(torch.float).sum().item()  if ok:  ok = False  L = X.cpu()  R = y.cpu()  M = pred.argmax(1).cpu()  plot_images_labels_prediction(np.array(L), np.array(R), np.array(M), 10, 25)  test_loss /= size  correct /= size  history['Test Loss'].append(test_loss)  history['Test Accuracy'].append(correct * 100)  print(f"Test Error: \nAccuracy:{(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

三、主函数和权值保存

(1)主函数

if __name__ == '__main__':  epochs = 10  for t in range(epochs):  print(f"Epoch {t + 1}\n------------------")  train(train_dataloader, model, loss_fn, optimizer)  test(test_dataloader, model)  print("Done!")

(2)保存和恢复网络权值

model = NeuralNetwork().to(device)  
if os.path.exists(filename):  model.load_state_dict(torch.load(filename))  
print(model)torch.save(model.state_dict(), filename)
print("Save PyTorch Model State to " + filename)

四、可视化

(1)显示图片以及预测结果

def plot_images_labels_prediction(images, labels, prediction, index, num=10):  fig = plt.gcf()  # 获取当前图表  fig.set_size_inches(10, 12)  # 显示成英寸(1英寸等于2.54cm)  if num > 25:  num = 25  # 最多显示25幅图片  for i in range(0, num):  ax = plt.subplot(5, 5, i + 1)  # 画多个子图(5*5)  ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')  # 显示第index张图像  title = "label=" + str(labels[index])  # 构建图片上要显示的title  if len(prediction) > 0:  title += ", predict=" + str(prediction[index])  ax.set_title(title, fontsize=10)  ax.set_xticks([])  # 不显示坐标轴  ax.set_yticks([])  index += 1  plt.show()if ok:  ok = False  L = X.cpu()  R = y.cpu()  M = pred.argmax(1).cpu()  plot_images_labels_prediction(np.array(L), np.array(R), np.array(M), 10, 25)

(2)Acc和loss的变化曲线

history = {'Test Loss': [], 'Test Accuracy': []}plt.plot(history['Test Loss'], label='Test Loss')  
plt.legend(loc='best')  
plt.grid(True)  
plt.xlabel('Epoch')  
plt.ylabel('Loss')  
plt.show()  plt.plot(history['Test Accuracy'], color='red', label='Test Accuracy')  
plt.legend(loc='best')  
plt.grid(True)  
plt.xlabel('Epoch')  
plt.ylabel('Accuracy%')  
plt.show()

五、实验结果展示

当学习率选择0.1时,10次训练后手写图片预测情况:

image.png

image.png

image.png

image.png

六、优化与参数调整

(1)调整优化器的学习率

在 learningrate = 1e-3 时训练结果:

image.png

在多次尝试的经验选择下

调整为 learningrate = 0.1 后的效果如上图实验结果所示

(2)保存模型,同时将学习率逐步减小以趋近极值点

考虑到局部最优解的原理,我们将 lr = 0.1 的模型保存后(目的为了加快求得解),之后加载模型,采用 lr = 1e-3 去多轮训练,最终得到结果:

image.png

但是可以见得由于 lr = 0.1 下解已然十分逼近最优,在此优化下提升已经不多。

Code:

import numpy as np  
import torch  
import torchvision  
from torch import nn  
from torch.utils.data import DataLoader  
from torchvision import datasets  
import matplotlib.pyplot as plt  
import os.path  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 将图片转化为张量以及归一化处理  
Trans = torchvision.transforms.Compose(  [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])  # 下载MNIST对应的训练和测试数据集  
training_data = datasets.MNIST(  root="data",  train=True,  download=True,  transform=Trans,  
)  test_data = datasets.MNIST(  root="data",  train=False,  download=True,  transform=Trans,  
)  # 设定batch大小  
batch_size = 64  # 构建用于训练和测试的数据集的dataloader  
train_dataloader = DataLoader(training_data, batch_size=batch_size)  
test_dataloader = DataLoader(test_data, batch_size=batch_size)  for X, y in test_dataloader:  print("Shape of X [N,C,H,W]:", X.shape)  print("Shape of y: ", y.shape, y.dtype)  break  class NeuralNetwork(nn.Module):  def __init__(self):  super(NeuralNetwork, self).__init__()  self.flatten = nn.Flatten()  self.linear_relu_stack = nn.Sequential(  nn.Linear(28 * 28, 512),  nn.ReLU(),  nn.Linear(512, 512),  nn.ReLU(),  nn.Linear(512, 10),  nn.ReLU()  )  def forward(self, x):  x = self.flatten(x)  logits = self.linear_relu_stack(x)  return logits  filename = "model.pth"  model = NeuralNetwork().to(device)  
if os.path.exists(filename):  model.load_state_dict(torch.load(filename))  
print(model)  loss_fn = nn.CrossEntropyLoss()  
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)  history = {'Test Loss': [], 'Test Accuracy': []}  def plot_images_labels_prediction(images, labels, prediction, index, num=10):  fig = plt.gcf()  # 获取当前图表  fig.set_size_inches(10, 12)  # 显示成英寸(1英寸等于2.54cm)  if num > 25:  num = 25  # 最多显示25幅图片  for i in range(0, num):  ax = plt.subplot(5, 5, i + 1)  # 画多个子图(5*5)  ax.imshow(np.reshape(images[index], (28, 28)), cmap='binary')  # 显示第index张图像  title = "label=" + str(labels[index])  # 构建图片上要显示的title  if len(prediction) > 0:  title += ", predict=" + str(prediction[index])  ax.set_title(title, fontsize=10)  ax.set_xticks([])  # 不显示坐标轴  ax.set_yticks([])  index += 1  plt.show()  def train(dataloader, model, loss_fn, optimizer):  size = len(dataloader.dataset)  for batch, (X, y) in enumerate(dataloader):  X, y = X.to(device), y.to(device)  pred = model(X)  loss = loss_fn(pred, y)  optimizer.zero_grad()  loss.backward()  optimizer.step()  if batch % 100 == 0:  loss, current = loss.item(), batch * len(X)  print(f"loss:{loss:>7f} [{current:>5d}/{size:>5d}]")  ok = False  def test(dataloader, model):  size = len(dataloader.dataset)  model.eval()  test_loss, correct = 0, 0  global ok  with torch.no_grad():  for X, y in dataloader:  X, y = X.to(device), y.to(device)  pred = model(X)  test_loss += loss_fn(pred, y).item()  correct += (pred.argmax(1) == y).type(torch.float).sum().item()  if ok:  ok = False  L = X.cpu()  R = y.cpu()  M = pred.argmax(1).cpu()  plot_images_labels_prediction(np.array(L), np.array(R), np.array(M), 10, 25)  test_loss /= size  correct /= size  history['Test Loss'].append(test_loss)  history['Test Accuracy'].append(correct * 100)  print(f"Test Error: \nAccuracy:{(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")  if __name__ == '__main__':  epochs = 10  for t in range(epochs):  if t == epochs - 1:  ok = True  print(f"Epoch {t + 1}\n------------------")  train(train_dataloader, model, loss_fn, optimizer)  test(test_dataloader, model)  print("Done!")  plt.plot(history['Test Loss'], label='Test Loss')  plt.legend(loc='best')  plt.grid(True)  plt.xlabel('Epoch')  plt.ylabel('Loss')  plt.show()  plt.plot(history['Test Accuracy'], color='red', label='Test Accuracy')  plt.legend(loc='best')  plt.grid(True)  plt.xlabel('Epoch')  plt.ylabel('Accuracy%')  plt.show()  torch.save(model.state_dict(), filename)  
print("Save PyTorch Model State to " + filename)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/741175.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【proto】python根据proto文件构造message,并换为二进制

一、场景测试需要构造数据,而且存储的格式为grpc消息的二进制格式,所以必须要根据proto构造二进制二、构造方法 1、根据proto文件生成python格式的pb文件python3 -m grpc_tools.protoc -I. proto/upload_state.proto --python_out=. --grpc_python_out=. 2、检查文件生成 3…

electron 跨域/CSP问题

请求报错:Refused to connect to http://127.0.0.1:8000/get?name=kv-grpc because it violates the following Content Security Policy directive: "default-src self". Note that connect-src was not explicitly set, so default-src is used as a fallback 这…

松灵机器人scout mini小车 自主导航(2)——仿真指南

松灵机器人Scout mini小车仿真指南 之前介绍了如何通过CAN TO USB串口实现用键盘控制小车移动。但是一直用小车测试缺乏安全性。而松灵官方贴心的为我们准备了gazebo仿真环境,提供了完整的仿真支持库,本文将介绍如何上手使用仿真。 官方仓库地址:https://github.com/agilexr…

zabbix“专家坐诊”第245期问答

问题一 Q:vfs.dev.discovery拿的是哪里的文件,我看源码里面获取的是/proc/parttions里面的信息,但是我没有这个device,是怎么获取出来的?A:检查下系统内核版本或者agent程序版本,如果未定义KERNEL_2_4的情况下,读的是后面这个文件。 Q:这两个文件我都看过,也没有cdro…

【Python】Word文档操作

一、全文替换 不是创建word文档写入内容,而是基于现有的Word文档进行替换处理 使用run.text直接赋值修改发现样式会丢失,而网上大部分办法都是这么写的... 直到我看到这篇文章的评论:https://blog.csdn.net/qq_40222956/article/details/106098464 除了段落替换后,Word文档…

【ubuntu】安装go

一、官网 https://golang.google.cn/dl/ 选择稳定版本,点击下载二、安装步骤 1、解压2、移动目录sudo mv go /usr/local3、配置环境变量 vim ~/.bashrcexport PATH=$PATH:/usr/local/go/bin export GOPATH=$HOME/gocode创建gocode目录 vim ~/.profile 添加同样配置三、验证$ g…

2024迎新马拉松——字典

思路 这道题可以把每个单词正过来放在一个字典树里。 把每个单词反过来,给每个单词单独建立一个字典树。 而询问要求的就是前缀在正串的字典树上的那个节点为根的子树中,所有串的反串字典树合并之后的那个字典树上,后缀的那个节点所对应的子树当中有多少个串就是答案。 举个…

win11系统 连接共享打印机提示 0x0000709

windows11 用户在添加共享打印机的时候,遇到了系统错误提示:操作无法完成(错误0x0000709) 其他查考文章: https://baijiahao.baidu.com/s?id=1788757659395932042&wfr=spider&for=pc

24迎新马拉松——字典

思路 这道题我们可以把每个单词正过来放在一个字典树里。 而我们把每个单词反过来,给每个单词单独建立一个字典树。 而询问要求的就是前缀在正串的字典树上的那个节点为根的子树中,所有串的反串字典树合并之后的那个字典树上,后缀的那个节点所对应的子树当中有多少个串就是答…

香橙派5plus上跑云手机方案二 waydroid

前言 上篇文章香橙派5plus上跑云手机方案一 redroid(带硬件加速)说了怎么跑带GPU加速的redroid方案,这篇说下怎么在香橙派下使用Waydroid。 温馨提示 虽然能运行,但是体验下来只能用软件加速,无法使用GPU加速,所有会很卡。而且Waydroid还依赖于桌面环境wayland,要么插上显…

弹性伸缩落地实践

1. 什么是 HPA ? HPA(Horizontal Pod Autoscaler)是 Kubernetes 中的一种资源自动伸缩机制,用于根据某些指标动态调整 Pod 的副本数量。 2. 什么时候需要 HPA ?负载波动:当您的应用程序的负载经常发生波动时,HPA 可以自动调整 Pod 的副本数量,以适应负载的变化。例如,…

三星 NAND FLASH命名规范 Samsung NAND Flash Code Information

一共有三页,介绍了前面主要的编号和横杠后面的编号,当前文档只关注前面的编号。 从前面的命名规范中可以得知当前芯片的容量、技术等概要信息,对芯片有一个整体了解。 详细解释 Small Classification 表示存储单元的类型和应用,比如 SLC 1 Chip XD Card 表示是SLC的,包含1…