分类神经网络2:ResNet模型复现

目录

ResNet网络架构

ResNet部分实现代码


ResNet网络架构

论文原址:https://arxiv.org/pdf/1512.03385.pdf

残差神经网络(ResNet)是由微软研究院的何恺明、张祥雨、任少卿、孙剑等人提出的,通过引入残差学习解决了深度网络训练中的退化问题,同时该网络结构特点主要表现为3点,超深的网络结构(超过1000层)、提出residual(残差结构)模块和使用Batch Normalization 加速训练(丢弃dropout)。ResNet网络的模型结构如下:

ResNet网络通过“跳跃连接”,将靠前若干层的某一层数据输出直接跳过多层引入到后面数据层的输入部分。即神经网络学习到该层的参数是冗余的时候,它可以选择直接走这条“跳接”曲线(shortcut connection),跳过这个冗余层,而不需要再去拟合参数使得H(x)=F(x)=x。同时通过这种连接方式不仅保护了信息的完整性(避免卷积层堆叠存在的信息丢失),整个网络也只需要学习输入、输出差别的部分,这克服了由于网络深度加深而产生的学习效率变低与准确率无法有效提升的问题(梯度消失或梯度爆炸)。

残差模块如下图示:

残差结构有两种,常规残差和瓶颈残差常规残差:由2个3x3卷积层堆叠而成,当输入和输出维度一致时,可以直接将输入加到输出上,这相当于简单执行了同等映射,不会产生额外的参数,也不会增加计算复杂度(随着网络深度的加深,这种残差模块在实践中并不十分有效);瓶颈残差:依次由1x1 、3x3 、1x1个卷积层构成,这里1x1卷积,能够对通道数channel起到升维或降维的作用,从而令3x3 的卷积,以相对较低维度的输入进行卷积运算,提高计算效率。

ResNet网络的具体配置信息如下:

在构建神经网络时,首先采用了步长为2的卷积层进行图像尺寸缩减,即下采样操作,紧接着是多个残差结构,在网络架构的末端,引入了一个全局平均池化层,用于整合特征信息,最后是一个包含1000个类别的全连接层,并在该层后应用了softmax激活函数以进行多分类任务。值得注意的是,通过引入残差连接模块,其最深的网络结构达到了152层,同时在50层后均使用的是瓶颈残差结构。

ResNet部分实现代码

老样子,直接上代码,建议大家阅读代码时结合网络结构理解

import torch
import torch.nn as nn__all__ = ['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']class ConvBNReLU(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBNReLU, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)x = self.bn(x)x = self.relu(x)return xclass ConvBN(nn.Module):def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):super(ConvBN, self).__init__()self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)self.bn = nn.BatchNorm2d(num_features=out_channels)def forward(self, x):x = self.conv(x)x = self.bn(x)return xdef conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):"""3x3 convolution with padding:捕捉局部特征和空间相关性,学习更复杂的特征和抽象表示"""return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,padding=dilation, groups=groups, bias=False, dilation=dilation)def conv1x1(in_channels, out_channels, stride=1):"""1x1 convolution:实现降维或升维,调整通道数和执行通道间的线性变换"""return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)class BasicBlock(nn.Module):expansion = 1def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(BasicBlock, self).__init__()self.convbnrelu1 = ConvBNReLU(in_channels, out_channels, kernel_size=3, stride=stride)self.convbn1 = ConvBN(out_channels, out_channels, kernel_size=3)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.convbnrelu1(x)out = self.convbn1(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self, in_channels, out_channels, stride=1, downsample=None):super(Bottleneck, self).__init__()groups = 1base_width = 64dilation = 1width = int(out_channels * (base_width / 64.)) * groups   # wide = out_channelsself.conv1 = conv1x1(in_channels, width)       # 降维通道数self.bn1 = nn.BatchNorm2d(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = nn.BatchNorm2d(width)self.conv3 = conv1x1(width, out_channels * self.expansion)   # 升维通道数self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = strideself.conv_down = nn.Sequential(conv1x1(in_channels, out_channels * self.expansion, self.stride),nn.BatchNorm2d(out_channels * self.expansion),)def forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample:residual = self.conv_down(x)out += residualout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,groups=1, width_per_group=64):super(ResNet, self).__init__()self.inplanes = 64self.dilation = 1replace_stride_with_dilation = [False, False, False]self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2,dilate=replace_stride_with_dilation[0])self.layer3 = self._make_layer(block, 256, layers[2], stride=2,dilate=replace_stride_with_dilation[1])self.layer4 = self._make_layer(block, 512, layers[3], stride=2,dilate=replace_stride_with_dilation[2])self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, num_classes)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)# Zero-initialize the last BN in each residual branch,# so that the residual branch starts with zeros, and each residual block behaves like an identity.# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck):nn.init.constant_(m.bn3.weight, 0)elif isinstance(m, BasicBlock):nn.init.constant_(m.bn2.weight, 0)def _make_layer(self, block, planes, blocks, stride=1, dilate=False):downsample = Falseif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = Truelayers = nn.ModuleList()layers.append(block(self.inplanes, planes, stride, downsample))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes, planes))return layersdef forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)for layer in self.layer1:x = layer(x)for layer in self.layer2:x = layer(x)for layer in self.layer3:x = layer(x)for layer in self.layer4:x = layer(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet18(num_classes, **kwargs):return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes, **kwargs)def resnet34(num_classes, **kwargs):return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet50(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, **kwargs)def resnet101(num_classes, **kwargs):return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, **kwargs)def resnet152(num_classes, **kwargs):return ResNet(Bottleneck, [3, 8, 36, 3], num_classes=num_classes, **kwargs)if __name__=="__main__":import torchsummarydevice = 'cuda' if torch.cuda.is_available() else 'cpu'input = torch.ones(2, 3, 224, 224).to(device)net = resnet50(num_classes=4)net = net.to(device)out = net(input)print(out)print(out.shape)torchsummary.summary(net, input_size=(3, 224, 224))# Total params: 134,285,380

