【论文阅读】Uncertainty-aware Self-training for Text Classification with Few Label

论文下载
GitHub
bib:

@INPROCEEDINGS{mukherjee-awadallah-2020-ust,title 		= "Uncertainty-aware Self-training for Few-shot Text Classification",author 		= "Subhabrata Mukherjee and Ahmed Hassan Awadallah",booktitle 	= "NeurIPS",year 		= "2020",pages      	= {21199--21212}
}

Notice:
这篇论文在arXiv上面的标题为 《Uncertainty-aware Self-training for Text Classification with Few Labels》,推测是发表后更改的。

1. 摘要

Recent success of pre-trained language models crucially hinges on fine-tuning them on large amounts of labeled data for the downstream task, that are typically expensive to acquire or difficult to access for many applications. We study selftraining as one of the earliest semi-supervised learning approaches to reduce the annotation bottleneck by making use of large-scale unlabeled data for the target task. Standard self-training mechanism randomly samples instances from the unlabeled pool to generate pseudo-labels and augment labeled data. We propose an approach to improve self-training by incorporating uncertainty estimates of the underlying neural network leveraging recent advances in Bayesian deep learning. Specifically, we propose (i) acquisition functions to select instances from the unlabeled pool leveraging Monte Carlo (MC) Dropout, and (ii) learning mechanism leveraging model confidence for self-training. As an application, we focus on text classification with five benchmark datasets. We show our methods leveraging only 20-30 labeled samples per class for each task for training and for validation perform within 3% of fully supervised pre-trained language models fine-tuned on thousands of labels with an aggregate accuracy of 91% and improvement of up to 12% over baselines.

预训练语言模型最近的成功关键取决于对下游任务的大量标记数据进行微调,这些数据通常获取成本昂贵或对于许多应用程序来说难以访问。我们将自我训练研究为最早的半监督学习方法之一,通过利用大规模未标记数据来完成目标任务,从而减少注释瓶颈。标准的自训练机制从未标记池中随机采样实例以生成伪标签并增强标记数据。我们提出了一种利用贝叶斯深度学习的最新进展,结合底层神经网络的不确定性估计来改进自我训练的方法。具体来说,我们提出(i)获取函数利用蒙特卡罗(MC)Dropout从未标记池中选择实例,以及(ii)利用模型置信度进行自我训练的学习机制。作为一个应用程序,我们专注于使用五个基准数据集进行文本分类。我们展示了我们的方法,每个任务仅利用每类 20-30 个标记样本进行训练和验证,其性能在完全监督的预训练语言模型的 3% 以内,该语言模型在数千个标签上进行了微调,总体准确度为 91%,并且改进了比基线高出 12%。

UncertainSelf-training的结合

2. 算法描述

Self-training process:
min ⁡ W E x l , y l ∈ D l [ − log ⁡ p ( y l ∣ x l ; W ) ] + λ E x u ∈ S u , S u ⊂ D u E y ∼ p ( y ∣ x u ; W ∗ ) [ − log ⁡ p ( y ∣ x u ; W ) ] (1) \begin{split} & \min_W{\mathbb{E}_{x_l,y_l \in D_l}[-\log{p(y_l|x_l;W)}]} \\ &+ \lambda \mathbb{E}_{x_u \in S_u, S_u \subset D_u} \mathbb{E}_{y \sim p(y|x _u;W^*)}[-\log p(y|x_u;W)] \end{split}\tag{1} WminExl,ylDl[logp(ylxl;W)]+λExuSu,SuDuEyp(yxu;W)[logp(yxu;W)](1)

Uncertain-aware Self-training process:
min ⁡ W , θ E x l , y l ∈ D l [ − log ⁡ p ( y l ∣ x l ; W ) ] + λ E x u ∈ S u , S u ⊂ D u E W ~ ∼ q θ ( W ∗ ) E y ∼ p ( y ∣ f W ~ ( x u ) ) [ − log ⁡ p ( y ∣ f W ( x u ) ) ] (2) \begin{split} & \min_{W, \theta}{\mathbb{E}_{x_l,y_l \in D_l}[-\log{p(y_l|x_l;W)}]} \\ &+ \lambda \mathbb{E}_{x_u \in S_u, S_u \subset D_u} \mathbb{E}_{\widetilde{W} \sim q_\theta(W^*)}\mathbb{E}_{y \sim p(y|f^{\widetilde{W}}(x_u))}[-\log p(y|f^{W}(x_u))] \end{split}\tag{2} W,θminExl,ylDl[logp(ylxl;W)]+λExuSu,SuDuEW qθ(W)Eyp(yfW (xu))[logp(yfW(xu))](2)

