「AI模型瘦身术」——知识蒸馏技术综述

使用KD原因

遇到问题:从产业发展的角度来看工业化将逐渐过渡到智能化,边缘计算逐渐兴起预示着 AI 将逐渐与小型化智能化的设备深度融合,这也要求模型更加的便捷、高效、轻量以适应这些设备的部署。

解决方案:知识蒸馏技术

知识蒸馏的关键点

如果回归机器学习最最基础的理论,我们可以很清楚地意识到一点(而这一点往往在我们深入研究机器学习之后被忽略): 机器学习最根本的目的在于训练出在某个问题上泛化能力强的模型。

泛化能力强: 在某问题的所有数据上都能很好地反应输入和输出之间的关系,无论是训练数据,还是测试数据,还是任何属于该问题的未知数据。

而现实中,由于我们不可能收集到某问题的所有数据来作为训练数据,并且新数据总是在源源不断的产生,因此我们只能退而求其次,训练目标变成在已有的训练数据集上建模输入和输出之间的关系。由于训练数据集是对真实数据分布情况的采样,训练数据集上的最优解往往会多少偏离真正的最优解(这里的讨论不考虑模型容量)。

而在知识蒸馏时,由于我们已经有了一个泛化能力较强的Net-T,我们在利用Net-T来蒸馏训练Net-S时,可以直接让Net-S去学习Net-T的泛化能力。

一个很直白且高效的迁移泛化能力的方法就是:使用softmax层输出的类别的概率来作为“soft target”。

KD的训练过程和传统的训练过程的对比

传统training过程(hard targets): 对ground truth求极大似然

KD的training过程(soft targets): 用large model的class probabilities作为soft targets

KD的训练过程为什么更有效?

softmax层的输出,除了正例之外,负标签也带有大量的信息,比如某些负标签对应的概率远远大于其他负标签。而在传统的训练过程(hard target)中,所有负标签都被统一对待。也就是说,KD的训练方式使得每个样本给Net-S带来的信息量大于传统的训练方式。

【举个例子】

在手写体数字识别任务MNIST中,输出类别有10个。

假设某个输入的“2”更加形似"3",softmax的输出值中"3"对应的概率为0.1,而其他负标签对应的值都很小,而另一个"2"更加形似"7","7"对应的概率为0.1。这两个"2"对应的hard target的值是相同的,但是它们的soft target却是不同的,由此我们可见soft target蕴含着比hard target多的信息。并且soft target分布的熵相对高时,其soft target蕴含的知识就更丰富。

这就解释了为什么通过蒸馏的方法训练出的Net-S相比使用完全相同的模型结构和训练数据只使用hard target的训练方法得到的模型,拥有更好的泛化能力。 下图为知识蒸馏的通用形式。

知识传递形式

原始知识蒸馏(Vanilla Knowledge Distillation)仅仅是从教师模型输出的软目标中学习出轻量级的学生模型。

然而,当教师模型变得更深时,仅仅学习软目标是不够的。

因此,我们不仅需要获取教师模型输出的知识,还需要学习隐含在教师模型中的其它知识,比如有输出特征知识、中间特征知识、关系特征知识和结构特征知识。

标签知识是神经网络对样本数据最终的预测输出中包含的潜在信息,这也是目前蒸馏过程中最简单、应用最多的方式。

标签知识(输出特征知识)通常指的是教师模型的最后一层特征,主要包括逻辑单元和软目标的知识。标签知识(输出特征知识)知识蒸馏的主要思想是促使学生能够学习到教师模型的最终预测,以达到和教师模型一样的预测性能。

原始知识蒸馏是针对分类任务来提出的仅包含类间相似性的软目标知识,然而其它任务(如目标检测)网络最后一层特征输出中还可能包含有目标定位的信息。

换句话说,不同任务教师模型的最后一层输出特征是不一样的。因此,本文根据任务的不同对输 出特征知识分别进行归纳和分析,如表 1 所示。

