【深度学习实验】卷积神经网络(七):实现深度残差神经网络ResNet

目录

一、实验介绍

二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. Residual(残差连接)

__init__(初始化)

forward(前向传播)

2. resnet_block(残差网络块)

3. ResNet(网络模型)

__init__(初始化)

forward(前向传播)

4. 代码整合


一、实验介绍

        本实验实现了实现深度残差神经网络ResNet

        残差网络(ResNet)是一种深度神经网络架构,用于解决深层网络训练过程中的梯度消失和梯度爆炸问题。通过引入残差连接(residual connection)来构建网络层与层之间的跳跃连接,使得网络可以更好地优化深层结构。

        残差网络的一个重要应用是在图像识别任务中,特别是在深度卷积神经网络(CNN)中。通过使用残差模块,可以构建非常深的网络,例如ResNet,其在ILSVRC 2015图像分类挑战赛中取得了非常出色的成绩。

        在ResNet中,每个残差块由一个或多个卷积层组成,其中包含了跳跃连接。跳跃连接将输入直接添加到残差块的输出中,从而使得网络可以学习残差函数,即残差块只需学习将输入的变化部分映射到输出,而不需要学习完整的映射关系。这种设计有助于减轻梯度消失问题,使得网络可以更深地进行训练。

二、实验环境

        本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

0. 导入必要的工具包

from torch import nn
import torch.nn.functional as F

1. Residual(残差连接)

class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层,将会在第7章讲到self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)

__init__(初始化)

  • 参数:
    • 输入通道数`input_channels`
    • 输出通道数`num_channels`
    • 是否使用1x1卷积`use_1x1conv`
    • 步幅`strides`
  • 在初始化过程中,创建了两个卷积层`conv1`和`conv2`,分别使用不同的输入和输出通道数,并指定了卷积核的大小、填充和步幅。
  • 如果`use_1x1conv`为True,则创建一个1x1卷积层`conv3`,用于进行维度匹配;
  • 否则,将`conv3`设为None。
  • 创建两个批量归一化层`bn1`和`bn2`,用于对卷积层的输出进行批量归一化操作。

forward(前向传播)

  • 将输入`X`通过`conv1`进行卷积操作,然后经过批量归一化层`bn1`和ReLU激活函数。
  • 将输出通过`conv2`进行卷积操作,再经过批量归一化层`bn2`。
  • 如果`conv3`不为None,则将输入`X`通过`conv3`进行卷积操作,用于进行维度匹配。
  • 最后,将经过卷积和批量归一化的结果与输入相加,得到残差连接的输出。
  • 通过ReLU激活函数处理输出,并返回结果。

2. resnet_block(残差网络块)

        生成由多个残差块组成的残差网络块。

def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blk
  • 参数
    • input_channels:输入通道数,即每个残差块的输入的通道数。
    • num_channels:每个残差块中卷积层的输出通道数,也是每个残差块内部卷积层的通道数。
    • num_residuals:残差块的数量。
    • first_block:一个布尔值,表示是否为整个 ResNet 中的第一个残差块。
  • 创建一个空列表 blk,用于存储构建的残差块。
  • 通过一个循环迭代 num_residuals 次,每次迭代都构建一个残差块并将其添加到 blk 列表中。
    • 在每个迭代中,首先检查是否为第一个残差块且 first_block 为 False。
      • 如果是,则创建一个具有下采样(strides=2)的残差块,并将其添加到 blk 列表中。这是为了在整个 ResNet 中的第一个残差块中进行下采样。
      • 如果不是第一个残差块或者 first_block 为 True,则创建一个普通的残差块,并将其添加到 blk 列表中。
  • 返回构建好的残差块列表 blk

3. ResNet(网络模型

        ResNet 网络模型,包含了多个残差块,用于实现图像分类任务。

class ResNet(nn.Module):def __init__(self, num_classes):super().__init__()self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))self.b3 = nn.Sequential(*resnet_block(64, 128, 2))self.b4 = nn.Sequential(*resnet_block(128, 256, 2))self.b5 = nn.Sequential(*resnet_block(256, 512, 2))self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))def forward(self, x):net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)return net(x)

