使用 TensorFlow 创建 DenseNet 121

一、说明

本篇示意DenseNet如何在tensorflow上实现,DenseNet与ResNet有类似的地方,都有层与层的“短路”方式,但两者对层的短路后处理有所不同,本文遵照原始论文的技术路线,完整复原了DenseNet的全部网络。

图1:DenseNet中的各种块和层(来源:原始DenseNet论文)

      

二、DenseNet综述

        DenseNet(密集卷积网络)是一种架构,专注于使深度学习网络更深入,但同时通过在层之间使用更短的连接来提高它们的训练效率。DenseNet 是一个卷积神经网络,其中每一层都连接到网络中更深的所有其他层,即第一层连接到第 2、3、4 层等,第二层连接到第 3、4、5 层等。这样做是为了在网络各层之间实现最大的信息流。为了保持前馈特性,每一层从前面的所有层获取输入,并将自己的特征图传递给它将要到达的所有层。与 Resnets 不同,它不是通过求和来组合特征,而是通过连接它们来组合特征。因此,“ith”层具有“i”输入,并且由其所有先前卷积块的特征图组成。它自己的特征图被传递到所有下一个“I-i”层。这在网络中引入了“(I(I+1)))/2”连接,而不是像传统深度学习架构中那样只是“I”连接。因此,与传统的卷积神经网络相比,它需要的参数更少,因为不需要学习不重要的特征图。

        DenseNet由两个重要的块组成,而不是基本的卷积层和池化层。它们是密集块和过渡层。

        接下来,我们看看所有这些块和层的外观,以及如何在 python 中实现它们。

图2:DenseNet121框架(来源:DenseNet原始论文,由作者编辑)

        DenseNet从基本的卷积和池化层开始。然后有一个密集块,然后是一个过渡层,另一个密集块后跟一个过渡层,另一个密集块后跟一个过渡层,最后是一个密集块,然后是一个分类层。

        第一个卷积块有 64 个大小为 7x7 的过滤器,步幅为 2。接下来是最大池化为 3x3 且步幅为 2 的 MaxPooling 层。这两行可以在 python 中用以下代码表示。

input = Input (input_shape)
x = Conv2D(64, 7, strides = 2, padding = 'same')(input)
x = MaxPool2D(3, strides = 2, padding = 'same')(x)

2.1 定义卷积块 —

        输入后的每个卷积块具有以下顺序:批处理归一化,然后是 ReLU 激活,然后是实际的 Conv2D 层。为了实现这一点,我们可以编写以下函数。

#batch norm + relu + conv
def bn_rl_conv(x,filters,kernel=1,strides=1):x = BatchNormalization()(x)x = ReLU()(x)x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)return x

图3.密集块(来源:DenseNet论文-作者编辑)

2.2 定义密集块 —

        如图 3 所示,每个密集块都有两个卷积,具有 1x1 和 3x3 大小的内核。在密集块 1 中,重复 6 次,在密集块 2 中重复 12 次,在密集块 3 中重复 24 次,最后在密集块 4 中重复 16 次。

在密集块中,每个 1x1 卷积都有 4 倍的滤波器数量。所以我们使用 4*过滤器,但 3x3 过滤器只存在一次。此外,我们必须将输入与输出张量连接起来。

每个块分别运行 6、12、24、16 次重复,使用 'for 循环'。

def dense_block(x, repetition):for _ in range(repetition):y = bn_rl_conv(x, 4*filters)y = bn_rl_conv(y, filters, 3)x = concatenate([y,x])return x

图4:过渡层(来源:DenseNet论文,作者编辑)

2.3 定义过渡层 

        — 在过渡层中,我们将通道数减少到现有通道的一半。有一个 1x1 卷积层和一个 2x2 平均池化层,步幅为 2。bn_rl_conv,函数中已经设置了 1x1 的内核大小,因此我们不需要明确地再次定义它。

        在过渡层中,我们必须将通道删除到现有通道的一半。我们有输入张量x,我们想找到有多少个通道,我们需要得到其中的一半。因此,我们可以使用 Keras 后端 (K) 获取张量 x 并返回一个维度为 x 的元组。而且,我们只需要该形状的最后一个数字,即过滤器的数量。所以我们加上 [-1]。最后,我们可以将这个数量的过滤器除以 2 以获得所需的结果。