其中:

  • q θ ( W ∗ ) q_\theta(W^*) qθ(W)表示Dropout distribution,是一种预估模型不确定性的一种方案,也叫做 Monte-Carlo Dropout
  • E \mathbb{E} E 可以看作是一种平均值,其中它的下标表示所有的可能方案。
  • 对于预测概率 p ( y ∣ f W ( x u ) ) p(y|f^{W}(x_u)) p(yfW(xu))为什么要 log ⁡ \log log计算,可能是为了方便计算,最大似然中将乘法转化为加法。

Account for the teacher uncertain for the pseudo-labels in terms of their predictive variance:

min ⁡ W , θ E x l , y l ∈ D l [ − log ⁡ p ( y l ∣ x l ; W ) ] + λ E x u ∈ S u , S u ⊂ D u E W ~ ∼ q θ ( W ∗ ) E y ∼ p ( y ∣ f W ~ ( x u ) ) [ log ⁡ p ( y ∣ f W ( x u ) ) ⋅ log ⁡ V a r ( y ) ] \begin{split} & \min_{W, \theta}{\mathbb{E}_{x_l,y_l \in D_l}[-\log{p(y_l|x_l;W)}]} \\ &+ \lambda \mathbb{E}_{x_u \in S_u, S_u \subset D_u} \mathbb{E}_{\widetilde{W} \sim q_\theta(W^*)}\mathbb{E}_{y \sim p(y|f^{\widetilde{W}}(x_u))}[\log p(y|f^{W}(x_u)) \cdot \log Var(y)] \end{split} W,θminExl,ylDl[logp(ylxl;W)]+λExuSu,SuDuEW qθ(W)Eyp(yfW (xu))[logp(yfW(xu))logVar(y)]

其中:

  • log ⁡ V a r ( y ) \log Var(y) logVar(y) 表示per-sample weight。对于单个样本 x u x_u xu的损失是 − log ⁡ p ( y ) -\log p(y) logp(y) log ⁡ 1 V a r ( y ) \log \frac{1}{Var(y)} logVar(y)1的组合。这会在老师更确定的错误分类实例(即低方差样本)上对学生模型进行更多惩罚,反之亦然。更加重视低方差样本。
  • Var(y)的定义。
    V a r ( y ) = V a r [ E ( y ∣ W , x ) ] + E [ V a r ( y ∣ W , x ) ] = V a r ( softmax ( f W ( x ) ) ) + σ 2 ≈ ( 1 T ∑ t = 1 T y t ∗ ( x ) T y t ∗ ( x ) − E ( y ) T E ( y ) ) + σ 2 \begin{aligned} Var(y) &=Var[\mathbb{E}(y|W, x)] + \mathbb{E}[Var(y|W, x)] \\ &=Var(\text{softmax}(f^W(x))) + \sigma^2\\ &\approx(\frac{1}{T}\sum_{t=1}^T y_t^*(x)^\mathsf{T}y_t^*(x) - E(y)^\mathsf{T}E(y)) + \sigma^2\\ \end{aligned} Var(y)=Var[E(yW,x)]+E[Var(yW,x)]=Var(softmax(fW(x)))+σ2(T1t=1Tyt(x)Tyt(x)E(y)TE(y))+σ2

D ( X ) = V a r ( X ) = E ( X 2 ) − [ E ( X ) ] 2 D(X) = Var(X) = E(X^2) - [E(X)]^2 D(X)=Var(X)=E(X2)[E(X)]2

注意的是,在代码实现中, σ 2 \sigma^2 σ2表示数据本身存在的噪声,这一步不在置信度考量范围,实际上也没有对此建模。
伪代码:
在这里插入图片描述
回过头来看伪代码,就很清楚了,这里还有几点想要说明一下:

  • S u S_u Su是随机采样的,这是为了节约计算资源,还有就是为了给算法带来随机性,就像是全局梯度下降与随机梯度下降一样。原文中是说使用简单样本还是探索困难样本。
  • R R R是基于BALD指标选择的,是进一步的提高伪标签的质量。

3. 总结

大厂出品必属精品。我读下来本文的核心就是将不确定性(主要是模型不确定性)融入了Self-training中,数学符号语言很丰富,值得学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/278609.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

浅析LDPC软解码对SSD延迟的影响-part1

