深入解析 ResNet:实现与原理

news/2024/11/20 16:40:55/文章来源:https://www.cnblogs.com/crazypigf/p/18558694

ResNet(Residual Network,残差网络)是深度学习领域中的重要突破之一,由 Kaiming He 等人在 2015 年提出。其核心思想是通过引入残差连接(skip connections)来缓解深层网络中的梯度消失问题,使得网络可以更高效地训练,同时显著提升了深度网络的性能。

本文以一个 ResNet 的简单实现为例,详细解析其工作原理、代码结构和设计思想,并介绍 ResNet 的发展背景和改进版本。


背景与动机

随着网络深度的增加,传统深层神经网络面临以下问题:

  1. 梯度消失与梯度爆炸: 在网络传播过程中,梯度逐层衰减或爆炸,使得深层网络难以有效训练。
  2. 退化问题: 增加网络深度并不一定带来更高的准确率,反而可能导致训练误差增大。

为了应对这些挑战,ResNet 提出了残差学习框架,通过学习输入与输出之间的残差来简化优化过程。


残差块 (Residual Block)

设计思想

在 ResNet 中,一个基本的单元是残差块。假设希望拟合一个目标映射H(x),ResNet 将其重新表述为:

\[H(x) = F(x) + x \]

其中:

  • F(x) 是要学习的残差函数。
  • x 是输入,直接通过快捷连接(shortcut connection)传递到输出。

这种设计可以让网络更容易优化,因为相比直接学习 H(x),学习 F(x)通常更容易。


代码实现

以下是一个标准的残差块实现:

class ResBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super().__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)self.bn1 = nn.BatchNorm2d(out_channels)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)self.bn2 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.downsample = None# 当输入和输出维度不匹配时,添加一个卷积层以调整维度if in_channels != out_channels or stride != 1:self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)self.downsample_bn = nn.BatchNorm2d(out_channels)def forward(self, x):residual = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:residual = self.downsample(x)residual = self.downsample_bn(residual)out += residualout = self.relu(out)return out

核心部分解析:

  1. 卷积操作:
    • 使用两个3 \(\times\) 3的卷积核,提取特征。
    • 通过批归一化 (BatchNorm) 稳定训练。
  2. 残差连接:
    • 当输入和输出通道数一致时,直接加和。
    • 若通道数或尺寸不同,则通过1 \(\times\) 1卷积调整形状。
  3. 激活函数:
    • 使用 ReLU 函数,增加非线性。

ResNet 网络结构

ResNet 由多个残差块堆叠而成,不同版本的 ResNet 使用的块数和通道数不同。以下是一个简化的 ResNet 实现:

class ResNet(nn.Module):def __init__(self, num_classes=10):super().__init__()# 输入图像尺寸为 28 x 28self.block1 = ResBlock(3, 64)# 输出 28 x 28self.block2 = ResBlock(64, 128, stride=2)# 输出 14 x 14self.block3 = ResBlock(128, 256, stride=2)# 输出 7 x 7self.block4 = ResBlock(256, 512, stride=2)# 输出 4 x 4self.block5 = ResBlock(512, 1024, stride=2)# 输出 2 x 2self.block6 = ResBlock(1024, 2048, stride=2)# 输出 1 x 1self.fc = nn.Linear(2048, num_classes)def forward(self, x):x = self.block1(x)x = self.block2(x)x = self.block3(x)x = self.block4(x)x = self.block5(x)x = self.block6(x)x = x.view(x.size(0), -1)x = self.fc(x)return x

网络结构说明:

  1. 输入为 28 \(\times\) 28 的图像,通过 6 个残差块提取特征。
  2. 每次通过残差块,通道数增加,空间尺寸减少一半。
  3. 最后通过全连接层实现分类。

ResNet 的优势

  1. 解决梯度问题: 残差连接使得梯度能够直接传递到前层,有效缓解了梯度消失问题。
  2. 更深的网络: ResNet-50 和 ResNet-152 等深度版本大大提升了性能,广泛用于图像分类、目标检测等任务。
  3. 模块化设计: 残差块设计简单,可扩展性强。

总结

