torch.utils.data

整体架构

平时使用 pytorch 加载数据时大概是这样的:

import numpy as np
from torch.utils.data import Dataset, DataLoaderclass ExampleDataset(Dataset):def __init__(self):self.data = [1, 2, 3, 4, 5]def __getitem__(self, idx):return self.data[idx]def __len__(self):return len(self.data)def collate_fn(batch):return np.array(batch)dataset = ExampleDataset()  # create the dataset
dataloader = DataLoader(dataset=dataset,batch_size=2,shuffle=True,num_workers=4,collate_fn=collate_fn
)
for datapoint in dataloader:print(datapoint)
  1. 继承 Dataset 类,定义一个迭代器,包含两个魔法方法:__getitem__(self, idx)__len__(self),分别实现如何获取一条数据和如何设定数据长度;
  2. 定义 collate_fn 函数,设定如何组织一个 batch
  3. 实例化 Dataset,并和 collate_fn 一起传入 DataLoader,参数 batch_size 设置批大小、shuffle 设置是否打乱、num_workers 设置并行加载数据的进程数。

然而,背后到底干了什么,我们不清楚,甚至遇到 DataLoader 的如 samplerbatch_samplerworker_init_fn 的其他参数,就会懵逼。那就看一看官方文档,了解一下 torch.utils.data 是如何工作的。


上图是数据加载的整体框架图,官网说 DataLoader 组合datasetsampler,多个 workers 根据 dataset 提供的数据副本sampler 提供的 keys 并行地加载数据,并通过 collate_fn 组成 batch 供用户迭代。需要注意的有:

  1. 每个 worker 持有数据的一个副本,故占用内存主线程内存 * num_workers”;
  2. 即使用户不提供 sampler 对象 (通常不提供),DataLoader 也会根据 shuffle 参数创建一个默认的 sampler 对象;一旦提供了,其前路的 shuffle 参数不能为 True (不提供就好);
  3. 即使用户不提供 batch_sampler 对象 (通常不提供),DataLoader 也会根据 batch_sampler, drop_last 参数创建一个默认的 batch_sampler 对象;一旦提供了,其前路的 shuffle, drop_last 不能为 Truebatch_size 必须为 1 1 1sampler 必须为 None,因为创建 BatchSampler 时已经有了这些参数;

    本质上是把创建 batch_sampler 的活拉出来由用户在 DataLoader 外自定义地做了。

Dataset

分为两种:map-styleiterable-style。前者的数据可通过 [idx or key] 访问,后者的数据只能通过迭代器 next 一个个访问。所以上面架构中的采样器是对于 map-style 数据集说的iterable-style 的数据集的访问顺序由迭代器决定。

Sampler

torch.utils.data.Sampler 的子类或 Iterable,两个例子:

class AccedingSequenceLengthSampler(tu_data.Sampler[int]):def __init__(self, data: List[str]) -> None:super().__init__()self.data = datadef __len__(self) -> int:return len(self.data)def __iter__(self) -> Iterator[int]:""":return: 实现了按数据长短顺序访问数据集"""sizes = torch.tensor([len(x) for x in self.data])yield from torch.argsort(sizes).tolist()class AccedingSequenceLengthBatchSampler(tu_data.Sampler[List[int]]):def __init__(self, data: List[str], batch_size: int) -> None:super().__init__()self.data = dataself.batch_size = batch_sizedef __len__(self) -> int:return (len(self.data) + self.batch_size - 1) // self.batch_sizedef __iter__(self) -> Iterator[List[int]]:sizes = torch.tensor([len(x) for x in self.data])for batch in torch.chunk(torch.argsort(sizes), len(self)):  # 按块遍历yield batch.tolist()

Batch

batch_sampler 提供一批下标,取得一批数据后由 collate_fn 将这批数据整合:

if collate_fn is None:if self._auto_collation:collate_fn = _utils.collate.default_collateelse:  # self.batch_sampler is None: (batch_size is None) and (batch_sampler is None)collate_fn = _utils.collate.default_convert

分两种情况:

  • automatic batching is disabled:调用 default_convert 函数简单地将 NumPy arrays 转化为 PyTorch Tensor;
  • automatic batching is enabled:调用 default_collate 函数,转化会变得复杂一点:
from torch.utils import data as tu_data
import collections# %% Example with a batch of `int`s:
tu_data.default_collate([0, 1, 2, 3])
# tensor([0, 1, 2, 3])# %% Example with a batch of `str`s:
tu_data.default_collate(['a', 'b', 'c'])
# ['a', 'b', 'c']# %% Example with `Map` inside the batch:
tu_data.default_collate([{'A': 0, 'B': 1},{'A': 100, 'B': 100}
])
# {'A': tensor([0, 100]), 'B': tensor([1, 100])}, 同 key 的合并了# %% Example with `NamedTuple` inside the batch:
Point = collections.namedtuple('Point', ['x', 'y'])
tu_data.default_collate([Point(0, 0), Point(1, 1)])
# Point(x=tensor([0, 1]), y=tensor([0, 1])), 同 name 的合并了, 大概和 dict 一样吧# %% Example with `Tuple` inside the batch:
tu_data.default_collate([(0, 1), (2, 3)])
# [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate# %% Example with `List` inside the batch:
tu_data.default_collate([[0, 1], [2, 3]])  # [tensor([0, 2]), tensor([1, 3])], 对 list 内部执行 collate, 并没有变成二维 tensor

