【PyTorch攻略(2/7)】 加载数据集

一、说明

        PyTorch提供了两个数据原语:torch.utils.data.DataLoadertorch.utils.data.Dataset,允许您使用预加载的数据集以及您自己的数据。数据集存储样本及其相应的标签,DataLoader 围绕数据集包装一个可迭代对象,以便轻松访问样本。

        PyTorch域库提供了许多示例预加载数据集,例如FashionMNIST,它子类torch.utils.data.Dataset并实现特定于特定数据的函数。可以在此处找到它们并用作原型设计和基准测试模型的示例:

  • 图像数据集
  • 文本数据集
  • 音频数据集

二、加载数据集

        我们将从TorchVision加载FashionMNIST数据集。FashionMNIST是Zalando文章图像的数据集,由60,000个训练示例和10,000个测试示例组成。每个示例包含一个 28x28 灰度图像和一个来自 10 个类之一的相关标签。

  • 每张图片高 28 像素,宽 28 像素,共 784 像素。
  • 这 10 个类告诉它是什么类型的图像。例如,T型短裤/上衣,裤子,套头衫,连衣裙,包,踝靴等。
  • 灰度是介于 0 到 255 之间的值,用于测量黑白图像的强度。强度值从白色增加到黑色。例如,白色为 0,黑色为 255。

        我们使用以下参数加载 FashionMNIST 数据集:

  • 是存储训练/测试数据的路径。
  • 训练指定训练或测试数据集。
  • download = 如果数据在根目录中不可用,则 True 从互联网下载数据。
  • 转换指定特征和标注转换
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
import matplotlib.pyplot as plttraining_data = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor()
)test_data = datasets.FashionMNIST(root="data",train=False,download=True,transform=ToTensor()
)

三、迭代和可视化数据集

我们可以像列表一样手动索引数据集:training_data[index]。我们使用 matplotlib 来可视化训练数据中的一些样本。

labels_map = {0: "T-Shirt",1: "Trouser",2: "Pullover",3: "Dress",4: "Coat",5: "Sandal",6: "Shirt",7: "Sneaker",8: "Bag",9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):sample_idx = torch.randint(len(training_data), size=(1,)).item()img, label = training_data[sample_idx] # Iterate training datafigure.add_subplot(rows, cols, i)plt.title(labels_map[label])plt.axis("off")plt.imshow(img.squeeze(), cmap="gray")
plt.show()

四、准备数据以使用数据加载程序进行训练

        数据集检索数据集的特征并一次标记一个样本。在训练模型时,我们通常希望以“小批量”方式传递样本,在每个时期重新洗牌数据以减少模型过度拟合,并使用 Python 的多处理来加速数据检索。

在机器学习中,需要指定数据集中的特征和标签。要素是输入,标注是输出。我们训练特征并训练模型来预测标签。

        DataLoader 是一个迭代对象,它在一个简单的 API 中为我们抽象了这种复杂性。要使用数据加载器,我们需要设置以下参数:

  • 数据是将用于训练模型的训练数据;以及用于评估模型的测试数据。
  • 批大小是每个批中要处理的记录数。
  • 随机播放是按索引随机抽取的数据。
from torch.utils.data import DataLoadertrain_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

五、遍历数据加载器

        我们已将数据集加载到数据加载器中,并可以根据需要循环访问数据集。下面的每次迭代都会返回一批train_featurestrain_labels(分别包含 batch_size = 64 个要素和标注)。由于我们指定了 shuffle = True,因此在我们遍历所有批次后,数据将被洗牌。

# Display image and label
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

        NOrmalization是一种常见的数据预处理技术,用于缩放或转换数据,以确保每个特征的学习贡献相等。例如,灰度图像中的每个像素都有一个介于 0 到 255 之间的值,这些值是特征。如果一个像素值为 17,另一个像素值为 197。像素重要性的分布将不均匀,因为较高的像素体积会偏离学习。归一化会更改数据的范围,而不会扭曲其在我们的功能之间的区别。进行此预处理是为了避免:

  • 预测精度降低
  • 模型学习的难度
  • 特征数据范围的不利分布

六、变换

数据并不总是以训练机器学习算法所需的最终处理形式出现。我们使用 transform 对数据执行一些操作,并使其适合训练。

所有 TorchVision 数据集都有两个参数:用于修改特征的转换和用于修改接受包含转换逻辑的可调用对象的标签target_transformtorchvision.transform模块提供了几种开箱即用的常用变换。

FashionMNIST 功能采用 PIL 图像格式,标签为整数。对于训练,我们需要特征作为规范化张量,标签作为独热编码张量。为了进行这些转换,我们使用ToTensorLamda

