Self Distillation 自蒸馏论文解读

paper:Be Your Own Teacher: Improve the Performance of Convolutional Neural Networks via Self Distillation

official implementation: https://github.com/luanyunteng/pytorch-be-your-own-teacher

前言

知识蒸馏作为一种流行的压缩方法,通过让参数较少的学生模型学习参数量更大的教师模型的知识,可以有效提高学生模型的性能,甚至比教师模型更好,在实际应用中用学生模型替代教师模型从而实现压缩和加速的效果。

但是存在两个问题,一是知识传递的效率较低,学生模型很难学习到教师模型的所有知识,通过蒸馏后性能优于教师模型的情况仍是极少数。二是如何设计和训练合适的教师模型仍是一个难题,现有的蒸馏方法需要大量的实验来找到教师模型的最优架构,非常耗时。

本文的创新点

为了克服传统蒸馏的缺点,本文提出了一种新的自蒸馏架构。和传统蒸馏需要两个步骤即首先训练一个教师模型,然后将知识从教师模型蒸馏到学生模型的方法不同,本文提出的方法只需要一步,训练点直指学生模型,大大减少了训练时间(比如在CIFAR100上,从26.98个小时到5.87个小时,速度快了4.6倍),同时获得了更高的精度(比如ResNet50从传统蒸馏的79.33%的精度提升至81.04%)。

方法介绍

完整的架构如下图所示

以 ResNet50为例,根据深度将其分为四个部分,在每部分后接一个分类器,这个分类器由一个bottleneck、一个全连接层、一个softmax层构成,该分类器只在训练时使用,推理时可以去掉。bottleneck的作用是为了减轻每个浅层分类器之间的影响,并与hints(即特征图)之间计算L2损失。在训练阶段,每个浅层的分类器可以当做学生模型,深层的当做教师模型,从而实现知识的蒸馏。

训练过程中一共有三种损失:

  • 标签之间的交叉熵损失。不仅是最深层即原本模型最终的分类输出,每个浅层分类器的softmax输出也与标签计算CE损失,通过这种方式,隐含在数据集中的知识直接从标签引入到所有的分类器中。

  • KL散度损失。计算学生和教师softmax之间的散度损失。注意教师只有一个,即最深层的输出

  • 和hints之间的L2损失。通过计算最深层分类器和每个浅层分类器特征之间的L2损失,引入feature map中的implicit knowledge,使得每个浅层分类器的bottleneck中的特征图都去拟合最深层分类器bottleneck中的特征图。

完整的损失如下所示

其中 \(\lambda\) 和 \(\alpha\) 是平衡各项损失的权重超参,对于最深层分类器 \(\lambda\) 和 \(\alpha\) 都为0。

实验结果

Compared with Standard Training

在CIFAR100和ImageNet上的结果分别如表1、2所示,其中集成结果通过对各个分类器输出加权求和得到。

从结果可以看出

  1.  通过自蒸馏,所有网络的精度都得到了提升。CIFAR100上评价提升了2.65%,ImageNet上平均提升了2.0.%。
  2. 网络越深,性能提升越大。比如ResNet101提升了4.05%,ResNet18提升了2.58%。
  3. 一般来说集成结果在CIFAR100上提升较大,在ImageNet上提升较小,这可能是由于浅层分类器的精度损失较大。
  4. 分类器的深度在ImageNet中起着更重要的作用,这表明对于复杂任务网络的冗余较小。

Compared with Distillation

与其他蒸馏方法的对比如表3所示,可以看出本文提出的自蒸馏获得了最高的精度。同时如图1所示,因为不用通过实验选择合适的教师模型以及训练教师模型,整个训练时长也大大减小。

Compared with Deeply Supervised Net

深度监督网络和自蒸馏的最大区别是,自蒸馏不仅是用标签训练浅层分类器,还以深层分类器作为教师模型进行知识的蒸馏。结果对比如表4所示,可以看出,自蒸馏在每个分类器的结果都优于深度监督。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/219198.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

