7.3 详解NiN模型--首次使用多层感知机(1x1卷积核)替换掉全连接层的模型

一.前提知识

多层感知机:由一个输入层,一个或多个隐藏层和一个输出层组成。(至少有一个隐藏层,即至少3层)

全连接层:是MLP的一种特殊情况,每个节点都与前一层的所有节点连接,全连接层可以解决线性可分问题,无法学习到非线性特征。(只有输入和输出层)

二.NiN模型特点

NiN与过去模型的区别:AlexNet和VGG对LeNet的改进在于如何扩大加深这两个模块。他们都使用了全连接层,使用全连接层就可能完全放弃表征的空间结构。
NiN放弃了使用全连接层,而是使用两个1x1卷积层(将空间维度中的每个像素视为单个样本,将通道维度视为不同特征。),相当于在每个像素的通道上分别使用多层感知机

优点:NiN去除了全连接层,可以减少过拟合,同时显著减少NiN的参数数量

三.模型架构

在这里插入图片描述

四.代码

import torch
from torch import nn
from d2l import torch as d2l
import time
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(# 卷积层nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),# 两个带有ReLU激活函数的 1x1卷积层nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,kernel_size=1),nn.ReLU())
net = nn.Sequential(nin_block(1,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,stride=2),nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,stride=2),nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,stride=2),nn.Dropout(0.5),# 标签类别是10nin_block(384,10,kernel_size=3,strides=1,padding=1),# 二维自适应平均池化,不用指定池化窗口大小nn.AdaptiveAvgPool2d((1,1)),# 将(样本,通道,w,h) = (批量,10,1,1),四维的输出转成2维的输出,其形状为(批量大小,10)nn.Flatten()
)
X = torch.rand(size=(1,1,224,224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t',X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

六.不同参数训练结果

学习率是0.1的情况

# 训练模型
lr,num_epochs,batch_size = 0.1,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())

在这里插入图片描述

学习率是0.05的情况(提升了6个点)

'''开始计时'''
start_time = time.time()
# 训练模型
lr,num_epochs,batch_size = 0.05,10,128
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size,resize=224)
d2l.train_ch6(net,train_iter,test_iter,num_epochs,lr,d2l.try_gpu())
'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
print(f'{round(run_time,2)}s')

在这里插入图片描述

学习率为0.01,批次等于30的情况(反而下降了)

在这里插入图片描述

思考

为什么NiN块中有两个1x1卷积层?

从NiN替换掉全连接层,使用多层感知机角度来说:
因为1个1x1卷基层相当于全连接层,两个1x1卷积层使输入和输出层中间有了隐藏层,才相当于多层感知机。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/66230.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

面试题:说说vue2的生命周期函数?说说vue3的生命周期函数?说说vue2和vue3的生命周期函数对比?

说说vue2的生命周期函数?说说vue3的生命周期函数?说说vue2和vue3的生命周期函数对比? 一、说说vue2的生命周期函数1.1 vue生命周期分为四个阶段、8个钩子1.1.1 beforeCreate 和 created 初始化阶段1.1.2 beforeMount 和 mounted 挂载阶段1.1.…

bigemap如何添加四维地图?

工具 Bigemap gis office地图软件 BIGEMAP GIS Office-全能版 Bigemap APP_卫星地图APP_高清卫星地图APP 打开软件,要提示需要授权和添加地图,需要授权可以联系客服处理,然后点击选择地图这个按钮,列表中有个添加按钮点进去选择…

提高 After Effects 效率的 40 个最佳快捷键

After Effects 是运动图形和视觉效果的强大工具,但它也可能让人不知所措。拥有如此多的特性和功能,很容易让人迷失在软件中。但是,有一种方法可以简化您的工作流程并提高工作效率 - 使用键盘快捷键。 After Effects素材文件巨大、占用电脑内…

快递管理系统springboot 寄件物流仓库java jsp源代码mysql

本项目为前几天收费帮学妹做的一个项目,Java EE JSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 快递管理系统springboot 系统有1权限:管…

android 开发中常用命令

1.反编译 命令&#xff1a;apktool d <test.apk> -o <folderdir> 其中&#xff1a;test.apk是待反编译文件的路径&#xff0c;folderdir是反编译后的文件的存储位置。 apktool d -f <test.apk> -o <folderdir> 注意&#xff1a;如果dir已经存在&am…

【Java】2021 RoboCom 机器人开发者大赛-高职组(初赛)题解

7-1 机器人打招呼 机器人小白要来 RoboCom 参赛了&#xff0c;在赛场中遇到人要打个招呼。请你帮它设置好打招呼的这句话&#xff1a;“ni ye lai can jia RoboCom a?”。 输入格式&#xff1a; 本题没有输入。 输出格式&#xff1a; 在一行中输出 ni ye lai can jia Robo…

快手商品详情数据API 抓取快手商品价格、销量、库存、sku信息

快手商品详情数据API是用来获取快手商品详情页数据的接口&#xff0c;请求参数为商品ID&#xff0c;这是每个商品唯一性的标识。返回参数有商品标题、商品标题、商品简介、价格、掌柜昵称、库存、宝贝链接、宝贝图片、商品SKU等。 接口名称&#xff1a;item_get 公共参数 名…

单例模式-java实现

介绍 单例模式的意图&#xff1a;保证某个类在系统中有且仅有一个实例。 我们可以看到下面的类图&#xff1a;一般的单例的实现&#xff0c;是属性中保持着一个自己的私有静态实例引用&#xff0c;还有一个私有的构造方法&#xff0c;然后再开放一个静态的获取实例的方法给外界…

多线程与并发编程面试题总结

多线程与并发编程 多线程 线程和进程的区别&#xff1f; 从操作系统层面上来讲&#xff1a;进程(process)在计算机里有单独的地址空间&#xff0c;而线程只有单独的堆栈和局部内存空间&#xff0c;线程之间是共享地址空间的&#xff0c;正是由于这个特性&#xff0c;对于同…

php如何对接伪原创api

在了解伪原创api的各种应用形态之后&#xff0c;我们继续探讨智能写作背后的核心技术。需要说明的是&#xff0c;智能写作和自然语言生成、自然语言理解、知识图谱、多模算法等各类人工智能算法都有紧密的关联&#xff0c;在百度的智能写作实践中&#xff0c;常根据实际需求将多…

Vue [Day7]

文章目录 自定义创建项目ESlint 代码规范vuex 概述创建仓库向仓库提供数据使用仓库中的数据通过store直接访问通过辅助函数 mapState&#xff08;简化&#xff09;mutations传参语法(同步实时输入&#xff0c;实时更新辅助函数 mapMutationsaction &#xff08;异步辅助函数map…

【electron】electron安装过慢和打包报错:Unable to load file:

文章目录 一、安装过慢问题:二、打包报错&#xff1a;Unable to load file: 一、安装过慢问题: 一直处于安装过程 【解决】 #修改npm的配置文件 npm config edit#添加配置 electron_mirrorhttps://cdn.npm.taobao.org/dist/electron/二、打包报错&#xff1a;Unable to load…