简介
Clothing1M 包含 14 个类别、共 100 万张服装图像。由于数据是从多个在线购物网站收集的,其中含有大量标注错误的样本,因此这是一个带有噪声标签的数据集。此外,该数据集还提供约 50k、14k 和 10k 张带有干净标签的图像,分别用于训练、验证和测试。
下载地址:https://github.com/Newbeeer/L_DMI/issues/8
Dataset & DataLoader
数据集目录结构:
└─images
    ├─0
    │  ├─00
    │  ├─...
    │  └─99
    ├─...
    └─9
        ├─00
        ├─...
        └─99
import os

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import transforms

# mode=0: noisy train set, mode=1: clean val set, mode=2: clean test set
class Clothing1m(Dataset):
    """Clothing1M dataset backed by the key-list / label text files.

    ``mode`` selects the split:
      * 0 -- noisy training set: keys from ``noisy_train_key_list.txt``,
        labels from ``noisy_label_kv.txt``. Items are ``(image, target, index)``.
      * 1 -- clean validation set: keys from ``clean_val_key_list.txt``,
        labels from ``clean_label_kv.txt``. Items are ``(image, target)``.
      * 2 -- clean test set: keys from ``clean_test_key_list.txt``,
        labels from ``clean_label_kv.txt``. Items are ``(image, target)``.

    All text files are read from ``root``; the image keys they contain are
    joined onto ``root`` to locate the image files.
    """

    nb_classes = 14

    # Entry ``i`` is the key-list file used for ``mode == i``.
    _KEY_LISTS = (
        'noisy_train_key_list.txt',
        'clean_val_key_list.txt',
        'clean_test_key_list.txt',
    )

    def __init__(self, mode=0, root='~/data/clothing1m', transform=None):
        """Load the key list and labels for one split.

        Args:
            mode: split selector, 0, 1 or 2 (see class docstring).
            root: dataset directory; a leading ``~`` is expanded.
            transform: optional callable applied to each RGB PIL image.
                When ``None`` the PIL image is returned unchanged.

        Raises:
            ValueError: if ``mode`` is not 0, 1 or 2.
        """
        # Validate before touching the filesystem so a bad mode fails with a
        # clear ValueError rather than an unrelated missing-file error.
        if mode not in (0, 1, 2):
            raise ValueError('mode should be 0, 1 or 2')
        self.mode = mode
        self.root = os.path.expanduser(root)
        self.transform = transform

        label_file = 'noisy_label_kv.txt' if mode == 0 else 'clean_label_kv.txt'
        self.labels = {}
        for line in self._read_lines(label_file):
            # Each line is "<image key> <integer label>".
            key, label = line.split()[:2]
            self.labels[key] = int(label)

        self.data = self._read_lines(self._KEY_LISTS[mode])
        self.targets = [self.labels[img_path] for img_path in self.data]

    def _read_lines(self, filename):
        """Return the stripped, non-blank lines of ``root/filename``."""
        with open(os.path.join(self.root, filename), 'r', encoding='utf-8') as f:
            return [line.strip() for line in f if line.strip()]

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, index):
        """Load image ``index`` as RGB, apply ``transform`` if set, and pair it with its label."""
        img_path = self.data[index]
        target = self.targets[index]
        # The context manager closes the file handle; convert() materializes
        # the pixels into a new image first, so ``image`` stays usable.
        with Image.open(os.path.join(self.root, img_path)) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        if self.mode == 0:
            # The noisy training split also yields the sample index.
            return image, target, index
        return image, target


class Clothing1mDataloader:
    """Builds train / val / test ``DataLoader``s over :class:`Clothing1m`."""

    # Per-channel normalization mean/std shared by train and test transforms
    # (presumably Clothing1M image statistics -- values kept as given).
    _MEAN = (0.6959, 0.6537, 0.6371)
    _STD = (0.3113, 0.3192, 0.3214)

    def __init__(self, batch_size=64, num_workers=8, root='~/data/clothing1m'):
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.root = root
        # Training: random 224 crop + horizontal flip for augmentation.
        self.transform_train = transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])
        # Evaluation: deterministic center crop.
        self.transform_test = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(self._MEAN, self._STD),
        ])

    def _make_loader(self, mode, transform, shuffle):
        """Create a pinned-memory DataLoader over the Clothing1m split ``mode``."""
        dataset = Clothing1m(mode=mode, root=self.root, transform=transform)
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle,
                          num_workers=self.num_workers, pin_memory=True)

    def train(self):
        """Return a shuffled loader over the noisy training split (mode 0)."""
        return self._make_loader(0, self.transform_train, shuffle=True)

    def val(self):
        """Return an unshuffled loader over the clean validation split (mode 1)."""
        return self._make_loader(1, self.transform_test, shuffle=False)

    def test(self):
        """Return an unshuffled loader over the clean test split (mode 2)."""
        return self._make_loader(2, self.transform_test, shuffle=False)
依赖
torch 2.3.1
torchvision(需与 torch 2.3.1 版本匹配)
Pillow