MNIST内置手写数字数据集的实现

torchvision库

torchivision库是PyTorch中用来处理图像和视频的一个辅助库,接下来我们就会使用torchvision库加载内置的数据集进行分类模型的演示

为了统一数据加载和处理代码,PyTorch提供了两个类用于处理数据加载,他们分别是torch.utils.data.Dataset类和torch.utils.data.DataLoader类,通过这两个类可使数据集加载和预处理代码与模型训练代码脱钩,从而获得更好的代码模块化和代码可读性。torchvision加载的内置图片数据集均继承自torch.utils.data.Dataset类,因此可直接使用加载的内置数据集创建DataLoader.

加载内置图片数据集

PyTorch的内置图片数据集均在torchvision.datasets模块下,包含Caltech、CelebA、CIFAR、Cityscapes、COCO、Fashion-MNIST、ImageNet、MNIST等很多著名的数据集,其中MNIAT数据集是手写数字数据集,这是一个很适合入门者学习使用的小型计算机视觉数据集,它包含0到9的手写数字图片和每一张图片对应的标签。接下来我们就以此数据集为例子进行学习。

import torchvision  # 导入torchvision库
from torchvision.transforms import ToTensor  #做好准备工作,导入所需要的包
import torch
import matplotlib.pyplot as plt
import numpy as np

首先就是对我们所需要的库进行导入。

我对上述的代码进行一下解读,首先导入了torchvision库,从torchvision.transforms模块下导入ToTensor类。torchvision.transforms模块包含了转换函数,使用它可以很方便的对加载的图形进行各种变换,这里用到的ToTensor类,该类的主要作用有以下3点。

  1. 将输入转换为张量
  2. 将读取图片的格式规范为(channel,height,width),这里和我们经常遇到的图片格式有可能会有一些去呗,PyTorch中的图片格式一般是通道数(channel)在前,然后是高度(height)和宽度(width)
  3. 将图片像素的取值范围归一化,规范为0到1的范围内
train_ds=torchvision.datasets.MNIST('data/',train=True,transform=ToTensor(),download=True)
test_ds=torchvision.datasets.MNIST('data/',train=False,transform=ToTensor(),download=True)

通过torchvision.datasets.MNIST方法加载MNIST数据集,方法中的第一个参数为data/表示下载数据集存放的位置,参数train表示是否是训练数据,若为True,则加载训练数据集,若为False,则加载测试数据集;
使用参数transform表示对加载数据的预处理,参数值为ToTensor();
最后一个参数download=True表示将下载此数据集,一旦下载完成后,下一次执行此代码是,将优先从本地文件夹直接加载,如果咱们的计算机不能连接互联网,也可以直接将文件复制到data文件夹中,这样就能从本地直接加载数据了。

现在我们得到了两个数据集,分别是训练数据集和测试数据集,PyTorch还提供了torch.utils.data.DataLoader类用以对数据集做进一步的处理,DataLoader接收数据集,并执行复杂的操作,如小批次处理、多线程、随机打乱等,以便从数据集中获取数据。它接收来自用户的Dataset实例,并使用采样器策略将数据采样为小批次。DataLoader的目的如下

1.使用shuffle参数对数据集做乱序的操作,一般情况下,需要对训练数据集进行乱序的操作,因为原始的数据在样本均衡的情况下肯呢个是按照某种顺序进行排列的,经过顺序打乱之后,数据的排列就会拥有一定的随机性,这样做可以避免出现模型反复依次序学习数据的特征或者学习到的只是数据的次数特征的情况。

2.将数据采样为小批次,可用batch_size参数指定批次大小。首先单个样本训练有一个很大的缺点,就是损失和梯度会受到单个样本的影响,如果样本分布不均匀,或者有错误标注样本,则会引起梯度的巨大震荡,从而导致模型训练效果很差。为了解决这个问题,我们可以考虑使用批量数据训练(也叫做批量梯度下降算法),通过遍历全部数据集算一次损失函数,然后计算损失对各个参数的梯度,并更新参数。这种训练方式没更新一次,参数都要把数据集里所有样本都看一遍,不仅计算开销大,而且计算速度慢。为了克服上述方法的缺点,一般采用的是一种折中手段进行损失函数计算:即把数据分为若干个小的批次,按批次来更新参数,这样,一个批次中的一组数据共同决定了本次梯度的方向,大大降低了参数更新时的梯度方差,下降起来更加稳定,减少了随机性,与单样本训练相比,小批次训练可利用矩阵操作进行有效的梯度计算,计算量也不是很大,对计算机内存的要求也不高。

3.可以充分利用多个子进程加速数据预处理。num_workers参数可以指定子进程的数量

4.可通过collate_fn参数传递批次数据的处理函数,实现在DataLoader中对批次数据做转换处理

train_dl=torch.utils.data.DataLoader(train_ds,batch_size=64,shuffle=True)
test_dl=torch.utils.data.DataLoader(test_ds,batch_size=46)

上面代码中分别创建了训练数据和测试数据的DataLoader,并设置他们的批次大小为64,对训练数据设置了shuffle为True;对测试数据,由于仅仅作为测试,没必要做乱序。

DataLoader是可迭代对象,我们观察它返回的数据集的类型,给大家对对DataLoader和MNIST数据集有一个直观的印象

imgs,labels=next(iter(train_dl))#创建生成器,并用next方法返回一个批次的数据
print(imgs.shape)
print(labels.shape)

