深度学习基础之参数量(3)

一般的CNN网络的参数量估计代码

class ResidualBlock(nn.Module):def __init__(self, in_planes, planes, norm_fn='group', stride=1):super(ResidualBlock, self).__init__()print(in_planes, planes, norm_fn, stride)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)self.relu = nn.ReLU(inplace=True)num_groups = planes // 8if norm_fn == 'group':self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)if not stride == 1:self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)elif norm_fn == 'batch':self.norm1 = nn.BatchNorm2d(planes)self.norm2 = nn.BatchNorm2d(planes)if not stride == 1:self.norm3 = nn.BatchNorm2d(planes)elif norm_fn == 'instance':self.norm1 = nn.InstanceNorm2d(planes)self.norm2 = nn.InstanceNorm2d(planes)if not stride == 1:self.norm3 = nn.InstanceNorm2d(planes)elif norm_fn == 'none':self.norm1 = nn.Sequential()self.norm2 = nn.Sequential()if not stride == 1:self.norm3 = nn.Sequential()if stride == 1:self.downsample = Noneelse:self.downsample = nn.Sequential(nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)def forward(self, x):print(x.shape)#exit()y = xy = self.relu(self.norm1(self.conv1(y)))y = self.relu(self.norm2(self.conv2(y)))if self.downsample is not None:x = self.downsample(x)return self.relu(x + y)R=ResidualBlock(384, 384, norm_fn='instance', stride=1)
summary(R.to("cuda" if torch.cuda.is_available() else "cpu"), (384, 32, 32))

transformer结构的参数量的估计结果

import torch
import torch.nn as nn
from thop import profile
from torchsummary import summary# 定义一个简单的Transformer模型
class Transformer(nn.Module):def __init__(self, input_dim, hidden_dim, num_heads, num_layers):super(Transformer, self).__init__()self.embedding = nn.Embedding(input_dim, hidden_dim)self.transformer_layers = nn.Transformer(d_model=hidden_dim,nhead=num_heads,num_encoder_layers=num_layers,num_decoder_layers=num_layers)self.fc = nn.Linear(hidden_dim, input_dim)def forward(self, src, tgt):src = self.embedding(src)tgt = self.embedding(tgt)output = self.transformer_layers(src, tgt)output = self.fc(output)return output# 创建Transformer模型实例
model2 = Transformer(input_dim=512, hidden_dim=512, num_heads=8, num_layers=6)# 使用thop进行FLOPS估算
flops, params = profile(model2, inputs=(torch.randint(0, 512, (128,)), torch.randint(0, 512, (64,))))
print(f"FLOPS: {flops / 1e9} G FLOPS")  # 打印FLOPS,以十亿FLOPS(GFLOPS)为单位# 计算参数量并打印
num_params = sum(p.numel() for p in model2.parameters() if p.requires_grad)
print(f"Total number of trainable parameters: {num_params}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/125613.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Python综合案例:学生管理系统

目录 需求说明: 功能: 创建入口函数: 实现菜单函数: 实现增删查操作: 1. 新增学生 2. 展示学生 3. 查找学生 4. 删除学生 加入存档读档: 1. 约定存档格式 2. 实现存档函数 3. 实现读档函数 打…

mysql双主互从通过KeepAlived虚拟IP实现高可用

mysql双主互从通过KeepAlived虚拟IP实现高可用 在mysql 双主互从的基础上, 架构图: Keepalived有两个主要的功能: 提供虚拟IP,实现双机热备通过LVS,实现负载均衡 安装 # 安装 yum -y install keepalived # 卸载 …

全志ARM926 Melis2.0系统的开发指引⑥

全志ARM926 Melis2.0系统的开发指引⑥ 编写目的9. 系统启动流程9.1. Shell 部分9.2.Orange 和 desktop 部分9.3. app_root 加载部分9.4. home 加载部分 10. 显示相关知识概述10.1. 总体结构10.2. 显示过程10.3. 显示宽高参数关系 -. 全志相关工具和资源-.1 全志固件镜像修改工具…

【开发篇】十五、Spring Task实现定时任务

文章目录 1、使用示例2、相关配置3、Scheduled注解4、Spring Task单线程下的阻塞坑5、Spring Task阻塞问题的处理思路6、Spring Task在分布式环境中 上一篇用Quartz来实现了定时任务,但相对来说,这个框架还是比较繁琐。Spring Boot默认在无任何第三方依赖…

minikube如何设置阿里云镜像以及如何解决dashboard无法打开的解决方案_已设置图床

minikube如何设置阿里云镜像以及如何解决dashboard无法打开的解决方案 minikube dashboard报错 considerconsider-Dell-G15-5511:~$ minikube dashboard 🤔 正在验证 dashboard 运行情况 ... 🚀 正在启动代理... 🤔 正在验证 proxy 运行…

2023/9/27 -- ARM

【汇编语言相关语法】 1.汇编语言的组成部分 1.伪操作:不参与程序的执行,但是用于告诉编译器程序该怎么编译 .text .global .end .if .else .endif .data2.汇编指令 编译器将一条汇编指令编译成一条机器码,在内存里一条指令占4字节内…

【C++设计模式之原型模式:创建型】分析及示例

简介 原型模式(Prototype Pattern)是一种创建型设计模式,它允许通过复制已有对象来生成新的对象,而无需再次使用构造函数。 描述 原型模式通过复制现有对象来创建新的对象,而无需显式地调用构造函数或暴露对象的创建…

013-第二代上位机开发环境搭建

第二代上位机开发环境搭建 文章目录 第二代上位机开发环境搭建项目介绍虚拟机安装Debian 10文件传输远程调试VNCrsync下载安装验证 配置远程调试环境配置远程设备配置 kitsCompilers配置Qtversions配置kits 测试 总结一下 关键字: Qt、 Qml、 关键字3、 关键字4…

Python常用功能的标准代码

后台运行并保存log 1 2 3 4 5 6 7 8 9 nohup python -u test.py > test.log 2>&1 & #最后的&表示后台运行 #2 输出错误信息到提示符窗口 #1 表示输出信息到提示符窗口, 1前面的&注意添加, 否则还会创建一个名为1的文件 #最后会把日志文件输出到test.log文…

八大排序算法汇总(C语言实现)

本专栏内容为:八大排序汇总 通过本专栏的深入学习,你可以了解并掌握八大排序以及相关的排序算法。 💓博主csdn个人主页:小小unicorn ⏩专栏分类:八大排序汇总 🚚代码仓库:小小unicorn的代码仓库…

智慧公厕:将科技融入日常生活的创新之举

智慧公厕是当今社会中一项备受关注的创新项目。通过将科技融入公厕设计和管理中,这些公厕不仅能够提供更便利、更卫生的使用体验,还能够极大地提升城市形象和居民生活质量。本文将以智慧公厕领先厂家广州中期科技有限公司,大量的精品案例项目…

React项目部署 - Nginx配置

写在前面:博主是一只经过实战开发历练后投身培训事业的“小山猪”,昵称取自动画片《狮子王》中的“彭彭”,总是以乐观、积极的心态对待周边的事物。本人的技术路线从Java全栈工程师一路奔向大数据开发、数据挖掘领域,如今终有小成…