完整模型的训练套路

从心所欲

不逾矩

天大地大

皆可去

一、官方模型的初使用

使用VGG16模型

 VGG模型使用代码示例:

import torchvision.models
from torch import nndataset = torchvision.datasets.CIFAR10('/cifar10', False, transform=torchvision.transforms.ToTensor())vgg16_true = torchvision.models.vgg16(pretrained=True)
vgg16_false = torchvision.models.vgg16(pretrained=False)
print(vgg16_false)# 改造VGG,增加一层
vgg16_true.classifier.add_module('add_linear', nn.Linear(1000, 10))
print(vgg16_true)# 改造vgg,修改一层
vgg16_false.classifier[6] = nn.Linear(4096, 10)
print(vgg16_false)

说明:

  1. pretrained=True:当设置为True时,模型将加载在大规模图像数据集(如ImageNet)上预训练的权重。这些预训练的权重经过了在大量图像上的训练,可以捕捉到通用的图像特征。通过加载预训练权重,可以将VGG模型初始化为在ImageNet上训练得到的状态,并且这些权重可以作为初始参数用于特定任务的微调或迁移学习。

  2. pretrained=False:当设置为False时,模型将使用随机初始化的权重。这意味着模型的权重没有经过预训练,需要从头开始进行训练。在这种情况下,模型将不会具备捕捉通用图像特征的能力,而是需要根据特定任务的数据进行训练。

pretrained=Truepretrained=False区别在于是否加载预训练的权重。如果你想要在特定任务上使用VGG模型,并且你的任务与图像分类或特征提取相关,那么通常建议使用pretrained=True,以便利用预训练权重的优势。如果你的任务与图像分类或特征提取无关,或者你希望从头开始训练模型以适应特定数据集,那么可以选择pretrained=False

二、模型的保存与加载

模型的保存:

两种保存模式,官方推荐第二种,只保存参数,不保存模型

import torch
import torchvision.modelsvgg16 = torchvision.models.vgg16(pretrained=False)# 保存方式1: 既保存模型结构,也保存参数
torch.save(vgg16, 'vgg16_model1.pth')# 保存方式2:把参数保存成字典,不保存结构(官方推荐)
torch.save(vgg16.state_dict(), 'vgg16_model2.pth')print("end")

模型的加载:
 

import torch
import torchvision.models# 加载方式1 - 保存方式1
model = torch.load('vgg16_model1.pth')# 加载方式2 - 保存方式2
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load('vgg16_model2.pth'))

三、完整的模型训练套路

以CIFAR10数据集来一个完整的模型训练。

代码示例:

model.py

from torch import nn# 搭建神经网络
class Lh(nn.Module):def __init__(self):super(Lh, self).__init__()self.model = nn.Sequential(nn.Conv2d(3, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 32, 5, 1, 2),nn.MaxPool2d(2),nn.Conv2d(32, 64, 5, 1, 2),nn.MaxPool2d(2),nn.Flatten(),nn.Linear(64 * 4 * 4, 64),nn.Linear(64, 10))def forward(self, x):x = self.model(x)return x

train.py

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterfrom model import Lh# 准备数据集
train_data = torchvision.datasets.CIFAR10('./cifar10', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10('./cifar10', train=False, transform=torchvision.transforms.ToTensor(), download=True)
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))
print("测试数据集的长度为:{}".format(test_data_size))# 利用DataLoader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)# 搭建神经网络 - 10分类
lh = Lh()# 损失函数
loss_fn = nn.CrossEntropyLoss()# 优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(lh.parameters(), lr=learning_rate)# 设置训练网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练轮数
epoch = 10# 添加tensorboard
writer = SummaryWriter("train_logs")for i in range(epoch):print("-----第{}轮训练开始了-----".format(i + 1))# 训练步骤开始for data in train_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)optimizer.zero_grad()loss.backward()optimizer.step()total_train_step += 1if total_train_step % 100 == 0:print("训练次数:{},Loss:{}".format(total_train_step, loss.item()))writer.add_scalar("train_loss", loss.item(), total_train_step)# 测试步骤开始total_test_loss = 0total_accuracy = 0with torch.no_grad():for data in test_dataloader:imgs, tragets = dataoutput = lh(imgs)loss = loss_fn(output, tragets)total_test_loss += lossaccuracy = (output.argmax(1) - - tragets).sum()total_accuracy += accuracyprint("整体测试机上误差:{}".format(total_test_loss))print("整体测试机上的正确率:{}".format(total_accuracy / test_data_size))writer.add_scalar("test_loss", total_test_loss, total_test_step)writer.add_scalar("test_accuracy", total_accuracy / total_test_step)total_test_step += 1# torch.save(lh, "lhy_{}.pth".format(i))# print("模型已保存")writer.close()

