Pytorch:torch.utils.data.DataLoader()

如果读者正在从事深度学习的项目,通常大部分时间都花在了处理数据上,而不是神经网络上。因为数据就像是网络的燃料:它越合适,结果就越快、越准确!神经网络表现不佳的主要原因之一可能是由于数据不佳或理解不足。因此,以更直观的方式理解、预处理数据并将其加载到网络中非常重要。
参考:https://zhuanlan.zhihu.com/p/596730297

DataLoader加载和迭代数据集

Dataloader本质是一个迭代器对象,也就是可以通过for batch_idx,batch_dict in dataloader 来提取数据集,提取的数量由batch_size 参数决定,得到这一batch的数据后,就可以喂入网络开始训练或者推理了。
在迭代的过程中,dataloader会自动调用dataset中的__getitem__ 函数,以获取一帧数据(item)

from torch.utils.data import DataLoaderDataLoader(dataset,batch_size=1,shuffle=False,num_workers=0,collate_fn=None,pin_memory=False,)

以U-Net中的代码为例:
具体详见:U-Net代码复现

loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
train_loader = DataLoader(train_set, shuffle=True, **loader_args)
val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

1. 数据集

**dataset (Dataset) ** – dataset from which to load the data.
即自定义的数据集,非常重要,因为dataloader会调用dataset的一些重载函数(e.g. getitem && len )

2. 对数据进行批处理

batch_size (int, optional)how many samples per batch to load(default: 1).

3. 在 CUDA 张量上加载数据

pin_memory(bool, optional)If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elementsare a custom type, or your collate_fn returns a batch that is a custom type,see the example below.

pin_memory参数直接将数据集加载为 CUDA 张量。它是一个可选参数,接受一个布尔值;如果设置为True,会在返回张量之前张量复制到 CUDA 固定内存中。这样在GPU训练过程中,数据从内存到GPU的复制可以使用异步的方式进行,从而提高数据读取的效率。

通常情况下,当使用GPU训练模型时,数据读取会成为整个训练过程的瓶颈之一。使用pin_memory可以将数据在CPU和GPU之间进行传输时的复制时间减少,从而提高数据加载的速度,加速训练过程。

需要注意的是,使用pin_memory会占用更多的内存空间,因此在内存资源紧张的情况下,需要谨慎使用。同时,在某些情况下(例如数据集比较小的情况下),使用pin_memory并不会带来明显的加速效果。

4.允许多进程

num_workers (int, optional)how many subprocesses to use for dataloading. 0 means that the data will be loaded in the main process.(default: 0)
这也是一个很有意思的参数,按照官方的说法, num_workers 用于设置数据加载过程中使用的子进程数。其默认值为0,即在主进程中进行数据加载,而不使用额外的子进程。

以下是我看到的一个解释,原文链接:https://blog.csdn.net/vonct/article/details/130263743
下面说一下个人的理解,在初始化 dataloader对象时,会根据num_workers创建子线程用于加载数据(主线程数+子线程=num_workers)。每个worker或者说线程都有自己负责的dataset范围(下面统称worker)

每当迭代 dataloader 对象时,工人们(workers)就开始干活了:将数据从数据源(如硬盘)加载到内存(数据加载),当一个worker读取(调用__getitem__)到足够的数据(看你在dataset中怎么定义一个item了)后,会将这些数据封装成一个(即一帧),并将其放到该worker独有的内存队列中。 要注意的是,每次迭代时,worker会尽可能地读数据,直到自己的队列被填满。

当所有workers的队列都被填满时,一个名为sampler的线程将会被创建,它的作用就是收集各workers队列中队首的 ,把他们放到一个各线程共享内存的缓冲队列中,并调用 collate_fn 函数来将 batch_size 个 整合,最后返回给迭代的输出。

这时候大家肯定会有点疑惑,那当迭代到后期时,需要读取的样本都已经在队列中了,是不是意味着这时候工人们已经在休息了?根据chatgpt的回答:是的!下面以一张图来帮助大家理解

在这里插入图片描述

5.合并数据集

collate_fn (Callable, optional)merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.

整合多个样本到一个batch时需要调用的函数,当 getitem 返回的不是tensor而是字典之类时,需要进行 collate_fn的重载,同时可以进行数据的进一步处理以满足pytorch的输入要求。
以U-Net为例:

def __getitem__(self, idx):name = self.ids[idx]mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))img_file = list(self.images_dir.glob(name + '.*'))assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'mask = load_image(mask_file[0])img = load_image(img_file[0])assert img.size == mask.size, \f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)return {'image': torch.as_tensor(img.copy()).float().contiguous(),'mask': torch.as_tensor(mask.copy()).long().contiguous()}

getitem 返回的是一个包含image和mask的 data_dict 字典,这时候就需要调用自定义的collate_fn来进行打包(待补充。。。)

6.数据采样

sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with len implemented. If specified, shufflemust not be specified.

sampler的主要作用是控制样本的采样顺序,并提供样本的索引。在默认情况下,dataloader使用的是SequentialSampler,它按照数据集的顺序依次提取样本,但在某些情况下,我们可能需要自定义采样顺序。比如说想从队尾提取数据。