def transition_layer(x):x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )x = AvgPool2D(2, strides = 2, padding = 'same')(x)return x

        因此,我们完成了定义密集块和过渡层的工作。现在我们需要将密集块和过渡层堆叠在一起。所以我们写了一个 for 循环来运行 6,12,24,16 次重复。因此,循环运行 4 次,每次使用 6、12、24 或 16 中的值之一。这样就完成了 4 个密集块和过渡层。

for repetition in [6,12,24,16]:d = dense_block(x, repetition)x = transition_layer(d)

        最后,是GlobalAveragePooling,然后是最终的输出层。正如我们在上面的代码块中看到的,密集块由“d”定义,而在最后一层,在密集块 4 之后,没有过渡层 4,而是直接进入分类层。因此,“d”是应用GlobalAveragePooling的连接,而不是“x”。另一种选择是从上面的代码中删除“for”循环,并将层一个接一个地堆叠,而不使用最终的过渡层。

x = GlobalAveragePooling2D()(d)
output = Dense(n_classes, activation = 'softmax')(x)

现在我们已经将所有块放在一起,让我们将它们合并以查看整个DenseNet架构。

三、完整的 DenseNet 121 架构 

import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Dense
from tensorflow.keras.layers import AvgPool2D, GlobalAveragePooling2D, MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ReLU, concatenate
import tensorflow.keras.backend as K
# Creating Densenet121
def densenet(input_shape, n_classes, filters = 32):#batch norm + relu + convdef bn_rl_conv(x,filters,kernel=1,strides=1):x = BatchNormalization()(x)x = ReLU()(x)x = Conv2D(filters, kernel, strides=strides,padding = 'same')(x)return xdef dense_block(x, repetition):for _ in range(repetition):y = bn_rl_conv(x, 4*filters)y = bn_rl_conv(y, filters, 3)x = concatenate([y,x])return xdef transition_layer(x):x = bn_rl_conv(x, K.int_shape(x)[-1] //2 )x = AvgPool2D(2, strides = 2, padding = 'same')(x)return xinput = Input (input_shape)x = Conv2D(64, 7, strides = 2, padding = 'same')(input)x = MaxPool2D(3, strides = 2, padding = 'same')(x)for repetition in [6,12,24,16]:d = dense_block(x, repetition)x = transition_layer(d)x = GlobalAveragePooling2D()(d)output = Dense(n_classes, activation = 'softmax')(x)model = Model(input, output)return model
input_shape = 224, 224, 3
n_classes = 3
model = densenet(input_shape,n_classes)
model.summary()

输出:(假设 3 个最终类 — 模型摘要的最后几行)

四、 查看体系结构关系图 

        可以使用以下代码。

from tensorflow.python.keras.utils.vis_utils import model_to_dot
from IPython.display import SVG
import pydot
import graphvizSVG(model_to_dot(model, show_shapes=True, show_layer_names=True, rankdir='TB',expand_nested=False, dpi=60, subgraph=False
).create(prog='dot',format='svg'))

        输出 — 图表的前几个块

        这就是我们如何实现DenseNet 121架构。

五、引用 

  1. 黄高、刘壮、劳伦斯·范德马滕和基利安·温伯格,密集连接的卷积网络,arXiv 1608.06993 (2016)

    阿琼·萨卡尔

       2 密网论文链接:https://arxiv.org/pdf/1608.06993.pdf 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/127534.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

人体姿态标注

人体姿态标注 一 标注工具labelme1.1 安装方式1.2 界面说明 二 数据集准备以下每张图片的命名方式:状态_学号_序号.jpg (注意 一定是jpg格式) 保存到一个文件夹中,便于标注。 例如:FALL_0000_0001.jpg 站立数据(UP):不同方向&…

ASPICE标准快速掌握「2.2. 过程参考模型(Process Reference Model,PRM)」

