pytorch -- CIFAR10 完整的模型训练套路

  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import *# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,transform=torchvision.transforms.ToTensor(),download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()# 5 损失函数
loss_fn = nn.CrossEntropyLoss()# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')for i in range(epoch):print('-----------第{}轮训练开始-----------'.format(i+1))# 训练开始# 训练步骤开始 dropout batchNorm仅对某些层次有作用tudui.train()for data in train_dataloader:imgs, targets = dataoutput = tudui(imgs) #训练模型的预测输出loss = loss_fn(output,targets)# 优化器优化模型optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字writer.add_scalar('train_loss',loss.item(),total_train_step)# 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失# 测试步骤开始tudui.eval()total_test_loss = 0total_accuracy = 0# 测试不需要对梯度进行调整with torch.no_grad():for data in test_dataloader:imgs,targets = dataoutputs = tudui(imgs)loss = loss_fn(outputs,targets)total_test_loss += loss.item()# accuracy 正确预测的样本数量accuracy = (outputs.argmax(1) == targets).sum()total_accuracy += accuracyprint('整体测试集上的loss是{}'.format(total_test_loss))print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))writer.add_scalar('test_loss',total_test_loss,total_test_step)writer.add_scalar('test_accuracy', total_accuracy, total_test_step)total_test_step+=1torch.save(tudui,'tudui_{}.pth'.format(i))print('模型已保存')writer.close()
# model.py
import torch
from torch import nn# 3 搭建神经网络
class Tudui(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(nn.Conv2d(3,32,5,1,2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(1024,64),nn.Linear(64, 10))def forward(self,x):x = self.model(x)return xif __name__ == '__main__':tudui = Tudui()# 验证一下输入输出尺寸input = torch.ones((64,3,32,32))output = tudui(input)print(output.shape)

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/492707.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nest创建神经元,并显示电压变化曲线

nest 安装与介绍 NEST(神经模拟工具)最初是在 1990 年代后期开发的。它的主要目标是作为计算神经科学模拟器。它支持具有不同生物学细节水平的各种神经元和突触模型。例如,NEST 的神经元模型范围从泄漏积分和激发模型到详细的 Hodgkin-Huxle…

java基于redis实现分布式锁

文章目录 前言一、redis二、Redisson1.引入库2. 分布式锁3. 锁自动续期 总结 前言 上篇文章介绍了Java中锁的应用,在SpringBoot单体应用中完全够用,但是SpringCloud微服务集群中就力所不及了。 我的使用场景是某些微服务应用中使用spring注解的形式来完成定时任务的功能,服务集…

在使用nginx的时候快速测试配置文件,并重新启动

小技巧 Nginx修改配置文件后需要重新启动,常规操作是启动在任务管理器中关闭程序然后再次双击nginx.exe启动,但是使用命令行就可以快速的完成操作。 将cmd路径切换到nginx的安装路径 修改完成配置文件后 使用 nginx -t校验nginx 的配置文件是否出错 …

C# OpenVINO PaddleSeg实时人像抠图PP-MattingV2

目录 效果 项目 代码 下载 C# OpenVINO 百度PaddleSeg实时人像抠图PP-MattingV2 效果 项目 代码 using OpenCvSharp; using Sdcb.OpenVINO; using System; using System.Diagnostics; using System.Drawing; using System.Security.Cryptography; using System.Text; us…

医院管理系统小程序

**🍅点赞收藏关注 → 私信领取本源代码、数据库🍅 本人在Java毕业设计领域有多年的经验,陆续会更新更多优质的Java实战项目希望你能有所收获,少走一些弯路。🍅关注我不迷路🍅**一 、设计说明 1.1 研究背景…

智慧公厕是什么?智慧公厕意义何在

随着城市化进程的加速,公厕成为城市管理中不容忽视的一环。智慧公厕传统的公厕管理方式已经无法满足当今社会的需求,因此智慧公厕的出现成为解决问题的利器。什么是智慧公厕?智慧公厕是实现公共厕所信息化、数字化、智慧化全方位管理与服务的…

Vue 卸载eslint

卸载依赖 npm uninstall eslint --save 然后 进入package.json中,删除残留信息。 否则在执行卸载后,运行会报错。 之后再起项目。

gitlab上传代码

1、先在gitlab上新建一个project 2、创建一个新的project 3、自定义项目名称和组名:无组名可新创建一个组 打开git命令行工具,执行如下命令。 git init 这将会创建一个新的空白的Git仓库。 git add . 当前目录及子目录下所有修改过的文件都被添加到Gi…

前端工程化面试题 | 17.精选前端工程化高频面试题

🤍 前端开发工程师、技术日更博主、已过CET6 🍨 阿珊和她的猫_CSDN博客专家、23年度博客之星前端领域TOP1 🕠 牛客高级专题作者、打造专栏《前端面试必备》 、《2024面试高频手撕题》 🍚 蓝桥云课签约作者、上架课程《Vue.js 和 E…

使用GPTQ进行4位LLM量化

使用GPTQ进行4位LLM量化 最佳脑量化GPTQ算法步骤1:任意顺序洞察步骤2:延迟批量更新第三步:乔尔斯基重塑 用AutoGPTQ量化LLM结论References 权重量化的最新进展使我们能够在消费级硬件上运行大量大型语言模型,例如在RTX 3090 GPU上运行LLaMA-30B模型。这要归功于性能…

分布式事务,zookeeper,dubbo,rocketmq

1.1 什么是CAP理论 CAP理论是分布式领域中非常重要的一个指导理论,C(Consistency)表示强一致性,A(Availability)表示可用性,P(Partition Tolerance)表示分区容错…

Kubernetes部署及运用

Kubernetes 1. Kubernetes介绍 1.1 应用部署方式演变 在部署应用程序的方式上,主要经历了三个时代: 传统部署:互联网早期,会直接将应用程序部署在物理机上 优点:简单,不需要其它技术的参与 缺点&#xf…