Pytorch入门实战 P2-CIFAR10彩色图片识别

目录

一、前期准备

1、数据集CIFAR10

2、判断自己的设备,是否可以使用GPU运行。

3、下载数据集,划分好训练集和测试集

4、加载训练集、测试集

5、取一个批次查看下

6、数据可视化

二、搭建简单的CNN网络模型

三、训练模型

1、设置超参数

2、编写训练函数

3、编写测试函数

4、正式训练

四、模型训练结果可视化

五、模型训练结果:


  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊 | 接辅导、项目定制

这周的实战内容,主要使用的数据集是CIFAR10数据集。用来验证彩色图片的识别。

一、前期准备

1、数据集CIFAR10

我们使用的数据集的文档地址:Datasets — Torchvision 0.17 documentation

简单介绍下CIFAR10数据集:

CIFAR-10数据集由60000张32 × 32彩色图像组成,分为10个类,每个类有6000张图像。

50000张训练图像10000张测试图像

2、判断自己的设备,是否可以使用GPU运行。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

3、下载数据集,划分好训练集和测试集

import torchvision.datasets# 下载训练集
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
# 下载测试集
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)

4、加载训练集、测试集

# 使用dataloader加载数据集,并设置好batch_size
batch_size = 32
train_dl = torch.utils.data.DataLoader(train_ds,shuffle=True,batch_size=batch_size)
test_dl = torch.utils.data.DataLoader(test_ds,batch_size=batch_size)

5、取一个批次查看下

# 取一个批次,查看下数据
imgs,labels = next(iter(train_dl))
print(imgs.shape)   #  数据的shape为:[batch_size,channel,height,weight]  
'''对于CIFAR10,这里的shape是 [32,3,32,32],即 因为取得是train_dl的数据,batch_size为32;channel为3是因为,是彩色图片RGB的3通道,如果是黑白图片,则channel为1;剩下的32x32是高度和宽度;
'''

6、数据可视化

即:展示下取到的数据。

# 数据可视化
plt.figure(figsize=(20,5))
for i, imgs in enumerate(imgs[:20]):npimg = imgs.numpy().transpose((1,2,0))   #.numpy()用于将Tensor转换为一个Numpy数组。transpose是Numpy数组的一个方法,用于重新排列数组的维度。plt.subplot(2, 10, i+1)plt.imshow(npimg, cmap=plt.cm.binary)plt.axis('off')
plt.show()

运行结果展示: 

二、搭建简单的CNN网络模型

 CNN(卷积神经网络),需要注意其结构、层与层之间的连接关系以及各层的功能。

①卷积层:负责提取特征。(通常使用局部连接权值共享方式,这有助于减少网络的参数数量和计算复杂度。)

②池化层:负责降低数据的空间尺寸和计算复杂度。

③全连接层:负责将提取的特征映射到输出类别。

# 构建简单的CNN网络
num_classes = 10
class Model(nn.Module):def __init__(self):super().__init__()# 特征提取self.conv1 = nn.Conv2d(3, 64, kernel_size=3)self.pool1 = nn.MaxPool2d(2)self.conv2 = nn.Conv2d(64, 64, kernel_size=3)self.pool2 = nn.MaxPool2d(2)self.conv3 = nn.Conv2d(64, 128, kernel_size=3)self.pool3 = nn.MaxPool2d(2)# 分类网络self.fc1 = nn.Linear(512, 256)self.fc2 = nn.Linear(256, num_classes)# 前向传播def forward(self,x):x = self.pool1(F.relu(self.conv1(x)))x = self.pool2(F.relu(self.conv2(x)))x = self.pool3(F.relu(self.conv3(x)))x = torch.flatten(x, start_dim=1)  # 线性层+激活函数  是构建复杂模型的基础x = F.relu(self.fc1(x))x = self.fc2(x)return x# 打印并加载模型
model = Model().to(device)
print(model)

三、训练模型

1、设置超参数

# 1、设置超参数
loss_fn = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-2   #学习率
opt = torch.optim.SGD(model.parameters(), lr=learn_rate)   # 定义一个随机梯度下降优化器,即SGD优化器。# model.parameters() 返回模型中所有可训练的参数(通常是权重和偏置)

2、编写训练函数

# 2、编写训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 数据集的大小,一共60000张图片num_batches = len(dataloader)  # 批次数目 1875 (60000/32 = 1875)train_loss, train_acc = 0, 0   # 初始化训练的损失和正确率for X,y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,y为真实值,计算二者差值,即为损失。# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss

3、编写测试函数

# 3、编写测试函数
# 测试函数和训练函数大致相同,但是由于不进行梯度下降对网络权重进行更新,所以不需要传入优化器。
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 数据集的大小,共10000张num_batches = len(dataloader)  # 批次数目 ,313( 10000/32 = 321.5 ,向上取整)test_loss, test_acc = 0, 0  # 初始化测试的损失和精确# 不进行训练时,停止梯度下降,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss

4、正式训练

# 4、正式训练
epochs = 10
train_loss = []
train_acc = []
test_loss = []
test_acc = []'''model.train()和model.eval() 是深度学习中常见的两个方法,它们用于设置模型的训练模式和评估模式。①当你调用model.train()时,你正在告诉模型你即将进入训练阶段。通常意味着模型中的某些层(如Dropout层和BatchNormalization层)会改变它们的行为以适应训练过程。Dropout层:在训练模式下,Dropout层会随机将一部分神经元的输出设置为0,有助于防止过拟合。BatchNormalization层:在训练模式下,BatchNoralization层会使用当前批次的数据来更新其运行均值和方差,并应用这些统计量来标准化输入。②当你调用model.eval()时,你正在告诉模型你即将进入评估或推断阶段。在这种模式下,模型的某些层会改变它们的行为,以确保在评估时模型给出一致的结果。
'''
for epoch in range(epochs):model.train()  # 进入训练阶段epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)template = 'Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}'print(template.format(epoch+1, epoch_train_acc*100,epoch_train_loss, epoch_test_acc*100, epoch_test_loss))
print('Finish')

