快速入门Torch构建自己的网络模型

真有用构建自己的网络模型

    • 读前必看
    • 刚学完Alex网络感觉很厉害的样子,我也要搭建一个
    • 可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK
    • 下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以
    • 不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:
    • Tips

读前必看

  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!
  • 如何用框架复现论文中的模型不重要,重要的是明白网络模型原理!!!

刚学完Alex网络感觉很厉害的样子,我也要搭建一个

在这里插入图片描述

回想一下torch构建网络的几种方法

  • nn.Sequential直接顺序实现
  • nn.Module继承基类构建自定义模型
feature = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),
)

现在需要计算卷积后图像的维度,根据公式 image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1计算

in_shape= 224
conv_size = [11, 5, 3, 3, 3]
padding_size = [2, 2, 1, 1, 1]
stride_size = [4, 1, 1, 1, 1]
# image_shape = (image_shape - kernel_size + 2 * padding) / stride + 1
for i in range(len(conv_size)):in_shape = (in_shape - conv_size[i] + 2 * padding_size[i]) / stride_size[i] + 1in_shape = math.floor(in_shape)if i in [0, 1, 4]:in_shape = (in_shape - 3 + 2 * 0) / 2 + 1in_shape = math.floor(in_shape)
print(in_shape)

计算结果是6,输出通道是256,所以特征有25666个,将下面代码添加到Sequential中完成自定义AlexNet构建

nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(p=dropout),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes)

可以看着网络结构实现上面的代码你已经很强了,千万不要再想实现VGG等网络!!!90%你能了解到的模型大佬早已实现好,直接调用就OK

下面是源码用nn.Module实现的AlexNet,和我们实现的区别并不大,将模型print出来能看懂就可以

class AlexNet(nn.Module):def __init__(self, num_classes: int = 1000, dropout: float = 0.5) -> None:super().__init__()# _log_api_usage_once(self)self.features = nn.Sequential(nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(64, 192, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(192, 384, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(384, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(256, 256, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),)self.avgpool = nn.AdaptiveAvgPool2d((6, 6))self.classifier = nn.Sequential(nn.Dropout(p=dropout),nn.Linear(256 * 6 * 6, 4096),nn.ReLU(inplace=True),nn.Dropout(p=dropout),nn.Linear(4096, 4096),nn.ReLU(inplace=True),nn.Linear(4096, num_classes),)def forward(self, x: torch.Tensor) -> torch.Tensor:x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

不忘初心,构建自己的网络模型,将AlexNet输入改为单通道图片:

model = AlexNet()
model.features[0] = nn.Conv2d(1, 64, kernel_size=11, stride=4, padding=2)
print(model)

Tips

Q1: padding是卷积之后还是卷积之前还是卷积之后实现的?
padding是在卷积之前补0,如果愿意的话,可以通过使用torch.nn.Functional.pad来补非0的内容。

Q2:padding补0的默认策略是什么?
四周都补!如果pad输入是一个tuple的话,则第一个参数表示高度上面的padding,第2个参数表示宽度上面的

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/412328.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Failed to load class org.slf4j.impl.StaticLoggerBinder

Failed to load class org.slf4j.impl.StaticLoggerBinder 问题描述问题分析解决方案1解决方案2 问题描述 在使用Slf4J的时候发现报错了,日志一直都是使用了slf4j-api、slf4j-log4j12、log4j这三个包结合起来使用,新搭建了一个项目,然后创建了…

adb 常用命令汇总

目录 adb 常用命令 1、显示已连接的设备列表 2、进入设备 3、安装 APK 文件到设备 4、卸载指定包名的应用 5、从设备中复制文件到本地 6、将本地文件复制到设备 7、查看设备日志信息 8、重启设备 9、截取设备屏幕截图 10、屏幕分辨率 11、屏幕密度 12、显示设备的…

vuex前端开发,getters是什么?怎么调用?简单的案例操作

vuex前端开发,getters是什么?怎么调用?简单的案例操作! 下面通过一些简单的案例,来了解一下,vuex当中的getters到底是什么意思,有哪些实际的操作案例。 Vuex的getters主要用于对store中的state进行计算或过…

IntelliJ IDEA - 快速去除 mapper.xml 告警线和背景(三步走)

1、去掉 No data sources configure 警告 Settings(Ctrl Alt S) ⇒ Editor ⇒ Inspections ⇒ SQL ⇒ No data sources configure 2、去掉 SQL dialect is not configured 警告 Settings(Ctrl Alt S) ⇒ Editor ⇒ Inspecti…

macOS向ntfs格式的移动硬盘写数据

最近想把日常拍摄的照片从SD存储卡中转存到闲置的移动硬盘中,但是转存的时候发现,mac只能读我硬盘里的东西,无法将数据写入到移动硬盘中,也无法删除移动硬盘的数据。后来在网上查了许久资料,终于可实现mac对移动硬盘写…

Spring中动态注册和销毁对象

1. 使用说明 通常我们项目中想要往spring容器中注入一个bean可以在项目初始化的时候结合Bean注解实现。但是该方法适合项目初始化时候使用,如果后续想要继续注入对象则无可奈何。本文主要描述一种在后续往spring容器注入bean的方法。 2. 实现 2.1 说明 2.1.1 注册…

三棋先手必胜证明

目录 创作原因 游戏规则 初始状态图 证明过程 先手必胜的证明 失败的博弈树(三个多小时的成果) 创作原因 这个棋不是网上流行的成三棋,我也不知道这个棋叫什么。由于这个棋是(横竖斜)连成三个就获胜,…

Dubbo负载均衡解析

Dubbo负载均衡四件套 相比Ribbon负载均衡策略里的十八般兵器,Dubbo就显得低调的多了,它只提供了负载均衡四件套,让我们先来简单了解一下: 负载均衡策略底层算法RandomLoadBalance基于权重算法的负载均衡策略LeastActiveLoadBalance基于最少…

PXE——高效批量网络装机

目录 部署PXE远程安装服务 1.PXE概述 2.实现过程 3.实验操作 3.1安装dhcp、vsftpd、tftp-server.x86_64、syslinux服务 3.2修改配置文件——DHCP 3.3修改配置文件——TFTP 3.4kickstart——无人值守安装 3.4.1选择程序 3.4.2修改基础配置 3.4.3修改安装方法 3.4.4…

【多线程】认识Thread类及其常用方法

📄前言: 本文是对以往多线程学习中 Thread类 的介绍,以及对其中的部分细节问题进行总结。 文章目录 一. 线程的 创建和启动🍆1. 通过继承 Thread 类创建线程🍅2. 通过实现 Runnable 接口创建线程🥦3. 其他方…

【Python数据可视化】matplotlib之绘制常用图形:折线图、柱状图(条形图)、饼图和直方图

文章传送门 Python 数据可视化matplotlib之绘制常用图形:折线图、柱状图(条形图)、饼图和直方图matplotlib之设置坐标:添加坐标轴名字、设置坐标范围、设置主次刻度、坐标轴文字旋转并标出坐标值matplotlib之增加图形内容&#x…

vue实现 marquee(走马灯)

样式 代码 <div class"marquee-prompt"><div class"list-prompt" refboxPrompt><span v-for"item in listPrompt" :title"item" class"prompt">{{item}}</span></div> </div>data() {…