Pytorch-ResNet50-MINIST Classify 网络实现流程

分两个文件讲解:1、train.py训练文件     2、test.py测试文件.

1、train.py训练文件

1)从主函数入口开始,设置相关参数

# 主函数入口
if __name__ == '__main__':# ----------------------------##   是否使用Cuda#   没有GPU可以设置成Fasle# ----------------------------#cuda = True# ----------------------------##   是否使用预训练模型# ----------------------------#pre_train = True# ----------------------------##   是否使用余弦退火学习率# ----------------------------#CosineLR = True# ----------------------------##   超参数设置#   lr:学习率#   Batch_size:batchsize大小# ----------------------------#lr = 1e-3Batch_size = 2Init_Epoch = 0Fin_Epoch = 100

 2)创建模型

# 创建模型
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)#判断是否需要预训练模型,在1)已经设置pre_train=True,这里会加载预训练模型,
#为"logs/resnet50-mnist.pth"。
#这里加载的是预训练模型的权重参数,实例化到本地模型ResNet上
if pre_train:model_path = 'logs/resnet50-mnist.pth'model.load_state_dict(torch.load(model_path))#判断cuda是否可用,如果cuda可用,模型将调用GPU,否则将调用CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

3)创建数据集

# ----------------------------#
root='data/' :路径
train=True   :训练设置为True
transform=transforms.ToTensor() :转化成Tensor
download=True :下载
# ----------------------------#
train_dataset = datasets.MNIST(root='data/', train=True,transform=transforms.ToTensor(), download=True)
#这里train = False, download=False,此时下载验证集
test_dataset = datasets.MNIST(root='data/', train=False,transform=transforms.ToTensor(), download=False)

4)加载数据集

# ----------------------------#
#DataLoader加载数据集
batch_size=Batch_size 批量输入
shuffle=True 打乱数据
num_workers=0 单个工作进程
# ----------------------------#
gen = DataLoader(dataset=train_dataset, batch_size=Batch_size, shuffle=True, num_workers=0)
gen_test = DataLoader(dataset=test_dataset, batch_size=Batch_size // 2, shuffle=True, num_workers=0)

5)设置损失函数和优化器

#损失函数为交叉熵损失
softmax_loss = torch.nn.CrossEntropyLoss()
#优化器选择Adams
optimizer = torch.optim.Adam(model.parameters(), lr)

6)设置学习率

#如果CosineLR = True,学习率为CosineAnnealingLR,否则为StepLR
if CosineLR:lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=5, eta_min=1e-10)
else:lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.92)

7)训练

# ----------------------------#
epoch_size 训练集一次加载多少个batch
epoch_size_val 验证集一次加载多少个batch
# ----------------------------#
epoch_size = len(gen) 
epoch_size_val = len(gen_test)# ----------------------------#
Init_Epoch 起始训练为0
Fin_Epoch  终止训练为100次
fit_one_epoch()函数进行训练数据
lr_scheduler.step()一次训练结束后,学习率进行更新
# ----------------------------#
for epoch in range(Init_Epoch, Fin_Epoch):fit_one_epoch(net=model, softmaxloss=softmax_loss, epoch=epoch, epoch_size=epoch_size,epoch_size_val=epoch_size_val, gen=gen, gen_test=gen_test, Epoch=Fin_Epoch, cuda=cuda)lr_scheduler.step()

2、test.py测试文件

展示运行结果

1)整段讲解

 

import torch
from nets.resnet50 import ResNet,Bottleneck
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torchvision
import cv2
import time# 设置权重文件路径
PATH = './logs/resnet50-mnist.pth'
# 谁知手动输入单次识别字数
Batch_Size = int(input('每次预测手写字体图片个数:'))
# 加载模型
model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)
model.load_state_dict(torch.load(PATH))
model = model.cuda()# 进入测试程序
model.eval()
# 设置测试数据集并加载
test_dataset = datasets.MNIST(root='data/', train=False,transform=transforms.ToTensor(), download=False)
gen_test = DataLoader(dataset=test_dataset, batch_size=Batch_Size, shuffle=True)# 进入循环
while True:# 获取图片和标签images, lables = next(iter(gen_test))img = torchvision.utils.make_grid(images, nrow=Batch_Size)img_array = img.numpy().transpose(1, 2, 0)# 获取开始时间start_time = time.time()# 输出预测结果outputs = model(images.cuda())_, id = torch.max(outputs.data, 1)end_time = time.time()# 打印用时和预测结果,由于输出的id为tensor,这里必须转换为numpyprint('预测用时:', end_time-start_time)print('预测结果为', id.data.cpu().numpy())# 展示图片cv2.imshow('img', img_array)cv2.waitKey(0)

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/15013.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SPEC CPU 2006 在 CentOS 5.0 x86_64 古老系统测试【2】