__init__(初始化)

  • 参数:
    • num_classes,表示分类的类别数目
  • 调用父类的构造函数 `super().__init__()`。
  • self.b1是一个包含了卷积层、批归一化层、ReLU激活函数和最大池化层的序列。它对输入数据进行卷积操作,然后进行批归一化、ReLU激活和最大池化,用于提取输入图像的特征。
    • nn.Conv2d,使用 7x7 的卷积核对输入进行卷积操作,输出通道数为 64,步长为 2,填充为 3。
    • nn.BatchNorm2d 层,用于进行批归一化操作。
    •  ReLU 激活函数层 nn.ReLU()。
    • nn.MaxPool2d`层,使用 3x3 的池化核进行最大池化操作,步长为 2,填充为 1。
  • self.b2self.b3self.b4self.b5分别是几个残差块(resnet_block)的序列。这些残差块包含了卷积层、批归一化层和ReLU激活函数,用于进一步提取输入数据的特征。
    • self.b2使用构建了 2 个残差块,输入通道数为 64,输出通道数也为 64,并且指定 `first_block=True`,表示它是第一个残差块;
    • ……
  • self.head是一个包含自适应平均池化层(AdaptiveAvgPool2d)、展平层(Flatten)和全连接层(Linear)的序列。它将输入数据进行自适应平均池化,然后展平为一维向量,并通过全连接层将特征映射到分类的类别数目上:
    • 自适应平均池化层nn.AdaptiveAvgPool2d:将输入的特征图池化为大小为 1x1 的特征图。
    • 展平层nn.Flatten,将池化后的特征图展平成一维向量。
    • 全连接层nn.Linear,将展平后的特征映射到输出类别的数量。

forward(前向传播)

        输入数据通过上述序列模块self.b1self.b2self.b3self.b4self.b5self.head进行处理,最终输出分类结果

4. 代码整合

# 导入必要的工具包
from torch import nn
import torch.nn.functional as F#  残差连接, 输入和输出的维度有时是相同的, 有时是不同的, 所以需要 use_1x1conv来判断是否需要
class Residual(nn.Module):def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):super().__init__()self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides)self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1)if use_1x1conv:self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)else:self.conv3 = None# 批量归一化层,将会在第7章讲到self.bn1 = nn.BatchNorm2d(num_channels)self.bn2 = nn.BatchNorm2d(num_channels)def forward(self, X):Y = F.relu(self.bn1(self.conv1(X)))Y = self.bn2(self.conv2(Y))if self.conv3:X = self.conv3(X)Y += Xreturn F.relu(Y)# 残差网络是由几个不同的残差块组成的
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):blk = []for i in range(num_residuals):if i == 0 and not first_block:blk.append(Residual(input_channels, num_channels,use_1x1conv=True, strides=2))else:blk.append(Residual(num_channels, num_channels))return blkclass ResNet(nn.Module):def __init__(self, num_classes):super().__init__()self.b1 = nn.Sequential(nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),nn.BatchNorm2d(64), nn.ReLU(),nn.MaxPool2d(kernel_size=3, stride=2, padding=1))self.b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))self.b3 = nn.Sequential(*resnet_block(64, 128, 2))self.b4 = nn.Sequential(*resnet_block(128, 256, 2))self.b5 = nn.Sequential(*resnet_block(256, 512, 2))self.head = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, num_classes))def forward(self, x):net = nn.Sequential(self.b1, self.b2, self.b3, self.b4, self.b5, self.head)return net(x)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/132817.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

Kotlin函数作为参数指向不同逻辑

Kotlin函数作为参数指向不同逻辑 fun sum(): (Int, Int) -> Int {return { a, b -> (a b) } }fun multiplication(): (Int, Int) -> Int {return { a, b -> (a * b) } }fun main(args: Array<String>) {var math: (Int, Int) -> Intmath sum()println(m…

Unity可视化Shader工具ASE介绍——6、通过例子说明ASE节点的连接方式

大家好&#xff0c;我是阿赵。继续介绍Unity可视化Shader编辑插件ASE的用法。上一篇已经介绍了很多ASE常用的节点。这一篇通过几个小例子&#xff0c;来看看这些节点是怎样连接使用的。   这篇的内容可能会比较长&#xff0c;最终是做了一个遮挡X光的效果&#xff0c;不过把这…

python随手小练5

1、求1-100的累加和&#xff08;终止条件 1-100&#xff09;&#xff08;while和for两种&#xff09; #while循环 count 0 index 0 while index < 100:count indexindex 1 print(count)#for循环 sum 0 for i in range(0,101):sum i print(sum)结果&#xff1a; 5050 2…

拓扑排序求最长路

P1807 最长路 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 题目要求我们求出第1号到第n号节点之间最长的距离。 我们想到使用拓扑排序来求最长路。 正常来讲&#xff0c;我们应该把1号节点入队列&#xff0c;再出队列&#xff0c;把一号节点能到达的所有的点的入度减一&a…

oracle connect by详解

1、作用&#xff1a; 用于存在父子&#xff0c;祖孙&#xff0c;上下级等层级关系的数据表进行层级查询。 2、语法 SELECT ... FROM .... START WITH cond1 CONNECT BY cond2 WHERE cond3;2.1、说明 start with: 指定起始节点的条件 connect by: 指定父子行的条件关系 …

PyTorch 深度学习之加载数据集Dataset and DataLoader(七)

1. Revision: Manual data feed 全部Batch&#xff1a;计算速度&#xff0c;性能有问题 1 个 &#xff1a;跨越鞍点 mini-Batch:均衡速度与性能 2. Terminology: Epoch, Batch-Size, Iteration DataLoader: batch_size2, sheffleTrue 3. How to define your Dataset 两种处…

Verilog功能模块——同步FIFO

前言 FIFO功能模块分两篇文章&#xff0c;本篇为同步FIFO&#xff0c;另一篇为异步FIFO&#xff0c;传送门&#xff1a; Verilog功能模块——异步FIFO-CSDN博客 同步FIFO实现起来是异步FIFO的简化版&#xff0c;所以&#xff0c;本博文不再介绍FIFO实现原理&#xff0c;感兴趣…

Java面试题-0919

集合篇 Java面试题-集合篇HashMap底层实现原理概述javaSE进阶-哈希表 为了满足hashmap集合的不重复存储&#xff0c;为什么要重写hashcode和equals方法&#xff1f; 首先理解一下hashmap的插入元素的前提&#xff1a; hashmap会根据元素的hashcode取模进行比较&#xff0c;当…

【Java 进阶篇】创建 HTML 注册页面

在这篇博客中&#xff0c;我们将介绍如何创建一个简单的 HTML 注册页面。HTML&#xff08;Hypertext Markup Language&#xff09;是一种标记语言&#xff0c;用于构建网页的结构和内容。创建一个注册页面是网页开发的常见任务之一&#xff0c;它允许用户提供个人信息并注册成为…

Unity ToLua热更框架使用教程(1)

从本篇开始将为大家讲解ToLua在unity当中的使用教程。 Tolua的框架叫LuaFramework&#xff0c;首先附上下载链接&#xff1a; https://github.com/jarjin/LuaFramework_UGUI_V2 这个地址的是UGUI的。 下载完之后导入项目&#xff0c;首先&#xff0c;我们要先让这个项目跑起…

域渗透04-漏洞(CVE-2020-1472)

Netlogon协议&#xff1a; 想了解CVE-2020-1472&#xff0c;我们首先必须要了解Netlogon协议是什么&#xff1a; Netlogon 远程协议是 Windows 域控制器上可用的 RPC 接口。它用于与用户和计算机身份验证相关的各种任务&#xff0c;最常见的是方便用户使用 NTLM 协议登录到服务…

【数据结构】二叉树--链式结构的实现 (遍历)

目录 一 二叉树的遍历 1 构建一个二叉树 2 前序遍历 3 中序遍历 4 后续遍历 5 层序 6 二叉树销毁 二 应用(递归思想) 1 二叉树节点个数 2 叶子节点个数 3 第K层的节点个数 4 二叉树查找值为x的节点 5 判断是否是二叉树 一 二叉树的遍历 学习二叉树结构&#xff0…