ASPICE归纳了大量的历史经验,分门别类总结出了适用于所有项目的过程。并将所有过程依据过程类别进行分组,并根据他们所处的活动类别在过程组内进一步划分。总共有 3 个过程类别: 主要生命周期过程组织生命周期过程支持生命周期过程上面的每个过程类别都又往下细分为1-N个子过…

超自动化加速落地,助力运营效率和用户体验显著提升|爱分析报告

RPA、iPaaS、AI、低代码、BPM、流程挖掘等在帮助企业实现自动化的同时,也在构建一座座“自动化烟囱”。自动化工具尚未融为一体,协同价值没有得到释放。Gartner于2019年提出超自动化(Hyperautomation)概念,主要从技术组…

黑豹程序员-架构师学习路线图-百科:AJAX

文章目录 1、什么是AJAX2、发展历史3、工作原理4、一句话概括 1、什么是AJAX Ajax即Asynchronous(呃森可乐思) Javascript And XML(异步JavaScript和XML) 在 2005年被Jesse James Garrett(杰西詹姆斯加勒特&#xff09…

服务器文件备份

服务器上,做好跟应用程序有关的文件备份(一般备份到远程的盘符),有助于当服务器发生硬件等故障时,可以对系统进行进行快速恢复。 下面以Windows服务器为例,记录如何做文件的备份操作。 具体操作如下&#…

Vue中如何进行网页截图与截屏

在Vue中实现网页截图与截屏功能 网页截图与截屏功能在许多Web应用程序中都非常有用。Vue.js作为一个流行的JavaScript框架,提供了许多工具和库来简化网页截图和截屏的实现。本文将介绍如何使用Vue来实现一个网页截图和截屏功能的示例,包括使用html2canv…

短视频矩阵源码开发部署---技术解析

一、短视频SEO源码搜索技术需要考虑以下几点: 1. 关键词优化:通过研究目标受众的搜索习惯,选择合适的关键词,并在标题、描述、标签等元素中进行优化,提高视频的搜索排名。 2. 内容质量:优质、有吸引力的内…

DiffusionDet:第一个用于物体检测的扩散模型(DiffusionDet: Diffusion Model for Object Detection)

提出了一种新的框架——DiffusionDet,它将目标检测定义为一个从有噪声的盒子到目标盒子的去噪扩散过程。在训练阶段,目标盒从真实值盒扩散到随机分布,模型学会了逆转这个噪声过程。 在推理中,该模型以渐进的方式将一组随机生成的框…

铭控传感亮相2023国际物联网展,聚焦“多场景物联感知方案”应用

金秋九月,聚焦IoT基石技术,荟萃最全物联感知企业,齐聚IOTE 2023第20届国际物联网展深圳站。铭控传感携智慧楼宇,数字工厂,智慧消防,智慧泵房等多场景物联感知方案及多品类无线传感器闪亮登场,现…

基于 Kettle + StarRocks + FineReport 的大数据处理分析方案

Kettle StarRocks FineReport 的大数据处理分析方案 其中 Kettle 负责数据的ETL处理,StarRocks 负责海量数据的存储及检索,FineReport 负责数据的可视化展示。整体过程如下所示: 如果多上面三个组件不了解可以先参考下下面的文章&#xff…

李沐深度学习记录5:13.Dropout

Dropout从零开始实现 import torch from torch import nn from d2l import torch as d2l# 定义Dropout函数 def dropout_layer(X, dropout):assert 0 < dropout < 1# 在本情况中&#xff0c;所有元素都被丢弃if dropout 1:return torch.zeros_like(X)# 在本情况中&…

采集网页数据保存到文本文件---爬取古诗文网站

访问古诗文网站&#xff08;https://so.gushiwen.org/mingju/&#xff09; 会显示出这个页面&#xff0c;里面包含了很多的名句&#xff0c;点击某一个名句&#xff08;比如点击无处不伤心&#xff0c;轻尘在玉琴&#xff09;就会出现完整的古诗 我们点击鼠标右键&#xff0c;点…