输出结果:

 在tensorboard打开

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/55321.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【逗老师的PMP学习笔记】5、项目范围管理

目录 一、规划范围管理二、收集需求1、【关键工具】头脑风暴2、【关键工具】访谈3、【关键工具】问卷调查4、【关键工具】标杆对照(对标)5、【关键工具】亲和图和思维导图6、【关键工具】质量功能展开7、【关键工具】用户故事8、【关键工具】原型法9、【…

VBA技术资料MF38:VBA_在Excel中隐藏公式

【分享成果,随喜正能量】佛祖也无能为力的四件事:第一,因果不可改,自因自果,别人是代替不了的;第二,智慧不可赐,任何人要开智慧,离不开自身的磨练;第三&#…

MySQL(1)

MySQL创建数据库和创建数据表 创建数据库 1. 连接 MySQL mysql -u root -p 2. 查看当前的数据库 show databases; 3. 创建数据库 create database 数据库名; 创建数据库 4. 创建数据库时设置字符编码 create database 数据库名 character set utf8; 5. 查看和显示…

【具身智能】前沿思考与总结(谷歌微软)

0. 总结 0.1 万字长文,当机器人拥抱大模型 只需要告诉机器人它要做的任务是什么,机器人就会理解需要做的事情,拆分任务动作,生成应用层控制指令,并根据任务过程反馈修正动作,最终完成人类交给的任务。整个…

小程序自定义tabBar+Vant weapp

1.构建npm,安装Vant weapp: 1)根目录下 ,初始化生成依赖文件package.json npm init -y 2)安装vant # 通过 npm 安装 npm i vant/weapp -S --production 3)修改 package.json 文件 开发者工具创建的项…

如何在 Ubuntu 上部署 ONLYOFFICE 协作空间社区版?

ONLYOFFICE 协作空间是一个在线协作平台,帮助您更好地与客户、业务合作伙伴、承包商及第三方进行文档协作。今天我们来介绍一下,如何在 Ubuntu 上安装协作空间的自托管版。 ONLYOFFICE 协作空间主要功能 使用 ONLYOFFICE 协作空间,您可以&am…

MySQL的关键指标及采集方法

MySQL 是个服务,所以我们可以借用 Google 四个黄金指标的思路来解决问题。 1、延迟 应用程序会向 MySQL 发起 SELECT、UPDATE 等操作,处理这些请求花费了多久,是非常关键的,甚至我们还想知道具体是哪个 SQL 最慢,这样…

用html+javascript打造公文一键排版系统15:一键删除所有空格

现在我们来实现一键删除所有空格的功能。 一、使用原有的代码来实现,测试效果并不理想 在这之前我们已经为String对象编写了一个使用正则表达式来删除所有空格的方法: //功能:删除字符串中的所有空格 //记录:20230726创建 Stri…

机器学习笔记 - 关于GPT-4的一些问题清单

一、简述 据报道,GPT-4 的系统由八个模型组成,每个模型都有 2200 亿个参数。GPT-4 的参数总数估计约为 1.76 万亿个。 近年来,得益于 GPT-4 等高级语言模型的发展,自然语言处理(NLP) 取得了长足的进步。凭借其前所未有的规模和能力,GPT-4为语言 AI​​设立了新标准,并为机…

16 - 初探Linux进程调度

---- 整理自狄泰软件唐佐林老师课程 查看所有文章链接:(更新中)Linux系统编程训练营 - 目录 文章目录 1. 初探Linux进程调度1.1 Linux系统调度1.2 进程调度原理1.3 Linux系统调度策略1.4 进程调度实验设计1.4.1 实验目标1.4.2 实验设计 1.5 实…

Mock.js的基本使用方法

官网网址:Mock.js (mockjs.com) 当前端工程师需要独立于后端并行开发时,后端接口还没有完成,那么前端怎么获取数据? 这时可以考虑前端搭建web server自己模拟假数据,这里我们选第三方库mockjs用来生成随机数据&#xf…

微信小程序开发【从0到1~入门篇】2023.08

一个小程序主体部分由三个文件组成,必须放在项目的根目录,如下: 文件必须作用app.js是小程序逻辑app.json是小程序公告配置app.wxss否小程序公告样式表 3. 小程序项目结构 一个小程序页面由四个文件组成,分别是: 文…