通过MNIST手写数字识别任务快速入门深度学习(事无巨细版)

可点此跳转看全篇

本文内容

  • 什么是深度学习
  • 入门深度学习时的困惑
  • 典型的入门案例——CNN实现的MNIST手写数字识别
    • 虚拟环境的创建
      • 创建虚拟环境
      • 配置需求的依赖包
    • 代码
      • 1. 引入依赖包
      • 2. 准备数据集datasets
      • 3. 准备数据加载器dataloader
      • 4. 配置网络
      • 5. 设置训练器
      • 6. 网络训练
      • 7. 模型保存
      • 8. 加载和测试模型

公众号

什么是深度学习

深度学习(DL, Deep Learning)是机器学习(ML, Machine Learning)领域中一个研究方向。
深度学习通过对样本数据的内在规律和特征的提取与抽象,在不同维度和层次上进行处理,让机器能够像人一样具有分析学习能力,能够识别文字、图像和声音等数据。
相比于初期的机器学习,深度学习是更加复杂的算法,但是同时因为深度学习算法的普适性,以及在语音和图像识别方面取得的惊人效果,他的发展速度远远超过先前相关技术。

入门深度学习时的困惑

很多同学刚入门的时候,会对代码中的网络主体在哪里,是怎么训练的,以及训练完如何保存,保存后如何使用等问题产生疑惑。这些问题会随着阅览的代码增多而自然化解。我们先从简单的开始:如何跑起来一个神经网络。

典型的入门案例——CNN实现的MNIST手写数字识别

废话不多说,直接通过MNIST手写识别快速入门深度学习。对于手写识别任务,目前已经能够被很轻松的解决。

虚拟环境的创建

创建虚拟环境

为了不使python环境变的混乱,我们使用conda工具创建虚拟python环境,每个虚拟环境之间是隔离的。具体的conda环境安装网上已经有很多教程,很多博主都写的很详细,这里就不展开了。我们直接使用conda工具创建一个新的python环境
conda create -n mnist_pytorch python=3.8
这句创建一个名字为mnist_pytorch,python版本为3.8的虚拟环境,使用如下命令激活环境
conda activate mnist_pytorch

配置需求的依赖包

conda install pytorch==2.0
conda install torchvision==0.15.1
除了会安装pytorch意外,还会自动配置相关的依赖包,比如numpy等。其他的库也会进行安装,如果没有找到对应的库,可以参考后面的安装命令重新来一次。
还需继续安装PIL的库:pillow,没错,import的名字和他的库名并不相同。
pip install pillow
安装画图的库
pip install matplotlib,python库os在安装python环境的时候,就会根据操作系统进行自动配置。

代码

1. 引入依赖包

首先我们需要如下的python包import

import numpy as np
import torch 
from torch import nn
from PIL import Image
import matplotlib.pyplot as plt
import os
from torchvision import datasets, transforms,utils

其中:

  • numpy是python著名以及普遍使用的第三方库,用于进行科学计算
  • PIL全名为Python Image Library,用于图像的处理
  • matplotlib是python中普遍使用的绘图库,而pyplot是其一种快捷的绘图接口
  • os为python提供了丰富的方法来处理文件和目录
  • torchvision通过这个库,我们能够实现很多经典数据集的下载,包括COCO,ImageNet,CIFCAR等,当然也包括我们的这个MNIST。

2. 准备数据集datasets

MNIST数据集是一张张黑底白字的手写体图片,大小均为28 × \times × 28。如下在这里插入图片描述

按照下面的代码取出的每一个数据都是:{ 图片, 数字 } 的组合。
运行以下的代码

# 1
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean=[0.5],std=[0.5])])
# 2
train_data = datasets.MNIST(root = "./data/",transform=transform,train = True,download = True)
# 3
test_data = datasets.MNIST(root="./data/",transform = transform,train = False)

我们使用torchvision库中提供的方法,把MNIST数据下载到本地并分为测试集和训练集。每一句的作用解释如下:

  • 第一句指定了数据集预处理的方案,在上述代码中,制定了将数据转变为tensor()格式,可以理解这是pytorch中的矩阵;并且将数值归一化,指定均值和标准差均为0.5。数据的标准化或者说是归一化是神经网络数据预处理中经常采用的方式,能够剔除数据中的极端情况,并且有利于模型训练过程的收敛。
  • 第二、三句通过引入torchvision中的datasets方法,指定MNIST数据集的位置,这里指定为和当前程序文件同一文件夹下的data文件夹,指定预处理方法为第一句设定的transform,然后训练集将train设定为true,测试集设定为false。第一个的download选项指定了如果没有在该路径下找到数据,那么会自动下载到该路径。
    由于没有实现下载数据集,代码运行后有输出:
    在这里插入图片描述

