Dataset的简单使用

Pytorch 给我们提供了一个方法,方便我们加载数据,我们可以使用这个框架,去加载我们的数据。看下伪代码:

# ================================================================== #
#                Input pipeline for custom dataset                 #
# ================================================================== ## You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):def __init__(self):# TODO# 1. Initialize file paths or a list of file names. passdef __getitem__(self, index):# TODO# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).# 2. Preprocess the data (e.g. torchvision.Transform).# 3. Return a data pair (e.g. image and label).passdef __len__(self):# You should change 0 to the total size of your dataset.return 0 # You can then use the prebuilt data loader. 
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,batch_size=64, shuffle=True)
  • __getitem__:返回一个样本
  • __len__:返回样本的数量

首先先创建一个文件夹,将图片放在同一个文件夹下。

image-20230829102904815

导入库文件

import torch 
import torchvision.datasets 
from torch.utils.data import Dataset 
import os 
from PIL import Image 
import numpy as np 
import torchvision.transforms as transforms 

图片数据预处理

预处理在机器学习和深度学习中起着重要的作用,它包括对输入数据进行一系列的变换和标准化操作。以下是为什么需要预处理的一些常见原因:

  1. 数据归一化/标准化:预处理过程中的归一化/标准化步骤有助于将数据的范围缩放到一个可接受的范围,以便更好地适应模型的训练。这有助于提高模型的收敛速度,并可以避免梯度消失或爆炸的问题。
  2. 数据增强:通过应用一系列的图像变换,如旋转、裁剪、平移、翻转等,可以扩增训练数据集,从而增加模型的泛化能力。数据增强可以减轻过拟合问题,并提高模型对多样性数据的鲁棒性。
  3. 数据格式转换:预处理可以将数据从原始格式(如图像文件、文本文件等)转换为模型所需的张量格式。例如,在计算机视觉任务中,图像通常被转换为张量,并进行通道重新排列、大小调整等操作。
  4. 噪声去除和数据清洗:预处理过程也可以用于去除数据中的噪声、异常值或无效样本。这有助于提高数据质量,并减少对模型的负面影响。

来自chatGPT

# 预处理
transform = transforms.Compose([ # 使用 Compose 可以将这些操作串联在一起transforms.Resize((224,224)), # 调整图片大小transforms.ToTensor(), # 将图片转换为Tensor对象,方便作为神经网络的输入transforms.Normalize( (0.1307, ), (0.3081, )) # 对图片进行归一化
])

定义Dataset类

__init__

__init__里面是初始化方法,例如传入图片的路径,或者要不要选择预处理等。

    # 初始化:指定路径,是否进行预处理等def __init__(self, path, transform = None) -> None: super().__init__()# os.listdir : 会将data下面的image中所有的文件读取,放在imgs里面img_path = os.path.join(path, "image/") # 进行拼接 得到 data/train/image/imgs = os.listdir(img_path) # 取出path下所有的文件self.imgs = [os.path.join(img_path, img) for img in imgs]self.transforms = transform # 图像预处理

__getitem__

__getitem__用于返回一个样本,返回之前做的处理数据的操作,也在__getitem__里面。

    def __getitem__(self, index): # 读取图片img_path = self.imgs[index] # 图片路径label_path = img_path.replace("image", "label") # 得到label文件夹下数据label = Image.open(label_path)data = Image.open(img_path)if self.transforms: # 图片预处理data = self.transforms(data)return data, label # tuple类型

__len__

__len__返回样本个数(图片路径的个数)

    def __len__(self):return len(self.imgs)

测试

image-20230829103513330

全部代码

import torch 
import torchvision.datasets 
from torch.utils.data import Dataset 
import os 
from PIL import Image 
import numpy as np 
import torchvision.transforms as transforms # 预处理
data_transform = transforms.Compose([ # 使用 Compose 可以将这些操作串联在一起transforms.Resize((224,224)), # 调整图片大小transforms.ToTensor(), # 将图片转换为Tensor对象,方便作为神经网络的输入transforms.Normalize( (0.1307, ), (0.3081, )) # 对图片进行归一化
])class Data(Dataset):# 初始化:指定路径,是否进行预处理等def __init__(self, path, transform = None) -> None: super().__init__()# os.listdir : 会将data下面的image中所有的文件读取,放在imgs里面img_path = os.path.join(path, "image/") # 进行拼接 得到 data/train/image/imgs = os.listdir(img_path) # 取出path下所有的文件self.imgs = [os.path.join(img_path, img) for img in imgs]self.transforms = transform # 图像预处理def __getitem__(self, index): # 读取图片img_path = self.imgs[index] # 图片路径label_path = img_path.replace("image", "label") # 得到label文件夹下数据label = Image.open(label_path)data = Image.open(img_path)if self.transforms: # 图片预处理data = self.transforms(data)label = self.transforms(label)return data, label # tuple类型def __len__(self):return len(self.imgs)# ts1 = Data('data/train/', transform=data_transform)
# print(type(ts1[0]))
# print(ts1[0])
# print(len(ts1))if __name__ == '__main__':ts1 = Data('data/train/', transform=data_transform)for i,(img, label) in enumerate(ts1):print(i, 'img', img.size(), 'label', label.size())