Hinton 等人最早提出的知识蒸馏方法就属于目标分类的标签知识(输出特征知识)。由于经过“蒸馏温度”调节后的软标签中具有很多不确定信息,通常的研究认为这其中反映了样本间的相似度或干扰性、样本预测的难度,因此标签知识又被称为“暗知识”。

  • 为了有效地解决基于聚类的算法中的伪标签噪声的问题,Ge等人[45]利用“同步平均教学”的蒸馏框架进行伪标签优化,核心思想是利用更为鲁棒的“软”标签对伪标签进行在线优化。

  • MLP[46]提出了基于元学习(Meta - learning)自适应生成目标分布的方法,用于教师和学生模型的伪标签学习过程.利用一个筛选网络从目标检测模型预测的伪标签中区分出正例和负例,将正例用于下一阶段的半监督自训练过程,可以有效提升标签数据的利用率[43]。

  • Xie等人[4]利用有监督训练学生模型自身,在自蒸馏训练中额外地引入无标签噪声数据产生伪标签,将ImageNet的Top-1识别结果提高了约1%.对于标签知识蒸馏方法本身,已经有非常多的变体和应用,主要是从改进蒸馏过程、挖掘标签信息、去除干扰等方面,提升学生模型的性能.

  • Gao等人[47]实现了一种简单的逐阶段的标签蒸馏训练过程,在梯度下降训练过程中,每次只更新学生网络的一个模块,从前至后直到全部更新完成。

  • 根据Mirzadeh等人[48]的研究发现,并不是教师模型性能越高对于学生模型的学习越有利,当教师-学生模型之间的差距过大时,会导致学生难以从教师模型获得提升.为此,他们提出使用辅助教师策略来逐渐缩小教师和学生之间的学习差距,取得更好的蒸馏效果.

  • 同样是为了缩小教师 - 学生之间的学习差距,Yang等人[49]则提出利用教师模型在每个训练周期更新的中间模型产生的标签知识指导学生模型.为了充分挖掘标签信息、去除干扰,Müller等人[50]采用了子类别蒸馏方法,将原标签分组合并参与软标签蒸馏学习;

  • 文献[51]则研究了蒸馏损失函数对犔2范数和归一化的软标签的作用,提出使用球面空间度量蒸馏的方法去除范数的影响;

  • Zhang等人[52]关注了样本权重的影响,通过预测不确定性自适应分配样本权重,改善蒸馏过程;

  • Wu等人[53]提出了同伴协同蒸馏,通过训练多个分支网络并将其他训练较强教师的 logits 知识转移给同伴,有利于模型的稳定和提高蒸馏的质量。

最早使用教师模型中间特征知识的是 FitNets[27],其主要思想是促使学生的隐含层能预测出与教师隐含层相近的输出。

知识传递方式中有同构蒸馏和异构蒸馏,主要就是区分 是否:教师和学生模型的架构相似或属于同一系列的、层与层(Layer -to - Layer)或块与块(Block - to - Block)之间一一对应;不过通过这几年的实验来看,这并没有什么区别

不同知识传递形式的效果

如图所示,不同的知识传递形式,相比是有差异的,使用经典的KD标签知识是还不错的;使用特征间的,有较多都不如开山鼻祖KD;不过近期又有更多优化,比如使用互信息与对比学习的方法;

温度的特点

在回答这个问题之前,先讨论一下温度T的特点

  1. 原始的softmax函数是 𝑇=1 时的特例, 𝑇<1 时,概率分布比原始更“陡峭”, 𝑇1 时,概率分布比原始更“平缓”。

  2. 温度越高,softmax上各个值的分布就越平均(思考极端情况: (i) 𝑇=∞ , 此时softmax的值是平均分布的;(ii) 𝑇→0,此时softmax的值就相当于 𝑎𝑟𝑔𝑚𝑎𝑥 , 即最大的概率处的值趋近于1,而其他值趋近于0)

  3. 不管温度T怎么取值,Soft target都有忽略相对较小的 𝑝𝑖 携带的信息的倾向

温度代表了什么,如何选取合适的温度?

温度的高低改变的是Net-S训练过程中对负标签的关注程度: 温度较低时,对负标签的关注,尤其是那些显著低于平均值的负标签的关注较少;而温度较高时,负标签相关的值会相对增大,Net-S会相对多地关注到负标签。

实际上,负标签中包含一定的信息,尤其是那些值显著高于平均值的负标签。但由于Net-T的训练过程决定了负标签部分比较noisy,并且负标签的值越低,其信息就越不可靠。因此温度的选取比较empirical,本质上就是在下面两件事之中取舍:

  1. 从有部分信息量的负标签中学习 --> 温度要高一些

  2. 防止受负标签中噪声的影响 -->温度要低一些

总的来说,T的选择和Net-S的大小有关,Net-S参数量比较小的时候,相对比较低的温度就可以了(因为参数量小的模型不能capture all knowledge,所以可以适当忽略掉一些负标签的信息)

