PyTorch的Dataset 和TorchData API的比较

深度神经网络需要很长时间来训练。训练速度受模型的复杂性、批大小、GPU、训练数据集的大小等因素的影响。

在PyTorch中,torch.utils.data.Dataset和torch.utils.data.DataLoader通常用于加载数据集和生成批处理。但是从版本1.11开始,PyTorch引入了TorchData库,它实现了一种不同的加载数据集的方法。

在本文中,我们将比较数据集比较大的情况下这两两种方法是如何工作的。我们以CelebA和DigiFace1M的面部图像为例。表1显示了它们的比较特征。我们训练使用ResNet-50模型。然后进行1轮的训练来进行使用方法和时间的比较。

数据集的信息如下:

CelebA (align) 图片数:202,599 总大小:1.4 图片大小:178x218

DigiFace1M 图片数:720,000 总大小:14.6 图片大小:112x112

我们使用的环境如下:

CPU: Intel® Core™ i9-9900K CPU @ 3.60GHz(16核)

GPU: GeForce RTX 2080 Ti 11Gb

驱动版本515.65.01 / CUDA 11.7 / CUDNN 8.4.0.27

Docker 20.10.21

Pytorch 1.12.1

TrochData 0.4.1

训练的代码如下:

 def train(data_loader: torch.utils.data.DataLoader, cfg: Config):# create modelmodel = resnet50(num_classes=cfg.n_celeba_classes + cfg.n_digiface1m_classes, pretrained=True)torch.cuda.set_device(cfg.gpu)model = model.cuda(cfg.gpu)model.train()# define loss function (criterion) and optimizercriterion = torch.nn.CrossEntropyLoss().cuda(cfg.gpu)optimizer = torch.optim.SGD(model.parameters(), lr=0.1,momentum=0.9,weight_decay=1e-4)start_time = time.time()for _ in range(cfg.epochs):scaler = torch.cuda.amp.GradScaler(enabled=cfg.use_amp)for batch_idx, (images, target) in enumerate(data_loader):images = images.cuda(cfg.gpu, non_blocking=True)target = target.cuda(cfg.gpu, non_blocking=True)# compute outputwith torch.cuda.amp.autocast(enabled=cfg.use_amp):output = model(images)loss = criterion(output, target)# compute gradientscaler.scale(loss).backward()# do SGD stepscaler.step(optimizer)scaler.update()optimizer.zero_grad()print(batch_idx, loss.item())print(f'{time.time() - start_time} sec')

Dataset

首先看看Dataset,这是自从Pytorch发布以来一直使用的方式,我们对这个应该非常熟悉。PyTorch 支持两种类型的数据集:map-style Datasets 和 iterable-style Datasets。Map-style Dataset 在预先知道元素个数的情况下使用起来很方便。

该类实现了__getitem__()和__len__()方法。如果通过索引读取太费时间或者无法获得,那么可以使用 iterable-style,需要实现__iter__() 方法。在我们的例子中,map-style已经可以了,因为对于 CelebA 和 DigiFace1M 数据集,我们知道其中的图像总数。

下面我们创建CelebADataset 类。对于 CelebA,类标签位于 identity_CelebA.txt 文件中。CelebA 和 DigiFace1M 中的面部图像在裁剪方面有所不同,因此为了在图像上传后减少getitem方法中的这些差异,必须从各个方面稍微裁剪它们。

 from PIL import Imagefrom torch.utils.data import Datasetclass CelebADataset(torch.utils.data.Dataset):def __init__(self, data_path: str, transform) -> None:self.data_path = data_pathself.transform = transformself.image_names, self.labels = self.load_labels(f'{data_path}/identity_CelebA.txt')def __len__(self) -> int:return len(self.image_names)def  __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:image_path = f'{self.data_path}/img_align_celeba/{self.image_names[idx]}'image = Image.open(image_path)left, right, top, bottom = 25, 153, 45, 173image = image.crop((left, top, right, bottom))if self.transform is not None:image = self.transform(image)label = self.labels[idx]return image, label@staticmethoddef load_labels(labels_path: str) -> Tuple[list, list]:image_names, labels = [], []with open(labels_path, 'r', encoding='utf-8') as labels_file:lines = labels_file.readlines()for line in lines:file_name, class_id = line.split(' ')image_names.append(file_name)labels.append(int(class_id[:-1]))return image_names, labels

