Noisy Student(CVPR 2020)论文解读

paper:Self-training with Noisy Student improves ImageNet classification

official implementation:https://github.com/google-research/noisystudent

本文的创新点

本文提出了一种新的半监督方法Noisy Student Training,主要包括三步:(1)在有标签数据上训练一个教师模型(2)利用教师模型在无标签数据上生成伪标签(3)结合有标签的图片和伪标签的图片训练学生模型。重复迭代这个过程,将学生作为教师重新生成伪标签,然后再训练一个新的学生模型。

Noisy Student Training从两个方面提高自训练和蒸馏的能力。首先,它使用的学生模型比教师模型更大(或至少相等),这样学生可以更好的从一个更大的数据集中学习。其次,它给学生模型增加的噪声,这迫使学生模型更努力的从伪标签中学习。

方法介绍

算法1概述了Noisy Student Training的过程。算法的输入包括有标签和无标签的图片。我们首先在有标签图片上用交叉熵损失训练一个教师模型。然后用教师模型在无标签图片上生成伪标签。伪标签可以是soft(连续分布)或hard(one-hot分布)。然后我们训练一个学生模型,在标签图片和无标签图片上最小化交叉熵损失。最后我们迭代这个过程,将训练好的学生模型作为教师模型在无标签数据上生成新的伪标签,并训练一个新的学生模型。算法的过程如图1所示

 

本文方法的关键改进在于给学生模型增加噪声,并使用不小于教师模型的学生模型。这使得该方法不同于知识蒸馏:1)蒸馏中通常不适用噪声 2)蒸馏中通常使用比教师小的学生模型。我们可以把本文的方法看作一种knowledge expansion,通过给学生模型足够的容量和更困难的学习环境(添加噪声),我们希望学生模型比教师更好。

作者在实验中使用了两种类型的噪声:输入噪声和模型噪声。输入噪声使用了数据增强方法RandAugment(具体见RandAugment(NeurIPS 2020)论文速读-CSDN博客),模型噪声使用了dropout和stochastic depth(见Stochastic Depth 原理与代码解析-CSDN博客)。

当应用于无标签数据时,噪声有一个重要的好处,即enforce决策函数在标签数据和无标签数据上的不变性。首先,数据增强是Noisy Student Training中一种重要的噪声方法,因为它迫使学生在同一张图片的不同增强变形上的预测一致性。具体来说,教师在干净的图片上生成高质量的伪标签,而学生则需要用增强后的图片来复制这些标签,即学生必须确保增强后的图片和原始图片是相同的类别。其次,当使用dropout和stochastic depth作为噪声时,教师模型在推理时(生成伪标签时)表现的像一个集成模型,而学生则像一个单一的模型。换句话说,学生被迫模仿一个更强大的集成模型。

Other Techniques 作者还使用了额外的trick:data filtering和balancing,这让Noisy Student Training的效果更进一步。具体来说,我们过滤掉教师模型推理时置信度低的图片,因为这些通常是out-of-doman图片。为了确保无标签图片的分布和训练集的分布相匹配,我们还需要平衡无标签图片中每个类别的数量,因为ImageNet中所有类别都有相似数量的标签图片。具体做法是,对于数量不够的类别进行复制,对数量太多的类别,只用置信度最高的图片。

最后作者还强调了下在实验中使用hard或soft伪标签Noisy Student的效果都很好。对于out-of-doman的无标签数据,soft伪标签的效果稍好些。

实验结果

实验细节

数据集. 有标签数据选用ImageNet数据集。无标签数据选用JFT数据集,并从中过滤掉ImageNet验证集中的图片。然后在剩下的图片上进行data filtering和balancing。首先用在ImageNet上训练的EfficientNet-B0模型为JFT中的每张图片预测一个标签,然后选择置信度大于0.3的图片。每个类别,选择置信度最高的130K张图片,如果本身不足130K的随机复制进行补齐。这样用于训练学生模型的图片共有130M张(实际只有81M,其它都是复制的)。

模型结构. 作者选择EfficientNet作为baseline模型,并进一步增大EfficientNet-B7得到EfficientNet-L2,后者并前者更深更宽但用了更小的输入分辨率。

训练细节. 对于有标签数据,batch size默认选择2048。有标签数据的训练时长和学习率根据batch size进行调整,具体来说,对于大于EfficientNet-B4的学生模型训练350个epoch,对于更小的学生模型训练700个epoch。训练350个epoch且batch size为2048时,初始学习率为0.128每个2.4个epoch衰减0.97,训练700个epoch时每个4.8个epoch进行衰减。

噪声. 对EfficientNet-B7和L2选择相同的噪声超参。其中最后一层stochastic depth的概率为0.8,其它层遵循linear decay rule。最后一层的dropout rate为0.5。对于RandAugment,选择两种随机变换,强度设置为27。

迭代训练. 作者实验得到的最好模型是将学生作为教师模型迭代三次得到的。首先在ImageNet上训练一个EfficientNet-B7作为教师模型。然后在无标签数据上训练一个EfficientNet-L2的学生模型,无标签数据的batch size是有标签数据的14倍。然后将这个EfficientNet-L2的学生模型作为教师模型再训练一个新的EfficientNet-L2的学生模型。最后再迭代一次,此时无标签的batch size设置有标签的28倍。

