实现pytorch版的mobileNetV1

mobileNet具体细节,在前面已做了分析记录:轻量化网络-MobileNet系列-CSDN博客

这里是根据网络结构,搭建模型,用于图像分类任务。

1. 网络结构和基本组件

2. 搭建组件

(1)普通的卷积组件:CBL = Conv2d + BN + ReLU6;

(2)深度可分离卷积:DwCBL  = Conv dw+ Conv dp;

Conv dw+ Conv dp = {Conv2d(3x3) + BN + ReLU6 }  + {Conv2d(1x1) + BN + ReLU6};

Conv dw是3x3的深度卷积,通过步长控制是否进行下采样;

Conv dp是1x1的逐点卷积,通过控制输出通道数,控制通道维度的变化;

# 普通卷积
class CBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(CBN, self).__init__()self.conv = nn.Conv2d(in_c, out_c, 3, stride, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_c)self.relu = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return x
# 深度可分离卷积: 深度卷积(3x3x1) + 逐点卷积(1x1xc卷积)
class DwCBN(nn.Module):def __init__(self, in_c, out_c, stride=1):super(DwCBN, self).__init__()# conv3x3x1, 深度卷积,通过步长,只控制是否缩小特征hwself.conv3x3 = nn.Conv2d(in_c, in_c, 3, stride, padding=1, groups=in_c, bias=False)self.bn1 = nn.BatchNorm2d(in_c)self.relu1 = nn.ReLU6(inplace=True)# conv1x1xc, 逐点卷积,通过控制输出通道数,控制通道维度的变化self.conv1x1 = nn.Conv2d(in_c, out_c, 1, stride=1, padding=0, bias=False)self.bn2 = nn.BatchNorm2d(out_c)self.relu2 = nn.ReLU6(inplace=True)def forward(self, x):x = self.conv3x3(x)x = self.bn1(x)x = self.relu1(x)x = self.conv1x1(x)x = self.bn2(x)x = self.relu2(x)return x

3. 搭建网络

