Ai项目十四:基于 LeNet5 的手写数字识别及训练

若该文为原创文章,转载请注明原文出处。

一、介绍

pytorch复现lenet5模型,并检测自己手写的数字图片。

利用torch框架搭建模型相对比较简单,但是也会遇到很多问题,网上资料很多,搭建模型的方法大同小异,在我尝试了自己搭建搭建出来模型,无论是训练还是检测都会遇到很多的问题,像这种自己遇到的问题,请教别人也没有用。原本使用的是github上的一份代码来复现,环境搭建完成后,才发现要有GPU,而我搭建是使用CPU,失败告终,为了复现,租用了AutoDL平台,在次搭建,这里记录GPU下的操作,CPU版本需要修改源码,自行修改,我的目的是在要训练自己的模型并在RK3568上部署,所以先训练并测试好。为后续部署作基础。

二、环境

三、搭建

1、创建虚拟环境

 conda create -n LeNet5_env python==3.8

2、安装pytorch

Previous PyTorch Versions | PyTorch

根据官方PyTorch,安装pytorch,使用的是CPU版本,其他版本自行安装,安装命令:

​​​​​​​
pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html-i https://pypi.tuna.tsinghua.edu.cn/simple

还需要安装一些其他的库

pip install matplotlib -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install opencv-python -i https://pypi.tuna.tsinghua.edu.cn/simple

3、数据集下载

http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz

直接把上面地址复制到网页上,就只可以下载

下载后保存到data/MNIST/raw目录下

四、训练代码

训练模型有四个文件分别为:LeNet5.py;myDatast.py;readMnist.py;train.py

文件LeNet5.py是网络层模型

train.py

import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import DataLoader
from readMnist import *
from myDatast import Mnist
from LeNet5 import LeNet5train_images = load_train_images()
train_labels = load_train_labels()trainData = Mnist(train_images, train_labels)
train_data = DataLoader(dataset=trainData, batch_size=1, shuffle=True)lenet5 = LeNet5()
lenet5.cuda()lossFun = nn.CrossEntropyLoss()optimizer = torch.optim.Adam(params=lenet5.parameters(), lr=1e-4)Epochs = 100
L = len(train_data)for epoch in range(Epochs):for i, (img, id) in enumerate(train_data):img = img.float()id = id.float()img = img.cuda()id = id.cuda()img = Variable(img, requires_grad=True)id = Variable(id, requires_grad=True)Output = lenet5.forward(img)loss = lossFun(Output, id.long())optimizer.zero_grad()loss.backward()optimizer.step()iter = epoch * L + i + 1if iter % 100 == 0:print('epoch:{},iter:{},loss:{:.6f}'.format(epoch + 1, iter, loss))torch.save(lenet5.state_dict(), 'lenet5.pth')

 LeNet5.py

import torch.nn as nnclass LeNet5(nn.Module):def __init__(self):super(LeNet5, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.conv2 = nn.Sequential(nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),nn.Sigmoid(),nn.MaxPool2d(kernel_size=2, stride=2))self.fc1 = nn.Sequential(nn.Linear(in_features=16 * 4 * 4, out_features=120),nn.Sigmoid())self.fc2 = nn.Sequential(nn.Linear(in_features=120, out_features=84),nn.Sigmoid())self.fc3 = nn.Linear(in_features=84, out_features=10)def forward(self, img):img = self.conv1.forward(img)img = self.conv2.forward(img)img = img.view(img.size()[0], -1)img = self.fc1.forward(img)img = self.fc2.forward(img)img = self.fc3.forward(img)return img

 readMnist.py

from torch.utils.data import Dataset
from torchvision import transforms
import numpy as npclass Mnist(Dataset):def __init__(self, dataset, label):self.dataset = datasetself.label = labelself.len = len(self.label)self.transforms = transforms.Compose([transforms.ToTensor() , transforms.Normalize(mean=[0.5], std=[0.5])])def __len__(self):return self.lendef __getitem__(self, item):img = self.dataset[item]img_id = self.label[item]img = np.transpose(img,(1,2,0))img = self.transforms(img)return img, img_id

