论文笔记--TinyBERT: Distilling BERT for Natural Language Understanding

论文笔记--TinyBERT: Distilling BERT for Natural Language Understanding

  • 1. 文章简介
  • 2. 文章概括
  • 3 文章重点技术
    • 3.1 Transformer Distillation
    • 3.2 两阶段蒸馏
  • 4. 数值实验
  • 5. 文章亮点
  • 5. 原文传送门
  • 6. References

1. 文章简介

  • 标题:TinyBERT: Distilling BERT for Natural Language Understanding
  • 作者:Xiaoqi Jiao, Yichun Yin, Lifeng Shang, Xin Jiang, Xiao Chen, Linlin Li, Fang Wang, Qun Liu
  • 日期:2019
  • 期刊:arxiv preprint

2. 文章概括

  文章提出了一种两阶段的BERT蒸馏模型TinyBERT。TinyBERT在GLUE上击败了所有当前的SOTA蒸馏BERT模型[1],且参数量仅为SOTA的38%,推理时间仅为SOTA的31%。此外TinyBERT在所有GLUE任务中平均表现约为96.8%,几乎完美还原BERT的能力。
  TinyBERT的整体学习步骤如下
整体架构

3 文章重点技术

3.1 Transformer Distillation

  所谓Transformer Distillation(TD),即对Transformer架构的蒸馏。假设教师模型和学生模型的层数分别为 N N N M M M,则首先定义一个映射函数 n = g ( m ) n=g(m) n=g(m)表示用学生模型的第 m m m层去学习教师模型的第 n = g ( m ) n=g(m) n=g(m)层的信息。文章通过数值实验选用了 g ( m ) = 3 m g(m)=3m g(m)=3m。定义第 0 0 0层为嵌入层,第 M + 1 M+1 M+1层为预测层,则我们可以将模型的损失函数写作 L m o d e l = ∑ x ∈ X ∑ m = 0 M + 1 λ m L l a y e r ( f m S ( x ) , f g ( m ) T ( x ) ) (1) \mathcal{L}_{model} = \sum_{x\in\mathcal{X}} \sum_{m=0}^{M+1} \lambda_m \mathcal{L}_{layer} (f_m^S(x), f_{g(m)}^T(x)) \tag{1} Lmodel=xXm=0M+1λmLlayer(fmS(x),fg(m)T(x))(1),其中 L l a y e r \mathcal{L}_{layer} Llayer表示 l a y e r layer layer层的损失函数, f m S ( x ) , f g ( m ) T ( x ) f_m^S(x), f_{g(m)}^T(x) fmS(x),fg(m)T(x)分别表示学生和教师模型在第 m m m g ( m ) g(m) g(m)层的函数, λ m \lambda_m λm为超参数,表示第 m m m层的重要性。下面为针对不同层的蒸馏方式

  • Transformer-layer Distillation:
    Transformer-layer Distillation
    如上图所示,Transformer-layer Distillation包含以下两种蒸馏方法
    • Attention based distillation:蒸馏注意力机制矩阵,损失函数为 L a t t n = 1 h ∑ i = 1 h M S E ( A i S , A i T ) (2) \mathcal{L}_{attn} = \frac 1h \sum_{i=1}^h MSE(A_i^S, A_i^T) \tag{2} Lattn=h1i=1hMSE(AiS,AiT)(2),其中 h h h为多头注意力机制的head数目, M S E MSE MSE表示Mean Squared Error, A i S , A i T A_i^S, A_i^T AiS,AiT分别表示学生模型和教师模型的注意力矩阵。
    • hidden tsates based distillation:蒸馏隐藏层(即FFN的输出层)状态,蒸馏的损失函数为 L h i d n = M S E ( H S W h , H T ) (3) \mathcal{L}_{hidn} = MSE(H^SW_h, H^T) \tag{3} Lhidn=MSE(HSWh,HT)(3),其中 H S , H T H^S, H^T HS,HT分别表示学生模型和教师模型的隐藏层状态, W h W_h Wh为可学习的参数,旨在将学生模型的隐藏向量映射到和教师模型隐藏状态相同的高维空间
  • Embedding-layer Distillation:对嵌入层进行蒸馏,损失函数为 L e m b d = M S E ( E S W e , E T ) (4) \mathcal{L}_{embd} = MSE(E^SW_e, E^T) \tag{4} Lembd=MSE(ESWe,ET)(4),其中 E S , E T E^S, E^T ES,ET分别表示学生模型和教师模型的嵌入层向量, W e W_e We和上述 W h W_h Wh作用相同,旨在将学生模型的嵌入向量映射到和教师模型嵌入向量相同的高维空间
  • Prediction-layer Distillation:采用损失函数 L p r e d = C E ( z T / t , z S / t ) (5) \mathcal{L}_{pred} =CE(z^T/t, z^S/t) \tag{5} Lpred=CE(zT/t,zS/t)(5),其中 z S , z T z^S, z^T zS,zT分别表示学生模型和教师模型的输出logits, t t t表示蒸馏的温度。此设置参考原始蒸馏论文中的设置。
      最后,将上述所有损失函数进行统一,得到 ( 1 ) (1) (1)式中的损失函数可表示为 L l a y e r = { L e m b d , m = 0 L h i d n + L a t t n , M ≥ m > 0 L p r e d , m = M + 1 \mathcal{L}_{layer} = \begin{cases}\mathcal{L}_{embd}, &m = 0\\\mathcal{L}_{hidn} + \mathcal{L}_{attn}, &M\ge m >0\\\mathcal{L}_{pred}, &m=M+1\end{cases} Llayer= Lembd,Lhidn+Lattn,Lpred,m=0Mm>0m=M+1

3.2 两阶段蒸馏

  TinyBERT采用两阶段蒸馏:general distillation和task-specific distillation,每一步骤通过上节介绍的蒸馏方式进行蒸馏

  • General Distillation:使用原始的BERT模型作为教师模型在大量无标注文本语料库上蒸馏得到General TinyBERT
  • Task-specific Distillation:通过数据增强构造一个下游任务的数据集,使用微调后的BERT在增强后的数据集上对general TinyBERT进行蒸馏,得到TinyBERT模型,这里相当于使用general TinyBERT作为第二次蒸馏的初始模型。具体来说,文章采用的数据增强方法为:首先使用BERT/GloVe预测随机掩码掉的单词,然后使用最相近的单词代替掩码位置,并随机将其增强入数据集。具体算法如下
    data-augumentation

4. 数值实验

  文章用BERT[1]原文训练方法训练了和TinyBERT模型大小相同的 BERT TINY \text{BERT}_{\text{TINY}} BERTTINY模型,对比 BERT TINY \text{BERT}_{\text{TINY}} BERTTINY,TinyBERT, BERT BASE \text{BERT}_{\text{BASE}} BERTBASE,DistilBERT[2]等先进的BERT蒸馏模型,得到以下实验结果
- BERT TINY \text{BERT}_{\text{TINY}} BERTTINY相比于 BERT BASE \text{BERT}_{\text{BASE}} BERTBASE 性能下降很多

  • TinyBERT相比于 BERT TINY \text{BERT}_{\text{TINY}} BERTTINY有大幅的性能提升,说明文章提出的KD算法是有效的
  • TinyBERT和当前的SOTA蒸馏模型(DistilBERT)等相比参数量降低28%,推理速度快3.1倍,且模型表现提升了4.4%
  • TinyBERT相比于KaTeX parse error: Expected '}', got 'EOF' at end of input: …T}_{\text{BASE}参数量降低7.5倍,速度快9.4倍,效果为BERT的96.8%,基本还原BERT能力

5. 文章亮点

  文章提出了对Transformer的两阶段蒸馏方法,相比于当前的SOTA蒸馏模型速度更快、参数量更小、表现更加出色。TinyBERT基本完美还原BERT在GLUE任务上的分析能力,可在对存储、运行效率要求更高的场景,如移动设备,作为BERT的替代模型。

5. 原文传送门

TinyBERT: Distilling BERT for Natural Language Understanding

6. References

[1] 论文笔记–BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding
[2] 论文笔记–DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/20359.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

GPON MAC SFP ONU模块介绍与应用

伴随着网络通讯技术的发展,pon无源光网络正逐步走进人们的视野;在这之前你是否仅知道以太网接入?相比与以太网接入,pon作为一种点到多点网络,具有运维成本低、服务范围广、资源占用少等优势;我们最为熟知的…

FlinkCDC第四部分-同步mysql到mysql,ctrl就完事~(flink版本1.17.1)

本文介绍了不同源单表-单表同步,不同源多表-单表同步。 注:此版本支持火焰图 Flink版本:1.17.1 环境:Linux CentOS 7.0、jdk1.8 基础文件: flink-1.17.1-bin-scala_2.12.tgz、 flink-connector-jdbc-3.0.0-1.16.…

centos7根分区、文件系统扩容

1、 输入lsblk,查看到新硬盘sde,根目录现71G. 2、 创建分区fidisk /dev/sde 3、 刷新分区 partprobe /dev/sde,并创建物理卷 pvcreate /dev/sde1 4、 查看卷组名 vgdisplay 5、 将物理卷扩展到卷组 vgextend centos /dev/sde1 6、 查看逻辑巻…

拉丁语翻译器有哪些?一分钟快速分享

拉丁语翻译器有哪些?拉丁语是一种古老的语言,现在已经不再作为主要的交流工具使用。然而,在某些学术领域和文化传承中,拉丁语仍然具有重要地位。因此,当我们需要翻译拉丁语时,使用翻译器可以提高效率和准确…

立式oled拼接屏有哪些产品优点?

葫芦岛oled拼接屏是一种高清晰度的显示屏,由多个oled屏幕拼接而成。它可以用于广告牌、展览、演示、会议等场合,具有高亮度、高对比度、高色彩饱和度、高刷新率等优点,能够吸引人们的眼球,提高信息传递效果。 葫芦岛oled拼接屏的优…

Nodejs快速搭建简单的HTTP服务器,并发布公网远程访问

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 前言 Node.js 是能够在服务器端运行 JavaScript 的开放源代码、跨平台运行环境。Node.js 由 OpenJS Foundation&#xff0…

ens33没有inet地址

1)切换到根用户 su - root 按提示输入密码(不切换到根用户没有权限修改文件) (2)输入cd /etc/sysconfig/network-scripts/ (3)输入vi ifcfg-ens33 ifcfg-ens33 (4)光标移…