对于DigiFace1M数据集,同一类的所有图像都在一个单独的文件夹中。但是这两个数据集中,类的标签是相同的,所以对于在DigiFace1M我们不需要获取类别,而是在CelebA中按类增加。所以我们需要add_to_class变量。另外就是DigiFace1M中的图像以“RGBA”格式存储,因此仍需将其转换为“RGB”。

 class DigiFace1M(torch.utils.data.Dataset):def __init__(self, data_path: str, transform, add_to_class: int = 0) -> None:self.data_path = data_pathself.transform = transformself.image_paths, self.labels = self.load_labels(data_path, add_to_class)def __len__(self):return len(self.image_paths)def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:image = Image.open(self.image_paths[idx]).convert('RGB')if self.transform is not None:image = self.transform(image)label = self.labels[idx]return image, label@staticmethoddef load_labels(data_path: str, add_to_class: int) -> Tuple[list, list]:image_paths, labels = [], []for root, _, files in os.walk(data_path):for file_name in files:if file_name.endswith('.png'):image_paths.append(f'{root}/{file_name}')labels.append(int(os.path.basename(root)) + add_to_class)return image_paths, labels

现在我们可以使用torch.utils.data将两个数据集合并为一个数据集ConcatDataset,创建DataLoader,开始训练。

 def main():cfg = Config()celeba_dataset = CelebADataset(f'{cfg.data_path}/CelebA', cfg.transform)digiface_dataset = DigiFace1M(f'{cfg.data_path}/DigiFace1M', cfg.transform, cfg.n_celeba_classes)dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset])loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=cfg.batch_size,shuffle=True,drop_last=True,num_workers=cfg.n_workers)utils.train(loader, cfg)

TorchData API

与Dataset一样,TorchData支持map-style 和 iterable-style的数据处理管道。但是官方建议使用IterDataPipe,只在必要时将其转换为MapDataPipe。

因为TorchData提供了优化的数据加载实用程序,可以帮助我们方便的构建处理流程。以下是一些主要的功能:

  • IterableWrapper:包装可迭代对象以创建IterDataPipe。
  • FileListerr:给定目录的路径,将生成根目录内文件的文件路径名(path + filename)
  • Filterr:根据输入filter_fn(函数名:filter)从源数据口过滤元素
  • Mapperr:对源DataPipe中的每个项应用函数(函数名:map)
  • Concaterr:连接多个可迭代数据管道(函数名:concat)
  • Shufflerr:打乱输入DataPipe数据的顺序(函数名:shuffle)
  • ShardingFilterr:允许对DataPipe进行分片(函数名:sharding_filter)

使用TorchData 构建CelebA和DigiFace1M的数据处理管道,我们需要执行以下步骤:

