Conv2Former:一种transformer风格的卷积特征提取方式

一、前言

昨天读到了一篇有意思的文章,文章提出通过利用卷积调制操作来简化self-attention。还证明了这种简单的方法可以更好地利用卷积层中嵌套的大核(≥7 × 7)。我们都知道ViTs推动了设计识别模型的发展,近几年使用的也相当的多,通常就是CNN网络引入注意力机制,往往可以获得不错的性能,因为相比较与卷积cnn,self-attention能够模拟全局成对依赖关系,这是一种更有效的空间信息编码方式。

自注意机制与所提出的卷积调制运算的比较 

作者注意到有研究者通过对标准ResNet进行现代化改进,采用与Transformers相似的设计和训练方法,ConvNets的性能在某些情况下甚至能够超越一些热门的ViTs。由此,引入了一种名为Conv2Former的新型网络结构,旨在更有效地利用空间卷积。与以往工作不同,Conv2Former采用了一种称为卷积调制的操作,通过将大内核卷积的输出与值表示进行Hadamard乘积,模拟了自注意力的输出计算过程。与传统的ViTs使用自注意力不同,Conv2Former完全是一个卷积网络,其计算复杂度随分辨率线性增加,而不是像Transformers那样呈二次增加。

二、Conv2Former的网络架构

在每两个阶段间,引入了一个Patch Embed来降低分辨率,通常是一个2 × 2的卷积,步幅为2。不同的阶段有不同数量的卷积块。

构建了五个Conv2Former变体,分别是Conv2Former- n、Conv2Former- t、Conv2Former- s、Conv2Former- b、Conv2Former- l。参数由下图所示:

残差块、自注意块和所提出的调制块之间的对比图

与自注意相比,(d)方法利用卷积来建立关系,这比自注意更节省内存,特别是在处理高分辨率图像时。与经典的残差块方法(a)相比,(d)还可以通过调制操作来适应输入内容。

在VGG,Resnet等网络提出后,3x3卷积几乎就是构建卷积网络的标准选择,但深度可分离卷积的出现改变了这一格局。ConvNeXt的研究显示,将卷积核从3扩大到7可以提升性能,然而,进一步增大卷积核几乎不再带来性能提升,反而增加了计算负担。

作者认为ConvNeXt未能从大于7×7的更大卷积核中获得更大性能提升的原因在于对空间卷积的使用方式。相比之下,Conv2Former通过观察发现,在卷积核大小从5×5增加到21×21的范围内,性能持续提升。这一现象在不同规模的Conv2Former模型中均得到验证,为更好地利用空间卷积提供了重要的设计指导。

考虑到模型的效率,作者选择将卷积核大小默认设置为11×11。这一研究发现对于ConvNet设计,更大的卷积核可以通过正确的使用方式获得持续的性能提升,为未来的视觉识别模型提供了有益的启示。

在Conv2Former的标准化和激活层方面,标准化层使用Layer Normalization,而不是常用的批标准化,激活层采用GELU。作者发现这种Layer Normalization和GELU的组合能够带来0.1%-0.2%的性能提升。

三、卷积调制操作

卷积核大小影响性能

实验证明,在Conv2Former中,增加深度卷积核大小(从5×5到21×21)可以显著提升模型性能,与传统结论不同,这表明使用卷积特征作为权重能更有效地利用大核,相较于传统方法无法带来明显性能提升。

Hadamard乘积优于求和

对比Hadamard乘积和元素求和两种融合策略,实验证明在Conv2Former中,使用深度卷积特征通过Hadamard乘积调制权重表现更好,尤其对于小型模型。这突显了卷积调制在编码空间信息方面的高效性。

权重策略分析

通过尝试不同的特征图融合方式,包括添加Sigmoid函数、应用L1标准化以及线性归一化等,结果显示Hadamard乘积依然是最优选择。有趣的是,将A的值调整为正值反而导致性能下降,与传统注意机制的做法有所不同,为未来研究提供了新的问题。

四、论文复现

作者也是给出了代码地址:HVision-NKU/Conv2Former

但我看了一下,里面只是给出了一些核心的代码,所以复现还是要靠咱自己。

