pytorch代码实现之动态卷积模块ODConv

ODConv动态卷积模块

ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。作为一种“即插即用”的操作,它可以轻易的嵌入到现有CNN网络中。ImageNet分类与COCO检测任务上的实验验证了所提ODConv的优异性:即可提升大模型的性能,又可提升轻量型模型的性能,实乃万金油是也!值得一提的是,受益于其改进的特征提取能力,ODConv搭配一个卷积核时仍可取得与现有多核动态卷积相当甚至更优的性能。

原文地址:Omni-Dimensional Dynamic Convolution

ODConv结构图
代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd
from models.common import Conv, autopadclass Attention(nn.Module):def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):super(Attention, self).__init__()attention_channel = max(int(in_planes * reduction), min_channel)self.kernel_size = kernel_sizeself.kernel_num = kernel_numself.temperature = 1.0self.avgpool = nn.AdaptiveAvgPool2d(1)self.fc = Conv(in_planes, attention_channel, act=nn.ReLU(inplace=True))self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)self.func_channel = self.get_channel_attentionif in_planes == groups and in_planes == out_planes:  # depth-wise convolutionself.func_filter = self.skipelse:self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)self.func_filter = self.get_filter_attentionif kernel_size == 1:  # point-wise convolutionself.func_spatial = self.skipelse:self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)self.func_spatial = self.get_spatial_attentionif kernel_num == 1:self.func_kernel = self.skipelse:self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)self.func_kernel = self.get_kernel_attentionself._initialize_weights()def _initialize_weights(self):for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')if m.bias is not None:nn.init.constant_(m.bias, 0)if isinstance(m, nn.BatchNorm2d):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)def update_temperature(self, temperature):self.temperature = temperature@staticmethoddef skip(_):return 1.0def get_channel_attention(self, x):channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return channel_attentiondef get_filter_attention(self, x):filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)return filter_attentiondef get_spatial_attention(self, x):spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)spatial_attention = torch.sigmoid(spatial_attention / self.temperature)return spatial_attentiondef get_kernel_attention(self, x):kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)return kernel_attentiondef forward(self, x):x = self.avgpool(x)x = self.fc(x)return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)class ODConv2d(nn.Module):def __init__(self, in_planes, out_planes, k, s=1, p=None, g=1, act=True, d=1,reduction=0.0625, kernel_num=1):super(ODConv2d, self).__init__()self.in_planes = in_planesself.out_planes = out_planesself.kernel_size = kself.stride = sself.padding = autopad(k, p)self.dilation = dself.groups = gself.kernel_num = kernel_numself.attention = Attention(in_planes, out_planes, k, groups=g,reduction=reduction, kernel_num=kernel_num)self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//g, k, k),requires_grad=True)self._initialize_weights()self.bn = nn.BatchNorm2d(out_planes)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())if self.kernel_size == 1 and self.kernel_num == 1:self._forward_impl = self._forward_impl_pw1xelse:self._forward_impl = self._forward_impl_commondef _initialize_weights(self):for i in range(self.kernel_num):nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')def update_temperature(self, temperature):self.attention.update_temperature(temperature)def _forward_impl_common(self, x):# Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,# while we observe that when using the latter method the models will run faster with less gpu memory cost.channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)batch_size, in_planes, height, width = x.size()x = x * channel_attentionx = x.reshape(1, -1, height, width)aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)aggregate_weight = torch.sum(aggregate_weight, dim=1).view([-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups * batch_size)output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))output = output * filter_attentionreturn outputdef _forward_impl_pw1x(self, x):channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)x = x * channel_attentionoutput = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,dilation=self.dilation, groups=self.groups)output = output * filter_attentionreturn outputdef forward(self, x):return self.act(self.bn(self._forward_impl(x)))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/110299.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【JavaSE笔记】抽象类与接口

一、抽象类 1、概念 在面向对象的概念中,所有的对象都是通过类来描绘的,但是反过来,并不是所有的类都是用来描绘对象的,如果一个类中没有包含足够的信息来描绘一个具体的对象,这样的类就是抽象类。 package demo2…

气传导耳机哪个好?值得推荐的气传导耳机分享