比如,当我们处理非常大的数据集时,为了提高训练效率,可能需要对数据进行分布式采样,这时候就需要使用DistributedSampler。DistributedSampler会将数据集划分成多个子集,每个子集分配给不同的进程进行采样。在这种情况下,如果使用默认的SequentialSampler,可能会导致各个进程采样到相同的数据,从而降低训练效率。

此外,还有一些自定义的sampler,比如随机采样器(RandomSampler)和加权采样器(WeightedRandomSampler),它们可以按照不同的采样策略对数据集进行采样,从而满足不同的训练需求。

因此,根据不同的训练需求,我们可能需要自定义sampler来控制数据的采样顺序。

原文链接:https://blog.csdn.net/vonct/article/details/130263743

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/229004.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Egg.js的方法扩展

Extend-application 方法扩展 eggjs的方法的扩展和编写 Egg.js可以对内部的五种对象进行扩展,以下是可扩展的对象、说明、this指向和使用方式。 application对象方法拓展 按照Egg的约定,扩展的文件夹和文件的名字必须是固定的。比如要对application扩…

C++学习之路(十一)C++ 用Qt5实现一个工具箱(增加一个进制转换器功能)- 示例代码拆分讲解

上篇文章,我们用 Qt5 实现了在小工具箱中添加了《时间戳转换功能》功能。为了继续丰富我们的工具箱,今天我们就再增加一个平时经常用到的功能吧,就是「 进制转换 」功能。下面我们就来看看如何来规划开发一个这样的小功能并且添加到我们的工具…

家政预约服务管理系统,轻松搭建专属家政小程序

家政预约服务管理系统,轻松搭建专属家政小程序app; 家政服务app开发架构包括: 1. 后台管理端:全面管理家政服务、门店、员工、阿姨信息、订单及优惠促销等数据,并进行统计分析。 2. 门店端:助力各门店及员工…

选择跨网数据摆渡系统时,你最关注的功能是哪些?

为什么要选择跨网数据摆渡系统呢?因为做了网络隔离后,要有数据交互。那为什么要做网络隔离呢?主要还是安全方面的考虑,一般有以下几个原因: 1、数据安全保护:对于一些重要数据,比如代码数据、隐…

根据端口查找进程

关闭kibana kibana自带命令 kibana没有提供关闭命令,通过命令 ps -ef|grep kibana查找不到kibana相关的信息。 可以通过进程暴露的端口来查找 netstat -anltp|grep 5601获取到进程号,然后kill掉进程 kill -9 进程号Docker管理Kibana 但是如果使用D…

Hive安装与配置

你需要掌握: 1.Hive的基本安装; 2.Mysql的安装与设置; 3.Hive 的配置。 注意:Hive的安装与配置建立在Hadoop已安装配置好的情况下。 hadopp安装与配置 Hive 的基本安装 从 官网 下载Hive二进制包,下载好放在/op…

Linux脚本sed命令

目录 一. sed命令定义 二. sed命令选项 三. sed语法选项 四. 案例解释 1. 打印奇数或偶数行 2. 打印固定行数 3. 打印包含字符的行 4. 打印特定字符首尾行 5. 删除固定行数 6. 删除特定字符行 7. 插入在固定行中 8. 替换规定行数 9. 使用变量 10. 多点编辑 11. 分…

【从零开始学习Linux】一文带你了解yum周边生态及vim常见模式

🚩纸上得来终觉浅, 绝知此事要躬行。 🌟主页:June-Frost 🚀专栏:Linux入门 🔭【从零开始学习Linux】系列均属于Linux入门,主要包含Linux操作系统下的指令、操作、权限以及开发工具&a…

TIME_WAIT状态TCP连接导致套接字无法重用实验

理论相关知识可以看一下《TIME_WAIT相关知识》。 #include<stdio.h> #include<stdlib.h> #include<string.h> #include<unistd.h> #include<arpa/inet.h> #include<sys/socket.h> #include<errno.h> #include<syslog.h> #inc…

【开题报告】基于深度学习的驾驶员危险行为检测系统

研究的目的、意义及国内外发展概况 研究的目的、意义&#xff1a;我国每年的交通事故绝对数量是一个十分巨大的数字&#xff0c;造成了巨大的死亡人数和经济损失。而造成交通事故的一个很重要原因就是驾驶员的各种危险驾驶操作行为。如果道路驾驶员的驾驶行为能够得到有效识别…

Django大回顾-2 之 Django的基本操作、路由层,MTV和MVC模型

【1】MTV和MVC模型 MVC与MTV模型 --->所有web框架其实都遵循mvc架构 MVC模型 MVC 本来坨在一起的代码&#xff0c;拆到不同的位置 模型(M&#xff1a;数据层)&#xff0c;控制器(C&#xff1a;逻辑判断)和视图(V&#xff1a;用户看到的)三层 他们之间以一种插件式…

linux(2)之buildroot使用手册

Linux(2)之buildroot配置toolchain Author&#xff1a;Onceday Date&#xff1a;2023年11月27日 漫漫长路&#xff0c;才刚刚开始… 参考文档&#xff1a; Buildroot - Making Embedded Linux Easy 文章目录 Linux(2)之buildroot配置toolchain1. 构建配置1.1 配置config生成…