pytorch保存、加载和解析模型权重

1、模型保存和加载

         主要有两种情况:一是仅保存参数,二是保存参数及模型结构。

保存参数:

         torch.save(net.state_dict())

加载参数(加载参数前需要先实例化模型):

         param = torch.load('param.pth')

         net.load_state_dict(param)

保存模型结构和参数:

         torch.save(net)

加载模型:

         net = torch.load('model.pt')

2、解析模型权重文件

         当加载某个模型文件后,如果需要查看模型中的算子和参数,可以将模型解析为字典,然后逐一打印。

以lent5为例,将lenet5模型保存为权重文件,然后重新加载权重文件并解析其中每一层的参数。

参考代码:

def pytorch_params(pth_file):par_dict = torch.load(pth_file, map_location='cpu')for name in par_dict:parameter = par_dict[name]print(name, parameter.numpy().shape)

        以上代码是加载的权重文件,文件只有参数,没有模型结构,如果加载的是包含模型结构的权重文件,可以做如下修改:

def pytorch_params(pt_file):net = torch.load(pt_file, map_location='cpu')par_dict = net.state_dict()for name in par_dict:parameter = par_dict[name]print(name, parameter.numpy().shape)

解析结果:

3、加载自定义参数

        某些情况下可能需要对某个算子进行单独调试,如加载特定参数进行推理计算,用来确定输出结果符合预期。以Conv2d算子为例进行测试,首先设定卷积层输入为3,输出为3,卷积核为3*3,偏置bias为False。通过numpy随机一个3*3*3*3的矩阵作为自定义参数,将参数转换为Tensor以后,添加到dict中,然后通过load_state_dict将参数加载进网络。

参考脚本:

 

import torch
import torch.nn as nn
import numpy as np
net = nn.Conv2d(3, 3, kernel_size=(3, 3), padding=1, bias=False)
param = np.random.random((3, 3, 3, 3))
param = param.astype(np.float32)
torch_param = {'weight': torch.Tensor(param)}
net.load_state_dict(torch_param)
net.eval()
data = np.random.random((1, 3, 16, 16))
data = data.astype(np.float32)
result = net(torch.Tensor(data))
print(result)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/26530.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

2023最新版本Activiti7系列-身份服务

身份服务 在流程定义中在任务结点的 assignee 固定设置任务负责人,在流程定义时将参与者固定设置在.bpmn 文件中,如果临时任务负责人变更则需要修改流程定义,系统可扩展性差。针对这种情况可以给任务设置多个候选人或者候选人组,可…

vue-next-admin vue3.x版本,table自定义

vue3.x版本&#xff0c;将table进行了封装。使用起来更方便了。但是&#xff0c;有时候我们需要将一组信息显示到一列中。所以我将其进行了简单的二次改造。支持table-column自定义。 table改造代码 <template><div class"table-container"><el-tabl…

【Ajax】笔记-POST请求(原生)

POST请求 html <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>AJAX POST 请求</title><…

VScode——NPM脚本窗口找不到

一、问题描述&#xff08;NPM终端在任务栏左侧找不到&#xff09; VScode&#xff08;Visual Studio Code&#xff09;版本&#xff1a;1.79.2 二、解决办法 第一步&#xff1a;通过设置/用户设置/扩展/MPM更改NPM默认配置&#xff0c;如下图所示&#xff1a; 第二步&#xff…

springboot实现全局异常捕获

导言&#xff1a; 为什么要做异常处理&#xff1a; 原因有三&#xff1a; 1、将系统产生的全部异常统一捕获处理。 2、自定义异常需要由全局异常来捕获。 3、JSR303规范的validator参数校验器、参数校验不通过、本身无法使用try…catch 其实对于前后端分离的项目做异常处理…

分布式应用之Zookeeper和Kafka

分布式应用之Zookeeper和Kafka 一、Zookeeper 1.定义 分布式系统管理框架&#xff0c;主要用来解决分布式集群中应用系统的一致性问题 相当于各种分布式应用服务的 注册中心 文件系统 通知机制2.特点 &#xff08;1&#xff09;Zookeeper&#xff1a;一个领导者&#…

【Java基础教程】(十五)面向对象篇 · 第九讲:抽象类和接口——定义、限制与应用的细节,初窥模板设计模式、工厂设计模式与代理设计模式~

Java基础教程之面向对象 第九讲 本节学习目标1️⃣ 抽象类1.1 抽象类定义1.2 抽象类的相关限制1.3 抽象类应用——模板设计模式 2️⃣ 接口2.1 接口定义2.2 接口的应用——标准2.3 接口的应用——工厂设计模式 (Factory)2.4 接口的应用——代理设计模式 (Proxy) 3️⃣ 抽象类与…

数据库的扩展策略

了解不同的数据库扩展技术可以帮助我们选择适合我们需求和目的的合适策略。 因此&#xff0c;在本文中&#xff0c;我们将展示不同的解决方案和技术&#xff0c;用于扩展数据库服务器。它们分为读取和写入策略。 读取/加载 有时我们的应用程序承受着巨大的负载。为了解决这个…

【VSCode | 使用技巧集锦】中文插件突然失效、配置单个工程(工作区)编码

目录 ✨技巧一&#xff1a;中文插件失效的解决办法✨技巧二&#xff1a;配置单个工程(工作区)编码 ✨技巧一&#xff1a;中文插件失效的解决办法 问题描述&#xff1a;VSCode之前安装了中文插件&#xff0c;可以正常汉化&#xff0c;用了一段时间都没问题&#xff0c;今天打开v…

51单片机的智能交通控制系统【含仿真+程序+演示视频带原理讲解】

51单片机的智能交通控制系统【含仿真程序演示视频带原理讲解】 1、系统概述2、核心功能3、仿真运行及功能演示4、程序代码 1、系统概述 该系统由AT89C51单片机、LED灯组、数码管组成。通过Protues对十字路口红绿灯控制逻辑进行了仿真。 每个路口包含了左转、右转、直行三条车道…

【UE4 塔防游戏系列】08-敌人到达终点对玩家造成伤害

目录 效果 步骤 一、敌人到终点时扣除玩家生命值 二、显示玩家生命值 效果 可以看到敌人进入终点后&#xff0c;左上角的玩家生命值会减少。 步骤 一、敌人到终点时扣除玩家生命值 新建一个Actor蓝图类&#xff0c;命名为“BP_EnemyEndPlace”&#xff0c;用来表示终点…

Pytest测试框架搭建需求及实现方案

目录 框架需求及实现方案 框架需求 实现方案 支持接口自动化、Web UI自动化及App自动化# 可以批量运行用例并生成测试报告 测试完成发送邮件 提供灵活的运行方式&#xff0c;如按功能模块运行、按脚本运行、按用例等级运行等等 提供运行日志方便定位问题 支持切换环境 …