Pytorch深度学习-----神经网络模型的保存与加载(VGG16模型)

系列文章目录

PyTorch深度学习——Anaconda和PyTorch安装
Pytorch深度学习-----数据模块Dataset类
Pytorch深度学习------TensorBoard的使用
Pytorch深度学习------Torchvision中Transforms的使用(ToTensor,Normalize,Resize ,Compose,RandomCrop)
Pytorch深度学习------torchvision中dataset数据集的使用(CIFAR10)
Pytorch深度学习-----DataLoader的用法
Pytorch深度学习-----神经网络的基本骨架-nn.Module的使用
Pytorch深度学习-----神经网络的卷积操作
Pytorch深度学习-----神经网络之卷积层用法详解
Pytorch深度学习-----神经网络之池化层用法详解及其最大池化的使用
Pytorch深度学习-----神经网络之非线性激活的使用(ReLu、Sigmoid)
Pytorch深度学习-----神经网络之线性层用法
Pytorch深度学习-----神经网络之Sequential的详细使用及实战详解
Pytorch深度学习-----损失函数(L1Loss、MSELoss、CrossEntropyLoss)
Pytorch深度学习-----优化器详解(SGD、Adam、RMSprop)
Pytorch深度学习-----现有网络模型的使用及修改(VGG16模型)


文章目录

  • 系列文章目录
  • 一、网络模型的保存
    • 1.方法一
    • 2.方法二
  • 二、网络模型的加载
    • 1.方法一
    • 2.方法二
  • 三、总结


一、网络模型的保存

1.方法一

保存整个模型,包括其相关的所有参数

torch.save(obj, f, pickle_protocol=DEFAULT_PROTOCOL)

参数说明:

obj: 要保存的对象,可以是模型、张量、字典等。
f: 要保存到的文件路径或文件对象。
pickle_protocol: 序列化协议的版本,默认为DEFAULT_PROTOCOL。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true, "vgg16_model_true.pth")

其中.pth是后缀标志。

在这里插入图片描述

2.方法二

只保存模型参数,在原有vgg16对象中使用.state_dict()方法即可。

代码如下:

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)
vgg16_false = models.vgg16(weights=False)torch.save(vgg16_true.state_dict(), "vgg16_model_true_2.pth")

在这里插入图片描述

二、网络模型的加载

1.方法一

对应于上述中保存模型的方法1进行加载。

相关函数如下:

torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)

参数说明:

f: 要加载的文件路径或文件对象。
map_location: 可选参数,用于指定在哪个设备上加载模型。如果不提供该参数,默认会加载到当前设备。
pickle_module: 可选参数,用于指定用于反序列化的模块。默认为pickle。
pickle_load_args: 其他可选的用于反序列化的参数。

代码如下:

import torch
import torchvision.models as models
from torch import nnmodel1 = torch.load("vgg16_model_true.pth")  # 因为vgg16_model_true.pth是使用方法一保存的,故输出后是整个模型网络结构
print(model1)
model2 = torch.load("vgg16_model_true_2.pth")  # 因为vgg16_model_true_2.pth是使用方法二保存的,只保留模型参数,故输出后是整个字典类型
print(model2)

vgg16_model_true.pth加载结果

VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

vgg16_model_true_2.pth加载结果

