最近在跟着小土堆pytorch的视频跟着学习python,根据自己的理解和课程上面的知识,写了这一篇学习笔记。
1、加载数据
数据的加载是学习pytorch的第一步,我们需要加载数据,完成特征工程,对加载数据存在的一些特征来进行分析和处理,进而利用相关算法训练得到模型。
数据该如何加载呢?
首先,如果是文本之类数据的话,可以使用open()函数进行文件读取操作,对于图片的话,可以使用PIL下面的一个api,调用里面的open()方法来打开图片,如果要使用,则需要进行导包的操作。
from PIL import Image
如果导需要从某种数据源加载数据,并对这些数据进行预处理和格式化的话,利用pytorch中的Dataset类是最为方便的,也需要导包。
from torch.utils.data import Dataset
Dataset类里面定义了两种方法:
__len__()
: 返回数据集中的样本数量。__getitem__(idx)
: 根据给定的索引idx
返回一个样本
我们需要自己定义一个类,这个类继承Dataset类,并重写相关方法
需要调用系统路径,导入os模块,不要忘记了
import os
此时先定义一个自定义类,这个自定义类继承于Dataset类
class MyData(Dataset):
然后重写Dataset里面的__init__和__getitem__方法
class MyData(Dataset):# 初始化方法,当创建MyData对象时会被调用# root_dir: 数据集的根目录# label_dir: 自定义的类别目录,通常是某个类别的子目录def __init__(self, root_dir, label_dir):self.root_dir = root_dir # 存储根目录self.label_dir = label_dir # 存储类别目录# 拼接根目录和类别目录,得到完整路径self.path = os.path.join(self.root_dir, self.label_dir)# 获取该类别目录下的所有文件/图片名,并存储到self.img_path列表中self.img_path = os.listdir(self.path)# 根据索引获取数据集中的单个样本# idx: 样本的索引def __getitem__(self, idx):# 获取索引对应的图片名img_name = self.img_path[idx]# 拼接完整的图片路径img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)# 打开图片并获取图片对象image = Image.open(img_item_path)# 使用类别目录作为标签label = self.label_dir# 返回图片对象和标签return image, label
记得在py文件统计目录下有相关文件,我是跟着小土堆的课程,所以就下载了dataset的数据集
最后定义好相关的变量和函数即可
root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
这样就能得到ants_dataset和bees_dataset两个类别的数据集
2、查看数据
查看数据集的话,需要用到DataLoader数据加载类来加载数据,调用Transform来对数据进行增强,通过使用 transforms.Compose
来组合多个转换操作(这是后面要学习的)
下面的代码可以看下,但不用深究,看个大致就行。
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from torchvision import transformsclass MyData(Dataset):def __init__(self, root_dir, label_dir, transform=None):self.root_dir = root_dirself.label_dir = label_dirself.path = os.path.join(self.root_dir, self.label_dir)self.img_names = os.listdir(self.path)self.transform = transformself.label = self.label_dir.replace("_image", "") # 假设标签是目录名的前缀def __getitem__(self, idx):img_name = self.img_names[idx]img_path = os.path.join(self.path, img_name)image = Image.open(img_path)if self.transform:image = self.transform(image)return image, self.labeldef __len__(self):return len(self.img_names)root_dir = "dataset/train"
ants_label_dir = "ants_image"
bees_label_dir = "bees_image"# 定义数据增强
transform = transforms.Compose([transforms.Resize((224, 224)), # 调整为适合模型输入的尺寸transforms.ToTensor(), # 将 PIL Image 或 numpy.ndarray 转换为 torch.FloatTensortransforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet 的均值和标准差
])ants_dataset = MyData(root_dir, ants_label_dir, transform=transform)
bees_dataset = MyData(root_dir, bees_label_dir, transform=transform)# 创建 DataLoader
batch_size = 4
train_loader = DataLoader(ants_dataset, batch_size=batch_size, shuffle=True)# 查看数据集
for images, labels in train_loader:print("Images batch shape:", images.shape)print("Labels batch:", labels)