希望对大家能够有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/637760.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Nodejs - 异步I/O

异步I/O 利用单线程,远离多线程死锁,状态同步等问题,利用异步I/O, 让单线程原理阻塞,更好的使用cpu异步I/O实现现状 阻塞IO 操作系统内对于I/O只有两种方式: 阻塞和非阻塞。在调用阻塞I/O的时候,应用程序需…

能源成果3D网络三维展厅越发主流化

在这个数字化飞速发展的时代,我们为您带来了全新的展览形式——线上3D虚拟展厅。借助VR虚拟现实制作和web3d开发技术,我们能够将物品、图片、视频和图文信息等完美融合,通过计算机技术和3D建模,为您呈现一个逼真、生动的数字化展览…

Llama 3大模型发布!快速体验推理及微调

Meta,一家全球知名的科技和社交媒体巨头,在其官方网站上正式宣布了一款开源的大型预训练语言模型——Llama-3。 据了解,Llama-3模型提供了两种不同参数规模的版本,分别是80亿参数和700亿参数。这两种版本分别针对基础的预训练任务…

JDBC学习

DriverManager(驱动管理类) Drivermanager的作用有: 1.注册驱动; 2.获取数据库连接 Class.forName("com.mysql.cj.jdbc.Driver"); 这一行的作用就是注册Mysql驱动(把我们下载的jar包加载到内存里去&…

地图图源#ESRI ArcGIS XYZ Tiles系列(TMS)

目录 1、前言 2、地图图源网址 2.1、Satellite 卫星图源 2.2、Terrain 地形图源 2.3、Street 路网/标注图源 2.4、Specifity 特色设计图源 3、专业推荐”穿搭“ 4、图源配置下载及使用 图源名称图层类别特别注意谷歌 Google①地形 ②影像 ③矢量及标注 ④特色图源国内大…

【讲解下Spring Boot单元测试】

🌈个人主页: 程序员不想敲代码啊 🏆CSDN优质创作者,CSDN实力新星,CSDN博客专家 👍点赞⭐评论⭐收藏 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共…

HotSpot JVM 中的应用程序/动态类数据共享

0.前言 本文的目的是详细讨论 HotSpot JVM 自 JDK 1.5 以来提供的一项功能,该功能可以减少启动时间,但如果在多个 JVM 之间共享相同的类数据共享 (CDS) 存档,则还可以减少内存占用。 1.类数据共享 (CDS) CDS 的想法是使用特定格式将预处理…

【LeetCode】187. 重复的DNA序列

题目链接:187. 重复的DNA序列 题目描述: 思路:首先要明白一个重要概念区别: 子串(substring): 原始字符串的一个连续子集。子序列(subsequence): 原始字符串的一个子集。 再注意题意是:长度为10的子串出现…

强固型工业电脑在码头智能化,龙门吊/流机车载电脑的行业应用

码头智能化行业应用 对码头运营来说,如何优化集装箱从船上到码头堆场到出厂区的各个流程以及达到提高效率。 降低成本的目的,是码头营运获利最重要的议题。为了让集装箱码头客户能够安心使用TOS系统来调度指挥码头上各种吊车、叉车、拖车和人员&#xf…

安居水站:理解烟瘾、探索戒烟动因及采用有效策略

摘要:本论文旨在深入探讨烟瘾的成因,分析戒烟的动因和好处,并提出一套科学的戒烟方法和具体策略。通过综合权威机构的研究文献,我们将为公众提供更全面、更深入的理解,从而推动更多人走上戒烟的道路,改善个…

鸿蒙OpenHarmony【轻量系统编写“Hello World”程序】 (基于Hi3861开发板)

编写“Hello World”程序 下方将通过修改源码的方式展示如何编写简单程序,输出“Hello world”。请在下载的源码目录中进行下述操作。 前提条件 已参考鸿蒙开发指导文档:gitee.com/li-shizhen-skin/harmony-os/blob/master/README.md点击或者复制转到…

力扣283. 移动零

Problem: 283. 移动零 文章目录 题目描述思路复杂度Code 题目描述 思路 1.定义一个int类型变量index初始化为0; 2.遍历nums当当前的元素nums[i]不为0时使nums[i]赋值给nums[index]; 3.从index开始将nums中置对应位置的元素设为0; 复杂度 时间…