深度学习手写字符识别:训练模型

说明

本篇博客主要是跟着B站中国计量大学杨老师的视频实战深度学习手写字符识别。
第一个深度学习实例手写字符识别

深度学习环境配置

可以参考下篇博客,网上也有很多教程,很容易搭建好深度学习的环境。
Windows11搭建GPU版本PyTorch环境详细过程

数据集

手写字符识别用到的数据集是MNIST数据集(Mixed National Institute of Standards and Technology database);MNIST是一个用来训练各种图像处理系统二进制图像数据集,广泛应用到机器学习中的训练和测试。
作为一个入门级的计算机视觉数据集,发布20多年来,它已经被无数机器学习入门者应用无数遍,是最受欢迎的深度学习数据集之一。

序号说明
发布方National Institute of Standards and Technology(美国国家标准技术研究所,简称NIST)
发布时间1998
背景该数据集的论文想要证明在模式识别问题上,基于CNN的方法可以取代之前的基于手工特征的方法,所以作者创建了一个手写数字的数据集,以手写数字识别作为例子证明CNN在模式识别问题上的优越性。
简介MNIST数据集是从NIST的两个手写数字数据集:Special Database 3 和Special Database 1中分别取出部分图像,并经过一些图像处理后得到的。MNIST数据集共有70000张图像,其中训练集60000张,测试集10000张。所有图像都是28×28的灰度图像,每张图像包含一个手写数字。

跟着视频跑源码

  1. 下载源码:mivlab/AI_course (github.com)
  2. 下载数据集:https://opendatalab.com/MNIST;网上下载的地址比较多,也可以直接下载B站中国计量大学杨老师的百度网盘位置里的MNIST。

运行源码

  1. 在Pycharm中打开AI_course项目,运行classify_pytorch文件目录里train_mnist.py的Python文件。
    在这里插入图片描述
    train_mnist.py具体的源码如下:
import torch
import math
import torch.nn as nn
from torch.autograd import Variable
from torchvision import transforms, models
import argparse
import os
from torch.utils.data import DataLoaderfrom dataloader import mnist_loader as ml
from models.cnn import Net
from toonnx import to_onnxparser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--datapath', required=True, help='data path')
parser.add_argument('--batch_size', type=int, default=256, help='training batch size')
parser.add_argument('--epochs', type=int, default=300, help='number of epochs to train')
parser.add_argument('--use_cuda', default=False, help='using CUDA for training')args = parser.parse_args()
args.cuda = args.use_cuda and torch.cuda.is_available()
if args.cuda:torch.backends.cudnn.benchmark = Truedef train():os.makedirs('./output', exist_ok=True)if True: #not os.path.exists('output/total.txt'):ml.image_list(args.datapath, 'output/total.txt')ml.shuffle_split('output/total.txt', 'output/train.txt', 'output/val.txt')train_data = ml.MyDataset(txt='output/train.txt', transform=transforms.ToTensor())val_data = ml.MyDataset(txt='output/val.txt', transform=transforms.ToTensor())train_loader = DataLoader(dataset=train_data, batch_size=args.batch_size, shuffle=True)val_loader = DataLoader(dataset=val_data, batch_size=args.batch_size)model = Net(10)#model = models.vgg16(num_classes=10)#model = models.resnet18(num_classes=10)  # 调用内置模型#model.load_state_dict(torch.load('./output/params_10.pth'))#from torchsummary import summary#summary(model, (3, 28, 28))if args.cuda:print('training with cuda')model.cuda()optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=1e-3)scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [20, 30], 0.1)loss_func = nn.CrossEntropyLoss()for epoch in range(args.epochs):# training-----------------------------------model.train()train_loss = 0train_acc = 0for batch, (batch_x, batch_y) in enumerate(train_loader):if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)  # 256x3x28x28  out 256x10loss = loss_func(out, batch_y)train_loss += loss.item()pred = torch.max(out, 1)[1]train_correct = (pred == batch_y).sum()train_acc += train_correct.item()print('epoch: %2d/%d batch %3d/%d  Train Loss: %.3f, Acc: %.3f'% (epoch + 1, args.epochs, batch, math.ceil(len(train_data) / args.batch_size),loss.item(), train_correct.item() / len(batch_x)))optimizer.zero_grad()loss.backward()optimizer.step()scheduler.step()  # 更新learning rateprint('Train Loss: %.6f, Acc: %.3f' % (train_loss / (math.ceil(len(train_data)/args.batch_size)),train_acc / (len(train_data))))# evaluation--------------------------------model.eval()eval_loss = 0eval_acc = 0for batch_x, batch_y in val_loader:if args.cuda:batch_x, batch_y = Variable(batch_x.cuda()), Variable(batch_y.cuda())else:batch_x, batch_y = Variable(batch_x), Variable(batch_y)out = model(batch_x)loss = loss_func(out, batch_y)eval_loss += loss.item()pred = torch.max(out, 1)[1]num_correct = (pred == batch_y).sum()eval_acc += num_correct.item()print('Val Loss: %.6f, Acc: %.3f' % (eval_loss / (math.ceil(len(val_data)/args.batch_size)),eval_acc / (len(val_data))))# 保存模型。每隔多少帧存模型,此处可修改------------if (epoch + 1) % 1 == 0:# torch.save(model, 'output/model_' + str(epoch+1) + '.pth')torch.save(model.state_dict(), 'output/params_' + str(epoch + 1) + '.pth')#to_onnx(model, 3, 28, 28, 'params.onnx')if __name__ == '__main__':train()
  1. 报错:没有cv2,即没有安装OpenCV库。
    在这里插入图片描述
  2. 安装OpenCV库,可以命令行安装,也可以Pycharm中安装。
  • 命令行激活虚拟环境:conda activate deeplearning
  • 命令行安装: pip install opencv-python(也可以Pycharm中下载,可能上梯子安装更快)
    在这里插入图片描述
  1. 再次运行,出现如下图提示,表明需要将下载好的数据集配置到configure中。
    在这里插入图片描述
  2. 加载下载好的数据集,即--datapath=数据集的路径
    在这里插入图片描述
  3. 点击“Run”,开始训练,损失和准确率在一直更新,持续训练,直到模型完成,未改动源码的情况下,训练时间可能需要较长。
    在这里插入图片描述
  4. 在小编的拯救者笔记本电脑上持续训练了10小时才完成最终的模型训练,可以看到训练损失已经很低了,准确度很高水平。
    在这里插入图片描述
  5. 在项目中output文件夹中可以看到已经训练好了很多模型;后面可以利用模型进行推理了。
    在这里插入图片描述