"""
Copyright (c) 2023, Auorui.
All rights reserved.reference <https://arxiv.org/pdf/2211.11943.pdf> (Conv2Former: A Simple Transformer-Style ConvNet for Visual Recognition)
Time:2023.12.31, Complete before the end of 2023.
"""
import torch
import torch.nn as nnfrom pyzjr.Models import DropPathC = {'n': [64, 128, 256, 512],'t': [72, 144, 288, 576],'s': [72, 144, 288, 576],'b': [96, 192, 384, 768],'l': [128, 256, 512, 1024],} #  reference <https://arxiv.org/pdf/2211.11943.pdf> Table 1
L = {'n': [2, 2, 8, 2],'t': [3, 3, 12, 3],'s': [4, 4, 32, 4],'b': [4, 4, 34, 4],'l': [4, 4, 48, 4],} #  reference <https://arxiv.org/pdf/2211.11943.pdf> Table 1class MLP(nn.Module):def __init__(self, dim, mlp_ratio=4.):super().__init__()self.norm = nn.LayerNorm(dim, eps=1e-6)self.fc1 = nn.Conv2d(dim, dim * mlp_ratio, 1)self.pos = nn.Conv2d(dim * mlp_ratio, dim * mlp_ratio, 3, padding=1, groups=dim * mlp_ratio)self.fc2 = nn.Conv2d(dim * mlp_ratio, dim, 1)self.act = nn.GELU()def forward(self, x):x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)x = self.fc1(x)x = self.act(x)x = x + self.act(self.pos(x))x = self.fc2(x)return xclass ConvMod(nn.Module):def __init__(self, dim):super().__init__()self.norm = nn.LayerNorm(dim, eps=1e-6)self.a = nn.Sequential(nn.Conv2d(dim, dim, 1),nn.GELU(),nn.Conv2d(dim, dim, 11, padding=5, groups=dim))self.v = nn.Conv2d(dim, dim, 1)self.proj = nn.Conv2d(dim, dim, 1)def forward(self, x):x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)a = self.a(x)x = a * self.v(x)x = self.proj(x)return xclass Block(nn.Module):def __init__(self, dim, mlp_ratio=4.0, drop_path=0.):super().__init__()self.attn = ConvMod(dim)self.mlp = MLP(dim, mlp_ratio)layer_scale_init_value = 1e-6self.layer_scale_1 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)self.layer_scale_2 = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path(self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(x))x = x + self.drop_path(self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(x))return xclass BaseLayer(nn.Module):def __init__(self, dim, depth, mlp_ratio=4., drop_path=None, downsample=True):super().__init__()self.dim = dimself.drop_path = drop_pathself.blocks = nn.ModuleList([Block(dim=self.dim,mlp_ratio=mlp_ratio,drop_path=drop_path[i],)for i in range(depth)])# patch merging layerif downsample:self.downsample = nn.Sequential(nn.GroupNorm(num_groups=1, num_channels=dim),nn.Conv2d(dim, dim * 2, kernel_size=2, stride=2,bias=False))else:self.downsample = Nonedef forward(self, x):for blk in self.blocks:x = blk(x)if self.downsample is not None:x = self.downsample(x)return xclass Conv2Former(nn.Module):def __init__(self, num_classes=10, depths=(2,2,8,2), dim=(64,128,256,512), mlp_ratio=2.,drop_rate=0.,drop_path_rate=0.15, **kwargs):super().__init__()norm_layer = nn.LayerNormself.num_classes = num_classesself.num_layers = len(depths)self.dim = dimself.mlp_ratio = mlp_ratioself.pos_drop = nn.Dropout(p=drop_rate)# stochastic depth decay ruledpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]# build layersself.layers = nn.ModuleList()for i_layer in range(self.num_layers):layer = BaseLayer(dim[i_layer],depth=depths[i_layer],mlp_ratio=self.mlp_ratio,drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],downsample=(i_layer < self.num_layers - 1),)self.layers.append(layer)self.fc1 = nn.Conv2d(3, dim[0], 1)self.norm = norm_layer(dim[-1], eps=1e-6,)self.avgpool = nn.AdaptiveAvgPool2d(1)self.head = nn.Linear(dim[-1], num_classes) \if num_classes > 0 else nn.Identity()self.apply(self._init_weights)def _init_weights(self, m):if isinstance(m, nn.Linear):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0.)elif isinstance(m, (nn.Conv1d, nn.Conv2d)):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0.)elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):nn.init.constant_(m.bias, 0.)nn.init.constant_(m.weight, 1.)def forward_features(self, x):x = self.fc1(x)x = self.pos_drop(x)for layer in self.layers:x = layer(x)x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)x = self.avgpool(x)x = torch.flatten(x, 1)return xdef forward(self, x):x = self.forward_features(x)x = self.head(x)return xif __name__ == '__main__':model = Conv2Former(num_classes=10, depths=L["b"], dim=C["b"], mlp_ratio=2, drop_path_rate=0.1)input_tensor = torch.randn(1, 3, 224, 224)output = model(input_tensor)print("Output shape:", output.shape)

暂时还未进行测试,如果有空再进行实验, 你可以从这里找到我的源码:pyzjr/pyzjr/Models/backbone/Conv2Former_2.py at main · Auorui/pyzjr (github.com)

总结

Conv2Former,这是一种新型的卷积神经网络架构,其核心是卷积调制操作,通过卷积和Hadamard乘积简化了自注意力机制。在ImageNet分类、目标检测和语义分割任务中,Conv2Former相对于先前的CNN和Transformer模型表现更优。作者强调了对大内核卷积的更有效利用,如何将提出的卷积调制块与Transformer结合起来,这是将来的研究方向。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/313158.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

你逛过凌晨四点的校园吗?--大四毕业生的年终总结