CRD 对比学习

首先 CRD是2020年提出的新模式的蒸馏方法,使用对比学习,在这年对比了12个KD方法都是最好的,其中,CRD+KD两个方法合在一起更好,相当于两个维度的知识传递的监督,在2023年有基于CRD实现的CRCD,效果好一点,方案是差不多的;

知识提炼(KD)将知识从一个深度学习模型(教师)转移到另一个深度学习模型(学生)。Hinton等人(2015)最初提出的目标是将教师和学生输出之间的KL差异最小化。当输出是一个分布,例如类上的概率质量函数时,该公式具有直观意义。然而,我们通常希望传递有关representation的知识。例如,在“跨模态蒸馏”问题中,我们可能希望将图像处理网络的表示转移到声音(Aytar等人,2016)或深度(Gupta等人,2016)处理网络,这样图像的深度特征和相关的声音或深度特征高度相关。在这种情况下,KL发散是不确定的。

表征知识是结构化的——维度表现出复杂的相互依赖性。最初的KD目标(Hinton等人,2015年)将所有维度视为独立的,以输入为条件。让yT成为老师的输出,yS成为学生的输出。那么原始的KD目标函数ψ具有全因子形式:. 这种带因素的目标不足以传递结构知识,即输出维度i和j之间的依赖关系。这与图像生成中的情况类似,在图像生成中,由于输出维度之间的独立性假设,L2目标会产生模糊的结果。

为了克服这个问题,我们想要一个目标,捕捉相关性和高阶输出依赖性。为了实现这一点,在本文中,我们利用了对比目标家族(Gutmann&Hyvärinen,2010;Oord等人,2018;Arora等人,2019;Hjelm等人,2018)。近年来,这些目标函数已成功地用于密度估计和表征学习,尤其是在自我监督环境中。在这里,我们让他们适应从一个深层网络到另一个深层网络的知识蒸馏任务。我们表明,致力于研究表现空间很重要,类似于最近的工作,如Zagoruyko和Komodakis(2016a);Remero等人(2014年)。然而,请注意,这些工作中使用的损失函数并没有明确尝试捕捉表征空间中的相关性或高阶相关性。

图1:我们考虑的三种提取设置:(a)压缩模型,(b)将知识从一种模式(例如RGB)转移到另一种模式(例如深度),(c)将网络集合提取到单个网络中。对比目标鼓励教师和学生将相同的输入映射到接近的表示(在某些度量空间中),并将不同的输入映射到遥远的表示,如阴影圈所示。

我们的目标是最大化教师和学生之间的互信息的下限。我们发现,这会在多个知识转移任务中产生更好的表现。我们推测,这是因为对比目标能更好地传递教师表征中的所有信息,而不仅仅是传递关于条件独立输出类概率的知识。有些令人惊讶的是,对比目标甚至改善了最初提出的提取类概率知识的任务的结果,例如,将大型CIFAR100网络压缩为性能几乎相同的较小网络。我们认为这是因为不同类别概率之间的相关性包含有用的信息,可以规范学习问题。我们的论文在两个主要独立发展的文献之间建立了联系:知识蒸馏和表征学习。这种联系使我们能够利用表征学习的强大方法,显著提高知识蒸馏的SOTA。

我们的贡献是:

1.基于对比的目标,用于在深度网络之间传递知识。

2.模型压缩、跨模态传输和整体蒸馏的应用。

3.对标12种最新蒸馏方法;CRD优于所有其他方法,例如,与原始KD相比,平均相对改善57%(Hinton等人,2015),令人惊讶的是,后者的表现次之。

这是近几年的得分,有使用crd结合其他损失的,可以在一些任务中得到较好表现,不同任务表现不一致,

多教师蒸馏

多教师蒸馏(Multi-Teacher Distillation)是一种知识蒸馏的方法,它通过同时蒸馏多个教师网络的知识来提升学生网络的性能。相比于传统的单一教师蒸馏,多教师蒸馏可以利用不同教师网络的多样性和丰富性,从而获得更全面的知识传递。

在多教师蒸馏中,通常包括一个学生网络(Student Network)和多个教师网络(Teacher Networks)。每个教师网络都是一个独立的模型,具有不同的架构或参数初始化。学生网络通过同时学习多个教师网络的知识来提高自己的性能。