from torchvision import datasets
from torchvision.transforms import ToTensor, Lambdads = datasets.FashionMNIST(root="data",train=True,download=True,transform=ToTensor(),target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

七、ToTensor()

ToTensor 将 PIL 图像或 NumPy ndarray 转换为 FloatTensor,并将图像的像素强度值缩放在 [0., 1.] 范围内。

八、Lambda()

        Lambda 应用任何用户定义的 lambda 函数。在这里,我们定义了一个函数来将整数转换为独热编码张量。它首先创建一个大小为 10(我们数据集中的标签数量)的张量并调用 scatter,它在索引上分配一个 value=1,如标签 y 给出的那样。您也可以将torch.nn.functional.one_hot用作其他选项。

target_transform = Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

下一>> PyTorch 简介 (3/7)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/112310.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Centos7.9 一键脚本部署 LibreNMS 网络监控系统

前言: LibreNMS 是个以 PHP/MySQL 为基底的自动探索网络监控系统 LibreNMS 官网 版本23.8.2-52-g7bbe0a2 - Thu Sep 14 2023 22:33:23 GMT0700数据库纲要2023_09_01_084057_application_new_defaults (259)Web 服务器nginx/1.20.1PHP8.1.23Python3.6.8DatabaseMa…

etcd之读性能主要影响因素

1、Raft模块-线性读ReadIndex-节点之间的RTT延时、磁盘IO 线性读时Follower节点首先会向Raft 模块发送ReadIndex请求,此时Raft模块会先向各节点发送心跳确认,一半以上节点确认 Leader 身份后由leader节点将已提交日志索引 (committed index) 封装成 Rea…

redis 集群(cluster)

1. 前言 我们知道,在Web服务器中,高可用是指服务器可以正常访问的时间,衡量的标准是在多长时间内可以提供正常服务(99.9%、99.99%、99.999% 等等)。但是在Redis语境中,高可用的含义似乎要宽泛一些&#xf…

Windows PHP 将 WORD转PDF,执行完成后 释放进程

Windows PHP 将 WORD转PDF,执行完成后 释放进程 word转PDF清理任务进程 【附赠彩蛋】每次PHP执行完word转pdf之后,在任务进程中都会生成并残留WINWORD.EXE进程,时间久了,服务器就会越来原卡,本文完整的讲述怎么转PDF和转换之后的操作。 word转PDF /**$doc 传入完整的doc路…

Vue3函数式编程

文章目录 前言一、三种编程风格1.template2.jsx/tsx3.函数式编写风格 二、函数式编程1.使用场景2.参数3.例子3.render渲染函数 总结 前言 本文主要记录vue3中的函数式编程以及其他编程风格的简介 一、三种编程风格 1.template Vue 使用一种基于 HTML 的模板语法,…

OPENCV实现人类识别(包括眼睛、鼻子、嘴巴)

人脸识别步骤 # -*- coding:utf-8 -*- """ 作者:794919561 日期:2023/9/14 """ import cv2 import numpy as np # load xml face_xml = cv2.CascadeClassifier(F:\\learnOpenCV\\opencv\\data\\haarcascades\\haarcascade_frontalface_defaul…

企业架构LNMP学习笔记43

memcached的使用: 命令行连接和操作: telnet连接使用: memcached默认使用启动服务占用tcp 11211端口,可以通过telnet进行连接使用。 安装telnet进行连接: 连接成功,敲击多次,如果看到error&…

使用Linkerd实现流量管理:学习如何使用Linkerd的路由规则来实现流量的动态控制

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…

I Pa?sWorD

2023icpc网络赛第一场 I 题意:题目给出只包含大小写字母,数字以及?的字符串,对于每一个小写字母,这一位字符既有可能是该小写字母,也有可能是该小写字母的对应大写字母,也就是该位的字符有两种可能&#x…

电商项目高级篇-01 elasticsearch

电商项目高级篇-01 elasticsearch 1、linux下安装elasticsearch和可视化工具 1、linux下安装elasticsearch和可视化工具 将安装好jdk1.8和tomcat的centos7下安装elasticsearch docker pull elasticsearch:7.4.2docker pull kibana:7.4.2##docker下安装软件需要配置挂载。方便…

基于matlab实现的多普勒脉冲雷达回波仿真

完整程序: clear all;clc;close all; fc3e9; %载波频率 PRF2000; Br5e6; %带宽 fs10*Br; %采样频率 Tp5e-6; %脉宽 KrBr/Tp; %频率变化率 c3e8; %光速 lamda…

MySQL查询表结构方法

MySQL查询数据库单个表结构代码 – 查询数据库表信息 SELECT​ COLUMN_NAME 列名,​ DATA_TYPE 字段类型,​ CHARACTER_MAXIMUM_LENGTH 长度,​ IS_NULLABLE 是否为空,​ IF(column_key PRI,Y,) 是否为主键,​ COLUMN_DEFAULT 默认值,​ COLUMN_COMMENT 备注FROM​ INFORMAT…