四、IDEA创建项目时,Maven Archetype模板工程说明

什么是Maven Archetype Archetype是一个Maven项目的模板工具包,它定义了一类项目的基本架构。Archetype为开发人员提供了创建Maven项目的模板,同时它也可以根据已有的Maven项目生成参数化的模板。 官方文档:https://maven.apache.org/archet…

pytorch导出rot90算子至onnx

如何导出rot90算子至onnx 1 背景描述2 等价替换2.1 rot90替换(NCHW)2.2 rot180替换(NCHW)2.3 rot270替换(NCHW) 3 rot导出ONNX 1 背景描述 在部署模型时,如果某些模型中或者前后处理中含有rot90算子,但又希望一起和模型导出onnx时,可能会遇到…

【二叉树】oj题

在处理oj题之前我们需要先处理一下之前遗留的问题 在二叉树中寻找为x的节点 BTNode* BinaryTreeFind(BTNode* root, int x) {if (root NULL)return NULL;if (root->data x)return root;BTNode* ret1 BinaryTreeFind(root->left, x);BTNode* ret2 BinaryTreeFind(ro…

【云原生】什么是 Kubernetes ?

什么是 Kubernetes ? Kubernetes 是一个开源容器编排平台,管理着一系列的 主机 或者 服务器,它们被称作是 节点(Node)。 每一个节点运行了若干个相互独立的 Pod。 Pod 是 Kubernetes 中可以部署的 最小执行单元&#x…

机器学习【03】在本地浏览器使用远程服务器的Jupyter Notebook【conda环境】

1.激活虚拟环境 conda activate 虚拟环境名字2.虚拟环境下安装jupyter notebook pip install jupyter3.配置 jupyter 文件 在 Jupyter Notebook 的配置目录中生成一个配置文件 jupyter_notebook_config.py jupyter notebook --generate-config3.设置密码 jupyter notebook …

性能压测工具:wrk

一般我们压测的时候,需要了解衡量系统性能的一些参数指标,比如。 1、性能指标简介 1.1 延迟 简单易懂。green:一般指响应时间 95线:P95。平均100%的请求中95%已经响应的时间 99线:P99。平均100%的请求中99%已经响应的时间 平…

JVM字节码文件的相关概述解读

Java全能学习面试指南:https://javaxiaobear.cn 1、字节码文件 从下面这个图就可以看出,字节码文件是可以跨平台使用的 想要让一个Java程序正确地运行在JVM中,Java源码就必须要被编译为符合JVM规范的字节码。 https://docs.oracle.com/java…

如何使用nginx部署静态资源

Nginx可以作为静态web服务器来部署静态资源,这个静态资源是指在服务端真实存在,并且能够直接展示的一些文件数据,比如常见的静态资源有html页面、css文件、js文件、图片、视频、音频等资源相对于Tomcat服务器来说,Nginx处理静态资…

Flutter桌面应用开发之毛玻璃效果

目录 效果实现方案依赖库支持平台实现步骤注意事项话题扩展 毛玻璃效果:毛玻璃效果是一种模糊化的视觉效果,常用于图像处理和界面设计中。它可以通过在图像或界面元素上应用高斯模糊来实现。使用毛玻璃效果可以增加图像或界面元素的柔和感,同…

Elasticsearch集群部署,配置head监控插件

Elasticsearch是一个开源搜索引擎,基于Lucene搜索库构建,被广泛应用于全文搜索、地理位置搜索、日志处理、商业分析等领域。它采用分布式架构,可以处理大规模数据集和支持高并发访问。Elasticsearch提供了一个简单而强大的API,可以…

超级详细的 Maven 教程(基础+高级)

1. Maven 是什么 Maven 是 Apache 软件基金会组织维护的一款专门为 Java 项目提供构建和依赖管理支持的工具。 一个 Maven 工程有约定的目录结构,约定的目录结构对于 Maven 实现自动化构建而言是必不可少的一环,就拿自动编译来说,Maven 必须…