class MobileNetV1(nn.Module):def __init__(self, class_num=1000):super(MobileNetV1, self).__init__()self.stage1 = torch.nn.Sequential(CBN(3, 32, 2),  # 下采样/2DwCBN(32, 64, 1))self.stage2 = torch.nn.Sequential(DwCBN(64, 128, 2),  # 下采样/4DwCBN(128, 128, 1))self.stage3 = torch.nn.Sequential(DwCBN(128, 256, 2),  # 下采样/8DwCBN(256, 256, 1))self.stage4 = torch.nn.Sequential(DwCBN(256, 512, 2),  # 下采样/16DwCBN(512, 512, 1),  # 5个DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),DwCBN(512, 512, 1),)self.stage5 = torch.nn.Sequential(DwCBN(512, 1024, 2),  # 下采样/32DwCBN(1024, 1024, 1))# classifierself.avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))self.fc = torch.nn.Linear(1024, class_num, bias=True)# self.classifier = torch.nn.Softmax()  # 原始的softmax值# torch.log_softmax 首先计算 softmax 然后再取对数,因此在数值上更加稳定。# 在分类网络在训练过程中,通常使用交叉熵损失函数(Cross-Entropy Loss)。# torch.nn.CrossEntropyLoss 会在内部进行 softmax 操作,因此在网络的最后一层不需要手动加上 softmax 操作。def forward(self, x):scale1 = self.stage1(x)  # /2scale2 = self.stage2(scale1)scale3 = self.stage3(scale2)scale4 = self.stage4(scale3)scale5 = self.stage5(scale4)  # /32. 7x7x = self.avg_pooling(scale5)  # (b,1024,7,7)->(b,1024,1,1)x = torch.flatten(x, 1)  # (b,1024,1,1)->(b,1024,)x = self.fc(x)  # (b,1024,)  -> (b,1000,)return xif __name__ == '__main__':m1 = MobileNetV1(class_num=1000)input_data = torch.randn(64, 3, 224, 224)output = m1.forward(input_data)print(output.shape)

待续。。。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/326593.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

通天星CMSV6车载视频监控平台 SQL注入漏洞复现

0x01 产品简介 通天星CMSV6车载视频监控平台是东莞市通天星软件科技有限公司研发的监控平台,通天星CMSV6产品覆盖车载录像机、单兵录像机、网络监控摄像机、行驶记录仪等产品的视频综合平台。通天星科技应用于公交车车载、校车车载、大巴车车载、物流车载、油品运输车载、警车…

构建自己的私人GPT

创作不易,请大家多鼓励支持。 在现实生活中,很多人的资料是不愿意公布在互联网上的,但是我们又要使用人工智能的能力帮我们处理文件、做决策、执行命令那怎么办呢?于是我们构建自己或公司的私人GPT变得非常重要。 一、本地部署…

Java8内置四大核心函数式接口

先来看几个例子,主要练习策略模式: 用策略模式的做法 定义个接口 其实像这样的接口并不需要我们自己创建 java8推出的Lambda表达式主要就是为了简化开发,而Lambda表达式 的应用主要是针对与函数式接口,自然也推出了对应的一些接口 /*** Java8 内置的四大核心函数式接口** C…

Opencv(C++)学习之cv::calcHist 任意bin数量进行直方图计算

**背景:**当前网上常见的直方图使用方法都是默认使用256的范围,而对于使用特定范围的直方图方法讲的不够清楚。仔细研究后总结如下: 1、常见使用方法,直接对灰度图按256个Bin进行计算。 Mat mHistUn; int channels[1] { 0 }; {…

HarmonyOS 应用开发学习笔记 状态管理概述

移动端开发,最重要的一点就是数据的处理,并且正确的显示渲染UI。 变量在页面和组件、组件和组件之间有时候并不能实时共享,而有时候,又不需要太多的作用域(节省资源),作用就需要根据不同场景&am…

【linux】更改infiniband卡在Debian系统的网络接口名

在Debian或任何其他基于Linux的系统中,网络接口的名称由udev系统管理。通过创建udev规则,可以修改网络接口名称。以下是更改InfiniBand卡接口名称的一般步骤: 1. 找到网络接口的属性,以编写匹配的udev规则 可以使用udevadm命令查…

​三子棋(c语言)

前言: 三子棋是一种民间传统游戏,又叫九宫棋、圈圈叉叉棋、一条龙、井字棋等。游戏规则是双方对战,双方依次在9宫格棋盘上摆放棋子,率先将自己的三个棋子走成一条线就视为胜利。但因棋盘太小,三子棋在很多时候会出现和…

STM32 JLINK SWD调试器手动复位才能烧写的问题

STM32 JLINK SWD调试器手动复位才能烧写的问题 Chapter1 STM32 JLINK SWD调试器手动复位才能烧写的问题 Chapter1 STM32 JLINK SWD调试器手动复位才能烧写的问题 原文链接:https://blog.csdn.net/denghuajing/article/details/121649667 问题 只有手动复位的情况下…

github 好项目 之 reference

github项目地址 网页网址 点进去以后你可以看到很多关于技术前沿的东西的简单笔记,一些实践的代码,或者是一些快捷键的命令 我个人比较喜欢 latex 的数学公式笔记 以及关于 vim 的一些命令 还有我最喜欢的git命令

安徽省暨合肥市“希望工程·梦想计划”小盖茨机器人捐赠启动仪式举行

1月5日,安徽省暨合肥市“希望工程梦想计划”小盖茨机器人捐赠启动仪式在合肥市一六八玫瑰园学校东校区举行。共青团安徽省委副书记叶征,北京儒布特教育科技有限公司董事牛俊明,北京儒布特教育科技有限公司市场总监高进,安徽省青基…

vue3.x 的生命周期和钩子函数

什么是生命周期 Vue 是组件化编程,从一个组件诞生到消亡,会经历很多过程,这些过程就叫做生命周期。 你理解了什么是生命周期,你还了解一个概念“钩子函数”。 钩子函数: 伴随着生命周期,给用户使用的函数&…

CMU15-445-Spring-2023-Project #1 - 前置知识(lec01-04)

Lecture #01_ Relational Model & Relational Algebra Databases 数据库是相互关联的数据的有组织集合,对现实世界的某些方面进行建模。区别于DBMS(MySQL、Oracle)。 Flat File Strawman 数据库以CSV文件的形式存储,并由D…