对于CelebA数据集:创建一个列表(file_name, label, ’ CelebA '),并使用IterableWrapper从它创建一个IterDataPipe

对于DigiFace1M:使用FileLister创建一个IterDataPipe,返回所有图像文件的路径,使用Mapper来使用collate_ann。这个函数以图像路径作为输入,并返回元组(file_name, label, ’ DigiFace1M ')。

上面两个步骤之后,我们得到两个数据类型(file_name, label, data_name)的结果。然后使用Concater将它们连接到一个数据管道中。

使用Shufflerr,打乱顺序,这与在DataLoader中设置了shuffle=True是一样的。

使用ShardingFilter将数据管道分割成片。每个worker将拥有原始DataPipe元素的n个部分,其中n等于worker的数量。(多线程处理,DataLoader中的num_worker)

最后就是从磁盘读取图像

完整代码如下:

 @torchdata.datapipes.functional_datapipe("load_image")class ImageLoader(torchdata.datapipes.iter.IterDataPipe):def __init__(self, source_datapipe, **kwargs) -> None:self.source_datapipe = source_datapipeself.transform = kwargs['transform']def __iter__(self) -> Tuple[torch.Tensor, int]:for file_name, label, data_name in self.source_datapipe:image = Image.open(file_name)if data_name == 'DigiFace1M':image = image.convert('RGB')elif data_name == 'CelebA':left, right, top, bottom = 25, 153, 45, 173image = image.crop((left, top, right, bottom))if self.transform is not None:image = self.transform(image)yield image, labeldef collate_ann(file_path):label = int(os.path.basename(os.path.dirname(file_path))) + N_CELEBA_CLASSESdata_name = os.path.basename(os.path.dirname(os.path.dirname(file_path)))return file_path, label, data_namedef load_celeba_labels(labels_path: str) -> Dict[str, int]:labels = []data_path = os.path.split(labels_path)[0]with open(labels_path, 'r', encoding='utf-8') as labels_file:lines = labels_file.readlines()for line in lines:file_name, class_id = line.split(' ')class_id = int(class_id[:-1])labels.append((f'{data_path}/img_align_celeba/{file_name}', class_id, 'CelebA'))return labelsdef build_datapipes(cfg: Config) -> torchdata.datapipes.iter.IterDataPipe:celeba_dp = torchdata.datapipes.iter.IterableWrapper(load_celeba_labels(labels_path=f'{cfg.data_path}/CelebA/identity_CelebA.txt'))digiface_dp = torchdata.datapipes.iter.FileLister(f'{cfg.data_path}/DigiFace1M', masks='*.png', recursive=True)digiface_dp = digiface_dp.map(collate_ann)datapipe = celeba_dp.concat(digiface_dp)datapipe = datapipe.shuffle(buffer_size=100000)datapipe = datapipe.sharding_filter()datapipe = datapipe.load_image(transform=cfg.transform)return datapipe

Torch的DataLoader是同时支持Datasets和DataPipe的,所以我们可以直接使用

 def main():cfg = Config()datapipe = build_datapipes(cfg)loader = torch.utils.data.DataLoader(dataset=datapipe,batch_size=cfg.batch_size,shuffle=True,drop_last=True,num_workers=cfg.n_workers)utils.train(loader, cfg)

加速数据读取的一个小技巧

批处理中耗时最长的操作之一是从磁盘读取图片。为了减少这个操作所花费的时间,可以加载所有图像并将它们分割成小的数据集,例如10,000张图像保存为.pickle文件。在读取时每一个worker只要读取一个相应的pickle文件即可

 def prepare_data():cfg = Config()cfg.transform = Noneos.makedirs(cfg.prepared_data_path, exist_ok=True)celeba_dataset = dataset_example.CelebADataset(f'{cfg.data_path}/CelebA', cfg.transform)digiface_dataset = dataset_example.DigiFace1M(f'{cfg.data_path}/DigiFace1M', cfg.transform, cfg.n_celeba_classes)dataset = torch.utils.data.ConcatDataset([celeba_dataset, digiface_dataset])shard_size = 10000next_shard = 0data = []shuffled_idxs = np.arange(len(dataset))np.random.shuffle(shuffled_idxs)for idx in tqdm(shuffled_idxs):data.append(dataset[idx])if len(data) == shard_size:with open(f'{cfg.prepared_data_path}/{next_shard}_shard.pickle', 'wb') as _file:pickle.dump(data, _file)next_shard += 1data = []with open(f'{cfg.prepared_data_path}/{next_shard}_shard.pickle', 'wb') as _file:pickle.dump(data, _file)

下面就是使用FileLister收集.pickle数据集的所有路径,按worker划分并在每个worker上加载.pickle数据。

 @torchdata.datapipes.functional_datapipe("load_pickle_data")class PickleDataLoader(torchdata.datapipes.iter.IterDataPipe):def __init__(self, source_datapipe, **kwargs) -> None:self.source_datapipe = source_datapipeself.transform = kwargs['transform']def __iter__(self) -> Tuple[torch.Tensor, int]:for file_name in self.source_datapipe:with open(file_name, 'rb') as _file:pickle_data = pickle.load(_file)for image, label in pickle_data:image = self.transform(image)yield image, labeldef build_datapipes(cfg: Config) -> torchdata.datapipes.iter.IterDataPipe:datapipe = torchdata.datapipes.iter.FileLister(cfg.prepared_data_path, masks='*.pickle')datapipe = datapipe.shuffle()datapipe = datapipe.sharding_filter()datapipe = datapipe.load_pickle_data(transform=cfg.transform)return datapipe

数据加载对比

我们比较三种不同数据加载方法。对于所有测试,batch_size = 600。

n workersDatasets, secDataPipes, secDataPipe + pickle, sec
1035817986758
5100342993760

当在未准备好的数据上使用DataPipe进行训练时(不使用pickle),前几百个批次生成非常快,GPU使用率几乎是100%,但随后速度逐渐下降,这种方法甚至比使用n_workers=10的数据集还要慢。虽然我理解这两种方法的速度是一样的因为执行的操作是一样的,但实际上却不一样

DataLoader的最佳n_workers没有一个固定值,因为这取决于任务(图像大小,图像预处理的复杂性等等)和计算机配置(HDD vs SSD)。

当在有大量小图像的数据集上训练时,做数据的准备是必要的的,比如将小文件组合成几个大文件,这样可以减少从磁盘读取数据的时间。但是使用这种方法需要在将数据写入shard之前彻底打乱数据,来避免学习收敛性恶化。还需要选择合理的shard大小(它应该足够大以防止磁盘问题并且足够小以有效地使用datappipes中的Shuffler打乱数据)。

最后本文的代码在这里,有兴趣的可以自行测试比较:

https://github.com/karinaodm/pytorch-compare-datasets-vs-datapipes

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/335986.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

数 据 分 析 1

1.使用Wireshark查看并分析靶机桌面下的capture.pcapng数据包文件,找到黑客的IP地址,并将黑客的IP地址作为Flag值(如:172.16.1.1)提交;172.16.1.41 查找:tcp.connection.syn 2.继续分析captu…

clickhouse常规的优化方法

一、建表优化 1.1日期字段避免使用String存储 建表时能用数值型或日期时间型表示的字段就不要用字符串,全String 类型在以Hive 为中心的数仓建设中常见,但ClickHouse 环境不应受此影响。 虽然ClickHouse 底层将DateTime 存储为时间戳Long 类型&#xf…

使用pyinstaller打包生成exe(解决gradio程序的打包问题)

解决 [Errno 2] No such file or directory: gradio_client\types.json 问题,不需要手动创建hook文件 解决 FileNotFoundError: [Errno 2] No such file or directory: gradio\blocks_events.pyc 问题,不需要将pyi文件重命名为pyc文件 最终实现gradio程…

mysql忘记root密码后怎么重置

mysql忘记root密码后重置方法【windows版本】 重置密码步骤停掉mysql服务跳过密码进入数据库在user表中重置密码使用新密码登录mysql到此,密码就成功修改了,完结,撒花~ 重置密码步骤 当我们忘记mysql的密码时,连接mysql会报这样的…

20240109适配selinux让移远的4G模块EC20在Firefly的AIO-3399J开发板的Android11下跑通

20240109适配selinux让移远的4G模块EC20在Firefly的AIO-3399J开发板的Android11下跑通 2024/1/9 10:46 缘起:使用友善之臂的Android11可以让EC20上网,但是同样的修改步骤,Toybrick的Android11不能让EC20上网。 最后确认是selinux的问题&#…

控制论和科学方法论

《控制论与科学方法论》,真心不错。 书籍原文电子版PDF:https://pan.quark.cn/s/aa40d59295df(分类在学习目录下) 备用链接:https://pan.xunlei.com/s/VNgj2vjW-Hf_543R2K8kbaifA1?pwd2sap# 控制论是一种让系统按照我…

git克隆失败提示RPC failed的解决方法

现象 $ git clone https://github.com/guillemj/dpkg.git Cloning into dpkg... remote: Enumerating objects: 113312, done. remote: Counting objects: 100% (18045/18045), done. remote: Compressing objects: 100% (3915/3915), done. error: RPC failed; curl 18 trans…

Python——通过统计图像像素值初步分析图像噪声类型

图像噪声是指图像中不随真实场景变化而变化的随机干扰。噪声会影响图像的质量,因此需要对其进行去噪处理。 目录 一、图像噪声1.1 噪声类型1.2 结合峰度和偏度判断噪声1.2.1 峰度和偏度1.2.2 常见噪声的峰度和偏度 二、代码三、测试结果四、总结 一、图像噪声 图像…

CloudCompare——拟合空间球

目录 1.拟合球2.软件操作3.算法源码4.相关代码 本文由CSDN点云侠原创,CloudCompare——拟合空间球,爬虫自重。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT生成的文章。 1.拟合球 源码里用到了四点定球,…

数学建模-Matlab R2022a安装步骤

软件介绍 MATLAB是一款商业数学软件,用于算法开发、数据可视化、数据分析以及数值计算的高级技术计算语言和交互式环境,主要包括MATLAB和Simulink两大部分,可以进行矩阵运算、绘制函数和数据、实现算法、创建用户界面、连接其他编程语言的程…

给自己创建的GPTs添加Action(查天气)

前言 在这篇文章中,我将分享如何利用ChatGPT 4.0辅助论文写作的技巧,并根据网上的资料和最新的研究补充更多好用的咒语技巧。 GPT4的官方售价是每月20美元,很多人并不是天天用GPT,只是偶尔用一下。 如果调用官方的GPT4接口&…

设备树在开发板的系统中的体现

一. 简介 设备树文件中的设备节点,可以在开发板系统中看到。 也就说,开发板加载设备树文件,Linux内核启动系统以后,可以在根文件系统里看到设备树的节点信息。在/proc/device-tree/目录下存放着设备树信息。 二. 设备树在开发板…