我们使用iter方法将DataLoader对象创建为生成器,并使用next方法反悔了一个批次的图像(imgs)和对应的一个批次的标签(labels),image.shape为torch.Size([64,1,28,28]),这里的64是批次,我们可以认为这代表64张形状为(1,28,28)的图片,其中1为通道数,28和28分别为高和宽;既然这里有64张图片,那么就对应着有64个标签,也就是labels.shape所显示的torch.Size([64])

结果绘制

# 我们使用Matplotlib来绘制一下前10张的图片
plt.figure(figsize=(20,2))  # 创建一个(10,1)大小的画布
for i,img in enumerate(imgs[:20]):npimg=img.numpy()  # 将张量转换为ndarraynpimg=np.squeeze(npimg)  # 图片形状由(1,28,28)转换为(28,28)plt.subplot(1,20,i+1)  # 初始化子图,3个参数表示1行10列的第i+1个子图plt.imshow(npimg)  #在子图中绘制单张图片plt.axis('off')  # 关闭显示子图坐标

plt.imshow() 是一个用于显示图像的函数,通常用于在 Python 中使用 Matplotlib 库绘制图像。它可以接受一个数组或图像数据,并将其显示为图像。这个函数通常用于可视化图像数据,比如热图、灰度图、彩色图等。plt.imshow() 可以接受一些参数,比如 cmap(颜色映射)、interpolation(插值方法)等,用来控制图像的显示效果。

接下来,我们打印对应的标签

print(labels[:20])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/282488.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【基础算法】试除法判定质数(优化)

文章目录 算法优化模板题目代码实现 算法优化模板 bool is_prime(int n){if(n < 2) return false;for(int i 2;i < n / i;i ){ //优化内容if(n % i 0){return false;}}return true; }注意这里的一个总要优化是for循环的终止条件是i<n/i。为什么不是i<n或者i<…

【ArkTS】如何修改应用的首页

之前看到一种说法&#xff0c;说是应用首页是 entry > src > main > resources > base > profile > main_pages.json 中src配置中数组第一个路径元素。这种说法是不对的&#xff01;&#xff01;&#xff01; 如果需要修改应用加载时的首页&#xff0c;需要…

二叉树前,中序推后续_中,后续推前序

文章目录 介绍思路例子 介绍 二叉树是由根、左子树、右子树三部分组成。 二叉树的遍历方式又可以分为前序遍历&#xff0c;中序遍历&#xff0c;后序遍历。 前序遍历&#xff1a;根&#xff0c;左子树&#xff0c;右子树 中序遍历&#xff1a;左子树&#xff0c;根&#xff0…

python学习1补充

大家好&#xff0c;这里是七七&#xff0c;这个专栏是用代码实例来学习的&#xff0c;不是去介绍很多知识的。 话不多说&#xff0c;开始今天的内容 目录 代码1 代码2 代码3 代码4 代码5 学习1的总代码 代码1 groupeddf.groupby(单品编码) result{} groupeddf.groupb…

配置 vim 默认显示行号 行数 :set number

vi ~/.vimrc 最后添加一行 :set number保存退出&#xff0c;再次 vim 打开文件&#xff0c;默认就会显示行号了

Python-折线图可视化

折线图可视化 1.JSON数据格式2.pyecharts模块介绍3.pyecharts快速入门4.创建折线图 1.JSON数据格式 1.1什么是JSON JSON是一种轻量级的数据交互格式。可以按照JSON指定的格式去组织和封装数据JSON本质上是一个带有特定格式的字符串 1.2主要功能json就是一种在各个编程语言中流…

网络监控软件提高企业网络效率

企业网络监控是主动监控和管理业务网络以确保无缝性能并提高可靠性的做法&#xff0c;持续监控和分析网络各层的可用性、运行状况和性能&#xff0c;但是&#xff0c;选择的网络监控软件应该能够满足业务需求。不是所有的网络监控工具都能用于监控企业网络&#xff0c;它们无法…

DENet:用于可见水印去除的Disentangled Embedding网络笔记

1 Title DENet: Disentangled Embedding Network for Visible Watermark Removal&#xff08;Ruizhou Sun、Yukun Su、Qingyao Wu&#xff09;[AAAI2023 Oral] 2 Conclusion This paper propose a novel contrastive learning mechanism to disentangle the high-level embedd…

【基础算法】前缀和

文章目录 算法介绍什么是前缀和&#xff1f;&#xff1f;前缀和的作用一维数组求解前缀和(Si)二维数组求解前缀项和 示例题目1&#xff1a;acwing795示例题目2&#xff1a;acwing796总结收获 算法介绍 什么是前缀和&#xff1f;&#xff1f; 数组: a[1], a[2], a[3], a[4], a[…

WPF——命令commond的实现方法

命令commond的实现方法 属性通知的方式 鼠标监听绑定事件 行为&#xff1a;可以传递界面控件的参数 第一种&#xff1a; 第二种&#xff1a; 附加属性 propa&#xff1a;附加属性快捷方式

加密的艺术:对称加密的奇妙之处(下)

&#x1f90d; 前端开发工程师&#xff08;主业&#xff09;、技术博主&#xff08;副业&#xff09;、已过CET6 &#x1f368; 阿珊和她的猫_CSDN个人主页 &#x1f560; 牛客高级专题作者、在牛客打造高质量专栏《前端面试必备》 &#x1f35a; 蓝桥云课签约作者、已在蓝桥云…

SpringData JPA 整合Springboot

1.导入依赖 <?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0"xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0…