# 判断某个文件是否是图像 # enswith判断是否以指定的.png,.jpg,.jpeg结尾的字符串 # 可以根据情况扩充图像类型,加入.bmp、.tif等 def is_image_file(filename):return any(filename.endswith(extension) for extension in [".png", ".jpg", ".jpeg"])# 读取图像转为YCbCr模式,得到Y通道 def load_img(filepath):img = Image.open(filepath).convert('YCbCr')y, _, _ = img.split()return y# 裁剪大小,宽高一致为300 # 如果想训练自己的数据集,请根据情况修改裁剪大小 CROP_SIZE = 300# 封装数据集,适配后面的torch.utils.data.DataLoader中的dataset,定义成类似形式 # 类参数为图像文件夹路径和放大倍数 # __len__(self) 定义当被len()函数调用时的行为(返回容器中元素的个数) #__getitem__(self) 定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。 #__iter__(self) 定义当迭代容器中的元素的行为 # 返回输入图像和标签,传入DataLoader的dataset参数 class DatasetFromFolder(Dataset):def __init__(self, image_dir, zoom_factor):super(DatasetFromFolder, self).__init__()self.image_filenames = [join(image_dir, x) for x in listdir(image_dir) if is_image_file(x)] # 图像路径列表crop_size = CROP_SIZE - (CROP_SIZE % zoom_factor) # 处理放大倍数,防止用户瞎设置,本例只能设置为2,3,4,大小不变# 数据集变换# 还有一些其他的变换操作,如归一化等,遇到一个积累一个self.input_transform = transforms.Compose([transforms.CenterCrop(crop_size), # 从图片中心裁剪成300*300 transforms.Resize(crop_size // zoom_factor), # Resize, 输入应该是缩放倍数后的图像,因为先缩小后放大 transforms.Resize(crop_size, interpolation=Image.BICUBIC), # 双三次插值transforms.ToTensor()]) # 图像转成tensor# label标签,超分不是分类问题,定义成一样的就行self.target_transform = transforms.Compose([transforms.CenterCrop(crop_size), transforms.ToTensor()])def __getitem__(self, index):input = load_img(self.image_filenames[index]) # 输入是图像的Y通道,即亮度通道target = input.copy()input = self.input_transform(input)target = self.target_transform(target)return input, targetdef __len__(self):return len(self.image_filenames) # 图像个数