本文通过代码实现和理论讲解,深入解析了 ResNet 的核心思想和设计细节。ResNet 是深度学习领域的重要里程碑,其提出的残差学习框架为训练深层网络提供了有效的方法。随着 ResNet 的不断发展,它在各种任务中依然表现强劲,是经典的深度学习模型之一。

通过理解 ResNet 的原理和实现,我们不仅可以灵活应用现有的网络架构,还能为创新和改进深度网络提供思路。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/837394.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

批量解除 此文件来自其他计算机,可能被阻止以帮助保护该计算机

下载微软工具 - Streams https://learn.microsoft.com/en-us/sysinternals/downloads/ streams -s -d D:/file留待后查,同时方便他人 联系我:renhanlinbsl@163.com

使用linq查询报错English Message : Join a needs to be the same as OrderBy it

可以使用 .Select 和 .MergeTable() 将多表结果集变成单表:这样问题就可以解决了

【淘汰9成NLP工程师的常识题】LSTM的前向计算如何进行加速?

【淘汰9成NLP工程师的常识题】LSTM的前向计算如何进行加速? 重要性:★★★ 💯 这是我常用的【淘汰9成NLP工程师的常识题】LSTM的前向计算如何进行加速? 重要性:★★★ 💯这是我常用的一个面试题。看似简单的基础题,但在面试中能准确回答的不足10% ,常识题的错误反而会…

【论文阅读笔记】多模态大语言模型必读 —— LLaVA

LLaVA (Large Language and Vision Assistant),proposed by Haotian Liu (UWM), et al.论文地址:https://arxiv.org/abs/2304.08485 代码地址:https://github.com/haotian-liu/LLaVA目录简介Visual Instruction 数据生成视觉指令微调模型架构训练 简介 人类对于世界的认知是…

接口文档和编写接口测试用例

一、熟悉接口文档和分析接口 1、发送接口文档 2、分析接口文档 3、了解需要测试接口,分析需求文档接口请求参数:接口返回参数:成功整理接口:(自己项目有哪些借款) cms项目接口:查询接口,登录接口,添加用户接口,用户管理接口,文章管理接口,删除用户接口,删除栏目接…

python代码实现RNN, LSTM, GRU

安装torch, transformers, loguru(本代码实现为下方版本,其余版本实现可比葫芦画瓢自行摸索)pip install torch==1.13.1 transformers==4.44.1 numpy==1.26.4 loguru -i https://pypi.tuna.tsinghua.edu.cn/simple/RNN:Recurrent Neural Network,网络结构如下图所示:import nu…

ChatGPT国内中文版镜像网站整理合集(2024/11/20)

ChatGPT 镜像站的用途 镜像站(Mirror Site)ChatGPT镜像网站是指通过复制原始网站内容和结构,创建的备用网站。其主要目的是在原始网站无法访问时,提供相同或类似的服务和信息。​ 一、ChatGPT中文镜像站 ① yixiaai.com 支持4o以及o1,支持MJ绘画 ② chat.lify.vip 支持通用…

鸿蒙NEXT开发案例:随机数生成

【引言】 本项目是一个简单的随机数生成器应用,用户可以通过设置随机数的范围和个数,并选择是否允许生成重复的随机数,来生成所需的随机数列表。生成的结果可以通过点击“复制”按钮复制到剪贴板。 【环境准备】 • 操作系统:Windows 10• 开发工具:DevEco Studio NEXT Be…

13、优化器_(执行计划、统计信息)_1

执行计划 一个SQL文本,经过解析,经过解析之后,oracle发现有很多种执行方案,然后oracle在这多种执行方案中,选出一种oracle认为最优的一种执行方案,来作为执行计划,然后oracle按照执行计划一步步去执行 因为oracle有多种的执行方案,但是,有的执行方案快,有的执行方案慢…

12、表的访问方式(索引)_2

表的访问方式 以t1表为例来看表的访问方式 首先创建了一个用户,建立了一张表t1,按照object_id列排序的: SQL> create user u1 identified by u1; -- 创建用户u1 User created.SQL> grant connect,resource,dba to u1; -- 给u1授权 Grant succeeded.SQL> conne…