Pytorch实战01——CIAR10数据集

目录

1、model.py文件 (预训练的模型)

2、train.py文件(会产生训练好的.th文件)

3、predict.py文件(预测文件)

4、结果展示:


1、model.py文件 (预训练的模型)

import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# RGB图像;  这里用了16个卷积核;卷积核的尺寸为5x5的self.conv1 = nn.Conv2d(3, 16, 5)  # 输入的是RBG图片,所以in_channel为3; out_channels=卷积核个数;kernel_size:5x5的self.pool1 = nn.MaxPool2d(2, 2)  # kernal_size:2x2   stride:2self.conv2 = nn.Conv2d(16, 32, 5)  # 这里使用32个卷积核;kernal_size:5x5self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(32*5*5, 120)  # 全连接层的输入,是一个一维向量,所以我们要把输入的特征向量展平。# 将得到的self.poolx(x) 的output(32,5,5)展开;  图片上给的全连接层是120self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)  # 这里的10,是需要根据训练集修改的def forward(self, x):   # 正向传播# Pytorch Tensor的通道排序:[channel,height,width]'''卷积后的尺寸大小计算:N = (W-F+2P)/S + 1其中,默认的padding:0   stride:1①输入图片大小:WxW②Filter大小 FxF  (卷积核大小)③步长S④padding的像素数P'''x = F.relu(self.conv1(x))   # 输入特征图为32x32大小的RGB图片;  input(3,32,32)  output(16,28,28)x = self.pool1(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半  output(16,14,14)   池化层,只改变特征矩阵的高和宽;x = F.relu(self.conv2(x))   # output(32, 10, 10)  因为第二个卷积层的卷积核大小是32个,这里就是32x = self.pool2(x)           # 经过最大下采样会将图片的高度和宽度:缩小为原来的一半output(32, 5, 5)x = x.view(-1, 32*5*5)   # x.view()  将其展开成一维向量,-1代表第一个维度x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x
# 测试下
# import torch
# input1 = torch.rand([32,3,32,32])
# model = LeNet()
# print(model)
# output = model(input1)

2、train.py文件(会产生训练好的.th文件)

