知识蒸馏Matching logits与RocketQAv2

知识蒸馏Matching logits

公式推导

刚开始的\frac{\partial L}{\partial z_i}=q_i-p_i怎么来,可以转看下面证明梯度等于输出值-标签y

C是一个交叉熵,我们要求解的是这个交叉熵对z_i的这个梯度。z_i就是你可以理解成第i个类别的得分。z_i就是student model,被蒸馏的模型,它所输出的logits。

p_i是什么?是target probability对吧。q_i是什么?q_i认为就是这个distilled model的输出的那个probability。所以就是说这两个概率相减,再乘以这个T分之一T是什么?T是一个温度。

我们现在假定是说我们是用teacher model输出的这个label,然后去训练student model,或者说去训练distilled model。我们对这个第i个类别的梯度,就等于\frac{1}{T}{(q_i-p_i)},然后呢,q_ip_i可以做一个化简。

q_ip_i进行展开,概率都是用softmax算出来的,就可以得到这个式子。

通过e^x\approx 1+x来进行化简,这个式子在x比较小的时候是成立的。

在这里,当T足够大的时(相比z的logits,即z),\frac{z_i}{T}就足够的小,接近于0,此时e^{\frac{z_i}{T}}\approx 1+\frac{z_i}{T}

\sum_j e^{\frac{z_i}{T}}\approx \sum_j{1+\frac{z_i}{T}}=N+\sum_j{\frac{z_i}{T}} 

z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0,所以这两个分母直接就变成了N。

\frac{1}{T}({\frac{1+z_i/T}{N}}-{\frac{1+v_i/T}{N}})=\frac{1}{TN}{\frac{z_i-v_i}{T}}

则所求梯度

想说明的事情

它其实就想说明这样一个事情。我们试图用一个teacher model,或者说我们想用VI对应的那个概率叫p_iz_i对应的概率叫q_i。如果我们想用这个p_i作为label去用交叉商去训练q_i去用这个soft label的交叉商去训练q_i,那么其实我们可能不需要套用交叉商这个东西了,我们也不需要什么softmax的label的交叉商,然后去做这个事了。因为这个东西在我们的这样一通推导下就会发现,其实就等于均方误差,右边这一项其实就是什么均方误差的求导,它就是均方误差求导之后的结果,你可以这样认为。

我们就会发现说,原来对于交叉商对于这个知识蒸馏的这个交叉商,然后我们对他求导求出来的梯度其实是近似等同于我们直接用MSE去训练,然后得到的梯度的。那么既然这样,我们为什么不直接用MSE?

它的推导就告诉我们说我们对于两个模型,两个多分类模型来说,我们要用a模型去交B模型做蒸馏。我们没有必要让这两个模型生成分别生成什么label,然后再生成预测的概率,然后再加上去优化了。

我们直接让这两个多分类模型的这个logic,然后直接做MSE就可以了,就可以做到一种就是一种这种MSE就是一种什么蒸馏的特殊形式。就是蒸馏的一个最早期的雏形,其实在这个时候都还没有考虑用这个什么KL散度来做,就只是提出最简单的一个思想是什么,就是用MSE来做就够了。

我们一直即便到今天,我们做很多知识正溜的实验,我们依然会发现MIC可能有的时候都会比K要好。虽然大家都说自己用什么KL散度用什么JS散度,但是就是否现在就最优,还真不一定有的时候就是MSE效果好。

注:MSE = \frac{1}{N} \sum_{i=1}^{N} (y_i - \hat{y}_i)^2

需要注意的事

公式的推导基于两个假设:

1.T得足够大的(相比z的logits,即z_i)

2.模型输出的logic是零均值的(即均值为0),因为模型输出的logic是零均值的,这个z_j的这个累加,它就等于零。这个v_j的这个累加也等于零,即\sum_j z_j=\sum_j v_j=0

 证明梯度等于输出值-标签y

softmax函数

归一化,使其输出的概率和为1

S_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

S_i代表的是第i个神经元的输出。

神经元的输出,一个神经元如下图:

z_i=\sum_jw_{ij}x_{ij}+b

其中w_{ij}是第i个神经元的第j个权重,b是偏移值。z_i表示该网络的第i个输出。

给这个输出加上一个softmax函数,得a_i=\frac{e^{z_i}}{\sum_ke^{z_k}}

a_i代表softmax的第i个输出值

交叉熵损失函数 loss function

L=-\sum_i{y_i}{lna_i}

其中y_i表示真实的分类结果。

证明梯度等于输出值-标签y

loss对于神经元输出z_i的梯度为\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}

由于softmax公式的特性,它的分母包含了所有神经元的输出,对于不等于i的其他输出里面,也包含着z_i,所有的a都要纳入到计算范围中,并且后面的计算可以看到需要分为i=ji \ne j两种情况求导。

