Factor Transfer(NeurIPS 2018)

paper:Paraphrasing Complex Network: Network Compression via Factor Transfer

official implementation:https://github.com/Jangho-Kim/Factor-Transfer-pytorch

背景

尽管现有的知识蒸馏方法如KD、FitNet等带来了性能的改善,但直接传递教师的输出忽略了教师和学生之间的内在差异,如网络结构、通道数量、初始条件等。因此,我们需要重新解释教师网络的输出来解决这些差异。例如,从老师和学生的角度来看,对于一个问题,直接给出老师的知识而不做任何解释对教学生来说是不够的。换句话说,在教孩子时,老师不应该使用她自己的术语,因为孩子不能理解它。另一方面,如果老师把自己的术语翻译成更简单的术语,孩子会更容易理解。

本文的创新点

本文提出了一种知识蒸馏方法,使得教师和学生都能生成更容易传递的知识,文中称为“factor”。和传统的方法不一样,该方法不是仅仅直接比较网络的输出,而是训练一个神经网络可以提取好的factor并匹配这些factor。从教师网络中提取factor的网络称为paraphraser,从学生网络提取factor的网络称为translator。paraphraser以无监督的方式训练,期望它提取不同于有监督损失可以获得的知识。translator和学生网络一同训练用来吸收paraphraser从教师网络提取的factor。

方法介绍

Paraphraser

paraphraser通过几个卷积层来得到教师的factor \(F_{T}\),并在训练阶段被一些转置的卷积层进一步处理。大多数卷积自动编码器autoencoder的设计都进行了降采样,从而增加感受野。相反,paraphraser在调整factor通道数量的同时保持空间维度大小,因为它使用了最后一组的特征图,它的特征图已经足够小了。如果教师网络产生m个特征图,我们将factor通道的数量调整为m×k。我们把超参k称为paraphraser rate。

无监督训练paraphraser的reconstruction loss如下

其中paraphraser网络 \(P(\cdot)\) 以 \(x\) 为输入。

Translator

如图1所示,在训练学生网络时,在学生网络最后一组卷积层的后面插入translator,并与学生网络一起训练。这里translator起到一个buffer的作用,通过重新表述学生网络的特征图,减轻了学生网络直接学习教师网络输出的负担。

学生网络的训练包含两个损失,分类loss和factor transfer loss,如下

其中 \(F_{T},F_{S}\) 分别表示教师和学生的factor。因为使用 \(l_{1}(p=1)\) 损失和 \(l_{2}(p=2)\) 损失的差异不大,后续所有实验都是用 \(l_{1}\) 损失。\(\beta\) 是权重参数,\(C(S(I_{x}),y)\) 表示ground truth \(y\) 和网络softmax输出 \(S(I_{x})\) 之间的交叉熵损失。

代码解析

paraphraser的代码如下,mode=0为无监督训练模式,包含三个卷积和三个转置卷积,步长都为1。

class Paraphraser(nn.Module):def __init__(self, in_planes, planes, stride=1):super(Paraphraser, self).__init__()self.leakyrelu = nn.LeakyReLU(0.1)# self.bn0 = nn.BatchNorm2d(in_planes)self.conv0 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn1 = nn.BatchNorm2d(planes)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn2 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn0_de = nn.BatchNorm2d(planes)self.deconv0 = nn.ConvTranspose2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn1_de = nn.BatchNorm2d(in_planes)self.deconv1 = nn.ConvTranspose2d(planes, in_planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn2_de = nn.BatchNorm2d(in_planes)self.deconv2 = nn.ConvTranspose2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=True)# ### Mode 0 - throw encoder and decoder (reconstruction)# ### Mode 1 - extracting teacher factorsdef forward(self, x, mode):if mode == 0:# encoderout = self.leakyrelu((self.conv0(x)))out = self.leakyrelu((self.conv1(out)))out = self.leakyrelu((self.conv2(out)))# decoderout = self.leakyrelu((self.deconv0(out)))out = self.leakyrelu((self.deconv1(out)))out = self.leakyrelu((self.deconv2(out)))if mode == 1:out = self.leakyrelu((self.conv0(x)))out = self.leakyrelu((self.conv1(out)))out = self.leakyrelu((self.conv2(out)))# only throw decoderif mode == 2:out = self.leakyrelu((self.deconv0(x)))out = self.leakyrelu((self.deconv1(out)))out = self.leakyrelu((self.deconv2(out)))return out

训练paraphraser时前面的网络和教师网络是一致的,在最后一组卷积后面添加paraphraser。训练的代码如下,其中model就是教师网络,outputs[2]为教师网络最后一组卷积的输出,module为paraphraser,criterion为nn.L1Loss()。

outputs = model(inputs)
# reconstructed feature maps (Mode 0; see FeatureProjection.py)
output_p = module(outputs[2], 0)
loss = criterion(output_p, outputs[2].detach())

translator的代码如下,只包含三个卷积

class Translator(nn.Module):def __init__(self, in_planes, planes, stride=1):super(Translator, self).__init__()self.leakyrelu = nn.LeakyReLU(0.1)# self.bn0 = nn.BatchNorm2d(in_planes)self.conv0 = nn.Conv2d(in_planes, in_planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn1 = nn.BatchNorm2d(planes)self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=1, padding=1, bias=True)# self.bn2 = nn.BatchNorm2d(planes)self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=True)def forward(self, x):out = self.leakyrelu((self.conv0(x)))out = self.leakyrelu((self.conv1(out)))out = self.leakyrelu((self.conv2(out)))return out

