PyTorch从入门到放弃之数据模块

news/2024/11/13 10:33:12/文章来源:https://www.cnblogs.com/kohler21/p/18400571

目录
  • Dataset简介及用法
    • Map-style datasets类型
    • Iterable-style datasets类型
  • DataLoader简介及用法

Dataset 和 DataLoader 都 是 用 来 帮 助 我 们 加 载 数 据 集 的 两 个 重 要 工 具类。 Dataset 用来构造支持索引的数据集。
在训练时需要在全部样本中拿出小批量数据参与每次的训练,因此我们需要使用 DataLoader ,即 DataLoader 是用来在 Dataset 里取出一组数据 (mini-batch)供训练时快速使用的。

Dataset简介及用法

Dataset 本质上就是一个抽象类,可以把数据封装成 Python 可以识别的数据结构。Dataset 类不能实例化,所以在使用 Dataset 的时候,我们需要定义自己的数据集类,也是 Dataset 的子类,来继承 Dataset 类的属性和方法。Dataset 可作为 DataLoader 的参数传入 DataLoader ,实现基于张量的数据预处理。Dataset 主要有两种类型,分别为 Map-style datasets 和 Iterable-style datasets 。

Map-style datasets类型

该类型实现了 getitem() 和 len() 方法,它代表数据的索引到真正数据样本的映射。也就是说,使用这种方式读取的数据并非直接直接把所有数据读取出来,而是读取数据的索引或者键值。其中,列表或者数组类型的数据读取的就是索引,而字典类型的数据读取的就是键值。在访问时,用dataset[idx]访问idx对应的真实数据。这种类型的数据也是使用最多的类型。

Iterable-style datasets类型

该类型实现了 iter() 方法,与上述类型不同之处在于,他会将真实的数据全部载入,然后在整个数据集上进行迭代。如果随机读取的情况不能实现或者代价太大就用这种读取方式。这种读取数据的方式比较适合处理流数据

Dataset 作为一个抽象类,需要定义其子类来实例化。所以需要自己定义其子类或者使用已经定义好的子类。

(1)自定义子类

  • 必须要继承已经内置的抽象类 dataset
  • 必须要重写其中的 init() 方法、 getitem() 方法和 len() 方法
  • 其中 getitem() 方法实现通过给定的索引遍历数据样本, len() 方法实现返回数据的条数

定义一个MyDataset类继承Dataset抽象类,其中pass为占位符,并且改写其中的三个方法

import torch
from torch.utils.data import Datasetclass MyDataset(Dataset):def __init__(self):passdef __getitem__(self, index):passdef __len__(self):pass

这里定义了一个MyDataset类继承Dataset抽象类,并且改写其中的三个方法。在创建的dataset类中可根据用户本身的需求对数据进行处理。可独立编写的数据处理函数,在__getitem__()函数中进行调用;或者直接将数据处理方法写在__getitem__()函数中或者__init__()函数中,但__getitem__()函数必须根据index返回响应的值,该值会通过index传到DataLoader中进行厚涂的Batch批量处理。

在创建的dataset类中可根据自己的需求对数据进行处理,以时间序列使用为示例,输入3个时间步,输出1个时间步,batch_size=5

import torch 
from torch.utils.data import Datasetclass GetTrainTestData(Dataset):def __init__(self, input_len, output_len, train_rate, is_train=True):super().__init__()# 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。# 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里self.x = torch.sin(torch.arange(0, 1000, 0.1))self.sample_num = len(self.x)self.input_len = input_lenself.output_len = output_lenself.train_rate = train_rateself.src, self.trg = [], []if is_train:for i in range(int(self.sample_num*train_rate)-self.input_len-self.output_len):self.src.append(self.x[i:(i+input_len)])self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])else:for i in range(int(self.sample_num*train_rate), self.sample_num-self.input_len-self.output_len):self.src.append(self.x[i:(i+input_len)])self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])print(len(self.src), len(self.trg))def __getitem__(self, index):return self.src[index], self.trg[index]def __len__(self):return len(self.src)  # 或者return len(self.trg), src和trg长度一样

实例化定义好的Dataset子类GetTrainTestData

data_train = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=True)
data_test = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=False)

(2)已经定义好的内置子类

除了自己定义子类继承Dataset外,还可以使用PyTorch提供的已经被定义好的子类,如TensorDataset和IterableDataset。

对 于 给 定 的 tensor 数 据 , TensorDataset 是 一 个 包 装 了 Tensor 的Dataset 子类,传入的参数就是张量,每个样本都可以通过 Tensor 第一个维度的索引获取,所以传入张量的第一个维度必须一致。

PyTorch官方给出的TensorDataset类的定义:

class TensorDataset(Dataset[Tuple[Tensor, ...]]):r"""Dataset wrapping tensors.Each sample will be retrieved by indexing tensors along the first dimension.Args:*tensors (Tensor): tensors that have the same size of the first dimension."""tensors: Tuple[Tensor, ...]def __init__(self, *tensors: Tensor) -> None:assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors), "Size mismatch between tensors"self.tensors = tensorsdef __getitem__(self, index):return tuple(tensor[index] for tensor in self.tensors)def __len__(self):return self.tensors[0].size(0)