在ImageNet验证集上的结果如表2所示,用Noisy Student训练的EfficientNet-L2达到了88.4%的准确率,远超于之前EfficientNet系列的85.0%的准确率。提升的3.4%中0.5%来源于更大的模型,2.9%来源于Noisy Student Training,这表明Noisy Student Training对模型精度的影响比调整网络结构大得多。

此外,88.4%的精度还超过了之前采用FixRes ResNeXt-101 WSL的86.4%的SOTA,且后者使用了35亿张有标签的Instagram图片。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/624173.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

L2-3 完全二叉树的层序遍历

完全二叉树的层序遍历 一个二叉树,如果每一个层的结点数都达到最大值,则这个二叉树就是完美二叉树。对于深度为 D 的,有 N 个结点的二叉树,若其结点对应于相同深度完美二叉树的层序遍历的前 N 个结点,这样的树就是完全…

防止企业数据泄密的四种有效措施

防止企业数据泄密的四种有效措施 泄密大案每天都在上演,受害者既有几十人的小型企业,也有上万人的世界500强,为什么这些企业都难逃数据泄露的噩梦呢?我们应该采取什么措施来防止信息泄密呢? 首先我们来看看数据泄密的…

Slf4j+Log4j简单使用

Slf4jLog4j简单使用 文章目录 Slf4jLog4j简单使用一、引入依赖二、配置 log4j2.xml2.1 配置结构2.2 配置文件 三、使用四、使用MDC完成日志ID4.1 程序入口处4.2 配置文件配置打印4.3 多线程日志ID传递配置 五. 官网 一、引入依赖 <dependencies><dependency><g…

C++项目 -- 负载均衡OJ(一)comm

C项目 – 负载均衡OJ&#xff08;一&#xff09;comm 文章目录 C项目 -- 负载均衡OJ&#xff08;一&#xff09;comm一、项目宏观结构1.项目功能2.项目结构 二、comm公共模块1.util.hpp2.log.hpp 一、项目宏观结构 1.项目功能 本项目的功能为一个在线的OJ&#xff0c;实现类似…

leetcode1448.统计二叉树中的好节点数目

1. 题目描述 题目链接 2. 解题思路 首先看一下题目的“核心”&#xff0c;什么是好节点&#xff1a;从根到该节点 X 所经过的节点中&#xff0c;没有任何节点的值大于 X 的值。也就是说&#xff0c;我们只要知道了从根节点到该节点的所有的值&#xff0c;就可以判断该节点是…

三个晚上!给干废了!MINI2440 挂载 NFS

虚拟机执行&#xff1a;sudo ifconfig tap0 10.10.10.1 up qemu 开发板&#xff1a; set bootargs noinitrd root/dev/nfs rw nfsroot10.10.10.1:/nfsroot ip10.10.10.10:10.10.10.1 ::255.255.255.0 consolettySAC0,115200 Hit any key to stop autoboot: 0 MINI2440 # set…

P5730 【深基5.例10】显示屏

思路&#xff1a; 此题只需要两层循环&#xff0c;通过数组映射即可求出答案 AC代码&#xff1a; #include<iostream>using namespace std;typedef long long ll; const int N 10; int a[N];int main() {ll n,m;cin >> n >> m;for(ll in;i<m;i){ll nu…

Level protection and deep learning

1.模拟生成的数据 import randomdef generate_data(level, num_samples):if level not in [2, 3, 4]:return Nonedata_list []for _ in range(num_samples):# 构建指定等级的数据data str(level)for _ in range(321):data str(random.randint(0, 9))data_list.append(data)…

自定义类型: 结构体 (详解)

本文索引 一. 结构体类型的声明1. 结构体的声明和初始化2. 结构体的特殊声明3. 结构体的自引用 二. 结构体内存对齐1. 对齐规则2. 为啥存在对齐?3. 修改默认对齐值 三. 结构体传参四. 结构体实现位段1. 什么是位段?2. 位段的内存分配3. 位段的应用4. 位段的注意事项 ​ 前言:…

yolov5 MMCV依赖库 报错(2个错误)

1.报错内容 错误1&#xff1a; ImportError: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory 错误2&#xff1a;ModuleNotFoundError: No module named ‘mmcv._ext’ 2. 原因分析 python 、torch、mmcv版本兼容问题 3.解决方法 首先运行…

linux进阶篇:文件查找的利器——grep命令+管道操作详解

Linux文件查找的利器——grep命令管道操作详解 1 grep简介 grep (global search regular expression(RE) and print out the line,全面搜索正则表达式并把行打印出来)是一种强大的文本搜索工具&#xff0c;它能使用正则表达式搜索文本&#xff0c;并把匹配的行打印出来。 Uni…

BackTrader 中文文档(十八)

原文&#xff1a;www.backtrader.com/ OCO 订单 原文&#xff1a;www.backtrader.com/blog/posts/2017-03-19-oco/oco/ 版本 1.9.34.116 添加了OCO&#xff08;又称一次取消其他&#xff09;到回测工具中。 注意 这只在回测中实现&#xff0c;尚未实现对实时经纪商的实现 注…