Pytorch从零开始实战16

Pytorch从零开始实战——ResNeXt-50算法的思考

本系列来源于365天深度学习训练营

原作者K同学

对于上次ResNeXt-50算法,我们同样有基于TensorFlow的实现。具体代码如下。

引入头文件

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Input, Dense, Dropout, Conv2D, MaxPool2D, Flatten, GlobalAvgPool2D, concatenate, \
BatchNormalization, Activation, Add, ZeroPadding2D, Lambda
from tensorflow.keras.layers import ReLU
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
from tensorflow.keras.callbacks import LearningRateScheduler
from tensorflow.keras.models import Model

分组卷积模块

# 定义分组卷积
def grouped_convolution_block(init_x, strides, groups, g_channels):group_list = []# 分组进行卷积for c in range(groups):# 分组取出数据x = Lambda(lambda x: x[:, :, :, c * g_channels:(c + 1) * g_channels])(init_x)# 分组进行卷积x = Conv2D(filters=g_channels, kernel_size=(3, 3),strides=strides, padding='same', use_bias=False)(x)# 存入listgroup_list.append(x)# 合并list中的数据group_merage = concatenate(group_list, axis=3)x = BatchNormalization(epsilon=1.001e-5)(group_merage)x = ReLU()(x)return x

残差单元

# 定义残差单元
def block(x, filters, strides=1, groups=32, conv_shortcut=True):if conv_shortcut:shortcut = Conv2D(filters * 2, kernel_size=(1, 1), strides=strides, padding='same', use_bias=False)(x)# epsilon为BN公式中防止分母为零的值shortcut = BatchNormalization(epsilon=1.001e-5)(shortcut)else:# identity_shortcutshortcut = x# 三层卷积层x = Conv2D(filters=filters, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 计算每组的通道数g_channels = int(filters / groups)# 进行分组卷积x = grouped_convolution_block(x, strides, groups, g_channels)x = Conv2D(filters=filters * 2, kernel_size=(1, 1), strides=1, padding='same', use_bias=False)(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = Add()([x, shortcut])x = ReLU()(x)return x

堆叠残差单元

# 堆叠残差单元
def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

网络搭建

# 定义ResNext50(32*4d)网络
def ResNext50(input_shape, num_classes):inputs = Input(shape=input_shape)# 填充3圈0,[224,224,3]->[230,230,3]x = ZeroPadding2D((3, 3))(inputs)x = Conv2D(filters=64, kernel_size=(7, 7), strides=2, padding='valid')(x)x = BatchNormalization(epsilon=1.001e-5)(x)x = ReLU()(x)# 填充1圈0x = ZeroPadding2D((1, 1))(x)x = MaxPool2D(pool_size=(3, 3), strides=2, padding='valid')(x)# 堆叠残差结构x = stack(x, filters=128, blocks=2, strides=1)x = stack(x, filters=256, blocks=3, strides=2)x = stack(x, filters=512, blocks=5, strides=2)x = stack(x, filters=1024, blocks=2, strides=2)# 根据特征图大小进行全局平均池化x = GlobalAvgPool2D()(x)x = Dense(num_classes, activation='softmax')(x)# 定义模型model = Model(inputs=inputs, outputs=x)return model

对于残差单元中的代码,提出一个问题:当conv_shortcut=False的时候,在执行Add操作时,理论上通道数不一致,为什么代码不报错?
在这里插入图片描述
答:这主要是跟下面堆叠残差单元的代码有关系,每个stack第一轮总会令conv_shortcut为True,使得x通道数进行扩展,而后面循环的时候传入的filters还是这个函数的实参,没有发生变化,但由于conv_shortcut为False,此时shortcut的通道数是与上面的x一致,所以在Add的时候,代码不会报错。

def stack(x, filters, blocks, strides, groups=32):# 每个stack的第一个block的残差连接都需要使用1*1卷积升维x = block(x, filters, strides=strides, groups=groups)for i in range(blocks):x = block(x, filters, groups=groups, conv_shortcut=False)return x

本文只是对ResNeXt-50算法的部分代码进行思考,学习过程中需要积极思考与探索,以提高能力和解决问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/337982.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【昕宝爸爸定制】如何将集合变成线程安全的?

如何将集合变成线程安全的? ✅典型解析🟢拓展知识仓☑️Java中都有哪些线程安全的集合?🟠线程安全集合类的优缺点是什么🟡如何选择合适的线程安全集合类☑️如何解决线程安全集合类并发冲突问题✔️乐观锁实现方式 (具体步骤)。✅…

