Nas-FPN(CVPR 2019)原理与代码解析

paper:NAS-FPN: Learning Scalable Feature Pyramid Architecture for Object Detection

third-party implementation:https://github.com/open-mmlab/mmdetection/tree/main/configs/nas_fpn

本文的创新点

本文采用神经网络结构搜索(Neural Architecture Search, NAS),在一个覆盖所有跨尺度连接的新型可扩展搜索空间中发现了一个新的特征金字塔结构,NAS-FPN。与原始FPN相比,NAS-FPN显著提高了目标检测的性能,并取得了更好了速度-精度的平衡。

方法介绍

考虑到其简单而高效的结构,目标检测模型采用RetinaNet,如图2所示

作者提出了merging cell作为FPN的basic building block,将任何两层的输入特征融合为一层的输出特征。如图3所示

其中Binary Op包括两种候选方案,如图4所示 

最终搜索到的NAS-FPN的完整结构如图6所示 

图7展示的搜索到的所有结构,其中(a)是原始FPN结构,(b)-(f)的精度逐渐变高,(f)是最终的NAS-FPN结构。

因为是搜索到的结构,并且图示非常清晰,这里就不过多介绍具体结构了。接下来结合代码和图(6)(7)的结构介绍一下实现细节 

代码解析

这里以mmdetection中的实现为例,实现代码在mmdet/models/necks/nas_fpn.py中,下面是完整的forward函数。其中self.fpn_stages=7是nas-fpn重复的次数,每个nas-fpn的输出是下一个nas-fpn的输入。forward最开始的输入是backbone的输出C2~C5,这里只取C3~C5通过lateral_conv得到P3~P5,然后进行下采样得到P6和P7,完整的P3~P7作为第一个nas-fpn的输入。

def forward(self, inputs: Tuple[Tensor]) -> tuple:# [(8,256,160,160),#  (8,512,80,80),#  (8,1024,40,40),#  (8,2048,20,20)]"""Forward function.Args:inputs (tuple[Tensor]): Features from the upstream network, eachis a 4D-tensor.Returns:tuple: Feature maps, each is a 4D-tensor."""# build P3-P5feats = [lateral_conv(inputs[i + self.start_level])for i, lateral_conv in enumerate(self.lateral_convs)]  # [(8,256,80,80),(8,256,40,40),(8,256,20,20)]# build P6-P7 on top of P5for downsample in self.extra_downsamples:feats.append(downsample(feats[-1]))  # [..., (8,256,10,10),(8,256,5,5)]p3, p4, p5, p6, p7 = featsfor stage in self.fpn_stages:# gp(p6, p4) -> p4_1# print(stage['gp_64_4'])p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])  # (8,256,40,40)# sum(p4_1, p4) -> p4_2p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])  # (8,256,40,40)# sum(p4_2, p3) -> p3_outp3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])  # (8,256,80,80)# sum(p3_out, p4_2) -> p4_outp4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])  # (8,256,40,40)# sum(p5, gp(p4_out, p3_out)) -> p5_outp5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])  # (8,256,20,20)p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])  # (8,256,20,20)# sum(p7, gp(p5_out, p4_2)) -> p7_outp7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])  # (8,256,5,5)p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])  # (8,256,5,5)# gp(p7_out, p5_out) -> p6_outp6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])  # (8,256,10,10)return p3, p4, p5, p6, p7

在for循环中,从上到到下分别对应图6中从左到右按顺序所有的GP和Sum。其中GP对应GlobalPoolingCell,Sum对应SumCell,具体实现都在MMCV中。

GlobalPoolingCell的实现如下,其中self.input1_conv和self.input2_conv是空的,self._resize通过双线性插值进行上采样,通过max pooling进行下采样。

def forward(self,x1: torch.Tensor,x2: torch.Tensor,out_size: Optional[tuple] = None) -> torch.Tensor:assert x1.shape[:2] == x2.shape[:2]assert out_size is None or len(out_size) == 2if out_size is None:  # resize to larger oneout_size = max(x1.size()[2:], x2.size()[2:])x1 = self.input1_conv(x1)x2 = self.input2_conv(x2)x1 = self._resize(x1, out_size)x2 = self._resize(x2, out_size)x = self._binary_op(x1, x2)if self.with_out_conv:x = self.out_conv(x)return x

self._binary_op的实现如下,其中self.global_pool是全局平均池化。

def _binary_op(self, x1, x2):x2_att = self.global_pool(x2).sigmoid()return x2 + x2_att * x1

最后的self.out_conv是Conv+BN+ReLU的组合,对应图6中的R-C-B。注意图6中只有第一个GP和最后一个GP后有R-C-B,中间两个GP后没有,即上面代码中self.with_out_conv=False。

SumCell和GlobalPoolingCell继承自同一个基类,forward函数是一样的。区别在于SumCell中的self._binary_op就是sum操作,如下

class SumCell(BaseMergeCell):def __init__(self, in_channels: int, out_channels: int, **kwargs):super().__init__(in_channels, out_channels, **kwargs)def _binary_op(self, x1, x2):return x1 + x2

