6-pytorch - 网络的保存和提取

前言

我们训练好的网络,怎么保存和提取呢?
总不可以一直不关闭电脑吧,训练到一半,想结束到明天再来训练,这就需要进行网络的保存和提取了。
本文以前面博客3-pytorch搭建一个简单的前馈全连接层网络(回归问题)的网络进行网络的保存和提取,建议先看完上面博客再来看本博客。

一、生成训练数据

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np# 生成数据(fake data)
x = torch.linspace(-1,1,100).reshape(-1,1)
# 加上点噪声
y = x.pow(2) + 0.2*torch.rand(x.shape)# 可视化一下数据
plt.scatter(x.data.numpy(),y.data.numpy())
plt.show()

输出:
在这里插入图片描述

二、网络保存

def save():net1 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(), torch.nn.Linear(10,1))optimizer = torch.optim.SGD(net1.parameters(),lr=0.5)loss_func = torch.nn.MSELoss()for t in range(100):prediction = net1(x)loss = loss_func(prediction,y)optimizer.zero_grad()loss.backward()optimizer.step()# 下面介绍两种不同的保存方法,方法二可能运行速度要快点# 保存整个网络的所有torch.save(net1, 'net.pkl')     # 保存好网络的参数torch.save(net1.state_dict(),'net_params.pkl')# plot resultplt.figure(1,figsize=(10,3))plt.subplot(131)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

【注】:保存整个网络还是保存网络参数,个人建议仅保存参数,这个速度更快。

三、网络提取

def restore_net():net2 = torch.load('net.pkl')prediction = net2(x)# plot resultplt.subplot(132)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)def restore_params():# 如果只是保留参数的情况,提取时需要再次定义相同网络才行net3 = torch.nn.Sequential(torch.nn.Linear(1,10),torch.nn.ReLU(),torch.nn.Linear(10,1))net3.load_state_dict(torch.load('net_params.pkl'))prediction = net3(x)# plot resultplt.subplot(133)plt.title('Net1')plt.scatter(x.data.numpy(),y.data.numpy())plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

四、对保存网络提取进行结果展示

save()
restore_net()
restore_params()

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/625344.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

详解UART通信协议以及FPGA实现

文章目录 一、UART概述二、UART协议帧格式2.1 波特率2.2 奇校验ODD2.3 偶校验EVEN 三、UART接收器设计3.1 接收时序图3.2 Verilog代码3.3 仿真文件测试3.4 仿真结果3.5 上版测试 四、UART发送器设计4.1 发送时序图4.2 Verilog代码4.3 仿真文件测试4.4 仿真结果4.5 上板测试 五、…

CommunityToolkit.Mvvm笔记---AsyncRelayCommand

AsyncRelayCommand 是 CommunityToolkit.Mvvm 中的一个功能,专门设计用来处理异步操作。它是 RelayCommand 的一个变体,提供了对异步任务的支持,允许开发者在 MVVM(Model-View-ViewModel)模式中方便地实现异步命令。使…

RK3568笔记二十二:基于TACO的垃圾检测和识别

若该文为原创文章,转载请注明原文出处。 基于TACO数据集,使用YOLOv8分割模型进行垃圾检测和识别,并在ATK-RK3568上部署运行。 一、环境 1、测试训练环境:AutoDL. 2、平台:rk3568 3、开发板: ATK-RK3568正点原子板子…

气膜滑冰馆:打破冰雪运动地域限制-轻空间

冰雪运动资源受地域与气候环境的限制,仅能在特定区域和时间段内进行,但随着科技的进步,气膜滑冰馆的出现打破了这一限制,成为了冰雪运动领域的新宠。这种全新的体育场馆让滑冰运动摆脱了地域限制,为运动爱好者提供了更…

[html]一个动态js倒计时小组件

先看效果 代码 <style>.alert-sec-circle {stroke-dasharray: 735;transition: stroke-dashoffset 1s linear;} </style><div style"width: 110px; height: 110px; float: left;"><svg style"width:110px;height:110px;"><cir…

【C++】set 类 和 map 类

1. 关联式容器 关联式容器也是用来存储数据的&#xff0c;与序列式容器不同的是&#xff0c;其里面存储的是<key, value>结构的 键值对&#xff0c;在数据检索时比序列式容器效率更高 2. 键值对 用来表示具有一一对应关系的一种结构&#xff0c;该结构中一般只包含…

Linux-时间同步服务器

1. (问答题) 一.配置server主机要求如下&#xff1a; 1.server主机的主机名称为 ntp_server.example.com 编写脚本文件 #!/bin/bash hostnamectl hostname ntp_server.example.com cd /etc/NetworkManager/system-connections/ rm -fr * cat > eth0.nmconnection <&…

【STL详解 —— priority_queue的使用与模拟实现】

STL详解 —— priority_queue的使用与模拟实现 priority_queue的使用priority_queue的介绍priority_queue的定义方式priority_queue各个接口的使用 priority_queue的模拟实现仿函数priority_queue的模拟实现 priority_queue的使用 priority_queue的介绍 std::priority_queue 是…

Java面试八股文(JVM篇)(❤❤)

Java面试八股文_JVM篇 1、知识点汇总2、知识点详解&#xff1a;3、说说类加载与卸载11、说说Java对象创建过程12、知道类的生命周期吗&#xff1f;14、如何判断对象可以被回收&#xff1f;17、调优命令有哪些&#xff1f;18、常见调优工具有哪些20、你知道哪些JVM性能调优参数&…

[Algorithm][双指针][查找总价格为目标值的两个商品][三数之和][四数之和]详细解读 + 代码实现

目录 1.查找总价格为目标值的两个商品1.题目链接2.算法原理讲解3.代码实现 2.三数之和1.题目链接2.算法原理讲解3.代码实现 3.四数之和1.题目链接2.算法原理讲解3.代码实现 1.查找总价格为目标值的两个商品 1.题目链接 题目链接 2.算法原理讲解 由于本题数据有序&#xff0c…

【前端】1. HTML【万字长文】

HTML 基础 HTML 结构 认识 HTML 标签 HTML 代码是由 “标签” 构成的. 形如: <body>hello</body>标签名 (body) 放到 < > 中大部分标签成对出现. <body> 为开始标签, </body> 为结束标签.少数标签只有开始标签, 称为 “单标签”.开始标签和…

Transformer的Decoder的输入输出都是什么

目录 1 疑问&#xff1a;Transformer的Decoder的输入输出都是什么 2 推理时Transformer的Decoder的输入输出 2.1 推理过程中的Decoder输入输出 2.2 整体右移一位 3 训练时Decoder的输入 参考文献&#xff1a; 1 疑问&#xff1a;Transformer的Decoder的输入输出都是什么 …