import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data
import torchvision
from torch import nn, optim
from torchvision import transformsfrom pilipala_pytorch.pytorch_learning.Test1_pytorch_demo.model import LeNet# 1、下载数据集
# 图形预处理 ;其中transforms.Compose()是用来组合多个图像转换操作的,使得这些操作可以顺序地应用于图像。
transform = transforms.Compose([transforms.ToTensor(),   # 将PIL图像或ndarray转换为torch.Tensor,并将像素值的范围从[0,255]缩放到[0.0, 1.0]transforms.Normalize((0.5, 0.5, 0.5),(0.5, 0.5, 0.5))]   # 对图像进行标准化;标准化通常用于使模型的训练更加稳定。
)
# 50000张训练图片
train_ds = torchvision.datasets.CIFAR10('data',train=True,transform=transform,download=False)
# 10000张测试图片
test_ds = torchvision.datasets.CIFAR10('data',train=False,transform=transform,download=False)
# 2、加载数据集
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=36, shuffle=True, num_workers=0)    # shuffle数据是否是随机提取的,一般设置为True
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=10000, shuffle=True, num_workers=0)test_image,test_label = next(iter(test_dl))  # 将test_dl 转换为一个可迭代的迭代器,通过next()方法获取数据classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')'''标准化处理:output = (input - 0.5) / 0.5反标准化处理: input = output * 0.5 + 0.5 = output / 2 + 0.5
'''
# 测试下展示图片
# def imshow(img):
#     img = img / 2 + 0.5   # unnormalize  反标准化处理
#     npimg = img.numpy()
#     plt.imshow(np.transpose(npimg, (1,2,0)))
#     plt.show()
#
# # 打印标签
# print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
# imshow(torchvision.utils.make_grid(test_image))# 实例化网络模型
net = LeNet()
# 定义相关参数
loss_function = nn.CrossEntropyLoss()  # 定义损失函数
optimizer = optim.Adam(net.parameters(), lr=0.001)  # 定义优化器, 这里使用的是Adam优化器
# 训练过程
for epoch in range(5):  # 定义循环,将训练集迭代多少轮running_loss = 0.0  # 叠加,训练过程中的损失for step,data in enumerate(train_dl,start=0):  # 遍历训练集样本inputs, labels = data   # 获取图像及其对应的标签optimizer.zero_grad()  # 将历史梯度清零;如果不清除历史梯度,就会对计算的历史梯度进行累加outputs = net(inputs)   # 将输入的图片输入到网络,进行正向传播loss = loss_function(outputs, labels)  # outputs网络预测的值, labels真实标签loss.backward()optimizer.step()running_loss += loss.item()if step % 500 == 499:with torch.no_grad():  # with 是一个上下文管理器outputs = net(test_image)  # [batch,10]predict_y = torch.max(outputs, dim=1)[1]   # 网络预测最大的那个accuracy = (predict_y == test_label).sum().item() / test_label.size(0)  # 得到的是tensor  (predict_y == test_label).sum()  要通过item()拿到数值print("[%d, %5d] train_loss: %.3f test_accuracy:%.3f" % (epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0
print('Finished Training')save_path = './Lenet.pth'  # 保存模型
torch.save(net.state_dict(), save_path)  # net.state_dict() 模型字典;save_path 模型路径

3、predict.py文件(预测文件)

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import LeNettransform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship' , 'truck')net = LeNet()
net.load_state_dict(torch.load('Lenet.pth'))  # 加载train里面的训练好 产生的模型。im = Image.open('2.jpg')  # 载入准备好的图片
im = transform(im)  # 如果要将图片放入网络,进行正向传播,就得转换下格式   得到的结果为:[C,H,W]
im = torch.unsqueeze(im, dim=0)    # 增加一个维度;得到 [N,C,H,W],从而模拟一个批量大小为1的输入。with torch.no_grad():  # 不需要计算损失梯度outputs = net(im)predict = torch.max(outputs, dim=1)[1].data.numpy()   # outputs是一个张量;torch.max()用于找到张量在指定维度上的最大值;# torch.max()函数返回两个张量,一个包含最大值,另一个包含最大值的作用。# .data()属性用于从变量中提取底层的张量数据。直接使用.data()已经被认为是不安全的,推荐使用.detach()# .numpy() 表示将pytorch转换成numpy数组,从而使用numpy库的各种功能来操作数据。
print(classes[int(predict)])#     predict = torch.softmax(outputs,dim=1)  # 可以返回概率
# print(predict)

4、结果展示:

返回结果:预测是猫的概率为 86%。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/535601.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

分布式搜索elasticsearch(1)

1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎,具备非常多强大功能,可以帮助我们从海量数据中快速找到需要的内容 例如: 在GitHub搜索代码 在电商网站搜索商品 在百度搜索答案…

识别恶意IP地址的有效方法

在互联网的环境中,恶意IP地址可能会对网络安全造成严重威胁,例如发起网络攻击、传播恶意软件等。因此,识别恶意IP地址是保护网络安全的重要一环。IP数据云将探讨一些有效的方法来识别恶意IP地址。 IP地址查询:https://www.ipdata…

RPC基础知识回顾

RPC基础知识回顾 1、先认识一下大家熟悉的HTTP 大家都了解HTTP吧。相信项目中也用过一些。 比如: JDK自带的老旧的HttpURLConnection,封装写的很累,java8之前基于HTTP1.0。在java9开始支持Http2.0Spring的其中RestTemplate都是基于HTTP/1.1的请求。最新的还有Sp…

OceanBase中binlog service 功能的试用

OBLogProxy简介 OBLogProxy即OceanBase的增量日志代理服务,它可与OceanBase建立连接并读取增量日志,从而为下游服务提供了变更数据捕获(CDC)的功能。 关于OBLogProxy的详尽介绍与具体的安装指引,您可以参考这篇官方OB…

离散系统描述模型及其转换

离散系统描述模型及其转换 常用离散模型 \textbf{常用离散模型} 常用离散模型系统传递函数模型 t f \bf{tf} tf零-极点增益模型 z p k \bf{zpk} zpk极点留数模型 r p k \bf{rpk} rpk二次分式模型 s o s \bf{sos} sos状态变量模型 s s \bf{ss} ss 例题 常用离散模型 \textbf{常用…

Arrays.asList转换为List集合后使用add方法抛出UnsupportedOperationException

问题场景: 将String[] 数组转为 List集合,后对list集合进行添加删除报UnsupportedOperationException 百度原因: Arrays.asList返回的集合不支持元素的添加和删除(不支持add、addAll、remove方法),否则抛出…

nginx实时流量拷贝ngx_http_mirror_module

参考: Module ngx_http_mirror_module Nginx流量拷贝ngx_http_mirror_module模块使用方法详解 ngx_http_mirror_module用于实时流量拷贝 请求一个接口,想实时拷贝这个请求转发到自己的服务上,可以使用ngx_http_mirror_module模块。 官网好像…

Java Web实战(四)Web后端之MyBatis-基础用法详解

文章目录 1. 使用MyBatis1-1. JDBC介绍1-2. 数据库连接池1-3. Lombook 2. mybatis 基础2-2. CURD操作2-2-1. delete 操作2-2-2. 预编译sql2-2-3. 插入语句2-2-4. XML-SQL2-2-5. insert主键回显 2-3. 查询语句 3. 动态SQL3-1. <if> MyBatis是一款优秀的 持久层 框架&#…

STM32点亮LED灯与蜂鸣器发声

STM32之GPIO GPIO在输出模式时可以控制端口输出高低电平&#xff0c;用以驱动Led蜂鸣器等外设&#xff0c;以及模拟通信协议输出时序等。 输入模式时可以读取端口的高低电平或电压&#xff0c;用于读取按键输入&#xff0c;外接模块电平信号输入&#xff0c;ADC电压采集灯 GP…

前端的数据标记协议

文章目录 数据标记协议是什么数据标记协议的作用常见的数据标记协议Open Graph protocol 开放图谱协议基本元数据协议可选元数据结构化属性 —— 元数据的属性多个相同的元数据标签类型元数据的使用方法全局类型使用自定义类型使用对象类型使用歌曲对象类型视频对象类型文章对象…

C++学习路线

C学习路线思维导图&#xff0c;肝了一个星期终于搞定&#xff0c;这么硬核求个赞不过分吧&#xff1f; 思维导图的内容&#xff0c;也是本文的内容框架&#xff0c;坐稳扶好&#xff0c; C 高速快车要发车了&#xff01; 内容我会持续更新&#xff0c;点赞收藏&#xff0c;…

用户视角的比特币和以太坊外围技术整理

1. 引言 要点&#xff1a; 比特币L2基本强调交易内容的隐蔽性&#xff0c;P2P交易&#xff08;尤其是支付&#xff09;成为主流&#xff0c;给用户带来一定负担&#xff08;闪电网络&#xff09;在以太坊 L2 中&#xff0c;一定程度上减少了交易的隐蔽性&#xff0c;主流是实…