readMnist.py

import numpy as np
import struct
import matplotlib.pyplot as plt
import cv2fpath = 'G:/enpei_Project_Code/21_LeNet5/LeNet5-master/myLeNet5/data/MNIST/raw/'# 训练集文件
train_images_idx3_ubyte_file = fpath + 'train-images-idx3-ubyte'
# 训练集标签文件
train_labels_idx1_ubyte_file = fpath + 'train-labels-idx1-ubyte'# 测试集文件
test_images_idx3_ubyte_file = fpath + 't10k-images-idx3-ubyte'
# 测试集标签文件
test_labels_idx1_ubyte_file = fpath + 't10k-labels-idx1-ubyte'def decode_idx3_ubyte(idx3_ubyte_file):"""解析idx3文件的通用函数:param idx3_ubyte_file: idx3文件路径:return: 数据集"""# 读取二进制数据bin_data = open(idx3_ubyte_file, 'rb').read()# 解析文件头信息,依次为魔数、图片数量、每张图片高、每张图片宽offset = 0fmt_header = '>iiii'  # 因为数据结构中前4行的数据类型都是32位整型,所以采用i格式,但我们需要读取前4行数据,所以需要4个i。我们后面会看到标签集中,只使用2个ii。magic_number, num_images, num_rows, num_cols = struct.unpack_from(fmt_header, bin_data, offset)print('魔数:%d, 图片数量: %d张, 图片大小: %d*%d' % (magic_number, num_images, num_rows, num_cols))# 解析数据集image_size = num_rows * num_colsoffset += struct.calcsize(fmt_header)  # 获得数据在缓存中的指针位置,从前面介绍的数据结构可以看出,读取了前4行之后,指针位置(即偏移位置offset)指向0016。print(offset)fmt_image = '>' + str(image_size) + 'B'  # 图像数据像素值的类型为unsigned char型,对应的format格式为B。这里还有加上图像大小784,是为了读取784个B格式数据,如果没有则只会读取一个值(即一副图像中的一个像素值)print(fmt_image, offset, struct.calcsize(fmt_image))images = np.empty((num_images, 1, num_rows, num_cols))# plt.figure()for i in range(num_images):if (i + 1) % 10000 == 0:print('已解析 %d' % (i + 1) + '张')print(offset)images[i] = np.array(struct.unpack_from(fmt_image, bin_data, offset)).reshape((1, num_rows, num_cols))# print(images[i])offset += struct.calcsize(fmt_image)#        plt.imshow(images[i],'gray')#        plt.pause(0.00001)#        plt.show()# plt.show()return imagesdef decode_idx1_ubyte(idx1_ubyte_file):"""解析idx1文件的通用函数:param idx1_ubyte_file: idx1文件路径:return: 数据集"""# 读取二进制数据bin_data = open(idx1_ubyte_file, 'rb').read()# 解析文件头信息,依次为魔数和标签数offset = 0fmt_header = '>ii'magic_number, num_images = struct.unpack_from(fmt_header, bin_data, offset)print('魔数:%d, 图片数量: %d张' % (magic_number, num_images))# 解析数据集offset += struct.calcsize(fmt_header)fmt_image = '>B'labels = np.empty(num_images)for i in range(num_images):if (i + 1) % 10000 == 0:print('已解析 %d' % (i + 1) + '张')labels[i] = struct.unpack_from(fmt_image, bin_data, offset)[0]offset += struct.calcsize(fmt_image)return labelsdef load_train_images(idx_ubyte_file=train_images_idx3_ubyte_file):"""TRAINING SET IMAGE FILE (train-images-idx3-ubyte):[offset] [type]          [value]          [description]0000     32 bit integer  0x00000803(2051) magic number0004     32 bit integer  60000            number of images0008     32 bit integer  28               number of rows0012     32 bit integer  28               number of columns0016     unsigned byte   ??               pixel0017     unsigned byte   ??               pixel........xxxx     unsigned byte   ??               pixelPixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).:param idx_ubyte_file: idx文件路径:return: n*row*col维np.array对象,n为图片数量"""return decode_idx3_ubyte(idx_ubyte_file)def load_train_labels(idx_ubyte_file=train_labels_idx1_ubyte_file):"""TRAINING SET LABEL FILE (train-labels-idx1-ubyte):[offset] [type]          [value]          [description]0000     32 bit integer  0x00000801(2049) magic number (MSB first)0004     32 bit integer  60000            number of items0008     unsigned byte   ??               label0009     unsigned byte   ??               label........xxxx     unsigned byte   ??               labelThe labels values are 0 to 9.:param idx_ubyte_file: idx文件路径:return: n*1维np.array对象,n为图片数量"""return decode_idx1_ubyte(idx_ubyte_file)def load_test_images(idx_ubyte_file=test_images_idx3_ubyte_file):"""TEST SET IMAGE FILE (t10k-images-idx3-ubyte):[offset] [type]          [value]          [description]0000     32 bit integer  0x00000803(2051) magic number0004     32 bit integer  10000            number of images0008     32 bit integer  28               number of rows0012     32 bit integer  28               number of columns0016     unsigned byte   ??               pixel0017     unsigned byte   ??               pixel........xxxx     unsigned byte   ??               pixelPixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).:param idx_ubyte_file: idx文件路径:return: n*row*col维np.array对象,n为图片数量"""return decode_idx3_ubyte(idx_ubyte_file)def load_test_labels(idx_ubyte_file=test_labels_idx1_ubyte_file):"""TEST SET LABEL FILE (t10k-labels-idx1-ubyte):[offset] [type]          [value]          [description]0000     32 bit integer  0x00000801(2049) magic number (MSB first)0004     32 bit integer  10000            number of items0008     unsigned byte   ??               label0009     unsigned byte   ??               label........xxxx     unsigned byte   ??               labelThe labels values are 0 to 9.:param idx_ubyte_file: idx文件路径:return: n*1维np.array对象,n为图片数量"""return decode_idx1_ubyte(idx_ubyte_file)if __name__ == '__main__':train_images = load_train_images()train_labels = load_train_labels()test_images = load_test_images()test_labels = load_test_labels()pass# 查看前十个数据及其标签以读取是否正确for i in range(10):print(train_labels[i])img = train_images[i]img = np.transpose(img, (1, 2, 0))cv2.namedWindow('img')cv2.imshow('img', img)cv2.waitKey(100)print('done')