所以这个类的实例化有两个参数,分别为data_tensor(Tensor)样本数据和target_tensor(Tensor)样本标签。

使用TensorDataset:

import torch
from torch.utils.data import TensorDatasetsrc = torch.sin(torch.arange(1, 1000, 0.1))
trg = torch.cos(torch.arange(1, 1000, 0.1))

于是可以直接实例化已定义好的Dataset子类TensorDataset

data = TensorDataset(src, trg)

DataLoader简介及用法

Dataset 和 DataLoader 是一起使用的,在模型训练的过程中不断为模型提供数据,同时,使用 Dataset 加载出来的数据集也是
DataLoader 的第一个参数。所以, DataLoader 本质上就是用来将已经加载好的数据以模型能够接收的方式输入到即将训练的模型中去。

几个深度学习模型训练时涉及的参数:

(1)Data_size:所有数据的样本数量。

(2)Batch_size:每个Batch加载多少个样本。

(3)Batch:每一批放进module训练的样本叫一个Batch。

(4)Epoch:模型把所有样本训练完毕一次叫做一个Epoch。

(5)Iteration:所有数据共分成了几个Batch,即训练几次才能够便利所有样本/数据。

(6)Shuffle:在抽取Batch之前是否将样本全部打乱顺序。

数据的输入过程如下图所示。

Data_size=10 , Batch_size=3 ,一次 Epoch 需要四次 Iteration ,第一列为所有样本,第二列为打乱之后的所有样本,由于 Batch_size=3 ,所以通过 DataLoader输入了 4 个 batch ,包括最后一个数量已经不够 3 个的 Batch4 ,里边只包含sample3

官方给出的DataLoader定义:

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,batch_sampler=None, num_workers=0, collate_fn=None,pin_memory=False, drop_last=False, timeout=0,worker_init_fn=None,*, prefetch_factor=2,persistent_workers=False)

参数说明:

dataset: 通过Dataset加载进来的数据集。

batch_size:每个Batch加载多少个样本。

shuffle: 是否打乱输入数据的顺序,设置为True时,调用RandomSample进行随机索引。

sampler: 定义从数据集中提取样本的策略,若指定,就不能用shuffle函数随机索引,其取值必须为False。

batch_sampler: 批量采样,每次返回一个Batch大小的索引,默认设置为None,和batch_size、shuffle等参数是互斥的。

num_workers: 用多少子进程加载数据。0表示数据将在主进程中加载,根据自己的计算资源配置选定。

collate_fn: 将一小段数据合并成数据列表以形成一个Batch。

pin_memory:是否在将张量返回之前将其复制到Cuda固定的内存中。

drop_last: 设置了batch_size的数目后,最后一批数据未必是设置的数目,有可能会小一些,这时需要丢弃这些数据。

timeout:设置数据表读取的超时时间,但超过这个时间还没读取到数据就会报错,不能为负。

worker_init_fn:是否在数据导入前和步长结束后根据工作子进程的ID逐个按照顺序导入数据,默认为None。

prefetch_factor:每个worker提前加载的Sample数量。

persistent_workers: 如果为True,DataLoader将不会终值worker进程,直到dataset迭代完成。

将Dataset读取的数据输入到DataLoader中。

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoaderclass GetTrainTestData(Dataset):def __init__(self, input_len, output_len, train_rate, is_train=True):super().__init__()# 使用sin函数返回10000个时间序列,如果不自己构造数据,就使用numpy,pandas等读取自己的数据为x即可。# 以下数据组织这块既可以放在init方法里,也可以放在getitem方法里self.x = torch.sin(torch.arange(1, 1000, 0.1))self.sample_num = len(self.x)self.input_len = input_lenself.output_len = output_lenself.train_rate = train_rateself.src,  self.trg = [], []if is_train:for i in range(int(self.sample_num*train_rate)-self.input_len-self.output_len):self.src.append(self.x[i:(i+input_len)])self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])else:for i in range(int(self.sample_num*train_rate), self.sample_num-self.input_len-self.output_len):self.src.append(self.x[i:(i+input_len)])self.trg.append(self.x[(i+input_len):(i+input_len+output_len)])print(len(self.src), len(self.trg))def __getitem__(self, index):return self.src[index], self.trg[index]def __len__(self):return len(self.src)  # 或者return len(self.trg), src和trg长度一样data_train = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=True)
data_test = GetTrainTestData(input_len=3, output_len=1, train_rate=0.8, is_train=False)
data_loader_train = DataLoader(data_train, batch_size=5, shuffle=False)
data_loader_test = DataLoader(data_test, batch_size=5, shuffle=False)

for idx, train in enumerate(data_loader_train):print(idx, train)break


文章推荐