多教师蒸馏的核心思想是将不同教师网络的预测结果作为辅助目标来训练学生网络。具体而言,多教师蒸馏包括以下步骤:

1、教师网络的训练:针对不同的教师网络,使用标准的监督学习方法进行训练,以获得具有丰富知识的教师模型。

2、教师网络的预测:使用已训练好的教师网络对输入样本进行预测,得到多个教师网络的预测结果。

3、学生网络的训练:将教师网络的预测结果作为辅助目标,与真实标签一起用于训练学生网络。通过最小化学生网络的预测与教师网络预测之间的差异,将教师网络的知识传递给学生网络。

4、蒸馏损失函数的定义:通常使用交叉熵损失函数来衡量学生网络的分类性能。同时,为了传递教师网络的知识,可以定义额外的辅助目标损失,如平均软标签损失(Mean Soft Labels Loss)或特定的蒸馏损失函数。

通过多教师蒸馏,学生网络能够从多个教师网络中获得更丰富的知识,并综合各个教师网络的预测结果来提高自己的性能。多教师蒸馏可以增强模型的泛化能力,减少过拟合问题,并在复杂任务中取得更好的性能表现。

好,接下来我们从源码分析;

蒸馏算法源码分析

KD

链接:https://arxiv.org/pdf/1503.02531.pd3f

发表:NIPS14

class DistillKL(nn.Module):"""Distilling the Knowledge in a Neural Network"""def __init__(self, T):super(DistillKL, self).__init__()self.T = T #教师模型指导学生模型的程度(蒸馏温度),值越大,指导程度越高def forward(self, y_s, y_t):p_s = F.log_softmax(y_s/self.T, dim=1)p_t = F.softmax(y_t/self.T, dim=1)#下面就是对两个模型的预测值,做KL散度的分布分析,如果偏差越大,则kl散度算出来的值越大。#p_t表示教师模型的目标值#p_s表示学生模型的预测值loss = F.kl_div(p_s, p_t, size_average=False) * (self.T**2) / y_s.shape[0]return loss

核心就是一个kl_div函数,用于计算学生网络和教师网络的分布差异。输入为学生和教师模型的分类输出,经过温度可控的软化之后进行KL散度计算,简单直接粗暴有效;

FitNet

全称:Fitnets: hints for thin deep nets

链接:https://arxiv.org/pdf/1412.6550.pdf

发表:ICLR 15 Poster

很容易理解,方法使用特征间信息,对中间层进行蒸馏的开山之作,通过将学生网络的feature map扩展到与教师网络的feature map相同尺寸以后,使用均方误差MSE Loss来衡量两者差异

(1)大模型训练,小模型随机初始化

(2)将大模型特征提取器的第H层作为hint,从第一层到第H层的参数对应图(a)中Whint,,选择小模型特征提取器的第G层作为guided,从第一层到第G层对应图(a)中Wguided

(3)两者feature map大小可能不匹配,引入卷积层调整器(Wr)对guided层进行调整,对应图(b)

(4)优化均方损失函数

(5)对预训练好的小模型进行进一步知识蒸馏,对应图

 
class HintLoss(nn.Module):"""Fitnets: hints for thin deep nets, ICLR 2015"""def __init__(self):super(HintLoss, self).__init__()self.crit = nn.MSELoss()  # 在这个类中,初始化函数中使用了nn.MSELoss(),即均方误差损失函数,
用于度量学生网络和教师网络之间的均方误差'''
在前向传播函数中,接收学生网络的中间层表示f_s和教师网络的中间层表示f_t作为输入。
然后使用均方误差损失函数计算它们之间的差异,得到"hint"损失。
'''def forward(self, f_s, f_t):loss = self.crit(f_s, f_t)return loss
class ConvReg(nn.Module):"""Convolutional regression for FitNet 用来对齐T-S某层feature map的特征尺寸 可学"""def __init__(self, s_shape, t_shape, use_relu=True):super(ConvReg, self).__init__()self.use_relu = use_relus_N, s_C, s_H, s_W = s_shapet_N, t_C, t_H, t_W = t_shapeif s_H == 2 * t_H:self.conv = nn.Conv2d(s_C, t_C, kernel_size=3, stride=2, padding=1)elif s_H * 2 == t_H:self.conv = nn.ConvTranspose2d(s_C, t_C, kernel_size=4, stride=2, padding=1)elif s_H >= t_H:self.conv = nn.Conv2d(s_C, t_C, kernel_size=(1+s_H-t_H, 1+s_W-t_W))else:raise NotImplemented('student size {}, teacher size {}'.format(s_H, t_H))self.bn = nn.BatchNorm2d(t_C)self.relu = nn.ReLU(inplace=True)def forward(self, x):x = self.conv(x)if self.use_relu:return self.relu(self.bn(x))else:return self.bn(x)

