register_parameter和register_buffer 详解

在参考yolo系列代码或其他开源代码,经常看到register_buffer register_parameter的使用,接下来将详细对他们进行介绍。

1. 前沿

在搭建网络时,我们 自定义的参数,往往不会保存到模型权重文件中,或者成为模型可学习的参数。即我们通过 net.named_parameters() (模型可学习参数)或 net.state_dict().items()(保存模型权重值)方法都无法遍历输出。那如何解决呢,这就需要用到本文讲的register_parameterregister_buffer方法。

2. register_parameter

register_parameter() 是 torch.nn.Module 类中的一个方法。

2.1 主要作用

  • 用于定义可学习参数
  • 定义的参数可被保存到网络对象的参数中,可使用 net.parameters()net.named_parameters() 查看
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件

2.2 函数说明

register_parameter(name,param)

参数:

  • name:参数名称

  • param:参数张量, 须是torch.nn.Parameter()对象 或 None ,否则报错如下
    TypeError: cannot assign 'torch.FloatTensor' object to parameter 'xx' (torch.nn.Parameter or None required)

2.3 举例说明

(1)自定义的参数未使用register_parameter

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.weight = torch.ones(10,10)self.bias = torch.zeros(10)def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

输出:
在这里插入图片描述
在网络搭建的代码中,我们自定义了self.weightself.bias参数。我们思考下2个问题:1. 我们定义的self.weightself.bias参数是否会保存到网络的参数中,是否能在优化器的作用下进行学习。2. 这些参数是否能够保存到模型文件中,从而可以利用state_dict中遍历出来。通过上面的打印信息我们发现:

  • 使用net.named_parameters()迭代网络中可学习的参数,发现输出的参数只有conv1conv2的weight参数,并没有输出我们定义的self.weightself.bias
  • 接下来使用net.state_dict()方法迭代保存的参数,同样发现self.weightself.bias参数也没有被输出出来。

(2)通过register_parameter方法来定义参数

  • 接下来我们使用register_parameter来定义weight和bias参数,看看会有啥效果。代码修改如下:
self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))
self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))

完整代码

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_parameter('weight',torch.nn.Parameter(torch.ones(10,10)))self.register_parameter('bias',torch.nn.Parameter(torch.zeros(10)))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()print('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述

  • 可以看到,使用了register_parameter定义的参数weight和bias,可以通过net.named_parameters或者net.parameters迭代出来的,这说明weight和bias已经存到了网络的参数中,他们是可学习的参数
  • 同时,通过state_dict()也能将参数和值给迭代出来,就说明如果要保存模型权重或网络参数时,这两个参数时可以被保存起来的。

3 register_buffer()

register_buffer()是 torch.nn.Module() 类中的一个方法

3.1 作用

  • 用于定义不可学习的参数
  • 定义的参数不会被保存到网络对象的参数中,使用 net.parameters() 或 net.named_parameters() 查看不到
  • 定义的参数可用 net.state_dict() 转换到字典中,进而 保存到网络文件 / 网络参数文件中

register_buffer() 用于在网络实例中 注册缓冲区,存储在缓冲区中的数据,类似于参数(但不是参数),它与参数的区别为:

  • 参数:可以被优化器更新 (requires_grad=False / True)

  • buffer 中的数据 (不可学习): 不会被优化器更新

3.2、举例说明

将定义的weight和bias,通过register_buffer来定义。

self.register_buffer('weight',torch.ones(10,10))
self.register_buffer('bias',torch.zeros(10))

运行完整代码看看效果:

import torch
import torch.nn as nnclass MyModule(nn.Module):def __init__(self):super(MyModule, self).__init__()self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=3, stride=1, padding=1, bias=False)self.conv2 = nn.Conv2d(in_channels=6, out_channels=9, kernel_size=3, stride=1, padding=1, bias=False)self.register_buffer('weight',torch.ones(10,10))self.register_buffer('bias',torch.zeros(10))def forward(self, x):x = self.conv1(x)x = self.conv2(x)x = x * self.weight + self.biasreturn xnet = MyModule()zprint('\n', '*'*30+"net.named_parameters"+'*'*30, '\n')
for name, param in net.named_parameters():print(name, param.shape)print('\n', '*'*30+"net.state_dict"+'*'*30, '\n')
for key, val in net.state_dict().items():print(key, val.shape) 

