NiNNet

目录

一、网络介绍

1、全连接层存在的问题

2、NiN的解决方案(NiN块)

3、NiN架构

4、总结

二、代码实现

1、定义NiN卷积块

2、NiN模型

3、训练模型


一、网络介绍

       NiN(Network in Network)是一种用于图像识别任务的卷积神经网络模型。它由谷歌研究员Min Lin、Qiang Chen和Shouyuan Chen于2013年提出。NiN的设计理念是通过引入“网络中的网络”结构来增强模型的表示能力。

1、全连接层存在的问题

       在之前的网络(比如AlexNet和VGGNet)后面都用了几个比较大的全连接层,全连接层中的参数相比于卷积层多得多,一个网络的参数大多都在全连接层,并且可以认为主要分布在卷积层之后的第一个全连接层。因此全连接层最大的问题是可能造成过拟合。

2、NiN的解决方案(NiN块)

       NiN的核心思想是使用1x1卷积层替代传统的全连接层。传统的卷积神经网络通常使用卷积层提取特征,然后通过全连接层进行分类。而NiN则在卷积层中引入了一种称为“1x1卷积”的操作,这个操作可以看作是在每个像素点上进行的全连接操作。通过使用1x1卷积,NiN能够在卷积层中引入非线性,增加模型的表达能力,并且减少了参数的数量。

       和VGG一样,NiN也有自己的块(NiN块),每一个NiN块其实就相当于一个小的神经网络(因为它具有卷积层和类似于全连接层的 $1 \times 1$ 卷积层),因此叫网络中的网络。NiN块首先有一个卷积层,然后后跟两个 $1 \times 1$ 的卷积层($1 \times 1$ 的卷积层等价于全连接层)。

3、NiN架构

全局池化层:池化层的高和宽等于输入的高和宽,一个通道得出一个值,用这个值当作对类别的预测。

4、总结

二、代码实现

       NiN的想法是将空间维度中的每个像素视为单个样本,将通道维度视为不同特征(feature)。下图说明了VGG和NiN及它们的块之间主要架构差异。NiN块以一个普通卷积层开始,后面是两个 $1 \times 1$ 的卷积层。NiN块第一层的卷积窗口形状通常由用户设置。随后的卷积窗口形状固定为 $1 \times 1$

1、定义NiN卷积块

import torch
from torch import nn
from d2l import torch as d2ldef nin_block(in_channels, out_channels, kernel_size, strides, padding):return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())

2、NiN模型

       最初的NiN网络是在AlexNet后不久提出的,显然从中得到了一些启示。NiN使用窗口形状为$11\times 11$$5\times 5$ 和 $3\times 3$ 的卷积层,输出通道数量与AlexNet中的相同。每个NiN块后有一个最大池化层,池化窗口形状为 $3\times 3$,步幅为2。

       NiN和AlexNet之间的一个显著区别是NiN完全取消了全连接层。相反,NiN使用一个个NiN块,最后一个NiN块的输出通道数等于标签类别的数量。最后放一个全局平均池化层(global average pooling layer),生成一个对数几率(logits)。NiN设计的一个优点是,它显著减少了模型所需参数的数量。然而,在实践中,这种设计有时会增加训练模型的时间。

net = nn.Sequential(nin_block(1, 96, kernel_size=11, strides=4, padding=0),nn.MaxPool2d(3, stride=2),nin_block(96, 256, kernel_size=5, strides=1, padding=2),nn.MaxPool2d(3, stride=2),nin_block(256, 384, kernel_size=3, strides=1, padding=1),nn.MaxPool2d(3, stride=2),nn.Dropout(0.5),# 标签类别数是10nin_block(384, 10, kernel_size=3, strides=1, padding=1),    # 通道数先增加后减少:1->96->256->384->10nn.AdaptiveAvgPool2d((1, 1)),   # 注意这里的(1, 1)不是kernel_size,而是output_size# 将四维的输出转成二维的输出,其形状为(批量大小, 10)nn.Flatten())   # Flatten会把channel、height和width展平成一行

       我们创建一个数据样本来查看每个块的输出形状。

X = torch.rand(size=(1, 1, 224, 224))
for layer in net:X = layer(X)print(layer.__class__.__name__,'output shape:\t', X.shape)
Sequential output shape:	 torch.Size([1, 96, 54, 54])
MaxPool2d output shape:	 torch.Size([1, 96, 26, 26])
Sequential output shape:	 torch.Size([1, 256, 26, 26])
MaxPool2d output shape:	 torch.Size([1, 256, 12, 12])
Sequential output shape:	 torch.Size([1, 384, 12, 12])
MaxPool2d output shape:	 torch.Size([1, 384, 5, 5])
Dropout output shape:	 torch.Size([1, 384, 5, 5])
Sequential output shape:	 torch.Size([1, 10, 5, 5])
AdaptiveAvgPool2d output shape:	 torch.Size([1, 10, 1, 1])
Flatten output shape:	 torch.Size([1, 10])

3、训练模型

       我们使用Fashion-MNIST来训练模型。训练NiN与训练AlexNet、VGG时相似。

lr, num_epochs, batch_size = 0.1, 10, 128
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224) # 调节图片尺寸为224
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.563, train acc 0.786, test acc 0.790
3087.6 examples/sec on cuda:0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/294458.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Linux学习小结