损失计算时,就先使用guided 网络处理完,送进fitloss算一次mse即可;

Fitloss 使用的特征维度做监督,效果没有kd好,可能是由于mse或者特征的提取选择不好,可以考虑多使用几个维度的特征监督;

PKT:Probabilistic Knowledge Transfer

全称:Probabilistic Knowledge Transfer for deep representation learning

链接:https://arxiv.org/abs/1803.10837

发表:CoRR18

提出一种概率知识转移方法,引入了互信息来进行建模。该方法具有可跨模态知识转移、无需考虑任务类型、可将手工特征融入网络等的优点。

 

class PKT(nn.Module):"""Probabilistic Knowledge Transfer for deep representation learningCode from author: https://github.com/passalis/probabilistic_kt"""def __init__(self):super(PKT, self).__init__()def forward(self, f_s, f_t):return self.cosine_similarity_loss(f_s, f_t)@staticmethoddef cosine_similarity_loss(output_net, target_net, eps=0.0000001):# Normalize each vector by its normoutput_net_norm = torch.sqrt(torch.sum(output_net ** 2, dim=1, keepdim=True))output_net = output_net / (output_net_norm + eps)output_net[output_net != output_net] = 0target_net_norm = torch.sqrt(torch.sum(target_net ** 2, dim=1, keepdim=True))target_net = target_net / (target_net_norm + eps)target_net[target_net != target_net] = 0# Calculate the cosine similaritymodel_similarity = torch.mm(output_net, output_net.transpose(0, 1))target_similarity = torch.mm(target_net, target_net.transpose(0, 1))# Scale cosine similarity to 0..1model_similarity = (model_similarity + 1.0) / 2.0target_similarity = (target_similarity + 1.0) / 2.0# Transform them into probabilitiesmodel_similarity = model_similarity / torch.sum(model_similarity, dim=1, keepdim=True)target_similarity = target_similarity / torch.sum(target_similarity, dim=1, keepdim=True)# Calculate the KL-divergenceloss = torch.mean(target_similarity * torch.log((target_similarity + eps) / (model_similarity + eps)))return loss

这和PKT方法效果比KD好一些,主要是使用了概率传递学习先将教师和学生的网络输出进行标准化,再将输出的特征信息使用矩阵乘法、概率化方法映射到另一个空间,最后进行KL散度计算,就是在KD的基础上,将网络输出进行非线性映射成一个更简单的空间,监督这个空间下的S-T KL散度

CRD: Contrastive Representation Distillation

全称:Contrastive Representation Distillation

链接:https://arxiv.org/abs/1910.10699v2

发表:ICLR20

将对比学习引入知识蒸馏中,其目标修正为:学习一个表征,让正样本对的教师网络与学生网络尽可能接近,负样本对教师网络与学生网络尽可能远离。

构建的对比学习问题表示如下:

整体的蒸馏Loss表示如下:

实现如下:https://github.com/HobbitLong/RepDistiller