NumPy从入门到放弃 https://mp.weixin.qq.com/s/EocThNWhQlI2zeLcUApsQQ
Pandas从入门到放弃 https://mp.weixin.qq.com/s/mSkA5KvL1390Js8_1ZBiyw
SciPy从入门到放弃 https://mp.weixin.qq.com/s/MulhzVRvWbaDUjfNPHN8qA
Scikit-learn从入门到放弃 https://mp.weixin.qq.com/s/L0tKz9JFnsgrzSCXDswbRA
PyTorch从入门到放弃之张量模块 https://www.cnblogs.com/kohler21/p/18392248

欢迎关注公众号:愚生浅末。
image

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/793131.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

pbootcms网站后台突然登录不了怎么解决

如果你使用的是PbootCMS V3.2.5之前的版本,并且遇到了无法登录后台的情况,可以按照以下步骤进行排查和修复: 步骤 1: 删除 runtime 文件夹找到 runtime 文件夹:通常 runtime 文件夹位于网站根目录下。 如果找不到,可以尝试搜索整个项目目录中的 runtime 文件夹。删除 runt…

如何在pbootcms网站中调用公司简介等频道内容

在PbootCMS中,使用{pboot:content}标签可以方便地调用特定频道的内容。下面是一个完整的示例,展示了如何使用{pboot:content}标签来调用公司简介等频道内容,并进行适当的展示。 示例代码html<!-- 调用ID为1的频道内容 --> {pboot:content id=1}<!-- 显示频道标题 -…

消费降级,我的订阅服务瘦身

前言 前几天看到一篇文章,《消费降级,我的订阅服务瘦身》。 自己平时花钱有点大手大脚的,也没有统计个每个月固定的开销,现在正好趁这个机会记录一下。现在挣钱不容易,看下哪些开销可以进行降级。 腾讯云 - 服务器分类 周期及成本 需要程度网络服务 510元/年 需要/续订这个…

3.元素定位、规避监控、APP自动化测试(Appium)等

元素定位 我们通过webdriver打开一个网络页面,目的是为了操作当前页面已完成浏览器中的一些UI测试步骤,所以必然需要操作网页。而网页的内容组成是由HTML标签(element,也叫元素),所以基于selenium操作网页实际上本质就是操作元素。那么要操作元素就必须先获取元素对象。s…

虚拟化技术:新能源汽车空调控制系统的智能新突破

汽车生产中,空调系统已经成为标配,空调系统的性能是衡量一辆汽车是否舒适的重要指标之一。 01.汽车空调系统组成 (1) 制冷系统:制冷系统的功能是给汽车内部提供冷空气,主要由压缩机、冷凝器、膨胀阀以及蒸发器组成。首先由压缩机对空气进行压缩,使空气通过蒸发器,并由制冷…

计算机计算小数除法的陷阱

小学生都知道上面的代码中,8.1/3=2.7 但是计算机计算的结果却出人意料:2.6999999999999997 原因:计算机是用二进制格式存储小数的,这个二进制格式不能精确表示8.1,它只能表示一个非常接近8.1但又不等于8.1的一个数。

pbootcms提交留言、提交自定义表单时取消验证码

进入菜单 全局配置 -> 配置参数 -> 安全配置扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网站改版、BUG修复、问题处理、二次开发、PSD转HTML、网站被黑、网站漏洞…

pbootcms站点信息调用

{pboot:siteindex} 站点入口地址,一般用于站内链接跳转设置地址前置,实现自适应URL模式{pboot:sitepath} 站点路径,根目录时值为空,为适应部署到二级目录时建议链接前面带上{pboot:sitelanguage} 站点语言{pboot:sitetitle} 站点标题{pboot:sitesubtitle} 站点副标题{pboot:…

house of stom

完成事项 house of stom学习 未完成事项 wmctf的blineless没打通 如何解决未完成事项 下周待做事项 house of orange house of lore 本周学习的知识分享 house of stom 条件:1.能控制unsorted的bk指针,还有largebin的fd_nextsize和bk_nextsize 码源分析 largebin attack:申…

Pbootcms留言“提交成功”的提示语修改

按照这个路径地址来修改下文件/apps/home/controller/MessageController.php 大概在103行,可以搜索提交成功快捷查询下。扫码添加技术【解决问题】专注中小企业网站建设、网站安全12年。熟悉各种CMS,精通PHP+MYSQL、HTML5、CSS3、Javascript等。承接:企业仿站、网站修改、网…

house of orange

house of orange 1.针对没有free的堆题目 orange部分 申请比topchunk的size大的chunk,会将原本的chunk放入unsortedbin中,可以借此泄露地址 FSOP io文件结构有chain连接成一个链表形式,这部分,头节点记录在_IO_list_all上,通过unsorted attack或者largebin attack劫持_io_…

docker 安装 redis 集群

集群搭建(三主三从) 集群搭建 集群中的节点都需要打开两个 TCP 连接。一个连接用于正常的给 Client 提供服务,比如 6379,还有一个额外的端口(通过在这个端口号上加10000)作为数据端口,例如:redis的端口为 6379,那么另外一个需要开通的端口是:6379 + 10000, 即需要开…