在蒸馏下训练学生网络的代码如下,teacher和student分别是教师和学生网络,输出都取最后一组卷积的输出。module_t和module_s分别是paraphraser和translator,这里mode=1,从上面paraphraser的代码可以看出mode=1时只经过前面三层卷积。最后cricriterion=nn.L1Loss(),criterion_CE=nn.CrossEntropyLoss()。

teacher_outputs = teacher(inputs)
student_outputs = student(inputs)factor_t = module_t(teacher_outputs[2], 1)
factor_s = module_s(student_outputs[2])loss = BETA * (criterion(utils.FT(factor_s), utils.FT(factor_t.detach()))) + criterion_CE(student_outputs[3], targets)

实验结果

下面是Factor Transfer与其它蒸馏方法在CIFAR-10和CIFAR-100数据集上的结果对比,可以看出当FT与原始KD结合使用时,效果可能会变差。

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/434727.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

element plus使用问题

文章目录 element plusvue.config.js注意1、有时候会报错 not a function2、使用 ElMessage 报错3、 element plus 版本过高4、警告Feature flag VUE_PROD_HYDRATION_MISMATCH_DETAILS is not explicitly defined.5、报错 ResizeObserver loop completed with undelivered noti…

如何监控两台android设备之间串口通讯的ADB日志?

如果你的目标是将设备通过 Wi-Fi 连接到计算机,可以执行以下步骤: 一.通过 USB 连接设备: adb devices 确保设备通过 USB 连接,并且可以通过 adb devices 命令正常识别。 二、将设备1和设备2都切换到 TCP/IP 模式:…

汇编led驱动的代码编写以及ubuntu下的烧录

文章目录 前言一、实验代码详解二、编译1、arm-linux-gnueabihf-gcc 编译文件2、arm-linux-gnueabihf-ld 链接文件3、arm-linux-gnueabihf-objcopy 格式转换4、arm-linux-gnueabihf-objdump 反汇编5、编写Makefile文件 三、代码烧写1、将 imxdownload 拷贝到工程根目录下2、给予…

幻兽帕鲁服务器多少钱一台?腾讯云新版报价

腾讯云幻兽帕鲁服务器4核16G、8核32G和16核64G配置可选,4核16G14M带宽66元一个月、277元3个月,8核32G22M配置115元1个月、345元3个月,16核64G35M配置580元年1个月、1740元3个月、6960元一年,腾讯云百科txybk.com分享腾讯云幻兽帕鲁…

Linux的常见指令和基本操作演绎【复习篇章一】

文章目录 前言下载安装 XShellXShell 下的复制粘贴热键操作01.ls指令tree 02.cd指令03.touch指令04.mkdir指令(重要):05.rmdir指令 && rm 指令(重要)06.组合07.man指令(重要)&#xff1…

【Linux 内核源码分析】多核调度分析

多核调度 SMP(Symmetric Multiprocessing,对称多处理)是一种常见的多核处理器架构。它将多个处理器集成到一个计算机系统中,并通过共享系统总线和内存子系统来实现处理器之间的通信。 首先,SMP架构将一组处理器集中在…

Unity 光照

光照烘培 光照模式切换为 Baked 或 Mixed,Baked 模式完全使用光照贴图模拟光照,运行时修改光照颜色不生效,Mixed 模式也使用光照贴图,并且进行一些实时运算,运行时修改光照颜色会生效 受光照影响的物体勾选 Contribute…

Cesium渲染白膜数据

async DrawBaiMoFun2() {// tiles 矩阵变换let changePostion = (tileSet, tx, ty, tz, rx, ry, rz, scale, center) => {if (!center) return;const m = Cesium.Transforms.eastNorthUpToFixedFrame(center);const surface =center ||Cesium.Cartesian3.fromRadians(cartog…

【深度学习:开源BERT】 用于自然语言处理的最先进的预训练

【深度学习:开源BERT】 用于自然语言处理的最先进的预训练 是什么让 BERT 与众不同?双向性的优势使用云 TPU 进行训练BERT 结果让 BERT 为您所用 自然语言处理 (NLP) 面临的最大挑战之一是训练数据的短缺。由于 NLP 是一个具有许多…

语音生成、写作增强、论文辅助、英文学习,AI原生应用精彩推荐一箩筐!

崭新的2024年已然降临,飞桨星河社区再次涌现出诸多精彩纷呈的AI原生应用,快来一同探索,发现这些应用带来的无限惊喜与可能吧! 语音生成:10音色自由选择 应用介绍 本应用基于ERNIE SDK和语音合成工具,可以…

C++ 关于“常量”的知识整理:

目录 1 常量对象: 2 常量成员: 2.1常量数据成员: 常数据成员总结: 2.2常量成员函数(使用最多): 常成员函数总结: 3 常量引用: C中常量的值在程序运行中不允许被改…

QT+VS实现Kmeans聚类算法

1、Kmeans的定义 聚类是一个将数据集中在某些方面相似的数据成员进行分类组织的过程,聚类就是一种发现这种内在结构的技术,聚类技术经常被称为无监督学习。k均值聚类是最著名的划分聚类算法,由于简洁和效率使得他成为所有聚类算法中最广泛使…