class ContrastLoss(nn.Module):"""contrastive loss, corresponding to Eq (18)"""def __init__(self, n_data):super(ContrastLoss, self).__init__()self.n_data = n_datadef forward(self, x):bsz = x.shape[0]m = x.size(1) - 1# noise distributionPn = 1 / float(self.n_data)# loss for positive pairP_pos = x.select(1, 0)log_D1 = torch.div(P_pos, P_pos.add(m * Pn + eps)).log_()# loss for K negative pairP_neg = x.narrow(1, 1, m)log_D0 = torch.div(P_neg.clone().fill_(m * Pn), P_neg.add(m * Pn + eps)).log_()loss = - (log_D1.sum(0) + log_D0.view(-1, 1).sum(0)) / bszreturn lossclass CRDLoss(nn.Module):"""CRD Loss functionincludes two symmetric parts:(a) using teacher as anchor, choose positive and negatives over the student side(b) using student as anchor, choose positive and negatives over the teacher sideArgs:opt.s_dim: the dimension of student's featureopt.t_dim: the dimension of teacher's featureopt.feat_dim: the dimension of the projection spaceopt.nce_k: number of negatives paired with each positiveopt.nce_t: the temperatureopt.nce_m: the momentum for updating the memory bufferopt.n_data: the number of samples in the training set, therefor the memory buffer is: opt.n_data x opt.feat_dim"""def __init__(self, opt):super(CRDLoss, self).__init__()self.embed_s = Embed(opt.s_dim, opt.feat_dim)self.embed_t = Embed(opt.t_dim, opt.feat_dim)self.contrast = ContrastMemory(opt.feat_dim, opt.n_data, opt.nce_k, opt.nce_t, opt.nce_m)self.criterion_t = ContrastLoss(opt.n_data)self.criterion_s = ContrastLoss(opt.n_data)def forward(self, f_s, f_t, idx, contrast_idx=None):"""Args:f_s: the feature of student network, size [batch_size, s_dim]f_t: the feature of teacher network, size [batch_size, t_dim]idx: the indices of these positive samples in the dataset, size [batch_size]contrast_idx: the indices of negative samples, size [batch_size, nce_k]Returns:The contrastive loss"""f_s = self.embed_s(f_s)f_t = self.embed_t(f_t)out_s, out_t = self.contrast(f_s, f_t, idx, contrast_idx)s_loss = self.criterion_s(out_s)t_loss = self.criterion_t(out_t)loss = s_loss + t_lossreturn loss
 

他会在训练过程中,使用contrast-memory 来记忆网络的负样本,在网络训练中互信息监督;效果不错;

超分等生成任务与蒸馏

众所周知,图像/视频超分 (SR) 是工业界非常具有应用场景的应用,但能够生产具有良好视觉效果的重建图像的SR模型的参数量和运算量都非常巨大,比如业界公认的优秀baseline模型EDSR,EDVR等的算力需求高达几百,几千GFLOPs。而业界真正需求的轻量化模型,尤其是可以部署于移动端设备的实时模型,其算力限制可能严苛到小于10GFlops。

在high-level CV tasks上得到广泛应用和验证的模型剪枝、c馏方法应用到超分任务上,即将一个训练好的大模型进行裁剪,或者用性能较强的教师大模型蒸馏原本较弱的学生小模型,使裁剪/蒸馏后的小模型能够取得相比普通训练方式更好,甚至接近原先大模型的性能。这里的challenge在于,直接的迁移应用这些算法,在超分任务上无法得到有效的性能提升,甚至可能导致非常严重的performance degradation.

  • SRKD:它将最基本的知识蒸馏直接应用到图像超分中,整体思想分类网络中的蒸馏方式基本一致,整体来看属于应用形式;

  • FAKD:它在常规知识蒸馏的基础上引入了特征关联机制,进一步提升被蒸馏所得学生网络的性能,相比直接应用有了一定程度的提升;

  • PISR:它则是利用了广义蒸馏的思想进行超分网络的蒸馏,通过充分利用训练过程中HR信息的可获取性进一步提升学生网络的性能。

上图给出了SRKD的蒸馏示意图,它采用了最基本的知识蒸馏思想对老师网络与学生网络的不同阶段特征进行蒸馏。考虑到老师网络与学生网络的通道数可能是不相同的,SRKD则是对中间特征的统计信息进行监督。该文考虑了如下四种统计信息:

owards Compact Single Image Super-Resolution via Contrastive Self-distillation

链接:

code:GitHub - Booooooooooo/CSD: Towards Compact Single Image Super-Resolution via Contrastive Self-distillation, IJCAI21

发表:IJCAI21

团队:Yonsei University

1.背景

卷积神经网络在超分任务上取得了很好的成果,但是依然存在着参数繁重、显存占用大、计算量大的问题,为了解决这些问题,作者提出利用对比自蒸馏实现超分模型的压缩和加速。

我们的目标是同时压缩和加速SR模型。我们提出了一个简单的自蒸馏框架,其中学生网络通过在每层使用教师的部分通道从教师(目标)网络中分离出来。我们将这种学生网络称为信道分割超分辨率网络(CSSRNet)。教师网络和学生网络共同训练,形成两个计算方式不同的SR模型。根据设备中计算资源的不同,我们可以动态分配这两种模型,即在资源有限的设备中,如果超过所需的计算开销,则选择CSSR-Net,否则选择教师模型.

主要贡献

作者提出的对比自蒸馏(CSD)框架可以作为一种通用的方法来同时压缩和加速超分网络,在落地应用中的运行时间也十分友好。