​随着生活节奏的加快,人们越来越关注听力健康。气传导耳机以其独特的传导方式和舒适的佩戴感受,逐渐成为耳机市场的新宠。气传导耳机不入耳设计听音,让你在享受音乐的同时,也能保护你的听力安全。今天我们就一起来看看几款值得大…

饲料添加剂 微生物 屎肠球菌

声明 本文是学习GB 7300.503-2023 饲料添加剂 第5部分:微生物 屎肠球菌. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了饲料添加剂屎肠球菌的技术要求、采样、检验规则、标签、包装、运输、贮存和保质 期&#xff0…

【vue】vue 中插槽的三种类型:

文章目录 一、匿名插槽&#xff1a;二、具名插槽&#xff1a;三、作用域插槽 一、匿名插槽&#xff1a;<slot></slot> 1.没有为插槽指定名称 2.通过slot标签可以添加匿名插槽 3.在使用组件的时候&#xff0c;组件中的内容会填充到所有匿名插槽的位置&#xff0c;所…

数据结构——散列函数、散列表

文章目录 前言一、散列表的基本概念二、散列函数的构造方法三、处理冲突的方法1. 开放定址法&#xff1a;2. 拉链法 四、散列查找及性能分析总结 前言 散列表的基本概念散列函数的构造方法处理冲突的方法散列查找及性能分析 提示&#xff1a;以下是本篇文章正文内容&#xff0…

七天学会C语言-第一天(C语言基本语句)

一、固定格式 这个是C程序的基本框架&#xff0c;需要记住&#xff01;&#xff01;&#xff01; #include<stdio.h>int main(){return 0; }二、printf 语句 简单输出一句C程序&#xff1a; #include<stdio.h> int main(){printf("大家好&#xff0c;&quo…

S7-1200PLC和LED电子看板通信(TCP/IP)

S7-200SMART PLC和LED电子看板通信应用,请查看下面文章链接: SMART 200 PLC UDP通讯应用LED看板_RXXW_Dor的博客-CSDN博客开放式用户通信 (OUC) 库:数据解析:https://rxxw-control.blog.csdn.net/article/details/121424897这篇博客我们主要介绍S7-1200PLC和LED电子看板通…

PowerDesigner 逆向工程以及IDEA中UML插件

1、MySQL数据库连接&#xff08;JDBC方式&#xff09; 1.1 新建一个pdm&#xff0c;dbms选择mysql 1.2 Database - Connect 选择数据库连接 1.3 配置连接信息 数据库连接这里是通过一个配置文件来获取连接信息的&#xff0c;首次的话因为没有&#xff0c;所以我们需要选择…

Michael.W基于Foundry精读Openzeppelin第34期——MerkleProof.sol

Michael.W基于Foundry精读Openzeppelin第34期——MerkleProof.sol 0. 版本0.1 MerkleProof.sol 1. 目标合约2. 代码精读2.1 processProof(bytes32[] memory proof, bytes32 leaf) && processProofCalldata(bytes32[] calldata proof, bytes32 leaf)2.2 verify(bytes32[…

人工智能现在可以从文本中生成具有CD音质的音乐,而且只会越来越好

想象一下&#xff0c;键入“戏剧性的介绍音乐”并听到一首飙升的交响乐&#xff0c;或者编写“令人毛骨悚然的脚步声”并获得高质量的音效。这是稳定音频的承诺&#xff0c;一个文本到音频的人工智能模型周三宣布由能合成立体声的稳定人工智能44.1千赫来自文字描述的音乐或声音…

进化算法、遗传编程和学习

一、说明 进化算法是一系列搜索算法&#xff0c;其灵感来自自然界&#xff08;达尔文主义&#xff09;进化过程。所有不同家庭成员的共同点是&#xff0c;通过应用受自然遗传学和自然选择启发的 算子&#xff0c;通过进化出最初 随机的候选解决方案群体来解决问题&#…

C++之哈希表、哈希桶的实现

哈希表、哈希桶的实现 哈希概念哈希冲突哈希函数哈希冲突解决闭散列哈希表闭散列实现哈希表的结构哈希表的插入哈希表的查找哈希表的删除 开散列开散列概念哈希表的结构哈希表的插入哈希表的查找哈希表的删除 哈希概念 顺序结构以及平衡树中&#xff0c;元素关键码与其存储位置…