由于\frac{\partial (-\sum_{k\ne j}y_{k}ln a_k)}{\partial a_j}=0

\frac{\partial C}{\partial a_j}=\frac{\partial (-\sum_jy_jln a_j)}{\partial a_j}=-\sum_jy_j\frac{1}{a_J}

如果i=j

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_{k\ne i}e^{z_k}+e^{z_i}})}{\partial z_i}=\frac{\sum_ke^{z_k}e^{z_i}-(e^{z_i})^2}{\sum_k(e^{z_k})^2}
=(\frac{e^{z_i}}{\sum_ke^{z_k}})(1-\frac{e^{z_i}}{\sum_ke^{z_k}})=a_i(1-a_i)

这里\sum_ke^{z_k}=\sum_{k\ne i}e^{z_k}+e^{z_i}

如果i \ne j

这里\sum_ke^{z_k}=\sum_{k\ne j}e^{z_k}+e^{z_j}

\frac{\partial a_i}{\partial z_i}=\frac{\partial (\frac{e^{z_i}}{\sum_ke^{z_k}})}{\partial z_i}=-e^{z_j}(\frac{1}{\sum_ke^{z_k}})e^{z_i}=-a_ia_j

综上

\frac{\partial L}{\partial z_i}=\frac{\partial L}{\partial a_j}\frac{\partial a_j}{\partial z_i}=(-\sum_jy_j\frac{1}{a_j})\frac{\partial a_j}{\partial z_i}=-\frac{y_i}{a_i}a_i(1-a_i)+\sum_{j\neq i}\frac{y_i}{a_j}a_ia_j
=-y_i+y_ia_i+\sum_{j\neq i}{y_ia_i}=-y_i+a_i\sum_{j}y_j

最后,针对分类问题,我们给定的结果y_i最终只会有一个类别是1,其他非标签类别都是0,因此,对于分类问题,这个梯度等于

\frac{\partial L}{\partial z_i}=a_i-y_i

知识蒸馏RocketQAv2

https://arxiv.org/pdf/2110.07367.pdf

这个模型有两部分组成一个retriever和一个ranker。这个做的事情就是说用label去监督re-ranker,然后用ranker去监督retriever。用KL散度去约束它约束,用这个K散路去让这个re-ranker的分布和retriever的分布对齐。

要注意就是说。这里就是他们就没有用MSE,就是说如果用MSE怎么做,就是说对应的这个直接相减,就对应位置直接相减,然后分MSE就行。这里用的是KL散度。

KL散度的定义,你可以认为是这样的,让这两个概率分别相除,除完了之后都要再取对数,然后再乘以这个概率。

DE,这个teacher model的概率乘以teacher model的概率乘以log,teacher model的概率除以student model的概率。然后把这么多概率给它都累加起来。

在这里,假定这里的是retriever给出来的一个概率分布假如说是十个候选,ranker也给了这样一个概率分布,那么就是十个概率分布对应的一项一项的去算这个KL度,即概率除概率,然后再取对数,然后再乘上ranker这个概率。

然后再把这十项给它累加起来,然后就是一个KL散度,这样的话,这个K散度其实是现在就是接受最多的一种损失函数。

因为KL散度就是天生的,可以捕获这个分布和分布之间的距离。像MSE缺点是什么?MSE的缺点是它没有整体的那种距离衡量的能力。MSE其实是对于细节的这种距离的衡量很强。如果MSE来的话,每一个每一项,这十项每一项的重要性对于MIC来说都是一样的。但是这个KL散度可能就会更在乎一个整体的一个分布上的一个区别了,就而不是说就在乎一些细节上的一些差别,因为有可能就是说。你某一些细节差距虽然大一些,但是你整体差距不大,所以KL散度也可以比较小。

实际上一切可以衡量两个分布之间距离的指标都可以用来做知识蒸馏,所以其实wasserstein距离也可以用来作为蒸馏的损失函数:

https://openaccess.thecvf.com/content/CVPR2021/papers/Chen_Wasserstein_Contrastive_Representation_Distillation_CVPR_2021_paper.pdf

为什么知识蒸馏会有效?

1. teacher model可以生成soft label,相比于原始数据的hard label,包含了更多信息量。

所以很多时候你与其说直接用一个数据集去训练一个模型,你还不如用这个数据集先训练一个大a模型比a模型要大的模型。再让大a模型去教会a模型去做,有可能效果就更好。就是因为大a模型这个teacher model可以生成soft label相比于原始数据的hard label,可以包含更多的信息量,从而就天然的有一种去燥的一种功能。

2. teacher model可以为大量的无标签数据打上label,然后为student提供一个大规模的训练集。然后从而可以给student提供一个更大尺度的训练集,然后防止student的一个过拟合,然后提高student model的一个泛化能力。也就是说,teacher model可以把自己的泛化能力交给student model