2024年了,Layui再战三年有问题不?

v2.9.3 2023-12-31 2023 收官。 form 优化 input 组件圆角时后缀存在方框的问题 #1467 bxjt123优化 select 搜索面板打开逻辑,以适配文字直接粘贴触发搜索的情况 #1498 Sight-wcgtable 修复非常规列设置 field 表头选项时,导出 excel 出现合计行错位的…

[Python办公自动化 – 数据预处理和数据校验

Python办公自动化 – 以下是往期的文章目录,需要可以查看哦。 Python办公自动化 – Excel和Word的操作运用 Python办公自动化 – Python发送电子邮件和Outlook的集成 Python办公自动化 – 对PDF文档和PPT文档的处理 Python办公自动化 – 对Excel文档和数据库的操作…

【C++】STL 算法 ③ ( 函数对象中存储状态 | 函数对象作为参数传递时值传递问题 | for_each 算法的 函数对象 参数是值传递 )

文章目录 一、函数对象中存储状态1、函数对象中存储状态简介2、示例分析 二、函数对象作为参数传递时值传递问题1、for_each 算法的 函数对象 参数是值传递2、代码示例 - for_each 函数的 函数对象 参数在外部不保留状态3、代码示例 - for_each 函数的 函数对象 返回值 一、函数…

中小企业低成本如何进行推广?媒介盒子解答

数字化时代下,用户想要购买产品的第一步都是在网上查询相关信息,因此对于中小企业来说,怎么把自己曝光到网上,怎么可以让用户搜到类似关键词时名列其中是企业需要考虑的问题,今天媒介盒子就来和大家聊聊:中…

五、Spring AOP面向切面编程(基于XML方式实现)

本章概要 Spring AOP基于XML方式实现(了解)Spring AOP对获取Bean的影响理解 根据类型装配 bean使用总结 5.6 Spring AOP基于XML方式实现(了解) 准备工作 加入依赖 <!-- spring-aspects会帮我们传递过来aspectjweaver --> <dependency><groupId>org.spr…

代理IP连接不上/网速过慢?如何应对?

当您使用代理时&#xff0c;您可能会遇到不同的代理错误代码显示代理IP连不通、访问失败、网速过慢等种种问题。 在本文中中&#xff0c;我们将讨论您在使用代理IP时可能遇到的常见错误、发生这些错误的原因以及解决方法。 一、常见代理服务器错误 当您尝试访问网站时&#…

10个最容易被忽视的 FastAPI 实用功能

FastAPI是一种现代、高性能的Python Web框架&#xff0c;用于构建Web应用程序和API。 它基于Python的异步编程库asyncio和await语法&#xff0c;以及类型注解和自动文档生成等特性&#xff0c;提供了快速、易用和可靠的开发体验&#xff0c;接下来本文将介绍10项被忽视的FastA…

Qt/QML编程学习之心得:hicar手机投屏到车机中控的实现(32)

hicar,是华为推出的一款手机APP,有百度地图、华为音乐,更多应用中还有很多对应手机上装在的其他APP,都可以在这个里面打开使用,对开车的司机非常友好。但它不仅仅是用在手机上,它还可以投屏到车机中控上,这是比较神奇的一点。 HiCar本质上是一套智能投屏系统,理论上所有…

React Native 桥接原生常量

一、编写并注册原生常量方法 在 SmallDaysAppModule 这个模块中有一个方法 getConstans &#xff0c;重载这个方法就可将自定义的常量返回&#xff0c;系统会自行调用该方法并返回定义的常量将其直接注入到 JS 层&#xff0c;在 JS 层直接获取即可。 二、JS 层获取原生常量&am…

Vulnhub靶机:Corrosion1

一、介绍 运行环境&#xff1a;Virtualbox 攻击机&#xff1a;kali&#xff08;10.0.2.15&#xff09; 靶机&#xff1a;corrosion:1&#xff08;10.0.2.12&#xff09; 目标&#xff1a;获取靶机root权限和flag 靶机下载地址&#xff1a;https://www.vulnhub.com/entry/c…

MongoDB快速实战与基本原理

MongoDB 介绍 什么是 MongoDB MongoDB 是一个文档数据库&#xff08;以 JSON 为数据模型&#xff09;&#xff0c;由 C 语言编写&#xff0c;旨在为 WEB 应用提供可扩展的高性能数据存储解决方案。文档来自于“JSON Document”&#xff0c;并非我们一般理解的 PDF、WORD 文档…