PyTorch搭建LeNet训练集详细实现

一、下载训练集

导包

import torch
import torchvision
import torch.nn as nn
from model import LeNet
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

 ToTensor()函数:

把图像[heigh x width x channels] 转换为 [channels x height x width]

Normalize() 数据标准化函数:

最后一行是标准化数值计算公式

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 50000张训练图片
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)

参数解释: 

root='./data':数据集下载的路径,我下载到当前目录下的data文件夹,下载完成后会自动创建 

train=True:当前为训练集

download=True:下载数据集时设置为True,下载完成后改为False

transform=transform :设置对图像进行预处理的函数

运行下载数据集结果为: 

 下载完成后生成了data文件夹

二、导入训练集 

# 导入训练集
trainloader = torch.utils.data.DataLoader(trainset, batch_size=36,shuffle=True, num_workers=0)

参数解释: 

        trainset:把刚刚下载的数据导入进来

        batch_size=36:一批数据的大小

        shuffle=True:训练集中的数据是否打乱(一般默认打乱)

        num_workers=0:载入数据的现成数,在lunix操作系统下,可以设置为别的参数,在windows操作系统系统下,默认为0.

三、下载测试集

# 10000张测试图片
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=10000,shuffle=False, num_workers=0)
test_data_iter = iter(testloader)
test_image, test_lable = test_data_iter.next()classes = ('plane', 'car', 'bird', 'cat',   # 数据集中的分类,设置为元组,不可变类'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

参数解释:

test_data_iter = iter(testloader):通过iter()函数把testloader转化成可迭代的迭代器
test_image, test_lable = test_data_iter.next():通过next()方法可以获得测试的图像和图像对应的标签值。

 四、查看导入的图片

在中间过程打印图片进行查看,后续会注释掉

def imshow(img):img = img / 2 + 0.5nping = img.numpy()plt.imshow(np.transpose(nping, (1, 2, 0)))plt.show()# print labels
print(' '.join('%5s' % classes[test_lable[j]] for j in range(4)))
# show images
imshow(torchvision.utils.make_grid(test_image))

运行结果

 图片很模糊,因为像素很低。

上面识别出来的结果都对了。

 我遇到的问题:
一开始有结果但是没有图片,我以为时matplotlib的问题,我重新安装并且更新了版本,但是我再运行后报错更多了,报错提示我 AttributeError: module 'numpy' has no attribute 'bool',我就知道是numpy的问题了,我重新安装并且更新了版本结果还是不行,我百度了一下,发现不是越新的版本越好,我重新下载了1.23.2这个版本的numpy,下载完成后运行就出来结果了。

pip install numpy==1.23.2

这个也只是中间过程,后续会注释或者删了。


五、将创建的模型实例化

创建模型请看PyTorch搭建LeNet神经网络-CSDN博客

# 将创建的模型实例化
net = LeNet()  # 实例化
loss_fuction = nn.CrossEntropyLoss()  # 定义损失函数# 通过优化器将所有可训练的参数都进行训练,lr是learningrate学习率
optimizer = optim.Adam(net.parameters(), lr=0.001)#通过for循环实现训练过程,循环几次就是将训练集迭代多少次
for epoch in range(5):running_loss = 0.0  # 用来累加在学习过程中的损失for step, data in enumerate(trainloader, start=0):# get the inputs; data is a list of [inputs, labels]inputs, labels = data# zero the parameter gradientsoptimizer.zero_grad()   # 历时损失梯度清零。# forward + backward + optimizeoutputs = net(inputs)loss = loss_fuction(outputs, labels)  # 计算神经网络的预测值和真实标签之间的损失loss.backward()optimizer.step()  # step()函数实现参数更新# print statistics  打印数据的过程running_loss += loss.item()if step % 500 == 499:  # 每隔500步打印一次数据的信息with torch.no_grad():  # 上下文管理器outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = (predict_y == test_lable).sum().item() / test_lable.size(0)print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print('Finished Training')# 将模型保存到文件夹中
save_path = './Lenet.pth'
torch.save(net.state_dict(), save_path)

 详细解释:

比较重点的单独解释了,其他的在注释中。

 optimizer.zero_grad()   # 历时损失梯度清零。

 ? 为什么每计算一个batch,就要调用一次 optimizer.zero_grad()函数

=> 通过清楚历史梯度,就会对计算的历史梯度进行累加。通过这个特性,能变相的实现一个很大的batch数值的训练(因为batch数值越大,训练效果越好)

 with torch.no_grad():  # 上下文管理器

 上下文管理器: 在接下来的计算过程中,不再去计算每个节点的误差损失梯度。

如果不调用这个函数,将会在测试过程中占用更多的算力,消耗更多的资源和占用更多的内存资源,导致内存容易崩。

print函数中打印参数解释:

print('[%d, %5d] train_loss: %.3f test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))

epoch + 1:迭代到第几轮了

step + 1:某一轮的第几步

running_loss / 500:训练过程中500步平均训练误差

accuracy:准确率

运行结果

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/526706.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【脚本玩漆黑的魅影】全自动刷努力值

文章目录 原理全部代码 原理 全自动练级,只不过把回城治疗改成吃红苹果。 吃一个可以打十下,背包留10个基本就练满了。 吃完会自动停止。 if img.getpixel(data_attack[0]) data_attack[1] or img.getpixel(data_attack_2[0]) data_attack_2[1]: # …

RESTful API关键部分组成和构建web应用程序步骤

RESTful API是一种基于HTTP协议的、符合REST原则的应用程序接口。REST(Representational State Transfer)是一种软件架构风格,用于设计网络应用程序的通信模式。 一个RESTful API由以下几个关键部分组成: 资源(Resour…

关于天线综合4(伍德沃德——罗森取样法)

伍德沃德——罗森取样法 就是在各个点指定方向图的值,对其方向图取样 主要就是将线源电流分布分解成一组等幅度、线性相位的源的和 求出对应电流分量方向图 中心位于wwn 最大值为an, 其中wn控制该分量方向图最大值的位置,an控制分量方向图的幅…

腾讯云8核16G服务器性能怎么样?能支持多少人访问?

腾讯云8核16G轻量服务器CPU性能如何?18M带宽支持多少人在线?轻量应用服务器具有100%CPU性能,18M带宽下载速度2304KB/秒,折合2.25M/s,系统盘为270GB SSD盘,月流量3500GB,折合每天116.6GB流量&…

从 iPhone 设备恢复误删微信消息的 4 种方法

您的微信消息可能会因无意删除、系统崩溃、卸载微信应用或升级过程失败而被删除。如果您遇到这种情况,您不必担心,因为您可以采取某些步骤来恢复丢失的微信历史记录。这里有 4 种方法可以帮助您从 iPhone恢复丢失的微信消息、群聊历史记录或微信联系人。…

直击现场 | 人大金仓携手中国大地保险上线核心超A系统

2023年底 中国大地保险 卡园三路59号办公室里 一群技术精英们正忙碌着 他们的眼中 闪烁着对即将到来的胜利的期待 这是大地保险超A系统 项目上线的关键时刻 也是通过科技创新 引领行业服务新趋势的一场征程 项目现场 #1 一次颠覆 改变传统保险服务模式 超A平台,是由…

kibana配置 dashbord,做可视化展示

一、环境介绍 这里我使用的kibana版本为7.17版本。 语言选择为中文。 需要已经有es,已经有kibana,并且都能正常访问。 二、背景介绍 kibana的可视化界面,可以配置很多监控统计界面。非常方便,做数据的可视化展示。 这篇文章&…

Java核心技术第十二章 并发

多进程和多线程的区别:每个进程拥有组件的一整套变量,线程则共享数据,一个程序可以同时运行多个线程,则为多线程程序。 什么是线程 线程状态 1. 新建线程 2.可运行线程 调用start方法,线程处于可运行状态&#xff0c…

汽车协议学习

ⅠOBD 1.OBD接口 OBD有16个引脚,每个引脚的电压不同(可以对应不同的协议) 车端: 16- 9 (短一点点的) 8-1 (长一点的) 2.基于OBDⅡ的通信协议 CAN (ISO-15765&am…

NPP VIIRS卫星数据介绍及获取

VIIRS(Visible infrared Imaging Radiometer)可见光红外成像辐射仪。扫描式成像辐射仪,可收集陆地、大气、冰层和海洋在可见光和红外波段的辐射图像。它是高分辨率辐射仪AVHRR和地球观测系列中分辨率成像光谱仪MODIS系列的拓展和改进。VIIRS数…

代码随想录刷题笔记-Day33

1. 跳跃游戏 55. 跳跃游戏https://leetcode.cn/problems/jump-game/ 给你一个非负整数数组 nums ,你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标,如果可以,返回 tru…

动态规划(算法竞赛、蓝桥杯)--数位DP度的数量

1、B站视频链接&#xff1a;E38 数位DP 度的数量_哔哩哔哩_bilibili #include <bits/stdc.h> using namespace std; const int N34; int a[N];//把B进制数的每一位抠出存入数组 int f[N][N];//f[i][j]表示在i个位置上&#xff0c;放置j个1的组合数 int K,B;void init(…