SQL力扣练习(六)

目录 1. 部门工资前三高的所有员工(185) 题解一(dense_rank()窗口函数) 题解二(自定义函数) 2.删除重复的电子邮箱(196) 题解一 题解二(官方解析) 3.上升的温度(197) 解法一(DATEDIFF())…

阿里云AliYun物联网平台使用-设备添加以及模拟设备端上云

一、前言 上一篇文章提到,我们已经申请了免费的阿里云平台,下面需要将我们的设备在阿里云上进行注册和申请,以便于我们的数据上云。 二、步骤 注册产品(设备模型) 在产品页面,点击 "创建产品" 。…

windows下使用arp 协议

/ //自动扫描局域网存活主机 本程序是利用arp协议去获取局域网中的存活主机 arp协议概述 地址解析协议,即ARP(Address Resolution Protocol),是根据IP地址获取物理地址的一个TCP/IP协议。主机发送信息时将包含目标IP地址的ARP请…

python散记

"""字符串格式化的两种方法"""name"sans" age18 math_score90.56 english_score88.8print(f"这个学生的名字叫{name},年龄{age},数学分数是{math_score},总分是{math_scoreenglish_score}") print("这个学生的名字叫%s…

克服 ClickHouse 运维难题:ByteHouse 水平扩容功能上线

前言 对于分析型数据库产品,通过增加服务节点实现集群水平扩容,并提升集群性能和容量,是运维的必要手段。 但是对于熟悉 ClickHouse 的工程师而言,听到“扩容”二字一定会头疼不已。开源 ClickHouse 的 MPP 架构导致扩容成本高&…