CNN(六):ResNeXt-50实战

  •  🍨 本文为🔗365天深度学习训练营中的学习记录博客
  • 🍖 原作者:K同学啊|接辅导、项目定制

        ResNeXt是有何凯明团队在2017年CVPR会议上提出来的新型图像分类网络。它是ResNet的升级版,在ResNet的基础上,引入了cardinality的概念,类似于ResNet。ResNeXt也有ResNeXt-50,ResNeXt-101版本。

1 模型结构

        在ResNeXt的论文中,作者提出了当时普遍存在的一个问题,如果要提高模型的准确率,往往采取加深网络或者加宽网络的方法。虽然这种方法有效,但随之而来的,是网络设计的难度和计算开销的增加。为了一点精度的提升往往需要付出更大的代价,因此需要一个更好的策略,在不额外增加计算代价的情况下,提升网络的精度。由此,何等人提出了cardinality的概念。

        下面是ResNet(左)与ResNeXt(右)block的差异。在ResNet中,输入的具有256个通道的特征经过1x1卷积压缩4倍到64个通道,之后3x3的卷积核用于处理特征,经1x1卷积扩大通道数与原特征残差连接后输出。ResNeXt也是相同的处理策略,但在ResNeXt中,输入的具有256个通道的特征被分为32个组,每组被压缩64倍到4个通道后进行处理。32个组相加后与原特征残差连接后输出。这里cardinatity指的是一个block中所具有的相同分支的数目。

图1 ResNet和ResNeXt​​​​

2 分组卷积 

        ResNeXt中采用的分组卷积简单来说是将特征图分为不同的组,再对每组特征图分别进行卷积,这个操作可以有效的降低计算量。

        在分组卷积中,每个卷积核只处理部分通道,比如下图中,红色卷积核只处理红色的通道,绿色卷积核只处理绿色的通道,黄色卷积核只处理黄色通道。此时每个卷积核有2个通道,每个卷积核生成一张特征图。

图2 分组卷积示意图

        分组卷积的优势在于其参数开销,图3是其对比效果。

图3 标准卷积和分组卷积参数量对比

 3 代码实现

3.1 开发环境

电脑系统:ubuntu16.04

编译器:Jupter Lab

语言环境:Python 3.7

深度学习环境:tensorflow

 3.2 前期准备

3.2.1 设置GPU

import tensorflow as tfgpus = tf.config.list_physical_devices("GPU")if gpus:tf.config.experimental.set_memory_growth(gpus[0], True) # 设置GPU显存用量按需使用tf.config.set_visible_devices([gpus[0]], "GPU")

3.2.2 导入数据

import matplotlib.pyplot as plt
# 支持中文
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号import os, PIL, pathlib
import numpy as npfrom tensorflow import keras
from tensorflow.keras import layers,modelsdata_dir = "../data/bird_photos"
data_dir = pathlib.Path(data_dir)image_count = len(list(data_dir.glob('*/*')))
print("图片总数为:", image_count)

3.2.3 加载数据

batch_size = 8
img_height = 224
img_width = 224train_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="training",seed=123,image_size=(img_height, img_width),batch_size=batch_size)val_ds = tf.keras.preprocessing.image_dataset_from_directory(data_dir,validation_split=0.2,subset="validation",seed=123,image_size=(img_height, img_width),batch_size=batch_size)class_Names = train_ds.class_names
print("class_Names:",class_Names)

3.2.4 可视化数据

plt.figure(figsize=(10, 5)) # 图形的宽为10,高为5
plt.suptitle("imshow data")for images,labels in train_ds.take(1):for i in range(8):ax = plt.subplot(2, 4, i+1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_Names[labels[i]])plt.axis("off")

3.2.5 检查数据

for image_batch, lables_batch in train_ds:print(image_batch.shape)print(lables_batch.shape)break

3.2.6 配置数据集

AUTOTUNE = tf.data.AUTOTUNEtrain_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

3.3 ResNeXt-50代码实现

3.3.1 分组卷积模块

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c*g_channels:(c+1)*g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3,3), strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merge = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merge)x = ReLU()(x)return x

3.3.2 残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters*2, kernel_size=(1,1), strides=strides, padding='same', use_bias=False)(x)# epsilon位BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1,1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1,1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

3.3.3 堆叠残差单元

        每个stack的第一个block的输入和输出shape是不一致的,所以残差连接都需要使用1x1卷积升维后才能进行Add操作。而其他block的输入和输出的shape是一致的,所以可以直接执行Add操作。

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