四、模型训练结果可视化

# 四、结果可视化
warnings.filterwarnings('ignore')   # 忽略警告信息
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100    # 分辨率epochs_range = range(epochs)  # 生成从0到epoches-1的整数序列plt.figure(figsize=(12,3))  # figsize=(12,3)  包含两个元素的元组,分别代表图形的宽度和高度,单位是英寸。plt.subplot(1,2,1)
plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1,2,2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validataion Loss')# 在远程服务器上面跑代码,想要保存下,plt.show()的结果,打下下面的注释
# plt.savfig('想要保存的服务器的地址+图片的名称.png/jpg自行定义即可')  
# eg:plt.savefig('/data/jupyter/deepinglearning/resultImg.jpg')plt.show()
print("画图结束。。。")

五、模型训练结果:

这周和上周的代码类似,但是,比起刚开始的时候,好多代码都清晰了很多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/537979.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

基于java+springboot+vue实现的小区物业管理系统(文末源码+Lw+ppt)23-34

摘 要 随着互联网时代的发展,传统的线下管理技术已无法高效、便捷的管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生,在人们生活环境要求不断提高的前提下,小区物业管理系统建设也逐渐进入了…

0基础学习VR全景平台篇第145篇:图层控件功能

大家好,欢迎观看蛙色VR官方——后台使用系列课程!这期,我们将为大家介绍如何使用图层控件功能。 一.如何使用图层控件功能? 进入作品编辑页面,点击左边的控件后就可以在右边进行相应设置。 二.图层控件有哪些功能&am…

Mysql 死锁案例1-记录锁读写冲突

死锁复现 CREATE TABLE t (id int(11) NOT NULL,c int(11) DEFAULT NULL,d int(11) DEFAULT NULL,PRIMARY KEY (id),KEY c (c) ) ENGINEInnoDB DEFAULT CHARSETutf8;/*Data for the table t */insert into t(id,c,d) values (0,0,0),(5,5,5),(10,10,10) 事务1事务2T1 START…

Material UI 5 学习03-Text Field文本输入框

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 Text Field文本输入框 一、最基本的本文输入框1、基础示例2、一些表单属性3、验证 二、多行文本 一、最基本的本文输入框 1、基础示例 import {Box, TextField} from "…

0301taildir-source报错-flume-大数据

1 基础环境简介 linux系统:centos,前置安装:jdk、hadoop、zookeeper、kafka,版本如下 软件版本描述centos7linux系统发行版jdk1.8java开发工具集hadoop2.10.0大数据生态基础组件zookeeper3.5.7分布式应用程序协调服务kafka3.0分…

【深入理解设计模式】命令设计模式

命令设计模式: 命令模式(Command Pattern)是一种行为型设计模式,它将请求封装为一个对象,从而使你可以用不同的请求对客户端进行参数化,对请求排队或记录请求日志,以及支持可撤销的操作。 概述…

YOLOv9实例分割教程|(一)训练教程

专栏介绍:YOLOv9改进系列 | 包含深度学习最新创新,主力高效涨点!!! 一、创建数据集及数据配置文件 创新一个文件夹存放分割数据集,包含一个images和labels文件夹。标签格式如下所示: 创新数据集…

可视化Relay IR

目标 为Relay IR生成图片形式的计算图。 实现方式 使用RelayVisualizer可视化Relay,RelayVisualizer定义了一组接口(包括渲染器、解析器)将IRModule可视化为节点和边,并且提供了默认解析器和渲染器。 首先需要安装依赖&#x…

基于PHP的数字化档案管理系统

有需要请加文章底部Q哦 可远程调试 基于PHP的数字化档案管理系统 一 介绍 此数字化档案管理系统基于原生PHP,MVC架构开发,数据库mysql,前端bootstrap。系统角色分为用户和管理员。 技术栈 php(mvc)mysqlbootstrapphpstudyvscode 二 功能 …

Js输入输出语句

输入语法 prompt("您想输入的是&#xff1f;")输出语法: 语法1: document.write(‘要出的内容’&#xff09; <body><script>document.write("你好")document.write("<h1>我是<h1>")</script> </body>作…

武汉星起航:秉承客户至上服务理念,为创业者打造坚实后盾

在跨境电商的激荡浪潮中&#xff0c;武汉星起航电子商务有限公司一直秉持着以客户为中心的发展理念&#xff0c;为跨境创业者提供了独特的支持和经验积累&#xff0c;公司通过多年的探索和实践&#xff0c;成功塑造了一个以卖家需求为导向的服务平台&#xff0c;为每一位创业者…

MongoDB从0到1:高效数据使用方法

MongoDB&#xff0c;作为一种流行的NoSQL数据库。从基础的文档存储到复杂的聚合查询&#xff0c;从索引优化到数据安全都有其独特之处。文末附MongoDB常用命令大全。 目录 1. 引言 MongoDB简介 MongoDB的优势和应用场景 2. 基础篇 安装和配置MongoDB MongoDB基本概念 使…