上面代码需要注意的是数据集的路径,需要修改成对应的路径。

运行python train.py

训练大概5小时

五、测试

from LeNet5 import LeNet5
import torch
from readMnist import *
from myDatast import Mnist
from torch.utils.data import DataLoader
import numpy as np
import cv2test_images = load_test_images()
test_labels = load_test_labels()testData = Mnist(test_images, test_labels)
test_data = DataLoader(dataset=testData, batch_size=1, shuffle=True)lenet5 = LeNet5()
lenet5.load_state_dict(torch.load('lenet5.pth'))
lenet5.eval()showimg = True
js = 0
for i, (img, id) in enumerate(test_data):img = img.float()outid = lenet5(img)oid = torch.argmax(outid)if oid == id:js = js + 1if showimg == True:img = img.numpy()img = np.squeeze(img)id = id.numpy()id = np.squeeze(id)id = np.int32(id)oid = oid.numpy()oid = np.squeeze(oid)maxv = np.max(img)minv = np.min(img)img = (img - minv) / (maxv - minv)cv2.namedWindow("img", 0)cv2.imshow("img", img)title = "img, predicted value:{},truth value:{}".format(oid, id)cv2.setWindowTitle("img",title)cv2.waitKey(1)print('准确率:{:.6f}'.format(js / (i + 1)))

测试结果准确率达到0.986基本达到要求 

如有侵权,或需要完整代码,请及时联系博主。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/123749.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

HTML的学习 Day02(列表、表格、表单)