运行完代码后,我们会得到以下的文件树:
在这里插入图片描述

3. 准备数据加载器dataloader

# 1
train_loader = torch.utils.data.DataLoader(train_data,batch_size=64,shuffle=True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/517924.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CVPR 2022 Oral | Bailando: 基于编舞记忆和Actor-Critic GPT的3D舞蹈生成

目录 测试结果: 02 提出的方法 测试结果: 预测有3个步骤,速度比较慢 02 提出的方法 1. 针对舞蹈序列的VQ-VAE和编舞记忆 与之前的方法不同,我们不学习从音频特征到 3D 关键点序列的连续域的直接映射。相反,我们先让…

2024环境工程、能源系统与化学材料国际会议(ICEEESCM 2024)

2024环境工程、能源系统与化学材料国际会议(ICEEESCM 2024) 一、【会议简介】 2024环境工程、能源系统与化学材料国际会议(ICEEESCM 2024)将于2024年在西安举行。会议将围绕环境工程、能源系统与化学材料等议题展开讨论,旨在为从事环境工程…

arm架构服务器使用Virtual Machine Manager安装的kylin v10虚拟机

本文中使用Virtual Machine Manager安装kylin v10的虚拟机 新建虚拟机 新建虚拟机 选择镜像,下一步 设置内存和CPU,下一步 选择或创建自定义存储(默认存储位置的磁盘空间可能不够用) 点击管理,打开选择存储卷页…

曲线曲面 - 连续性, 坐标变换矩阵

连续性 有两种:参数连续性(Parametric Continuity)、几何连续性(Geometric Continuity)参数连续性: 零阶参数连续性,记为,指相邻两段曲线在结合点处具有相同的坐标 一阶参数连续性&…

1.JavaWebJava基础加强[万字长文]-Junit、反射、注解核心知识点梳理

导语: 一、Junit单元测试 1.Junit测试概述 2.Junit使用步骤 3.Junit_Before&After 二、反射 1.反射概述 2.反射获取字节码Class对象的三种方式 3.Class对象功能概述 4.Class对象功能_获取Field 5.Class对象功能_获取Constructor 6.Class对象功能_获取…

响应人大代表王旭的提议:996程序员也要每天一节体育课

哈喽,我是熊子峰,38岁程序员,正在结合AI写作进行自我成长,穿越程序员的中年危机,这是第 69 篇日更文章。 每天一节体育课 今天,看到一条新闻,人大代表王旭提议中小学生每天应该有一节体育课&am…

你不得不知道的Python AI库

Python是人工智能(AI)和机器学习(ML)领域中使用最广泛的编程语言之一,拥有丰富的库支持各种AI和ML任务。本文介绍一些经典的Python AI库。 1. NumPy 简介:NumPy(Numerical Python)…

Docker部署ruoyi前后端分离项目

目录 一. 介绍前后端项目 二. 搭建局域网 2.1 创建网络 2.2 注意点 三. Redis 3.1 安装 3.2 配置redis.conf文件 3.3 测试 四. 安装MySQL 4.1 安装 4.2 配置my2.cnf文件 4.3 充许远程连接 五. 若依部署后端服务 5.1 数据导入 5.2 使用Dockerfile自定义镜像 5.3 运行…

基于单片机的医院输液系统设计

目 录 摘 要 Ⅰ Abstract Ⅱ 引 言 1 1系统方案设计与论证 3 1.1系统硬件结构总体设计方案 3 1.2点滴速度测量电路方案的选择与论证 3 1.3液面检测电路方案的选择与论证 4 1.4通过电机控制滴速电路的方案与论证 4 1.5显示器接口电路方案选择与论证 5 1.6键盘接口电路方案选择与…

土地利用数据分类过程教学/土地利用分类/遥感解译/土地利用获取来源介绍/地理数据获取

本篇主要介绍如何对影像数据进行分类解译,及过程教学,示例数据下载链接:数据下载链接 一、背景介绍 土地是人类赖以生存与发展的重要资源和物质保障,在“人口-资源-环境-发展&#x…

稀碎从零算法笔记Day6-LeetCode:长度最小的子数组

前言:做JD的网安笔试题,结果查找子串(单词)这个操作不会。痛定思痛,决定学习滑动数组 题型:数组、双指针、滑动窗口 链接:209. 长度最小的子数组 - 力扣(LeetCode) 来…

关于yolov8的DFL模块(pytorch以及tensorrt)

先看代码 class DFL(nn.Module):"""Integral module of Distribution Focal Loss (DFL).Proposed in Generalized Focal Loss https://ieeexplore.ieee.org/document/9792391"""def __init__(self, c116):"""Initialize a convo…