自蒸馏被引用进超分领域来实现模型的加速和压缩,同时作者提出利用对比学习进行有效的知识迁移,从而 进一步的提高学生网络的模型性能。

在Urban100数据集上,加速后的EDSR+可以实现4倍的压缩比例和1.77倍的速度提高,带来的性能损失仅为0.13 dB PSNR。

2.方法

我们的CSD包括两个部分:CSSR-Net和对比损失(CL)。首先,我们描述了CSSR-Net。然后,我们给出了构造CSSR-Net的上界和下界的正则表达式。

最后,给出了CSD方案的总体损失函数,并用一种新的优化策略对其进行了求解。

总结

回顾

近年来,知识蒸馏(Knowledge Distillation)方法在深度学习领域中备受关注,它是一种模型压缩技术,旨在将一个复杂的模型(通常被称为教师模型)的知识转移到一个简化的模型(通常被称为学生模型)中,从而使学生模型能够在保持性能的同时具有更小的模型尺寸和计算成本。

一些近年来的知识蒸馏方法和拓展包括:

  1. Teacher-Student Architecture: 最常见的知识蒸馏方法之一是使用教师模型和学生模型之间的监督信号。教师模型通常是一个大型、复杂的模型,而学生模型则是一个较小、简化的模型。通过让学生模型学习教师模型的输出,学生模型可以在学习到教师模型的知识的同时获得更好的泛化性能。

  2. Soft Target Training: 传统的监督学习使用的是硬标签(one-hot编码),即只有正确类别的概率为1,其余为0。而软目标训练则使用教师模型的输出概率分布作为目标。这种方法能够提供更丰富的信息,使得学生模型可以学习到更多的知识。

  3. Attention Mechanisms: 在知识蒸馏中引入注意力机制可以帮助学生模型更好地关注教师模型的重要信息,从而提高模型性能。

  4. Self-Distillation: 自蒸馏是一种方法,其中学生模型在训练过程中不仅要学习来自教师模型的知识,还要学习自身的输出。这种方法可以进一步提高学生模型的性能,同时减少对教师模型的依赖。

  5. Multi-Teacher Distillation: 多教师蒸馏是一种将多个教师模型的知识融合到学生模型中的方法。每个教师模型可能具有不同的视角或专长,通过结合它们的知识,学生模型可以获得更全面和鲁棒的学习。

未来

随着深度学习模型的不断发展和复杂化,未来的知识蒸馏方法可能会涉及更复杂的模型结构。这可能包括对于更深、更宽的神经网络架构的探索,以及对于更复杂的模型组合和蒸馏技术的研究。例如,结合Transformer模型的自注意力机制与知识蒸馏技术可能会带来更加高效的模型压缩和知识传递方式。

其次,未来的知识蒸馏方法可能会更加注重模型的智能化和个性化。这意味着,蒸馏过程将更加关注于学生模型的个性化需求和特征提取,以及对于不同学习任务和场景的适应性。这可能会涉及到更加精细的目标函数设计、更加智能化的蒸馏策略以及更加灵活的模型结构。

目前有的蒸馏方法效果提升不大,知识蒸馏还有很大提升空间,因为网络中有大量的参数,而实际使用到的很少,所以可以在蒸馏方法上优化,将特征提取和知识传递做得更通用,或者更准确,甚至像大模型的预训练与微调一样,或者是自监督蒸馏,或者是自动地结合上剪枝量化,感知量化等等方法。

reference

1、crd https://arxiv.org/abs/1910.1069

2、crd code https://github.com/HobbitLong/RepDistiller

3、cls kd https://blog.csdn.net/akaweige/article/details/131520764

4、sr kd https://zhuanlan.zhihu.com/p/346422123

5、cls kd https://zhuanlan.zhihu.com/p/102038521

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/706878.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

【全开源】云界旅游微信小程序(源码搭建/上线/运营/售后/维护更新)

开启您的云端旅行新体验 一、引言 在快节奏的现代生活中&#xff0c;旅行成为了人们放松身心、探索世界的重要方式。让您的旅行更加便捷、高效&#xff0c;打造了云界旅游小程序&#xff0c;带您领略云端旅行的无限魅力。 二、小程序功能概览 云界旅游小程序集成了丰富的旅游…

短视频创作者的9个免费实用的视频素材网站