此外,5个SumCell后都有R-C-B。

实验结果

和其他SOTA模型的对比如表1所示

其中7@256表示NAS-FPN堆叠7次,通道数为256。

文中特别提到由于NAS-FPN的结构堆叠多层引入了更多的参数,需要一个合适的正则化方法来防止过拟合。本文采用DropBlock,具体介绍见DropBlock(NeurIPS 2018)论文与代码解析-CSDN博客。图10展示了DropBlock显著提升了NAS-FPN的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/427275.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaWeb】日程管理系统 项目搭建 第二期

文章目录 一、数据库准备二、导入依赖 与 JDBC工具类三、pojo包处理四、daodao包工具类 五、service六、controllerservlet 基类 反射 七、加密工具类 MD5八、页面文件九、业务代码9.1 注册业务处理9.2 登录业务处理 总结 一、数据库准备 创建数据库: SET NAMES …

【码农新闻】浏览器上有趣的 Console 命令,VSCode 插件 FreeWindow......

目录 【码农新闻】浏览器上有趣的 Console 命令,VSCode 插件 FreeWindow...... 浏览器上有趣的 Console 命令VSCode 插件 FreeWindow拖拽竟然还能这样玩!阮一峰 ES6 教程总结学习网站总结与整理买临期食品的年轻人,在向“吃喝内卷”低头文章所属专区 码农新闻 欢迎各位编程大…

​第20课 在Android Native开发中加入新的C++类

​这节课我们开始利用ffmpeg和opencv在Android环境下来实现一个rtmp播放器,与第2课在PC端实现播放器的思路类似,只不过在处理音视频显示和播放的细节略有不同。 1.压缩备份上节课工程文件夹并修改工程文件夹为demo20,将demo20导入到Eclipse或…

80.网游逆向分析与插件开发-背包的获取-自动化助手显示物品数据

内容参考于:易道云信息技术研究院VIP课 上一个内容:升级Notice类获得背包基址-CSDN博客 码云地址(ui显示角色数据 分支):https://gitee.com/dye_your_fingers/sro_-ex.git 码云版本号:3be017de38c50653b1…

智能泊车,再上热搜

编者按:相比于行车,低速可控场景,更有利于泊车功能快速迭代。同时,对于部分消费者来说,泊车智能化也是加分项。 智能泊车赛道,正在重新成为各路势力争夺的焦点。而上一次“高潮”,要追溯到2018年…

CSC5613C

CSC5613C是一款DC/DC同步降压IC,输入电压8V-30V,CSC5613C具有良好的瞬态响应和环路稳定性。CSC5613C外围元器件极少具有项目过流保护,过热保护功能。CSC5613C可通过调节FB电阻的比例来调节输出电压,可用于快充。CSC561…

OpenCV-Python(49):图像去噪

目标 学习使用非局部平均值去噪算法去除图像中的噪音学习函数cv2.fastNlMeansDenoising()、cv2.fastNlMeansDenoisingColored等 原理 在前面的章节中我们已经学习了很多图像平滑技术,比如高斯平滑、中值平滑等。当噪声比较小时,这些技术的效果都是很好…

k-Wave仿真例程:对圆形换能器记录的光声波场进行时间反转重建

使用 k-Wave 对圆形换能器阵列上记录的二维光声波场进行时间反转重建。 1. 模拟换能器数据 使用 kspaceFirstOrder2D 和外部图像模拟换能器数据以获取初始压力分布,图像代表脉管系统。 定义一个居中的圆形换能器 sensor,在 270 角度上放置 70 个阵元 sensor_radius = 4.5…

基于若依的ruoyi-nbcio流程管理系统一种简单的动态表单模拟测试实现(四)

更多ruoyi-nbcio功能请看演示系统 gitee源代码地址 前后端代码: https://gitee.com/nbacheng/ruoyi-nbcio 演示地址:RuoYi-Nbcio后台管理系统 更多nbcio-boot功能请看演示系统 gitee源代码地址 后端代码: https://gitee.com/nbacheng/n…

IO 专题

使用try-with-resources语句块,可以自动关闭InputStream [实践总结] FileIUtils 共通方法最佳实践 [实践总结] java 获取在不同系统下的换行符 [实践总结] StreamIUtils 共通方法最佳实践 斜杠“/“和反斜杠“\“的区别 路径中“./”、“…/”、“/”代表的含义…

SpringBoot整合ElasticSearch实现基础的CRUD操作

本文来说下SpringBoot整合ES实现CRUD操作 文章目录 概述spring-boot-starter-data-elasticsearch项目搭建ES简单的crud操作保存数据修改数据查看数据删除数据 本文小结 概述 SpringBoot支持两种技术和es交互。一种的jest,还有一种就是SpringData-ElasticSearch。根据…

Stable Diffusion学习

参考 Stable Diffusion原理详解_stable diffusion csdn-CSDN博客 Stable Diffusion是stability.ai开源的图像生成模型,可以说Stable Diffusion的发布将AI图像生成提高到了全新高度,其效果和影响不亚于Open AI发布ChatGPT。 图像生成的发展 在Stable D…