在这里插入图片描述
我们可以看到:

  • 通过register_buffer定义的参数weight和bias,它是没有被named_parameter给迭代出来的,也就是说weight和bias不是网络的可学习参数,无法通过优化器来迭代更新,我们把它叫做buffer,而不是参数
  • 然而我们使用net.state_dict去迭代的话,weight和bias事可以被迭代出来的,这就说明使用register_buffer定义的数据,可以保持到模型或者权重文件中。

注意:

  • 在使用register_parameter定义参数时,必须定义为可学习的参数,因此需要通过torch.nn.Parameter去定义为一个可学习的参数
  • 而我们使用register_buffer定义参数时,是不需要通过torch.nn.Parameter去定义为可学习的参数的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/161117.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

“轻松管理收支明细,随时筛选借款信息,财务清晰无忧“

作为现代人,我们每天都在与金钱打交道。无论是个人还是企业,都需要对收支情况进行详细的管理和分析。然而,繁琐的财务数据往往让人头疼。现在,我们为您推荐一款强大的财务管理工具,让您轻松管理收支明细,随…

超越 GLIP! | RegionSpot: 识别一切区域,多模态融合的开放世界物体识别新方法

本文的主题是多模态融合和图文理解,文中提出了一种名为RegionSpot的新颖区域识别架构,旨在解决计算机视觉中的一个关键问题:理解无约束图像中的各个区域或patch的语义。这在开放世界目标检测等领域是一个具有挑战性的任务。 关于这一块&…

Android可绘制资源概览(背景、图形等)

关于作者:CSDN内容合伙人、技术专家, 从零开始做日活千万级APP。 专注于分享各领域原创系列文章 ,擅长java后端、移动开发、商业变现、人工智能等,希望大家多多支持。 目录 一、导读二、概览三、drawable 分类3.1 Bitmap fileXML …

速学数据结构 | 链表实现队列究竟有什么优势?

🎬 鸽芷咕:个人主页 🔥 个人专栏:《速学数据结构》 《C语言进阶篇》 ⛺️生活的理想,就是为了理想的生活! 📋 前言 🌈hello! 各位宝子们大家好啊,栈区的实现我们前面已经讲了&#…

shell学习脚本05(小滴课堂)

可以对海量的数据进行提取。 -v对提取的内容进行取反。 -n显示出行号。 -w精确匹配: -i:忽略大小写: -E正则匹配: cut命令: -d指定分隔符,-f指定截取区域: 截取第一列到第三列: 截取第二列到最…

趋势:实时的stable diffusion

视频中使用了实时模型:只需2~4 个步骤甚至一步即可生成768 x 768分辨率图像。 这项技术可以把任意的stable diffusion模型转为实时模型。 潜在一致性模型 LCM LCM 只需 4,000 个训练步骤(约 32 个 A100 GPU 一小时)即可从任何预训练的SD模型中…

在markdown中怎么画表格

2023年11月5日,周日上午 下面是一种常用的方式来编写表格: | 列1标题 | 列2标题 | 列3标题 | |:------:|:------:|:------:| | 内容 | 内容 | 内容 | | 内容 | 内容 | 内容 |在这个示例中,第一行用于定义表格的列标…

论文阅读—— UniDetector(cvpr2023)

arxiv:https://arxiv.org/abs/2303.11749 github:https://github.com/zhenyuw16/UniDetector 一、介绍 通用目标检测旨在检测场景那种的一切目标。现有的检测器依赖于大量数据集 通用的目标检测器应该有两个能力:1、可以利用多种来…

asp.net人事管理信息系统VS开发sqlserver数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net 人事管理信息系统是一套完善的web设计管理系统,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为sqlserver2008,使用c#语言 开发 asp.net 人事管理系统1 应用技术…

宠物医院服务预约小程序的效果如何

随着养宠家庭增多及对爱宠的照顾加深,除了食品、服饰外,宠物医院近些年也迎来了较高发展,部分城市甚至聚集着众多品牌,以单店或多店品牌的方式拓展市场。 对宠物医院来说,一般都是拓展同市客户,或者多门店…

强化学习的动态规划三

一、策略的改进 假设新的贪婪策略π0与旧的策略π效果相当,但并不优于π。由此得出vπvπ0,且根据之前的推导可以得出:对于所有的s∈S 这与贝尔曼最优方程相同,因此,vπ0是v∗,π和π0是最佳策略。因此&…

类和对象 下

构造函数 初始化列表 构造函数是在构造类的时候给对象赋初值,并不是给类的成员函数初始化,赋值可以赋多次,而初始化只能初始化一次,这里我们引入初始化列表来对成员函数进行初始化 初始化列表,以冒号 :开…