在视频剪辑的过程中&#xff0c;找到高质量、无水印且可商用的视频素材是每个创作者的梦想。下面为大家推荐9个无水印素材网站&#xff0c;助你轻松获取所需的视频素材。 1. 蛙学府 - 提供丰富的高清视频素材&#xff0c;涵盖风景、人物、科技等类别。所有素材高清且可商用&…

网络工程师----第三十一天

DNS&#xff1a; DNS含义&#xff1a;DNS 是 Domain Name System&#xff08;域名解析系统&#xff09; 端口号&#xff1a;DNS为53&#xff08;UDP&#xff09; 域名的层次结构&#xff1a; 域名的分级&#xff1a; 域名服务器&#xff1a; 域名解析过程&#xff1a; 递归查…

网站开发初学者指南:2024年最新解读

在信息交流迅速的时代&#xff0c;网页承载着大量的信息&#xff0c;无论你知道还是不知道&#xff0c;所以你知道什么是网站开发吗&#xff1f;学习网站开发需要什么基本技能&#xff1f;本文将从网站开发阶段、网站开发技能、网站开发类型等角度进行分析&#xff0c;帮助您更…

MATLAB车辆动力学建模 ——《控制系统现代开发技术》

引言 在上这门课之前&#xff0c;我已经用过CasADi 去做过最优化的相关实践&#xff0c;其中每一步迭代主要就是由&#xff1a;对象系统优化求解两部分组成的。这里我们重点介绍 “对象系统”如何去描述 &#xff0c;因为它是每一步迭代中重要的一环——“优化求解”会获得控制…

花花省V6淘宝客APP社交电商自营商城聚合优惠券系统功能介绍

花花省V6淘宝客APP的社交电商自营商城聚合优惠券系统具有多种功能&#xff0c;以满足用户的不同需求。以下是其主要功能的介绍&#xff1a; 首页功能&#xff1a;首页设计包含广告位、淘口令识别、微信登录、淘宝登录等。此外&#xff0c;还有淘宝返佣、拼多多返佣、京东返佣、…

Danfoss丹佛斯S90泵比例放大器

S90R042、S90R055、S90R075、S90R100、S90R130、S90R180、S90R250电气排量控制变量泵比例阀放大器&#xff0c;电气排量控制为高增益控制方式&#xff1a;通过微小变化的输入电流控制信号即可推动伺服阀主阀芯至全开口位置&#xff0c;进而将最大流量的控制油引入到伺服油缸。伺…

沉钒废水回收钒

沉钒废水处理与钒回收的重要性 沉钒废水是含钒元素的特殊废水&#xff0c;钒在工业生产中广泛应用&#xff0c;但其排放造成资源浪费与环境威胁。为实现钒的有效回收&#xff0c;研究和实践了多种处理技术。 沉钒废水处理技术 1. 化学沉淀法&#xff1a;添加沉淀剂&#xff…

02 VUE学习:模板语法

模板语法 Vue 使用一种基于 HTML 的模板语法&#xff0c;使我们能够声明式地将其组件实例的数据绑定到呈现的 DOM 上。所有的 Vue 模板都是语法层面合法的 HTML&#xff0c;可以被符合规范的浏览器和 HTML 解析器解析。 在底层机制中&#xff0c;Vue 会将模板编译成高度优化的…

Python专题:八、为整数增加小数点

1、题目 虽说很多人讨厌小数点&#xff0c;但是有时候小数点是必不可少的一项&#xff0c;请你使用强制类型转换为输入的整数增加小数点&#xff0c;并输出改变类型后的变量类型。 2、代码 import sysa float(int(input())) print(f"(a:.lf)",type(a),sep"\…

CentOS 的常见命令

CentOS 是一种广泛使用的 Linux 发行版&#xff0c;特别在服务器环境中。本文将详细介绍 CentOS 中常见的命令&#xff0c;以便帮助用户在操作系统中有效地进行各种操作。下面介绍一下文件和目录操作、用户和权限管理、系统信息查看、软件包管理以及网络配置等方面的命令。 一…

【CTF Web】NSSCTF 3863 [LitCTF 2023]导弹迷踪 Writeup(JS分析+源码泄漏+信息收集)

[LitCTF 2023]导弹迷踪 你是一颗导弹&#xff0c;你需要&#xff0c;飞到最后&#xff01;&#xff08;通过6道关卡就能拿到flag哦~ Flag形式 NSSCTF{} 出题人 探姬 解法 查看网页源代码。 flag 应该在这些文件里面。 <!-- Game files --><script type"applicat…