笔记2:cifar10数据集获取及pytorch批量处理

(1)cifar10数据集预处理

CIFAR-10是一个广泛使用的图像数据集,它由10个类别的共60000张32x32彩色图像组成,每个类别有6000张图像。
CIFAR-10官网
以下为CIFAR-10数据集data_batch_*表示训练集数据,test_batch表示测试集数据
在这里插入图片描述
预处理结果(将CIFAR-10保存为图片格式)
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
import glob
import pickle
import numpy as np
import cv2 as
import os
#%% md
cifar10官网处理函数:
#%%
def unpickle(file):with open(file, 'rb') as fo:dict = pickle.load(fo, encoding='bytes')return dict
#%% md
利用上面的函数进行读取数据:
#%%
label = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]  #标签矩阵
filepath = glob.glob("../../test_doucments/cifar-10-batches-py/data_batch_*") # 获取当前文件的路径,返回路径矩阵,获取test数据集时将data_batch——*改为test_batch*
write_path =["./train","./test"] #
print(filepath)
for file in filepath:if not file:print("空集出错")else:# print(file)data_dic = unpickle(file) # 将二进制表示形式转换回 Python 对象的反序列化过程,结果为字节型数据# print(data_dic.keys()) #此处的keys主要有b"data",b"labels",b"filenames"index = 0for im_data in data_dic[b"data"]:  # 遍历影像矩阵数据im_label = data_dic[b"labels"][index] # 赋值标签数据im_filename = data_dic[b"filenames"][index] # 赋值影像名字index +=1# print(f"图像的文件名为:{im_filename}\n",f"图像的所属标签为:{im_label}\n",f"图像的矩阵数据为:{im_data}\n")#开始存放数据im_label_name = label[im_label]im_data_data = np.reshape(im_data,(3,32,32)) # 将影像矩阵数据转换为图像形式# 由于需要opencv进行写出图像,因此需要转化通道im_data_data = np.transpose(im_data_data,(1,2,0))imgname = f"当前图像名称{im_label},所属标签{im_label_name}"cv.imshow(str( im_label_name),cv.resize(im_data_data,(500,500))) # 将显示时的图像变大,图像数据本身大小不变cv.waitKey(0)cv.destroyAllWindows()#创建文件夹for path in write_path:if not os.path.exists("{}/{}".format(path,im_label_name)): #查看存储路径中的文件夹是否存在os.mkdir("{}/{}".format(path,im_label_name)) # 没有就创建文件else:breakcv.imwrite("{}/{}/{}".format(write_path[0],im_label_name,str(im_filename,'utf-8')),im_data_data)# #write_path[1]写出测试数据的时候将write_path[0]改为write_path[1]
#%% md
将cifar10数据转为图片格式并保存

(2)利用pytorch将图像转为张量数据

或是批量读取训练集和测试集数据
在这里插入图片描述

#!/usr/bin/env python
# -*- coding:utf-8 -*-
"""
@author: LIFEI
@time: 2024/5/8 15:00 
@file: 加载cifar10数据.py
@project: 深度学习(4):深度神经网络(DNN)
@describe: TEXT
@# ------------------------------------------(one)--------------------------------------
@# ------------------------------------------(two)--------------------------------------
"""
# 导入库
import glob
from torchvision import transforms
from torch.utils.data import Dataset,DataLoader
import cv2 as cv
# DataLoader参考网址https://blog.csdn.net/sazass/article/details/116641511from PIL import Imagelabel_name = ["airplane","automobile", "bird","cat", 'deer',"dog","frog","horse","ship","truck"]
label_list = {} # 创建一个字典用于存储标签和下标
index = 0
for name in label_name:  # 也可以采用for index,name in enumerate(label_name)label_list[name] = index # 字典的常规赋值操作index += 1def default_loder(path):# return Image.open(path).convert("RGB") # 也可采用opencv读取img = cv.imread(path)return cv.cvtColor(img,cv.COLOR_BGR2RGB)# 定义训练集数据的增强   下面的Compose表示拼接需要增强的操作
train_transform = transforms.Compose([transforms.RandomCrop(28,28), #进行随机裁剪为28*28大小transforms.RandomHorizontalFlip(), #垂直方向翻转transforms.RandomVerticalFlip(), #水平方向的翻转transforms.RandomRotation(90), #随机旋转90度transforms.RandomGrayscale(0.1), #灰度转化transforms.ColorJitter(0.3,0.3,0.3,0.3), #随机颜色增强transforms.ToTensor() #将数据转化为张量数据
])# 定义pytorh的dataset类
class MyData(Dataset):def __init__(self,im_list,transform = None,loder = default_loder):     #初始化函数super(MyData,self).__init__() #初始化这个类# 获取图片的路径以及标签号images = []for item_data in im_list:# 注意下面这一步,split("\\")根据不同的操作系统会不相同,有的是"/"img_label_name = item_data.split("\\")[-2] #通过遍历每一个路径进行获取当前图片的文字标签images.append([item_data,label_list[img_label_name]])self.images = imagesself.tranform =transformself.loder = loderdef __getitem__(self, index_num): # 此处的index_num是在训练的时候反复传进来的值img_path , img_label = self.images[index_num] #这里的img_data = self.loder(img_path)  # 这里用到了self.loder(path)==>default_loder(path)外置函数if self.tranform is not None: # 判断数据是否增强img_data = self.tranform(img_data)return img_data,img_labeldef __len__(self):return len(self.images)train_list = glob.glob("./train/*/*.png") # glob.glob 获取改路径下的所有文件路径并返回为列表
test_list = glob.glob("./test/*/*.png")train_dataset = MyData(train_list,transform = train_transform)
test_dataset = MyData(test_list,transform = transforms.ToTensor()) #测试集无需进行图像增强操作,直接转为张量train_data_loder = DataLoader(dataset =train_dataset,batch_size=6,shuffle=True,num_workers=4)
test_data_loder = DataLoader(dataset =test_dataset,batch_size=6,shuffle=False,num_workers=4)
print(f"训练集的大小:{len(train_dataset)}")
print(f"测试集的大小:{len(test_dataset)}")