参考

https://zhuanlan.zhihu.com/p/681236488

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/450968.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

代码随想录算法训练营DAY11 | 栈与队列 (2)

一、LeetCode 20 有效的括号 题目链接:20.有效的括号https://leetcode.cn/problems/valid-parentheses/ 思路:遇到左括号直接进栈;遇到右括号判断站顶是否有匹配的括号,没有就返回flase,有就将栈顶元素出栈&#xff1…

ctfshow——文件包含

文章目录 web 78——php伪协议第一种方法——php://input第二种方法——data://text/plain第三种方法——远程包含(http://协议) web 78——str_replace过滤字符php第一种方法——远程包含(http://协议)第二种方法——data://&…

Nicn的刷题日常之杨氏矩阵(三种方法求解,逐级递增详解,手把手教学,建议三连收藏)

目录 1.杨氏矩阵知识普及:什么是样式矩阵 2.题目描述 3.解题 3.1暴力求解,遍历法 3.2巧妙解题:对角元素法 3.3将巧解法封装为函数 4.结语 1.杨氏矩阵知识普及:什么是样式矩阵 杨氏矩阵,是对组合表示理论和…

蓝牙 - BLE Basics

BLE Basics [ BLE基础 ] 需要了解的两个主要概念是 BLE 设备的两种模式: Two major concepts to know about are the two modes of BLE devices: * 广播模式(也称为通用访问配置文件 GAP) * 连接设备模式(也称为通用属性配置文件…

halcon中的坐标系相关

一、定义 世界坐标系:真实世界中物体实际位置(三维) 相机坐标系:以镜头光心为原点,光轴为Z轴(三维) 图像物理坐标系:以成像图像中心维原点(二维) 像素坐标系…

如何计算两个指定日期相差几年几月几日

一、题目要求 假定给出两个日期,让你计算两个日期之间相差多少年,多少月,多少天,应该如何操作呢? 本文提供网页、ChatGPT法、VBA法和Python法等四种不同的解法。 二、解决办法 1. 网页计算法 这种方法是利用网站给…

数据中心机房建设的关键痛点及解决方案

随着信息技术的飞速发展,数据中心机房已成为企业信息系统的核心。然而,在机房系统的建设过程中,投资及运行维护成为项目管理的关键痛点。合理的投资决策和高效的运维管理是确保机房系统经济性和可靠性的重要因素。本文将探讨机房系统建设的投…

学习Spring的第十三天

Repository : 注解Dao层 Service : 注解Service层 Controller : 注解Web层 值得注意的是 : 当业务中出现一个bean三层都不属于时 , 我们用Component进行注解 Bean依赖注入注解开发 : Value : 可把zhangsan注解进username属性 Value("zhangsan")private String …

debian12 解决 github 访问难的问题

可以在 /etc/hosts 文件中添加几个域名与IP对应关系,从而提高 github.com 的访问速度。 据搜索了解(不太确定),可以添加这几个域名:github.com,github.global.ssl.fastly.net,github.global.fa…

Linux 多线程 | 线程的操作、线程库、线程ID

Linux进程和线程 进程是资源分配的基本单位线程是调度的基本单位线程共享进程数据,但是也有自己的一部分数据:线程ID(LWP)、一组寄存器、栈、errno、信号屏蔽字、调度优先级 进程的多个线程共享 同一地址空间,因此Text Segment、Data Segment都是共享的…

Linux(一)

介绍 常见的操作系统(windows、IOS、Android、MacOS, Linux, Unix); 一个开源、免费的操作系统,其稳定性、安全性、处理多并发已经得到业界的认可;目前很多企业级的项目(c/c/php/python/java/go)都会部署到 Linux/unix 系统上。 吉祥物 …

遗失的源代码之回归之路的探索与实践

背景 最近比较突然被安排接手一个项目,该项目的情况如下 原生和RN结合的混合开发模式组件化开发,有很多基础组件以及业务组件但是在梳理项目依赖时发现了个别组件源码不全的情况,于是写了个cli用于对比两个版本产物文件,生成差异结果以便于快速进行源码找回恢复。 结果如下…