前言&#xff1a; Hello大家好&#xff0c;我是Dream。 又是一年的年终总结&#xff0c;我也迎来了自己的毕业季&#xff0c;没错&#xff0c;我马上要毕业啦&#xff01;不知道大家是什么时候认识我的呢&#xff0c;又或者是第一次发现我~这一年&#xff0c;迎接过朝阳、拍下过…

Linux:apache优化(1)—— 长链接/保持连接

系统:CentOS 7.9 apache版本为&#xff1a;2.4.25 需要使用源码包进行安装才能够使用这些扩展模块 在使用这些扩展模块前要先下载zlib-devel 安装--enable-deflate选项需要的网页压缩传输的软件包 yum -y install zlib-devel 在配置编译安装时需要使用扩展配置 ./config…

使用ChatGLM3自定义工具实现大模型查询MySQL数据库

ChatGLM3-6B 采用了全新设计的 Prompt 格式&#xff0c;除正常的多轮对话外。同时原生支持工具调用&#xff08;Function Call&#xff09;、代码执行&#xff08;Code Interpreter&#xff09;和 Agent 任务等复杂场景。 什么是工具调用 大模型虽然强大&#xff0c;但是由于…

java struts2教务管理系统Myeclipse开发mysql数据库struts2结构java编程计算机网页项目

一、源码特点 java struts2 教务管理系统 是一套完善的web设计系统&#xff0c;对理解JSP java编程开发语言有帮助 struts2 框架开发&#xff0c;系统具有完整的源代码和数据库&#xff0c;系统主要采用B/S模式开发。开发环境 为TOMCAT7.0,Myeclipse8.5开发&#xff0c;数据库…

【数据结构-单链表】(C语言版本)

今天分享的是数据结构有关单链表的操作和实践&#xff08;图解法&#xff0c;图变化更利于理解&#xff09; 记录宗旨&#x1f4dd;&#xff1a; 眼&#xff08;脑&#xff09;过千遍&#xff0c;不如手过一遍。 我们都知道单链表是一种常见的链表数据结构&#xff0c;由一系列…

泽攸科技PECVD设备助力开发新型石墨烯生物传感器

近日&#xff0c;松山湖材料实验室许智团队与清华大学符汪洋合作在纳米领域头部期刊《Small》上发表了一项引人注目的研究成果&#xff0c;题为“Ultrasensitive biochemical sensing platform enabled by directly grown graphene on insulator”&#xff08;硅晶圆上直接生长…

34--JDK8新特性

1. Java版本迭代概述 1.1 发布特点&#xff08;小步快跑&#xff0c;快速迭代&#xff09; 发行版本 发行时间 备注 Java 1.0 1996.01.23 Sun公司发布了Java的第一个开发工具包 Java 5.0 2004.09.30 ①版本号从1.4直接更新至5.0&#xff1b;②平台更名为JavaSE、JavaE…

[每周一更]-(第49期):一名成熟Go开发需储备的知识点(答案篇)- 2

答案篇 1、Go语言基础知识 什么是Go语言&#xff1f;它有哪些特点&#xff1f; Go语言&#xff08;也称为Golang&#xff09;是一种由Google开发的开源编程语言。它于2007年首次公开发布&#xff0c;并在2012年正式推出了稳定版本。Go语言旨在提供简单、高效、可靠的编程解决…

【Spring】AOP的AspectJ开发

AOP基础不了解可以阅读&#xff1a;【Spring】AOP原来如此-CSDN博客 AspectJ是一个居于JAVA开发的AOP框架 基于XML的声明式AspectJ 基于XML的声明式AspectJ是通过XML文件来定义切面&#xff0c;切入点及通知&#xff0c;所有的切面、切入点和通知必须定义在内&#xff0c; 元…

webpack的深入学习与实战(持续更新)

一、何为Webpack Webpack是 一个开源的JavaScript模块打包工具&#xff0c;其最核心的功能是解决模块之间的依赖&#xff0c;把各个模块按照特定的规则和顺序组织在一起&#xff0c;最终合并为一个JS文件或多个。 二、带宽的换算 目前我们的云服务器带宽为5M 三 、bundle 体…

Qt高质量的开源项目合集

文章目录 1.Qt官网下载/文档2.第三方开源 1.Qt官网下载/文档 Qt Downloads Qt 清华大学开源软件镜像站 Qt 官方博客 2.第三方开源 记录了平常项目开发中用到的第三方库&#xff0c;以及一些值得参考的项目&#xff01; Qt AV 基于Qt和FFmpeg的跨平台高性能音视频播放框…

Javascript 循环结构while do while for实例讲解

Javascript 循环结构while do while for实例讲解 目录 Javascript 循环结构while do while for实例讲解 一、while语句 二、do…while语句 三、for循环 疑难解答 我们从上一节课知道&#xff0c;JavaScript循环结构总有3种&#xff1a; &#xff08;1&#xff09;while语…