分组卷积的思想神了

大家好啊,我是董董灿。

最近,分组卷积帮我解决了一个大忙,事情是这样的。

这几天遇到一个头疼的问题,就是要在某一芯片上完成一个神经网络的适配,这个神经网络中卷积居多,并且有一些卷积的通道数很大,比如2048个输入通道。

问题是,该芯片是专用芯片,所对应的硬件模块无法直接支持这种通道数很大的卷积运算。

于是开始了头脑风暴,因为芯片中有良好的向量指令集来支持内积运算,因此第一反应便是拿内积运算来拼凑出卷积。

但实验结果表明,利用内积指令来拼凑的卷积效果不如人意,主要在于内积指令调用次数过多,导致神经网络的整体性能太差。

就在一筹莫展时,一个声音传过来,“我们改图吧”。

改图,指的是改神经网络的结构,很多推理框架都具备这个能力,比如pytorch,tvm等。

这些推理框架可以针对性的适配某些专用AI芯片加速器,为此魔改一些神经网络结构,通过增加一些优化节点(pass),来使原本不支持的运算变为可支持的运算。

比如这个卷积的例子,可以将一个大卷积(指的是通道数很大),魔改为两个或多个小卷积,分别计算,计算完成后再将结果合并。

这就要提一下将卷积在通道维度分成多个卷积计算的操作——分组卷积(Group Convolution)了。

1、 什么是分组卷积

网上有很多关于分组卷积的资料。说的简单点,分组卷积是将卷积在channel 维度分组来计算,以达到将一个大卷积分成多个小卷积的目的。

为了清晰,我们将卷积操作简化为一次最简单的乘累加运算,channel维度只有2个数据,如下图。

图片

正常的卷积操作,A和B的乘累加,计算的是 1x3 + 2x4 = 11。

而如果将其在channel维度分组(例子中channel维度只有两个数据,我们就分成两组),那么会是这样

图片

第一组只计算channel 维度的前半部分,第二组只计算channel维度的后半部分。

分组的过程是不是很好理解。

2、为什么需要分组卷积

分组卷积最早由Alex等人在2012年的ImageNet图像分类竞赛中提出并使用,提出的初衷是为了解决卷积神经网络训练期间的计算和内存开销问题。

2012年的GPU不像现在的GPU内存那么大,当初GPU内存还很有限,一个channel通道数很大的卷积直接计算,放在整个网络中,是很耗费内存的。

于是,他们把大卷积在通道方向拆分成多个小卷积来分别计算,这样可以让拆分出来的多个小卷积分别运行在多张GPU卡上,达到一个模型多卡并行计算的目的,从而提高训练性能。

需要说明的是,论文中的分组卷积不仅将输入通道进行了分组,同时将输出通道进行了分组。

由此而来的分组卷积,在计算量上变为原来的1/G,G为分组的组数。

3、分组卷积和原始卷积在数学上等价吗?

细心的小伙伴可能会问这个问题。

如果仅仅说分组卷积,那么结果肯定和原始大卷积不等价,因为把channel维度给拆开了。

并且如果不做处理,还会影响最终的推理精度,对于这个问题,有个很好的解决办法。

我们知道,卷积算法的核心是特征提取和融合:5分钟理解什么是卷积的特征提取。

如果不进行其他操作,那么分组卷积仅仅进行了组内小卷积的特征融合,而缺少了分组间的特征融合,这样对于最终的训练推理结果会有影响。

为了解决这个问题,往往在分组卷积前在channle维度进行 shuffle 操作,也就是洗牌,使得特征可以随机的分配到每一个组内,能够更好的完成组间的特征融合。

4、“分组卷积“”的思想神了

回到上面我遇到的问题,我们需要在自己的需求下,利用分组卷积的思想,魔改大卷积运算。

如下示意图:一个输入channel 为 ci 的卷积,通过 split 在输入 channel 维度拆成两个 ci/2 的卷积,然后“分组”进行卷积操作,然后通过加法进行相加。

图片

这里并没有对输出channel 进行分组,因为我们解决的问题不一样,内存对我们来说不是问题,问题仅仅在于输入channel太大。

而通过上面的魔改变换,便可以使得最终的结果和原始卷积计算一致,借用“分组卷积”的思想,可以很好的解决我遇到的问题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/150829.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

gitlab查看、修改用户和邮箱,gitlab生成密钥