关于pytorch的数据处理-数据加载Dataset_datasets pytorch_Henry_zhangs的博客-CSDN博客

Pytorch深度学习实战教程(三):UNet模型训练,深度解析! - 知乎 (zhihu.com)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/87685.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

CUDA小白 - NPP(2) -图像处理-算数和逻辑操作

cuda小白 原文链接 NPP GPU架构近些年也有不少的变化,具体的可以参考别的博主的介绍,都比较详细。还有一些cuda中的专有名词的含义,可以参考《详解CUDA的Context、Stream、Warp、SM、SP、Kernel、Block、Grid》 常见的NppStatus&#xff0c…

传承精神 缅怀伟人——湖南多链优品科技有限公司赴韶山开展红色主题活动

8月27日上午, 湖南多链优品科技有限公司全体员工怀着崇敬之情,以红色文化为引领,参加了毛泽东同志诞辰130周年的纪念活动。以董事长程小明为核心的公司班子成员以及全国优秀代表近70人一行专赴韶山,缅怀伟人毛泽东同志的丰功伟绩。…

Hbase文档--架构体系

阿丹: 基础概念了解之后了解目标知识的架构体系,就能事半功倍。 架构体系 关键组件介绍: HBase – Hadoop Database,是一个高可靠性、高性能、面向列、可伸缩的分布式存储系统,利用HBase技术可在廉价PC Server上搭建起…

TensorBoard的使用

TensorBoard:对图像进行变换 1. SummaryWriter的使用 ctrl类出现注释解析: 将条目直接log_dir写入要成为由TensorBoard使用。 “摘要编写器”类提供了一个高级 API 来创建事件文件,并在给定目录中添加摘要和事件。该类更新文件内容异步。…

ceph peering机制-状态机

本章介绍ceph中比较复杂的模块: Peering机制。该过程保障PG内各个副本之间数据的一致性,并实现PG的各种状态的维护和转换。本章首先介绍boost库的statechart状态机基本知识,Ceph使用它来管理PG的状态转换。其次介绍PG的创建过程以及相应的状…

Java学数据结构(4)——散列表Hash table 散列函数 哈希冲突

目录 引出散列表Hash table关键字Key和散列函数(hash function)散列函数解决collision哈希冲突(碰撞)分离链接法(separate chaining)探测散列表(probing hash table)双散列(double hashing) Java标准库中的散列表总结 引出 1.散列表,key&…

day2 牛客TOP100:BM 11-20 链表 二分法 流输入 小美加法

文章目录 链表BM11 链表相加(二)BM12 单链表的排序归并排序分割 超时辅助数组快排 BM13 判断一个链表是否为回文结构BM14 链表的奇偶重排BM15 删除有序链表中重复的元素-IBM16 删除有序链表中重复的元素-IIJZ35 复杂链表的复制 二分法BM17 二分查找-IBM18 二维数组中的查找BM19…

Git中smart Checkout与force checkout

Git中smart Checkout与force checkout 使用git进行代码版本管理,当我们切换分支有时会遇到这样的问题: 这是因为在当前分支修改了代码,但是没有commit,所以在切换到其他分支的时候会弹出这个窗口, 提示你选force checkout或者smart checko…

Windows11 安装 nvm node版本管理工具

在 Windows 11 上安装并配置 NVM 与 Node.js 版本管理工具 引言: Node.js 是一款强大的开发工具,而版本管理工具 NVM 则可以帮助我们在不同的项目中灵活地切换和管理 Node.js 版本。本篇博客将为大家介绍如何在 Windows 11 操作系统上安装 NVM&#xff…

手机无人直播软件有哪些,又有哪些优势?

如今,随着智能手机的普及和移动互联网的发展,手机无人直播成为了一个炙手可热的领域。手机无人直播软件为用户提供了便捷、灵活的直播方式,让更多商家人能够实现自己的直播带货的梦想。接下来,我们将探讨手机无人直播软件有哪些&a…

【随笔】如何使用阿里云的OSS保存基础的服务器环境

使用阿里云OSS创建一个存储仓库:bucket 在Linux上下载并安装阿里云的ossutil工具 // 命令行,是linux环境 3. 安装ossutil。sudo -v ; curl https://gosspublic.alicdn.com/ossutil/install.sh | sudo bash 说明:安装过程中,需要使用解压工具…

AP9234 9W升压恒流型 DCDC多串LED恒流驱动 2串3串 LED灯串

描述 AP9234是一款由基准电压源、振荡电路、误差放大电路、相位补偿电路、电流限制电路等构成的CMOS升压型DC/DC LED驱动。由于内置了低导通电阻的增强型N沟道功率MOSFET,因此适用于需要高效率、高输出电流的应用电路。另外,可通过在VSENSE端子连接电流…