目录结构 tree -L 1 / # /root #root用户的家目录 /home #存储普通用户家目录 lostfound #这个目录平时是空的,存储系统非正常关机而留下“无家可归”的文件 /usr #系统文件,相当于C:\Windows /usr/local #软件安装的目录,相当于C:\Progra…

跨境电商独立站深度分析演示网站

对于跨境电商卖家来说,多平台、多站点的布局是非常重要的战略。这样做可以规避”鸡蛋放在同一个篮子里”的风险也能够追求更高的销售额和利润。同时,市场的变化也带来了新的发展机会,因此很多出海企业都希望抓住独立站的新机遇,抢…

【华为数据之道学习笔记】6-4 打造数据供应的“三个1”

数据服务改变了传统的数据集成方式,所有数据都通过服务对外提供,用户不再直接集成数据,而是通过服务获取。因此,数据服务应该拉动数据供应链条的各个节点,以方便用户能准确地获取数据为重要目标。 数据供应到消费的完整…

【Linux笔记】文件和目录操作

🍎个人博客:个人主页 🏆个人专栏:Linux学习 ⛳️ 功不唐捐,玉汝于成 目录 前言 命令 ls (List): pwd (Print Working Directory): cp (Copy): mv (Move): rm (Remove): 结语 我的其他博客 前言 学习Linux命令…

开发知识点-HTML/JavaScript

HTML/JavaScript xlinksvgviewBoxuse基础预热与语法基础知识js 如何运行页面适用js 及输出 面向对象抽奖功能 json 支持 字符串转数组数组转字符串数组元素删除长度0位添加一个元素// 表示在下标为1处添加一项tttarray.splice(1,0,ttt)//[123,ttt,456]// 数组是否包含某个元素a…

PIC单片机项目(8)——基于PIC16F877A的温度光照检测装置的protues仿真

1.功能设计 使用PIC16F877A单片机,进行温度检测、光照检测。温度使用的是DS18B20,光照检测直接利用的AD转换。 光照太暗就开灯,温度太高就开风扇。温度阈值和光照阈值都实时显示在LCD1602屏幕上面。 完成了protues仿真。文件里面包含代码和仿…

blender径向渐变材质-着色编辑器

要点: 1、用纹理坐标中的物体输出连接映射中的矢量输入 2、物体选择一个空坐标,将空坐标延z轴上移一段距离 3、空坐标的大小要缩放到和要添加材质的物体大小保持一致

微前端样式隔离、sessionStorage、localStorage隔离

1、样式隔离 前端样式不隔离,会产生样式冲突的问题,这个点在qiankun也存在 子应用1修改一个样式 button {background: red!important; }其它应用也会受到影响 qiankun的css隔离方案(shadow dom) shadow …

华为 1+X 网络系统运维与建设中级实操模拟题

目 实验拓扑 配置中的注意事项:(针对新手) 实验目的 实验要求 实验步骤 一、搭建实验拓扑 二、配置主机名称 三、配置链路聚合 四、VLAN 配置 五、配置 RSTP 协议 六、配置 IP 地址 七、配置 VRRP 协议。 八、配置 OSPF 协议 九…

C语言中常用的sscanf函数

文章目录 1. 接受全部参数:2、分辨数字和字符3. 数字和字符一起会默认是字符4. 同时接收多个变量5. 指定长度的集合操作6. 排除部分字符 sscanf()定义于头文件stdio.h。sscanf()会将参数str的字符串根据参数format字符串来转换并格式化数据。格式转换形式请参考scan…

计算机网络——计算机网络的概述(一)

前言: 面对马上的期末考试,也为了以后找工作,需要掌握更多的知识,而且我们现实生活中也已经离不开计算机,更离不开计算机网络,今天开始我们就对计算机网络的知识进行一个简单的学习与记录。 目录 一、什么…

OpenCV4 工业缺陷检测的六种方法

文章目录 机器视觉缺陷检测工业上常见缺陷检测方法方法一:基于简单二值图像分析实现划痕提取,效果如下:方法二:复杂背景下的图像缺陷分析,基于频域增强的方法实现缺陷检测,运行截图:方法三&…