查看用户、邮箱 git config user.name git config user.email 修改用户、邮箱 git config --global user.name “xxx” git config --global user.email “xxxxxx.com” 生成ssh密钥 ssh-keygen -t rsa -C “xxxxxx.com” 查看SSH秘钥 cat ~/.ssh/id_rsa.pub 将秘钥复制&…

玩转ChatGPT:批量下载Alphafold的蛋白pdb文件

一、写在前面 突发奇想,想批量下载Alphafold网站的蛋白pdb文件,后续再做个分子对接用。又不想手动下载,来求助CSDN和GPT。 二、CSDN白嫖基础代码 CSDN大神多,这不,找到一个:Alphafold批量下载蛋白的pdb文…

mysql-面试50题-2

一、查询数据 学生表 Student create table Student(SId varchar(10),Sname varchar(10),Sage datetime,Ssex varchar(10)); insert into Student values(01 , 赵雷 , 1990-01-01 , 男); insert into Student values(02 , 钱电 , 1990-12-21 , 男); insert into Student v…

hadoop集群搭建

hadoop有三种部署方式 1、Local (Standalone) Mode(单机模式) 数据存储在本地 2、Pseudo-Distributed Mode(伪集群模式) 数据存储在HDFS 3、Fully-Distributed Mode(集群模式) 集群部署,数据存储…

网络扫描与网络监听

前言:前文给大家介绍了网络安全相关方面的基础知识体系,以及什么是黑客,本篇文章笔者就给大家带来“黑客攻击五部曲”中的网络扫描和网络监听 目录 黑客攻击五部曲 网络扫描 按扫描策略分类 按照扫描方式分类 被动式策略 系统用户扫描 …

Matter.js 插件:matter-wrap(世界是圆的)

本文简介 点赞 关注 收藏 学会了 记得以前看爆笑校园里有一集讲到,一个人对着前面开了一枪,过了一阵子弹打中他自己的后脑勺。作者想通过这个冷笑话告诉大家一件事:地球是圆的。 在 Matter.js 世界里,默认是没有边界的&#…

MSQL系列(八) Mysql实战-SQL存储引擎

Mysql实战-SQL存储引擎 前面我们讲解了索引的存储结构,BTree的索引结构,我们一般都知道Mysql的存储引擎有两种,MyISAM和InnoDB,今天我们来详细讲解下Mysql的存储引擎 文章目录 Mysql实战-SQL存储引擎1.存储引擎2.MyISAM的特点3. InnoDB的特…

11 结构型模式- 代理模式

结构性模式一共包括七种: 代理模式、桥接模式、装饰者模式、适配器模式、门面(外观)模式、组合模式、和享元模式。 1 代理模式介绍 软件开发中的代理: 代理模式中引入了一个新的代理对象,代理对象在客户端对象和目标对象之间起到了中介的作用,它去掉客…

Linux系列讲解 —— VIM配置与美化

目录 1. Vim基本配置2. 插件管理器Vundle2.1 下载Vundle2.2 在vimrc中添加Vundle的配置 3. Vundle的使用3.1 安装常用插件3.1.1 NERDTree 3.2 卸载插件 1. Vim基本配置 1.1 配置文件 vim的配置文件有两处,请根据实际情况选择修改哪个。 (1) 全局配置文件&#xff…

RT-Thread 7. RT-Thread Studio ENV修改MCU型号

1. 修改MCU型号 2.在ENV界面输入 scons -c scons --dist3. dist下为更新后完整源代码 4.导入RT-Thread Studio 发现GD32F330已经生效了。 5. 自己编写startup_gd32f3x0.S,准确性待验证 ;/* ; * Copyright (c) 2006-2021, RT-Thread Development Team ; * ; * SPD…

javaEE -10(11000字详解5层重要协议)

一:应用层重点协议 1.1: DNS DNS,即Domain Name System,域名系统。DNS是一整套从域名映射到IP的系统。 TCP/IP中使用IP地址来确定网络上的一台主机,但是IP地址不方便记忆,且不能表达地址组织信息&#x…

c++ qt连接操作sqlite

qt客户端编程,用到数据库的场景不多,但是部分项目还是需要数据库来保存同步数据,客户端用到的数据库,一般是sqlite。 Qt提供了数据库模块,但是qt本身的数据库模块并不好用,会有各种问题, 建议大家不要,可以自己封装数据库的操作。本篇博客介绍qt连接操作sqlite。 sqlit…