在这个知识蒸馏的过程当中,这也是为什么说很多大公司里边现在线上的模型都是蒸馏出来的小模型就是因为我们与其说直接训练小模型。还不如说就用这个蒸馏去蒸馏一个小模型反而泛化能力会更强一些

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/539921.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

vue iview 级联选择器遇到的坑

我们PC项目用到的前端技术栈是vue+iview,最近有个需求,要做个级联选择器,并且是懒加载动态加载后端返回的数据。效果如下: 如下图所示,在我们封装的公共组件form-box.vue里有我们级联选择器: 代码如下: <!--级联选择器--><template v-else-if="item.type…

蓝桥杯 EDA 组 2021-2022 省赛真题+模拟题原理图解析

本文解析了标题内的原理图蓝桥杯EDA组真题&#xff0c;为方便阅读2023年真题/模拟和国赛部分放到其他章节解析。下文中重复或者是简单的电路节约篇幅不在赘述。 其中需要补充和计算原理图的题目解析都放在最下面 一、2021第十二届真题第一场 1.1 AMS1117 线性稳压器 最常见的1…

PyTorch搭建AlexNet训练集

本次项目是使用AlexNet实现5种花类的识别。 训练集搭建与LeNet大致代码差不多&#xff0c;但是也有许多新的内容和知识点。 1.导包&#xff0c;不必多说。 import torch import torch.nn as nn from torchvision import transforms, datasets, utils import matplotlib.pypl…

【STM32学习】基本定时器,输出比较模式,基本参数

1、概述 此项功能是用来控制一个输出波形&#xff0c;或者指示一段给定的的时间已经到时。 如输出PWM信号时&#xff0c;可用这个模式。 2、输出比较初始化函数&#xff0c;基本参数 以上函数是用来配置输出比较模块的&#xff0c;每个函数对应一个定时器的通道&#xff0c;配…

LVGL移植到ARM开发板(GEC6818开发板)

LVGL移植到ARM开发板&#xff08;GEC6818开发板&#xff09; 一、LVGL概述 LVGL&#xff08;Light and Versatile Graphics Library&#xff09;是一个开源的图形用户界面库&#xff0c;旨在提供轻量级、可移植、灵活和易于使用的图形用户界面解决方案。 它适用于嵌入式系统…

自然语言处理实验2 字符级RNN分类实验

实验2 字符级RNN分类实验 必做题&#xff1a; &#xff08;1&#xff09;数据准备&#xff1a;academy_titles.txt为“考硕考博”板块的帖子标题&#xff0c;job_titles.txt为“招聘信息”板块的帖子标题&#xff0c;将上述两个txt进行划分&#xff0c;其中训练集为70%&#xf…

概率论与数理统计(随机事件与概率)

1随机事件与概率 1.1随机事件及其运算规律 1.1.1运算 交换律结合律分配律德摩根律 1.2概率的定义及其确定方法 1.2.1概率的统计定义 频率 设在 n 次试验中&#xff0c;事件 A 发生了(A)次&#xff0c;则称为事件 A 发生的频率。 1.2.2概率的统计定义 在一组恒定不变的条…

GPT-SoVITS开源音色克隆框架的训练与调试

GPT-SoVITS开源框架的报错与调试 遇到的问题解决办法 GPT-SoVITS是一款创新的跨语言音色克隆工具&#xff0c;同时也是一个非常棒的少样本中文声音克隆项目。 它是是一个开源的TTS项目&#xff0c;只需要1分钟的音频文件就可以克隆声音&#xff0c;支持将汉语、英语、日语三种…

vscode 导入前端项目

vscode 导入前端项目 导入安装依赖 运行 参考vscode 下载 导入 安装依赖 运行 在前端项目的终端中输入npm run serve

KKVIEW: 远程控制软件哪个好用

远程控制软件哪个好用 随着科技的发展和工作方式的改变&#xff0c;远程控制软件越来越受到人们的关注和需求。无论是在家中远程办公&#xff0c;还是技术支持人员为远程用户提供帮助&#xff0c;选择一款高效稳定的远程控制软件至关重要。在众多选择中&#xff0c;有几款远程…

【数学建模】线性规划

针对未来可能的数学建模比赛内容&#xff0c;我对学习的内容做了一些调整&#xff0c;所以先跳过灰色关联分析和模糊综合评价的代码&#xff0c;今天先来了解一下运筹规划类——线性规划模型。 背景&#xff1a; 某数学建模游戏有三种题型&#xff0c;分别是A&#xff0c;B&am…

【AI论文阅读笔记】ResNet残差网络

论文地址&#xff1a;https://arxiv.org/abs/1512.03385 摘要 重新定义了网络的学习方式 让网络直接学习输入信息与输出信息的差异(即残差) 比赛第一名1 介绍 不同级别的特征可以通过网络堆叠的方式来进行丰富 梯度爆炸、梯度消失解决办法&#xff1a;1.网络参数的初始标准化…