注:以上代码非原创,仅供个人记录学习笔记,若有侵权,请我联系删除

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/678786.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu22.04下安装kafka_2.11-0.10.1.0并运行简单实例

目录 一、版本信息 二、安装Kafka 1.将Kafka安装包移到下载目录中 2.下载Spark并确保hadoop用户对Spark目录有操作权限 三、启动Kafka并测试Kafka是否正常工作 1.启动Kafka 2.测试Kafka是否正常工作 一、版本信息 虚拟机产品:VMware Workstation 17 Pro 虚…

C++从入门到精通---模版

文章目录 泛型编程函数模版模版参数的匹配原则类模版类模版的定义格式类模版的实例化 总结 泛型编程 泛型编程是一种编程范式,旨在实现通用性和灵活性。它允许在编写代码时使用参数化类型,而不是具体的类型,从而使代码更加灵活和可重用。 在…

Autoxjs 实践-Spring Boot 集成 WebSocket

概述 最近弄了福袋工具,由于工具运行中,不好查看福袋结果,所以我想将福袋工具运行数据返回到后台,做数据统计、之后工具会越来越多,就弄了个后台,方便管理。 实现效果 WebSocket? websocket是…

图像处理之SVD检测显示屏缺陷(C++)

图像处理之SVD检测显示屏缺陷(C) 文章目录 图像处理之SVD检测显示屏缺陷(C)前言一、SVD算法简介二、代码实现总结 前言 显示屏缺陷检测是机器视觉领域的一处较广泛的应用场景,显示屏主要有LCD和OLED,缺陷类…

本地连接服务器Jupyter【简略版】

首先需要在你的服务器激活conda虚拟环境: 进入虚拟环境后使用conda install jupyter命令安装jupyter: 安装成功后先不要着急打开,因为需要设置密码,使用jupyter notebook password命令输入自己进入jupyter的密码: …

USB系列一:USB技术概念

在这里USB的历史就不赘述了,有兴趣可以自己去搜索。也省略掉USB接口的概述,这些都是一些飞技术性的常识性的知识,没必要浪费篇幅和文字来描述。 一、USB总线版本:(从USB1.1说起) 1、USB1.1 1998年9月23日…

怎么获取二维码链接?二维码图片链接提取技巧

现在很多内容会使用二维码的方式来做展示,这种方式可以实现内容的快速传递,而且当遇到无法扫码的情况时,也可以通过二维码链接来访问内容,与其他内容分享方式相比更加灵活。那么如何快速提取二维码链接呢?可能有很多小…

0508_IO3

练习1&#xff1a; 1&#xff1a;使用 dup2 实现错误日志功能 使用 write 和 read 实现文件的拷贝功能&#xff0c;注意&#xff0c;代码中所有函数后面&#xff0c;紧跟perror输出错误信息&#xff0c;要求这些错误信息重定向到错误日志 err.txt 中去 1 #include <stdio.h…

鸿蒙开发接口Ability框架:【 (ServiceExtensionAbility)】

ServiceExtensionAbility ServiceExtensionAbility模块提供ServiceExtension服务扩展相关接口的能力。 说明&#xff1a; 本模块首批接口从API version 9开始支持。后续版本的新增接口&#xff0c;采用上角标单独标记接口的起始版本。 本模块接口仅可在Stage模型下使用。 导入…

Kafka和Spark Streaming的组合使用(Spark 3.5.1)

一、安装Kafka 1.执行以下命令完成Kafka的安装&#xff1a; cd ~ //默认压缩包放在根目录 sudo tar -zxf kafka_2.11-2.3.1.tgz -C /usr/local cd /usr/local sudo mv kafka_2.11-2.3.1 kafka-2.3.1 sudo chown -R qiangzi ./kafka-2.3.1 二、启动Kafaka 1.首先需要启动K…

Qt应用开发(拓展篇)——图表 QChart

一、前言 QChart是一个图形库模块&#xff0c;它可以实现不同类型的序列和其他图表相关对象(如图例和轴)的图形表示。要在布局中简单地显示图表&#xff0c;可以使用QChartView来代替QChart。此外&#xff0c;线条、样条、面积和散点序列可以通过使用QPolarChart类表示为极坐标…

给window电脑安装Linux系统

身为小白的我们在安装Linux系统时会遇到很多麻烦&#xff0c;没有接触过的命令&#xff0c;没有实际操作的经验&#xff0c;觉得Linux遥不可及&#xff0c;这篇博客讲述了我安装Linux的经历与安装过程遇到的问题与解决方案。我是为了学习Linux开发而安装&#xff0c;目的不同安…