Multi-process Data Loading

dataset, collate_fn, and worker_init_fn are passed to each worker,大概能说明 batch 是在子进程内部合成的。

有一个需要注意的地方是内存增长问题,当 __get_item__(self, key) 访问数据时,由于 Python 对象的 refcount 机制,数据会不断地复制,从而内存爆炸。但这里说解决 number of workers * size of parent process 问题,就不追究了,反正尽量用 numpy 或 pytorch tensor 吧。
iterable-style datasets 的随机性

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/473490.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

祝所有的CSDN社区成员们新年快乐

文章目录 尊敬的CSDN社区成员们, 在新年的钟声即将敲响之际,我携带着满心祝福与期许,以字为舟,穿越虚拟与现实的界限,来到您的身边,向每一位热爱编程、投身技术研究、在CSDN平台上挥洒智慧和汗水的朋友们&a…

Netty中的适配器、Handler共享和资源管理

ChannelHandler的适配器 有一些适配器类可以将编写自定义的ChannelHandler所需要的工作降到最低限度, 因为它们提供了定义在对应接口中的所有方法的默认实现。因为有时会忽略那些不感兴趣的 事件,所以Netty提供了抽象积累ChannelInboundHandlerAdapter(…

深夜突发! OpenAI震撼发布了SORA文生视频模型,对职场人的影响可能跟你想的不一样

深夜突发! OpenAI震撼发布了SORA文生视频模型,对职场人的影响可能跟你想的不一样。 马上就要节后返工了,顾问老师也早已回到了温暖的广州。与一位同城的学员相聚在老广州的一个茶楼中,喝起了下午茶。面对各式的广式茶点,在淡淡的茶…

防火墙 iptables(二)--------------------SNAT与DNAT

一、SNAT ①SNAT 应用环境: 局域网主机共享单个公网IP地址接入Internet (私有IP不能在Internet中正常路由) ②SNAT原理: 源地址转换,根据指定条件修改数据包的源IP地址,通常被叫做源映射 数据包从内网发送到公网时,SNAT会把数据包的源IP由…

Shokz韶音是运动耳机的领导品牌

在一年一度的Keep官方营销沙龙Keep自由营上,运动耳机领导品牌Shokz韶音和全球运动科技App Keep共同宣布达成深度合作。Shokz韶音运动耳机将成为Keep官方合作运动耳机。同时,双方将在线上赛事、电商购物、新品发布乃至圈层耕耘等诸多方面,展开全方位合作。 Shokz韶音是运动耳机的…

前端秘法进阶篇----这还是我们熟悉的浏览器吗?(浏览器的渲染原理)

目录 一.浏览器渲染原理 二.渲染时间点 三.渲染流水线 1.解析html(Parse HTML) 1.1解析成DOM树(document object model) 1.2解析成CSSOM树(css object model) 2.样式计算(Recalculate Style) 3.布局(Layout) 4.分层(Layer) 5. 绘制(Paint) 6.分块(Tiling) 7. 光栅化…

java 课程签到管理系统Myeclipse开发mysql数据库web结构jsp编程servlet计算机网页项目

一、源码特点 java 课程签到管理系统是一套完善的java web信息管理系统 采用serlvetdaobean,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为TOMCAT7.0,Myeclipse8.5开发&#xff0…

【Linux】进程信号的保存 | 自定义捕捉

文章目录 三、信号的阻塞(信号的保存)1. 信号相关其他常见概念2. 在内核中的表示3. sigset_t类型4. 信号集操作函数函数列表注意事项 5. 读取/修改block位图 - sigprocmask6. 读取pending位图 - sigpending 四、信号捕捉1. 信号捕捉的初步认识自定义捕捉…

【Python--Web应用框架大比较】

🚀 作者 :“码上有前” 🚀 文章简介 :Python 🚀 欢迎小伙伴们 点赞👍、收藏⭐、留言💬 Django Django太重了,除了web框架,自带ORM和模板引擎,灵活和自由度不…

android获取sha1

1.cmd在控制台获取 切换到Android Studio\jre\bin目录下执行keytool -list -v -keystore 签名文件路径例如: 2.也可以在android studio中获取 在Terminal中输入命令:keytool -list -v -keystore 签名文件路径获取 获取到的sha1如下:

【分享】JLINK的SW调试模式连线方式

大家知道,JLINK有2种调试模式:JTAG和SWD(串行模式)。 JTAG是常用模式,大家都熟悉、不废话了;如果使用SW模式,需要(只需要)4根连线,连接方式如下: …

四种mfc140u.dll丢失的解决方法,有效恢复mfc140u.dll丢失

mfc140u.dll文件的重要性,当系统中出现mfc140u.dll丢失的情况时,可能会导致一系列问题和影响。因此,保持mfc140u.dll文件的完整性对于系统和应用程序的稳定运行至关重要。一旦出现mfc140u.dll文件丢失的情况,我们需要采取有效的方…