文章目录 一、列表列表主要分为以下三种类型:1. 无序列表(Unordered List):2. 有序列表(Ordered List):将有序列表的数字改为字母或自定义内容li.../li 列表项标签中value属性,制定列…

【kylin】【ubuntu】搭建本地源

文章目录 一、制作一个本地源仓库制作ubuntu本地仓库制作kylin本地源 二、制作内网源服务器ubuntu系统kylin系统 三、使用内网源ubuntukylin 一、制作一个本地源仓库 制作ubuntu本地仓库 首先需要构建一个本地仓库,用来存放软件包 mkdir -p /path/to/localname/pac…

Linux 下如何调试代码

debug 和 release 在Linux下的默认模式是什么? 是release模式 那你怎么证明他就是release版本? 我们知道如果一个程序可以被调试,那么它一定是debug版本,如果它是release版本,它是没法被调试的,所以说我们可以来调试一…

Jmeter+jenkins接口性能测试平台实践整理

最近两周在研究jmeter+Jenkin的性能测试平台测试dubbo接口,分别尝试使用maven,ant和Shell进行构建,jmeter相关设置略。 一、Jmeterjenkins+Shell+tomcat 安装Jenkins,JDK,tomcat,并设置环境变量&#xff0…

【VR】【unity】如何在VR中实现远程投屏功能?

【背景】 目前主流的VD应用,用于娱乐很棒,但是用于工作还是无法效率地操作键鼠。用虚拟键盘工作则显然是不现实的。为了让自己的头显能够起到小面积代替多显示屏的作用,自己动手开发投屏VR应用。 【思路】 先实现C#的投屏应用。研究如何将C#投屏应用用Unity 3D项目转写。…

WebSocket的那些事(6- RabbitMQ STOMP目的地详解)

目录 一、目的地类型二、Exchange类型目的地三、Queue类型目的地四、AMQ Queue类型目的地五、Topic类型目的地 一、目的地类型 在上节 WebSocket的那些事(5-Spring STOMP支持之连接外部消息代理)中我们已经简单介绍了各种目的地类型,如下图&…

Redis持久化(RDB/AOF)

"在哪里走散,你都会 找 到 我。" 认识持久化 我们在接触Mysql事务的时候,一定了解过Mysql事务的四个特性: "原子性(A)一致性(C)隔离性(I)持久性(D)" 而其中持久性其实与持久化是一回事,所谓持久与不持久&#x…

python实现http/https拦截

python实现http拦截 前言:为什么要使用http拦截一、技术调研二、技术选择三、使用方法前言:为什么要使用http拦截 大多数爬虫玩家会直接选择API请求数据,但是有的网站需要解决扫码登录、Cookie校验、数字签名等,这种方法实现时间长,难度高。需求里面不需要高并发,有没有…

Docker 容器监控 - Weave Scope

Author:rab 目录 前言一、环境二、部署三、监控3.1 容器监控 - 单 Host3.2 容器监控 - 多 Host 总结 前言 Docker 容器的监控方式有很多,如 cAdvisor、Prometheus 等。今天我们来看看其另一种监控方式 —— Weave Scope,此监控方法似乎用的人…

【C语言】循环结构程序设计 (详细讲解)

前言:前面介绍了程序中常常用到的顺序结构和选择结构,但是只有这两种结构是不够的,还有用到循环结构(或者称为重复结构)。因为在日常生活中或是在程序所处理的问题中常常遇到需要重复处理的问题。 【卫卫卫的代码仓库】 【选择结构】 【专栏链…

蜂蜜配送销售商城小程序的作用是什么

蜂蜜是农产品中重要的一个类目,其受众之广市场需求量大,但由于非人人必需品,因此传统线下门店经营也面临着痛点,线上入驻平台也有很多限制难以打造自有品牌,无法管理销售商品及会员、营销等,缺少自营渠道&a…

列表的增删改查和遍历

任务概念 什么是任务 任务是一个参数为指针,无法返回的函数,函数体为死循环不能返回任务的实现过程 每个任务是独立的,需要为任务分别分配栈称为任务栈,通常是预定义的全局数组,也可以是动态分配的一段内存空间&#…