论文浅尝 | 少样本学习的语言模型的持续训练

be30bd99acfe779ad2b14937ab10519a.png

笔记整理:王贵涛,东南大学硕士,研究方向为自然语言处理

链接:https://github.com/UIC-Liu-Lab/CPT

一、动机

克服灾难性遗忘(CF)是持续学习(CL)的一个主要目标。目前有许多方法,例如基于正则化的方法、基于重放的方法以及基于参数隔离的方法。从头开始训练一个大型的语言模型是非常困难且昂贵的。在领域的最终任务微调之前,使用一个大的未标记领域语料库进行后训练(Post Training),即领域自适应预训练或预微调,可以比直接微调预训练模型获得更好的结果。使用语言本身不断变化的发展,社会事件和来自不同领域的知识来逐步更新语言数据变得越来越重要。由于人类在增量学习方面非常有效,如果能够很少或不被遗忘地模仿这种人类能力,将显著推动人工智能研究的发展。

二、贡献

本文提出了利用未标记域语料库序列增量后训练语言模型,在不忘记其现有知识的情况下不断扩展语言模型的问题。其目标是提高这些领域的少镜头最终任务学习。由此产生的系统被称为CPT(持续后训练)。

三、方法

本研究提出方法CPT(Continual Post Training),是一种用于后训练的CL系统。从预训练的多模态模型开始,使用未标记的语料库对域序列的多模态进行后训练。一旦一个任务被训练好了,它的数据就不再可访问了。在任何时候,所产生的持续训练后的多模态模型都可以被训练领域中的最终任务所使用。这是在CL的任务增量学习设置中,当稍后需要使用任务的学习模型时,提供任务ID 。

CPT对多模态模型进行持续的后训练,通过插入到预训练模型的每个transformer层中的两个持续学习插件(称为CL插件)的模块来实现的。CL插件的灵感来自于适配器。虽然适配器可以隔离不同的任务,但需要为每个任务分配一个新的适配器,并且在不同任务的适配器之间不能共享任何知识。然而,CL插件是一个持续学习系统,它可以通过所有领域共享的适配器来学习一系列任务。图1给出了添加到预训练模型中的两个CL插件的CPT架构。

3b7c0d258ccf03e70955ca7b5ba710fb.png

图1 加入CL插件的CPT结构

在后训练任务中,只训练两个CL插件。原始预训练过的多模态模型的组成部分是固定的。而在最终任务的微调中,所有组件都是可训练的。CL插件是一个带有任务掩码机制的双层全连接网络。它需要两个输入:来自transformer层的前馈层的隐藏状态和任务增量学习所需的任务ID  。在一个CL插件中,任务掩码表示特定于任务的神经元,用于处理CF。由于任务掩码是可微的,所以整个CPT可以进行端到端训练。

学习新领域包括两个主要步骤:(1)学习领域  及其掩码,以供将来使用。(2)在每个旧任务的每一层应用掩码,阻止梯度流,保护旧任务的模型。

(1)学习任务掩码以克服CF。在学习每个任务  时,在CL插件中的每一层上训练一个伪二进制掩码  ,表明对该任务很重要的神经元,借用硬注意的想法,并利用任务ID嵌入来训练掩码。对于任务ID  ,其嵌入  由可微的确定性参数组成,可以与网络的其他部分一起学习。为了从  中生成任务掩码  ,使用Sigmoid作为一个伪门(掩码)函数。  的计算方法如下:

096f1c8df8f9a49eb3f844f366d66218.png

其中 τ 是一个温度变量,从1线性退回到 τ 。

在正向传递中,给定每个层的输出  ,按对应元素乘以掩码  :

298a6ee21552a60d2d24d517ee343dfb.png

CL插件中最后一层的掩蔽输出  通过跳跃连接输入到多模态预训练模型的下一层。在学习任务  之后,保存最终的  并添加到集合{  }中。

(2)应用任务掩码。在学习新任务  之前,首先在所有旧任务iprev的每一层神经元上积累并设置掩码  ,这样在反向传播中,任务  的梯度  就不会流向这些神经元。由于  是伪二进制,使用最大池化来实现积累和条件梯度:

ba73aba6a7defdddae3600b567c6b546.png

与MaxPool({  })中的1项对应的梯度被设置为0以阻止梯度流,而其他梯度保持不变。这样,旧任务中的神经元就受到了保护。

四、实验

本文使用四个未标注的领域数据集:Yelp Restaurant (Xu et al., 2019), AI Papers (Loet al., 2020), ACL Papers (Lo et al., 2020), AGNews (Zhang et al., 2015) 及其4个相应的最终任务分类数据集。

本文使用6个非持续学习方法和7个自适应的持续学习方法作为基线。

非持续学习基线包括:(1) RoBERTa;(2)Adapter,直接微调预训练模型或适配器;(3) RoBERTa-ONE;(4)Adapter-ONE;(5)Prompt-ONE,使用单独的网络为每个任务建立一个模型,没有知识转移或灾难性遗忘。(6)DEMIX,为每个任务训练一个单独的适配器,并从其之前最相似的先前任务适配器初始化适配器。