OrderedDict([('features.0.weight', tensor([[[[-5.5373e-01,  1.4270e-01,  5.2896e-01],[-5.8312e-01,  3.5655e-01,  7.6566e-01],[-6.9022e-01, -4.8019e-02,  4.8409e-01]],[[ 1.7548e-01,  9.8630e-03, -8.1413e-02],[ 4.4089e-02, -7.0323e-02, -2.6035e-01],[ 1.3239e-01, -1.7279e-01, -1.3226e-01]],[[ 3.1303e-01, -1.6591e-01, -4.2752e-01],[ 4.7519e-01, -8.2677e-02, -4.8700e-01],[ 6.3203e-01,  1.9308e-02, -2.7753e-01]]],[[[ 2.3254e-01,  1.2666e-01,  1.8605e-01],[-4.2805e-01, -2.4349e-01,  2.4628e-01],[-2.5066e-01,  1.4177e-01, -5.4864e-03]],[[-1.4076e-01, -2.1903e-01,  1.5041e-01],[-8.4127e-01, -3.5176e-01,  5.6398e-01],[-2.4194e-01,  5.1928e-01,  5.3915e-01]],[[-3.1432e-01, -3.7048e-01, -1.3094e-01],[-4.7144e-01, -1.5503e-01,  3.4589e-01],[ 5.4384e-02,  5.8683e-01,  4.9580e-01]]],[[[ 1.7715e-01,  5.2149e-01,  9.8740e-03],[-2.7185e-01, -7.1709e-01,  3.1292e-01],[-7.5753e-02, -2.2079e-01,  3.3455e-01]],[[ 3.0924e-01,  6.7071e-01,  2.0546e-02],[-4.6607e-01, -1.0697e+00,  3.3501e-01],[-8.0284e-02, -3.0522e-01,  5.4460e-01]],[[ 3.1572e-01,  4.2335e-01, -3.4976e-01],[ 8.6354e-02, -4.6457e-01,  1.1803e-02],[ 1.0483e-01, -1.4584e-01, -1.5765e-02]]],...,

2.方法二

import torch
import torchvision.models as models
from torch import nnvgg16_true = models.vgg16(weights=True)vgg16_true.load_state_dict(torch.load("vgg16_model_true_2.pth"))  # 针对第二种加载参数的情况,使其显示完整的网络结构
print(vgg16_true)
VGG((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(25): ReLU(inplace=True)(26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(27): ReLU(inplace=True)(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(29): ReLU(inplace=True)(30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

注意: 加载模型时,要确保当前代码中使用的模型类与之前保存的模型类相同。

三、总结

torch.load()是PyTorch中用于加载保存的对象的函数,可以加载之前使用torch.save()保存的模型、张量、字典等。可以指定要加载的文件路径或文件对象,并可选地指定加载到的设备、反序列化模块等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/59959.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

大数据技术之Hadoop:HDFS集群安装篇(三)

目录 分布式文件系统HDFS安装篇 一、为什么海量数据需要分布式存储 二、 分布式的基础架构分析 三、 HDFS的基础架构 四 HDFS集群环境部署 4.1 下载安装包 4.2 集群规划 4.3 上传解压 4.4 配置HDFS集群 4.5 准备数据目录 4.6 分发hadoop到其他服务器 4.7 配置环境变…

Linux驱动之设备树添加蜂鸣器驱动

目录 一、蜂鸣器简介 二、硬件原理分析 三、蜂鸣器驱动原理 四、开发环境 五、修改设备树文件 1、添加 pinctrl 节点 2、添加 BEEP 设备节点 3、检查 PIN 是否被其他外设使用 六、蜂鸣器驱动程序编写 七、测试程序编写 八、运行验证 在 I.MX6U-ALPHA 开发板上有一个有源…

《合成孔径雷达成像算法与实现》Figure3.6

代码复现如下: clc clear all close all%参数设置 TBP 100; %时间带宽积 T 10e-6; %脉冲持续时间%参数计算 B TBP/T; …

zadig安装驱动潜在风险与解决策略

zadig安装驱动潜在风险与解决策略 ✨没事不要闲着乱打驱动,能正常使用的情况下,不要轻易或随意去乱打驱动,可能会导致新的驱动对已有的设备不兼容的问题。✨🔰特别说明:本文介绍的方法,并不能包治百病&…

Java EE 突击 9 - Spring Boot 日志文件

Spring Boot 日志文件 学习目标一 . 日志有什么用1.1 日志格式说明 二 . 自定义日志打印2.1 得到日志对象2.2 使用日志对象提供的方法 , 输出自定义的日志内容2.3 日志的级别 三 . 日志持久化3.1 在配置文件里面设置日志名称3.2 设置日志的保存目录 四 . 日志级别的设置五 . 简…

虹科方案 | 成都大运会进行时,保障大型活动无线电安全需要…

成都大运会 7月28日,备受关注的第31届世界大学生夏季运动会在成都正式开幕。据悉,这是全球首个5G加持的智慧大运会,也是众多成熟信息技术的综合“应用场”。使用基于5G三千兆、云网、8K超高清视频等技术,在比赛现场搭建多路8K摄像…

【算法|数组】双指针

算法|数组——双指针 引入 给你一个按 非递减顺序 排序的整数数组 nums,返回 每个数字的平方 组成的新数组,要求也按 非递减顺序 排序。 示例 1: 输入:nums [-4,-1,0,3,10] 输出:[0,1,9,16,100] 解释:…

爬虫014_文件操作_打开关闭_读写_序列化_反序列化---python工作笔记033

报错,没有指定路径,没有指定路径无法创建文件 这样可以在当前目录下创建一个可写的文件 可以看到找到刚才生成的文件,看看内容

企业服务器数据库中了devos勒索病毒怎么办如何解决预防勒索病毒攻击

随着科学技术的不断发展,计算机可以帮助我们完成很多重要的工作,但是随之而来的网络威胁也不断提升。近期,我们收到很多企业的求助,企业的服务器数据库遭到了devos勒索病毒攻击,导致系统内部的许多重要数据被加密无法正…

1.RuoYi-Vue前后端分离版本启动

1.代码下载 去若依官网,选择RuoYI前后端分离版: 下载地址:https://gitee.com/y_project/RuoYi-Vue 2.依赖环境的部署 1.Mysql 2.Redis安装部署 : https://blog.csdn.net/qq_27860623/article/details/132168382 3.Idea打开后端服务 用Idea打开后端的…

【Linux命令详解 | grep命令】 grep命令用于在文件中搜索指定模式的文本,功能强大且常用

文章标题 简介一,参数列表二,使用介绍1. 使用基本模式搜索2. 忽略大小写匹配3. 反向匹配4. 递归搜索目录5. 显示文件名6. 显示行号7. 显示上下文行8. 启用扩展正则表达式9. 将模式视为固定字符串10. 使用颜色高亮显示匹配文本 总结 简介 在Linux系统中&…

界面控件DevExpress WPF Chart组件——拥有超快的数据可视化库!

DevExpress WPF Chart组件拥有超大的可视化数据集,并提供交互式仪表板与高性能WPF图表库。DevExpress Charts提供了全面的2D / 3D图形集合,包括数十个UI定制和数据分析/数据挖掘选项。 PS:DevExpress WPF拥有120个控件和库,将帮助…