上一篇 SPEC CPU 2006 在 CentOS 5.0 x86_64 古老系统测试_hkNaruto的博客-CSDN博客 虚拟机时间,一天后获得结果 由于ssh版本太低,采用nc把文件拷贝出来 结果 SPEC CFP2006 Result Copyright 2006-2023 Standard Performance Evaluation Corporatio…

SpringBoot集成Quartz集群模式

<!-- quartz定时任务 --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-quartz</artifactId></dependency> 单机版本&#xff1a; SpringBoot集成Quartz动态定时任务_jobgroupname_小…

09_Linux内核定时器

目录 Linux时间管理和内核定时器简介 内核定时器简介 Linux内核短延时函数 定时器驱动程序编写 编写测试APP 运行测试 Linux时间管理和内核定时器简介 学习过UCOS或FreeRTOS的同学应该知道, UCOS或FreeRTOS是需要一个硬件定时器提供系统时钟,一般使用Systick作为系统时钟…

人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型,加载数据进行模型训练与预测

大家好&#xff0c;我是微学AI&#xff0c;今天给大家介绍一下人工智能(pytorch)搭建模型17-pytorch搭建ReitnNet模型&#xff0c;加载数据进行模型训练与预测&#xff0c;RetinaNet 是一种用于目标检测任务的深度学习模型&#xff0c;旨在解决目标检测中存在的困难样本和不平衡…

Maven学习

1.配置环境变量 1.M2_HOME Maven的安装目录 2.修改Path %M2_HOME%\bin2.配置IDEA 配置文件的地址 本地仓库的地址 修改配置文件的路径 修改本地仓库的目录 注意&#xff0c;这里的路径的分隔符必须是/ 配置镜像 <mirror><id>aliyunmaven</id><mi…

在idea中使用Git技术

1.配置git环境 打开idea,点击file->setting->搜索git&#xff0c; 将git的安装路径填写进去 2.去gitee创建一个远程仓库 3.拉入一个.gitignore文件&#xff0c;过滤掉不需要管理的文件 4.在idea进行如下操作 5.选择要提交的内容 目前只是保存在了本地仓库 6.推送到远端…

SEGA: Semantic Guided Attention on Visual Prototype for Few-Shot Learning

方法比较简单&#xff0c;利用语义改进prototype&#xff0c;能促进性提升

十二、Docker Compose 介绍与安装

学习参考&#xff1a;尚硅谷Docker实战教程、Docker官网、其他优秀博客(参考过的在文章最后列出) 目录 前言一、docker compose介绍二、docker compose能干嘛三、docker compose安装与卸载3.1 docker-compose安装3.2 docker-compose卸载 总结 前言 在使用k8s之前&#xff0c;随…

LNMP架构及部署、skyuc电影网站部署

目录 一、安装nginx 1、关闭防火墙 2.创建管理nginx用户 3.配置nginx 4.命令优化 5.创建nginx脚本 二、安装mysql数据库 三、安装PHP 1.上传php安装包 2.上传 zend-loader-hph5.6 3.创建用户 四、LNMP平台中部署skyuc电影网站 1.解压 SKYUC.v3.4.2.srouce 2.创建数据…

光场1.0——非聚焦型光场相机

本文概要 本文讲主要从光场硬件结构设计以及软件处理方式的层面来介绍一下光场的相关内容&#xff0c;关于光场的优势和具体应用点并不在本文的主要范围内。 光场1.0 1. 结构原理说明 首先来介绍一下光场相机&#xff0c;那么什么是光场相机呢&#xff0c;光场相机经历了两…

SPEC CPU 2006 在 CentOS 5.0 x86_64 古老系统测试

下载镜像 CentOS 2 3 4 5 6 等历史老版本下载地址 国内镜像地址_hkNaruto的博客-CSDN博客 下载CentOS 5.0 1-7 ISO文件 注意&#xff1a;尝试过下载DVD版本&#xff0c;速度太慢了。还是通过国内镜像下载这几个iso快。 安装虚拟机 VirtualBox 挂载第一个iso&#xff0c;启动…

突破数据边界,开启探索之旅!隐语开源Meetup一周年专场7月22日上海见

小伙伴们&#xff0c;&#x1f4e2;「隐语开源一周年 Meetup 」即将来袭&#xff01;&#x1f389;在一周年 Meetup 上&#xff0c;不仅会对隐语 1.0 版本进行详解&#xff0c;还有新鲜出炉的隐语 MVP 部署体验包&#xff0c;让你秒变高手&#xff01;更有机会与隐私计算行业的…