7个适应的持续学习基线包括(7) RoBERTa-NCL和(8)Adapter-NCL,一个接一个对领域进行后训练,没有处理灾难性遗忘和转移的机制。其他的是最先进的持续学习基线,调整以适应持续的后训练。

实验结果如1表所示:

表1 实验结果

e93fb1ff4aad0900995fb68f47274101.png

五、总结

本文提出了利用未标记域语料库连续对具有域序列的语言模型进行后训练。并提出了一种有效的计算方法(CPT)。来自任何领域后训练的最终任务都可以微调生成的语言模型。实验结果证明了CPT的有效性。


OpenKG

OpenKG(中文开放知识图谱)旨在推动以中文为核心的知识图谱数据的开放、互联及众包,并促进知识图谱算法、工具及平台的开源开放。

57d496adc2d3790c598805c35c02c86d.png

点击阅读原文,进入 OpenKG 网站。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/19520.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

SGD原理及Pytorch实现

🎏目录 🎈1 SGD       🎄1.1 原理       🎄1.2 构造       🎄1.3 参数详解——momentum ✨1 SGD 损失函数是用来度量模型输出和真实值的偏差,损失函数越小,说明我们的模型效…

804. n的阶乘

链接: https://www.acwing.com/problem/content/806/ 题目: 输入一个整数 nn,请你编写一个函数,int fact(int n),计算并输出 nn 的阶乘。 输入格式 共一行,包含一个整数 nn。 输出格式 共一行,包…

深度学习笔记之Transformer(八)Transformer模型架构基本介绍

机器学习笔记之Transformer——Transformer模型架构基本介绍 引言回顾:简单理解: Seq2seq \text{Seq2seq} Seq2seq模型架构与自编码器自注意力机制 Transformer \text{Transformer} Transformer架构关于架构的简单认识多头注意力机制包含掩码的多头注意力…

40.RocketMQ之高频面试题大全

消息中间件如何选型 RabbitMQ erlang开发,对消息堆积的支持并不好,当大量消息积压的时候,会导致 RabbitMQ 的性能急剧下降。每秒钟可以处理几万到十几万条消息。 RocketMQ java开发,面向互联网集群化功能丰富,对在线业…

MySQL物理文件----日志文件(错误日志、通用查询日志、二进制日志、慢查询日志)

文章目录 MYSQL5.7/8.0支持的几种日志文件1、错误日志(Error log)2、一般或通用查询日志(General query log)3、二进制日志(Binary log)3、1 查看是否开启二进制日志3、2二进制日志开启3、3查看二进制文件位…

简单爬虫项目练习

爬虫项目练习 前言任务基本爬虫框架URL管理器Html 下载器HTML 解析器数据存储器爬虫调度器效果分析 前言 自学,参考书籍为 Python爬虫开发与项目实战 ,具体参考了该书的第六章。过程中出现两个问题: 在 Pycharm 上实现时发现有些库名更改及…

自定义程序包不存在的解决方法

方案一&#xff1a; 在pom文件中加入以下代码 <plugin><groupId>org.apache.maven.plugins</groupId><artifactId>maven-surefire-plugin</artifactId><version>2.4.2</version><configuration><skipTests>true</sk…

java并发编程原理-----线程

目录 上下文切换 java代码创建线程的两种方式 线程的五个状态 线程join方法 多线程之间的影响 上下文切换 CPU的每一个核心同一时刻只能执行一个线程&#xff0c;但是我们会发现电脑同一时刻现实会进行几千个线程&#xff0c;这就是cpu在快速的切换执行线程&#xff0c;由…

最早做「行业化」安全托管MSS的厂商,现在怎么样了?

科技云报道原创。 1家三甲医院&#xff0c;4个院区&#xff0c;9万台终端&#xff0c;数千台服务器&#xff0c;只配备1个安全运营人员&#xff0c;换做任何一家企事业单位都不敢想象&#xff0c;但武汉某医院却实现了7*24小时的自动化监测响应&#xff0c;所有威胁均可在1小时…

基础篇--初识STM32

初识STM32 STM32是什么 ST&#xff1a;意法半导体 M&#xff1a;MCU/MPU32:32位 ST累计推出了&#xff1a;5大类、18个系列、1000多个型号的Cortex内核微控制器 STM32芯片分类 ST中文社区网&#xff1a;https://www.stmcu.org.cn/ ST官网&#xff1a;https://www.st.com …

【从零开始学习CSS | 第一篇】选择器介绍

目录 前言&#xff1a; 选择器介绍&#xff1a; 各类选择器&#xff1a; 总结&#xff1a; 前言&#xff1a; 本文以及后续几篇文章我们将会集中介绍CSS中的常见选择器&#xff0c;选择器的出现可以让我们实现对具体的元素标签进行定制&#xff0c;因此我们要掌握好各类选择…

python+unittest+requests+HTMLRunner搭建接口测试框架,执行用例请求多个不同请求方式的接口

问题描述&#xff1a; 搭建接口测试框架&#xff0c;执行用例请求多个不同请求方式的接口 实现步骤&#xff1a; ① 创建配置文件config.ini&#xff0c;写入部分公用参数&#xff0c;如接口的基本url、测试报告文件路径、测试数据文件路径等配置项 1 [DATABASE] 2 data_addre…