此前,存储随笔有发布一篇关于SSD QoS相关问题,文章中有从以下方面做了全景的分析: 扩展阅读: 全景解析SSD IO QoS性能优化 SSD基础架构与NAND IO并发问题探讨 本文主要在之前文章的基础上,再做个补充,本…

【Java】5分钟读懂Java虚拟机架构

5分钟读懂Java虚拟机架构 Java虚拟机(JVM)架构JVM是如何工作的?1. 类加载器子系统2. 运行时数据区3. 执行引擎 相关资料 本文阐述了JVM的构成和组件。每个Java开发人员都知道字节码经由JRE(Java运行时环境)执行。但他们…

PythonGame图形绘制函数详解

文章目录 五种图形矩形圆形 五种图形 除了直线之外,pygame中提供了多种图形绘制函数,除了必要的绘图窗口、颜色以及放在最后的线条宽度之外,它们的参数如下表所示 函数图形参数/类型rect矩形Rectellipse椭圆Rectarc椭圆弧Rect, st, edcircl…

拼多多买家页面批量导出订单excel

拼多多买家页面批量导出订单excel 由于拼多多不支持订单导出excel清算起来很麻烦,就自己写了一个页面批量导出脚本代码。 首先打开拼多多手机端网站:https://mobile.pinduoduo.com/ 登录后点击我的订单打开f12审查元素 在控制台引入jquery,引…

Python等比例缩放图片并修改对应的Labelme标注文件(v2.0)

Python等比例缩放图片并修改对应的Labelme标注文件(v2.0) 前言前提条件相关介绍实验环境Python等比例缩放图片并修改对应的Labelme标注文件Json文件代码实现输出结果 前言 此版代码,相较于Python等比例缩放图片并修改对应的Labelme标注文件&a…

巴贝拉葡萄酒是单一品种还是混合品种制成的?

大多数巴贝拉葡萄酒都是由单一的巴贝拉葡萄品种制成的,许多意大利葡萄酒商开始尝试在巴贝拉葡萄酒中加入其它葡萄品种,其中两个最受欢迎的意大利品种是皮埃蒙特的巴贝拉德阿尔巴和达斯蒂。和朋友在一家意大利餐厅吃饭,被酒单吓到了&#xff1…

YOLOv8改进 | Conv篇 | 轻量级下采样方法ContextGuided(涨点幅度)

一、本文介绍 本文给大家带来的是改进机制是一种替换Conv的模块Context Guided Block (CG block) ,其是在CGNet论文中提出的一种模块,其基本原理是模拟人类视觉系统依赖上下文信息来理解场景。CG block 用于捕获局部特征、周围上下文和全局上下文&#…

Linux环境下安装JDK

本文将介绍在Linux环境下,如何安装JDK 1.用yum方式安装(无需配置环境变量) 检索yum中有没有java1.8的包:yum list java-1.8*安装:yum install java-1.8.0-openjdk* -y检查是否安装合适 2. 用JDK安装包安装 查看是否已经安装JDK&#xff1…

RabbitMQ插件详解:rabbitmq_message_timestamp【Rabbitmq 五】

欢迎来到我的博客,代码的世界里,每一行都是一个故事 RabbitMQ时空之旅:rabbitmq_message_timestamp的奇妙世界 前言什么是rabbitmq_message_timestamprabbitmq_message_timestamp 的定义与作用:如何在 RabbitMQ 中启用消息时间戳&…

寒冷冬天,再次撕下了新能源汽车的续航遮羞布,北方真不适合

随着懂车帝的冬测,新能源汽车的冬天续航成为关注焦点,电池性能在冬天里缩减众所周知,不过从来没有机构告诉消费者,到底冬天电池的续航会减少多少,如今这一切显然暴露在人们眼前了。 懂车帝的冬测显示除了特别优秀的新能…

YOLOv8 | 代码逐行解析(一) | 项目目录构造分析

一、本文介绍 Hello,大家好这次给大家带来的不是改进,是整个YOLOv8项目的分析,整个系列大概会更新7-10篇左右的文章,从项目的目录到每一个功能代码的都会进行详细的讲解,同时YOLOv8改进系列也突破了三十篇文章&#x…

探索Linux服务器配置信息的命令

目录 前言1 uname2 lscpu3 free4 df5 lspci6 lsusb7 lshw结语 前言 Linux系统提供了许多命令,用于获取和查看服务器的软硬件配置信息。这些命令可以帮助管理员和用户了解系统的状态、资源使用情况以及硬件设备的相关信息。以下是一些常用的命令以及它们的作用、使用…