Pytorch读取自己的数据集

数据集

在这里插入图片描述

流程图

  1. 导包
  2. 设置tfs
  3. 创建datasets.ImageFolder
  4. 创建torch.utils.data.DataLoader()
import time
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader# 忽略烦人的红色提示
import warnings
warnings.filterwarnings("ignore")# 获取计算硬件
# 有 GPU 就用 GPU,没有就用 CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('device', device)from torchvision import transforms# 训练集图像预处理:缩放裁剪、图像增强、转 Tensor、归一化
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])# 测试集图像预处理-RCTN:缩放、裁剪、转 Tensor、归一化
test_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])# 数据集文件夹路径
dataset_dir = "C:/Users/Administrator/Desktop/ResNet/data/xiaobo"train_path = os.path.join(dataset_dir, 'train_img')
test_path = os.path.join(dataset_dir, 'test_img')
print('训练集路径', train_path)
print('测试集路径', test_path)from torchvision import datasets
# 载入训练集
train_dataset = datasets.ImageFolder(train_path, train_transform)
# 载入测试集
test_dataset = datasets.ImageFolder(test_path, test_transform)BATCH_SIZE = 32# 训练集的数据加载器
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,num_workers=4)# 测试集的数据加载器
test_loader = DataLoader(test_dataset,batch_size=BATCH_SIZE,shuffle=False,num_workers=4)print('训练集图像数量', len(train_dataset))
print('类别个数', len(train_dataset.classes))
print('各类别名称', train_dataset.classes)
print('测试集图像数量', len(test_dataset))
print('类别个数', len(test_dataset.classes))
print('各类别名称', test_dataset.classes)# 各类别名称
class_names = train_dataset.classes
n_class = len(class_names)
# 映射关系:类别 到 索引号
train_dataset.class_to_idx
# 映射关系:索引号 到 类别
idx_to_labels = {y:x for x,y in train_dataset.class_to_idx.items()}
print(idx_to_labels)
# 保存为本地的 npy 文件
np.save('idx_to_labels.npy', idx_to_labels)
np.save('labels_to_idx.npy', train_dataset.class_to_idx)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/701596.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

随机梯度算法应用动画场景

https://github.com/lilipads/gradient_descent_viz #include <math.h> #include <iostream> namespace Function { enum FunctionName { local_minimum, // 许多小坑 global_minimum, // 大坑 saddle_point, // 起伏山路 …

基于GWO灰狼优化的CNN-GRU-Attention的时间序列回归预测matlab仿真

目录 1.算法运行效果图预览 2.算法运行软件版本 3.部分核心程序 4.算法理论概述 4.1卷积神经网络&#xff08;CNN&#xff09;在时间序列中的应用 4.2 GRU网络 4.3 注意力机制&#xff08;Attention&#xff09; 4.4 GWO优化 5.算法完整程序工程 1.算法运行效果图预览…

【Android踩坑】 Constant expression required

gradle 8&#xff0c;报错 Constant expression required&#xff1a;意思是case语句后面要跟常量 解决1 单击switch语句&#xff0c;键盘按下altenter&#xff0c;将switch-case语句替换为if-else语句(或者手动修改) 解决2 在gradle.properties中添加 android.nonFinalRes…

IP代理网络协议介绍

在IP代理页面上&#xff0c;存在HTTP/HTTPS/Socks5三种协议。它们都是客户端与服务器之间交互的协议。 HTTP HTTP又称之为超文本传输协议&#xff0c;在因特网使用范围广泛。它是一种请求/响应模型&#xff0c;客户端向服务器发送请求&#xff0c;服务器解析请求后对客户端作出…

meshlab: pymeshlab合并多个物体模型并保存(flatten visible layers)

一、关于环境 请参考&#xff1a;pymeshlab遍历文件夹中模型、缩放并导出指定格式-CSDN博客 二、关于代码 本文所给出代码仅为参考&#xff0c;禁止转载和引用&#xff0c;仅供个人学习。 本文所给出的例子是https://download.csdn.net/download/weixin_42605076/89233917中的…

Centos 安装jenkins 多分支流水线部署前后端项目

1、安装jenkins 1.1 安装jdk 要求&#xff1a;11及以上版本 yum install yum install java-11-openjdk 1.2 安装jenkins 导入镜像 sudo wget -O /etc/yum.repos.d/jenkins.repo https://pkg.jenkins.io/redhat-stable/jenkins.repo出现以下错误 执行以下命令 sudo yum …

iview(viewUI) span-method 表格实现将指定列的值相同的行合并单元格

效果图是上面这样的&#xff0c;将第一列的名字一样的合并在一起&#xff1b; <template><div class"table-wrap"><Table stripe :columns"columns" :data"data" :span-method"handleSpan"></Table></div&…

喜大普奔!VMware Workstation Pro 17.5 官宣免费!

Broadcom 已经正式收购 VMware&#xff0c;【VMware中国】官方公众号已于3月11日更名为【VMware by Broadcom中国】。 13日傍晚&#xff0c;该公众号发表推文 V风拂面&#xff0c;好久不见 - 来自VMware 中国的问候 &#xff0c;意味着 VMware 带着惊喜和美好的愿景再次归来。 …

新书速览|MATLAB科技绘图与数据分析

提升你的数据洞察力&#xff0c;用于精确绘图和分析的高级MATLAB技术。 本书内容 《MATLAB科技绘图与数据分析》结合作者多年的数据分析与科研绘图经验&#xff0c;详细讲解MATLAB在科技图表制作与数据分析中的使用方法与技巧。全书分为3部分&#xff0c;共12章&#xff0c;第1…

IPSSL证书:为特定IP地址通信数据保驾护航

IPSSL证书&#xff0c;顾名思义&#xff0c;是专为特定IP地址设计的SSL证书。它不仅继承了传统SSL证书验证网站身份、加密数据传输的基本功能&#xff0c;还特别针对通过固定IP地址进行通信的场景提供了强化的安全保障。在IP地址直接绑定SSL证书的模式下&#xff0c;它能够确保…

分析 vs2019 cpp20 规范的 STL 库模板 function ,源码注释并探讨几个问题

&#xff08;1 探讨一&#xff09;第一个尝试弄清的问题是父类模板与子类模板的模板参数的对应关系&#xff0c;如下图&#xff1a; 我们要弄清的问题是创建 function 对象时&#xff0c;传递的模板参数 _Fty , 传递到其父类 _Func_class 中时 &#xff0c;父类的模板参数 _Ret…

k8s的整体架构及其内部工作原理,以及创建一个pod的原理

一、k8s整体架构 二、k8s的作用&#xff0c;为什么要用k8s&#xff0c;以及服务器的发展历程 1、服务器&#xff1a;缺点容易浪费资源&#xff0c;且每个服务器都要装系统&#xff0c;且扩展迁移成本高 2、虚拟机很好地解决了服务器浪费资源的缺点&#xff0c;且部署快&#x…