BN体系理解——类封装复现

 

 

 

 

 

from pathlib import Path
from typing import Optionalimport torch
import torch.nn as nn
from torch import Tensorclass BN(nn.Module):def __init__(self,num_features,momentum=0.1,eps=1e-8):##num_features是通道数"""初始化方法:param num_features:特征属性的数量,也就是通道数目C"""super(BN, self).__init__()##register_buffer:将属性当成parameter进行处理,唯一的区别就是不参与反向传播的梯度求解self.register_buffer('running_mean', torch.zeros(1, num_features, 1, 1))self.register_buffer('running_var', torch.zeros(1, num_features, 1, 1))self.running_mean: Optional[Tensor]self.running_var: Optional[Tensor]self.running_mean=torch.zeros([1,num_features,1,1])self.running_var=torch.zeros([1,num_features,1,1])self.gamma=nn.Parameter(torch.ones([1,num_features,1,1]))self.beta=nn.Parameter(torch.zeros(1,num_features,1,1))self.eps=epsself.momentum=momentumdef forward(self,x):"""前向过程output=(x-μ)/α*γ+β:param x: [N,C,H,W]:return: [N,C,H,W]"""if self.training:#训练阶段--》使用当前批次的数据_mean=torch.mean(x,dim=(0,2,3),keepdim=True)_var = torch.var(x, dim=(0, 2, 3), keepdim=True)#将训练过程中的均值和方差保存下来--方便推理的时候使用--》滑动平均self.running_mean=self.momentum*self.running_mean+(1.0-self.momentum)*_meanself.running_var=self.momentum*self.running_var+(1.0-self.momentum)*_varelse:#推理阶段-->使用的是训练过程中的累积数据_mean=self.running_mean_var=self.running_varz=(x-_mean)/torch.sqrt(_var+self.eps)*self.gamma+self.betareturn zif __name__ == '__main__':torch.manual_seed(28)path_dir=Path("./output/models")path_dir.mkdir(parents=True,exist_ok=True)device=torch.device("cuda" if torch.cuda.is_available() else "cpu")bn=BN(num_features=12)bn.to(device)#只针对子模块和参数进行转换#模拟训练过程bn.train()xs=[torch.randn(8,12,32,32).to(device) for _ in range(10)]for _x in xs:bn(_x)print(bn.running_mean.view(-1))print(bn.running_var.view(-1))#模拟推理过程bn.eval()_r=bn(xs[0])print(_r.shape)bn=bn.cpu()#保存都是以cpu保存,恢复再自己转回GPU上#模拟模型保存torch.save(bn,str(path_dir/'bn_model.pkl'))#state_dict:获取当前模块的所有参数(Parameter+register_buffer)torch.save(bn.state_dict(),str(path_dir/"bn_params.pkl"))#pt结构的保存traced_script_module=torch.jit.trace(bn.eval(),xs[0].cpu())traced_script_module.save("./output/bn_model.pt")#模拟模型恢复bn_model=torch.load(str(path_dir/"bn_model.pkl"),map_location='cpu')bn_params=torch.load(str(path_dir/"bn_params.pkl"),map_location='cpu')print(len(bn_params))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/132576.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

为Yolov7环境安装Cuba匹配的Pytorch

1. 查看Cuba版本 方法一 nvidia-smi 找到CUDA Version 方法二 Nvidia Control Panel > 系统信息 > 组件 > 2. 安装Cuba匹配版本的PyTorch https://pytorch.org/get-started/locally/这里使用conda安装 conda install pytorch torchvision torchaudio pytorch-cu…

LeetCode416 分割等和子集

题目: 、 分析: 因为分割的子数组,不连续;所以双指针、栈,一般不适用,分析起来很像是DP问题。 思路: https://www.imooc.com/article/300277 代码: //TODO 这题有难度

如何在C++项目中用C#运行程序调试C++ DLL

问题描述 在C#项目中调用C DLL时报错或者运行结果不符,此时需要运行C#项目并在C中加入断点进行调试 项目准备 项目一:C#项目(该项目调用C DLL)项目二:C项目(生成C DLL) 这两个项目不需要在同…

【web实现右侧弹窗】JS+CSS如何实现右侧缓慢弹窗动态效果『附完整源码下载』

文章目录 写在前面涉及知识点页面效果1、页面DOM创建1.1创建底层操作dom节点1.2 创建存放弹窗dom节点 2、页面联动功能实现(关闭与弹出)2.1 点击非右侧区域实现关闭2.2 点击叉叉及关闭按钮实现关闭功能 3、完整源码包下载3.1百度网盘3.2 123云盘3.3邮箱留…

基于若依ruoyi-nbcio支持flowable流程增加自定义业务表单(二)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 之前讲了自定义业务表单,现在讲如何与流程进行关联 1、后端部分 WfCustomFormMapper.xml &…

【SoC FPGA】HPS启动过程

SoC HPS启动流程 Boot ROMPreloaderBoot Loader HPS的启动是一个多阶段的过程,每一个阶段都会完成对应的工作并且将下一个阶段的执行代码引导起来。每个阶段均负责加载下一个阶段。第一个软件阶段是引导 ROM,引导 ROM 代码查找并且执行称为预加载器的第 …

数据挖掘实战(3):如何对比特币走势进行预测?

⭐️⭐️⭐️⭐️⭐️欢迎来到我的博客⭐️⭐️⭐️⭐️⭐️ 🐴作者:秋无之地 🐴简介:CSDN爬虫、后端、大数据领域创作者。目前从事python爬虫、后端和大数据等相关工作,主要擅长领域有:爬虫、后端、大数据…

203、RabbitMQ 之 使用 direct 类型的 Exchange 实现 消息路由 (RoutingKey)

目录 ★ 使用direct实现消息路由代码演示这个情况二ConstantUtil 常量工具类ConnectionUtil 连接RabbitMQ的工具类Publisher 消息生产者测试消息生产者 Consumer01 消息消费者01测试消费者结果: Consumer02 消息消费者02测试消费者结果: 完整代码&#x…

机器学习(22)---信息熵、纯度、条件熵、信息增益

文章目录 1、信息熵2、信息增益3、例题分析 1、信息熵 1. 信息熵(information entropy)是度量样本集合纯度最常用的一种指标。信息的混乱程度越大,不确定性越大,信息熵越大;对于纯度,就是信息熵越大,纯度越低。 2. 纯度…

Text-to-SQL小白入门(八)RLAIF论文:AI代替人类反馈的强化学习

学习RLAIF论文前,可以先学习一下基于人类反馈的强化学习RLHF,相关的微调方法(比如强化学习系列RLHF、RRHF、RLTF、RRTF)的论文、数据集、代码等汇总都可以参考GitHub项目:GitHub - eosphoros-ai/Awesome-Text2SQL: Cur…

论文阅读/写作扫盲

第一节:期刊科普 JCR分区和中科院分区是用于对期刊进行分类和评估的两种常见方法。它们的存在是为了帮助学术界和研究人员更好地了解期刊的学术质量、影响力和地位。 JCR分区(Journal Citation Reports):JCR分区是由Clarivate Ana…

android U广播详解(一)

概念介绍 进程队列 BroadcastQueueModernImpl 的设计围绕着为设备上的每个潜在进程维护一个单独的 BroadcastProcessQueue 实例。表明用于传送到特定进程的Pending {link BroadcastRecord} 条目队列。整个类都标记为 {code NotThreadSafe},因为调用者有责任始终与…