3.3.4 搭建ResNeXt-50网络

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224, 224, 3] -> [230, 230, 3]x = ZeroPadding2D((3,3))(inputs)x = Conv2D(filters=64, kernel_size=(7,7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3,3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return modelmodel = ResNext50(input_shape=(224,224,3), num_classes=4)
model.summary()

     结果显示如下(由于结果内容较多,只展示前后部分内容):

 (中间内容省略)

3.4 正式训练

# 设置优化器
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)model.compile(optimizer="adam",loss='sparse_categorical_crossentropy',metrics=['accuracy'])epochs = 10history = model.fit(train_ds,validation_data=val_ds,epochs=epochs)

     结果如下图所示:

3.5 模型评估

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range = range(epochs)plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.suptitle("ResNeXt-50 test")plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation loss')
plt.legend(loc='upper right')
plt.title('Training and Validation loss')
plt.show()

      结果如下图所示: 

4 总结

         总而言之,ResNeXt是在ResNet的网络架构上,使用类似于Inception的分治思想,即split-tranform-merge策略,将模块中的网络拆开分组,与Inception不同,每组的卷积核大小一致,这样其感受野一致,但由于每组的卷积核参数不同,提取的特征自然不同。然后将每组得到的特征进行concat操作后,再与原输入特征x或者经过卷积等处理(即进行非线性变换)的特征进行Add操作。这样做的好处是,在不增加参数复杂度的前提下提高准确率,同时还能提高超参数的数量。

        另外,cardinality是基的意思,将数个通道特征进行分组,不同的特征组之间可以看作是由不同基组成的子空间,每个组的核虽然一样,但参数不同,在各自的子空间中学到的特征就多种多样,这点跟transformer中的Multi-head attention不谋而合(Multi-head attention allows the model to jointly attend to information from different representation subspaces.)而且分组进行特征提取,使得学到的特征冗余度降低,获取能起到正则化的作用。(参考ResNeXt的分类效果为什么比Resnet好? - 知乎)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/90646.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

入门vue——创建vue脚手架项目 以及 用tomcat和nginx分别部署vue项目(vue2)

入门vue——创建vue脚手架项目 以及 用tomcat和nginx分别部署vue项目(vue2) 1. 安装npm2. 安装 Vue CLI3. 创建 vue_demo1 项目(官网)3.1 创建 vue_demo1 项目3.1.1 创建项目3.1.2 解决 sudo 问题 3.2 查看创建的 vue_demo1 项目3…

Mysql高阶语句 (一)

一、常用查询 (增、删、改、查) 对 MySQL 数据库的查询,除了基本的查询外,有时候需要对查询的结果集进行处理。 例如只取 10 条数据、对查询结果进行排序或分组等等 1、按关键字排序 PS:类比于windows 任务管理器 使用 SELECT 语句…

QT基础教程之七Qt消息机制和事件

QT基础教程之七Qt消息机制和事件 事件 事件(event)是由系统或者 Qt 本身在不同的时刻发出的。当用户按下鼠标、敲下键盘,或者是窗口需要重新绘制的时候,都会发出一个相应的事件。一些事件在对用户操作做出响应时发出&#xff0c…

windows 中pycharm中venv无法激活

1.用管理员身份打开Windows PowerShell 2.进入项目的:venv\Scripts 如:D: (1): cd .\project\venv\Scripts\ (2): 执行命令: Set-ExecutionPolicy RemoteSigned (3): 选择:Y (4): .\activate

小程序隐私保护授权处理方式之弹窗组件

欢迎点击关注-前端面试进阶指南:前端登顶之巅-最全面的前端知识点梳理总结 *分享一个使用比较久的🪜 小程序隐私保护授权弹窗组件 调用wx.getUserProfile进行授权时,返回错误信息:{errMsg: “getUserProfile:fail api scope is…

创作纪念日-我的第1024天

机缘 不知不觉已经成为创作者的第1024天啦… … 刚开始接触博客的初衷就是为了记笔记📒、记总结📝,或许对于当时就等同于是为了找工作。坚持学习并持续输出博客一年后,这时我发现再写博客,不在是为了找一份工作&…

如何利用SFTP协议远程实现更安全的文件传输 ——【内网穿透】

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《高效编程技巧》《cpolar》 ⛺️生活的理想,就是为了理想的生活! 文章目录 1. 安装openSSH1.1 安装SSH1.2 启动ssh 2. 安装cpolar2.1 配置termux服务 3. 远程SFTP连接配置3.1 查看生成的随机公…

机器学习:争取被遗忘的权利

随着越来越多的人意识到他们通过他们经常访问的无数应用程序和网站共享了多少个人信息,数据保护和隐私一直在不断讨论。看到您与朋友谈论的产品或您在 Google 上搜索的音乐会迅速作为广告出现在您的社交媒体提要中,这不再那么令人惊讶。这让很多人感到担…

【HSPCIE仿真】输入网表文件(3)子电路描述语句

子电路描述语句 1. 子电路的定义定义子电路的基本语法子电路终止语句子电路的调用语句全局节点(.gloab)示例 2. 基于子电路执行多次分析 HSPICE 允许用户在程序执行过程中调用由各种 HSPICE 元件和器件构成的子电路,即电路结构的层次化描述。 子电路是以 .SUBCKT 或…

MongoDB实验——MongoDB配置用户的访问控制

MongoDB 配置用户的访问控制 一、 实验原理 理解admin数据库:安装MongoDB时,会自动创建admin数据库,这是一个特殊数据库,提供了普通数据库没有的功能,例如,有些账户角色赋予用户操作多个数据库的权限&…

大数据之Maven

一、Maven的作用 作用一:下载对应的jar包 避免jar包重复下载配置,保证多个工程共用一份jar包。Maven有一个本地仓库,可以通过pom.xml文件来记录jar所在的位置。Maven会自动从远程仓库下载jar包,并且会下载所依赖的其他jar包&…