大模型增量预训练新技巧:解决灾难性遗忘

大家好,目前不少开源模型在通用领域具有不错的效果,但由于缺乏领域数据,往往在一些垂直领域中表现不理想,这时就需要增量预训练和微调等方法来提高模型的领域能力。

但在领域数据增量预训练或微调时,很容易出现灾难性遗忘现象,也就是学会了垂直领域知识,但忘记了通用领域知识,之前介绍过增量预训练以及领域大模型训练技巧。

今天给大家带来一篇增量预训练方法-Llama-Pro,对LLMs进行Transformer块扩展后,增量预训练过程中仅对新增块进行训练,有效地进行模型知识注入,并且极大程度地避免灾难性遗忘。

图片

LLaMA Pro: Progressive LLaMA with Block Expansion

LLaMA Pro: Progressive LLaMA with Block Expansion
Paper: https://arxiv.org/abs/2401.02415
Github: https://github.com/TencentARC/LLaMA-Pro

文章目录

    • 技术交流群
    • 用通俗易懂方式讲解系列
    • 块扩展方法
    • 实验细节
    • 讨论分析
    • 写在最后

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了大模型面试与技术交流群, 想要进交流群、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2060。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2060,备注:技术交流

资料1
在这里插入图片描述

用通俗易懂方式讲解系列

  • 用通俗易懂的方式讲解:自然语言处理初学者指南(附1000页的PPT讲解)
  • 用通俗易懂的方式讲解:1.6万字全面掌握 BERT
  • 用通俗易懂的方式讲解:NLP 这样学习才是正确路线
  • 用通俗易懂的方式讲解:28张图全解深度学习知识!
  • 用通俗易懂的方式讲解:不用再找了,这就是 NLP 方向最全面试题库
  • 用通俗易懂的方式讲解:实体关系抽取入门教程
  • 用通俗易懂的方式讲解:灵魂 20 问帮你彻底搞定Transformer
  • 用通俗易懂的方式讲解:图解 Transformer 架构
  • 用通俗易懂的方式讲解:大模型算法面经指南(附答案)
  • 用通俗易懂的方式讲解:十分钟部署清华 ChatGLM-6B,实测效果超预期
  • 用通俗易懂的方式讲解:内容讲解+代码案例,轻松掌握大模型应用框架 LangChain
  • 用通俗易懂的方式讲解:如何用大语言模型构建一个知识问答系统
  • 用通俗易懂的方式讲解:最全的大模型 RAG 技术概览
  • 用通俗易懂的方式讲解:利用 LangChain 和 Neo4j 向量索引,构建一个RAG应用程序
  • 用通俗易懂的方式讲解:使用 Neo4j 和 LangChain 集成非结构化知识图增强 QA
  • 用通俗易懂的方式讲解:面了 5 家知名企业的NLP算法岗(大模型方向),被考倒了。。。。。
  • 用通俗易懂的方式讲解:NLP 算法实习岗,对我后续找工作太重要了!。
  • 用通俗易懂的方式讲解:理想汽车大模型算法工程师面试,被问的瑟瑟发抖。。。。
  • 用通俗易懂的方式讲解:基于 Langchain-Chatchat,我搭建了一个本地知识库问答系统
  • 面试了字节大模型算法岗(实习),快被问哭了。。。。

块扩展方法

块扩展,顾名思义,就是在原始模型中每个Transformer块或者某几个Transformer块后增加一个Transformer块,但为了保持扩展后的模型输出保持不变,需要增加的块为恒等块(输入输出相同),如下图所示。

图片

在构建恒等块过程中,主要是将多头注意力层和FFN层中的最后一个线性层(Linear)权重置为0变成Zero-Linear,即可保持经过该块的输入输出一致。

PS:论文附录A中写了大段的推导公式来证明,在此不做过多介绍。

块的增加方式是,对原始模型的 个Transformer块分成 组,每组中包含 个Transformer块,对于每组后添加 个恒等块。代码实现具体如下:

model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16)
ckpt = model.state_dict()# original_layers是模型原始层数,layers是模型最后达到层数
split = int(original_layers / (layers - original_layers))layer_cnt = 0output = {}
for i in range(original_layers):for k in ckpt:if ('layers.' + str(i) + '.') in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1if (i+1) % split == 0:for k in ckpt:if ('layers.' + str(i) + '.') in k:if 'down_proj' in k or 'o_proj' in k:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = torch.zeros_like(ckpt[k])else:output[k.replace(('layers.' + str(i) + '.'), ('layers.' + str(layer_cnt) + '.'))] = ckpt[k]layer_cnt += 1assert layer_cnt==layers
for k in ckpt:if not 'layers' in k:output[k] = ckpt[k]torch.save(output, output_path)

实验细节

数据由代码和数学组成,其中代码数据采用The-Stack-Dedup数据集中Python语言部分共22B Token,数学数据采用Proof-Pile-2数据集中AlgebraicStack、OpenWebMath和ArXiv部分共55B,详细如下表所示。

图片

数据分布

基础模型为LLaMA2-7B模型,通过块扩展方法将32层模型扩展到40层,其中 、 、 ,每个组从4个Transformer块扩展到5个Transformer块。

对于代码和数学数据进行增量预训练,批量大小为1024,序列最大长度为4096,预热比率为6%,学习率为2e-4,采用余弦学习率调度器,BF16混合精度训练,权重衰减为0.1。使用16个NVIDIA H800 GPU进行了15900个步骤的训练,大约耗费2830个GPU/小时。

在ARC、HellaSwag、MMLU、TruthfulQA、Winogrande、GSM8K、GSM8K-PoT、HumanEval、MBPP等多个评测数据集中进行评测,可以看出,在保持通用任务能力不下降的情况下,数学和代码能力较原始LLaMA2-7B模型有很大提升。

图片

图片

讨论分析

对比块扩展方法与正常训练和Lora方法之间的区别,采用TRACE基准利用总体性能(OP)和逆向转移(BWT)指标进行评估。,如下表所示,块扩展方法整体提升较大。

图片

对比块个数对块扩展方法的影响,进行了不同个数块的实验,并且对比了MoE的方法,训练损失如下,MoE方法的损失下降程度与添加四个块相当。

图片

在代码和法律(16.7B)领域数据下进行增量预训练,在通用任务以及领域任务上比较不同个数块之间的差异,同时比较扩展块全部添加到模型底部或顶部之间的差别,如下所示。可以发现块个数为8时效果最佳,并且不能直接将扩展块全部堆积在头部或尾部,需要分开插入。

图片

写在最后

该方法主要通过增加恒定块扩展模型层数,使模型在增量训练过程中仅训练新增层、冻结原始层,保持模型原有能力,防止模型出现灾难性遗忘现象。

但有两点存疑:

  • 目前来说mistral要好于llama,为啥不用mistral进行实验

  • 不用恒定块,性能会差多少

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/451353.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

c++重载、隐藏和覆盖

重载 函数名字相同&#xff0c;但参数列表或者返回值不同一组函数要重载必须处在同一作用域中 class Base { public:Base(int data0):m_a(data){}void show(){cout<<"Base::show()"<<endl;}void show(int){cout<<"Base:show(int)"<&l…

STM32--SPI通信协议(2)W25Q64简介

一、W25Q64简介 1、W25Qxx中的xx是不同的数字&#xff0c;表示了这个芯片不同的存储容量&#xff1b; 2、存储器分为易失性与非易失性&#xff0c;主要区别是存储的数据是否是掉电不丢失&#xff1a; 易失性存储器&#xff1a;SRAM、DRAM&#xff1b; 非易失性存储器&#xff…

YOLOv8进阶 | 如何用yolov8训练自己的数据集(以安全帽佩戴检测举例)

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。YOLOv8是一种目标检测算法&#xff0c;它是YOLO&#xff08;You Only Look Once&#xff09;系列算法的最新版本。本节课就带领大家如何基于YOLOv8来训练自己的目标检测模型&#xff0c;本次作者就以安全帽佩戴检测为案例进…

C#(C Sharp)学习笔记_前言及Visual Studio Code配置C#运行环境【一】

前言 这可以说是我第一次正式的踏入C#的学习道路&#xff0c;我真没想过我两年前是怎么跳过C#去学Unity3D游戏开发的&#xff08;当然了&#xff0c;游戏开发肯定是没有成功的&#xff0c;都是照搬代码&#xff09;。而现在&#xff0c;我真正地学习一下C#&#xff0c;就和去年…

2024亿级密码泄露事件:涉及7084万个邮箱账号

近日&#xff0c;热门漏洞通知服务HIBP所有者特洛伊・亨特&#xff08;Troy Hunt&#xff09;发布博文&#xff0c;表示在暗网上发现了超大规模的泄漏数据集&#xff0c;被称为Naz.API列表。该数据集包含7084万个电子邮件地址以及超过1亿个密码凭证&#xff0c;至少有超过40万 …

idea创建spring项目

一、环境 window10 IDEA 2022.2.3 maven-3.8.6 二、创建spring项目 1、新建Maven项目 File -> New -> Project 然后如下图选中Maven Archetype&#xff0c;在Archetype&#xff0c;选中maven-archetype-webapp&#xff0c;点击Create 2、配置maven 默认是使用IDEA内…

CCReportAdv的一个配置技巧

关于CCReportAdv CCReportAdv是我们推出的基于经典WinCC/TIA WinCC Prof.的一款报表控件。它支持导入Excel模板&#xff0c;可以灵活生成美观的数据报表。 配置示例 CCReportAdv功能非常强大。通过简单的配置就可以生成客户需要的报表。以下面这款报表为例&#xff0c;参见下面…

基于SpringBoot的家电销售展示网页的设计与实现

文章目录 项目介绍主要功能截图&#xff1a;部分代码展示设计总结项目获取方式 &#x1f345; 作者主页&#xff1a;超级无敌暴龙战士塔塔开 &#x1f345; 简介&#xff1a;Java领域优质创作者&#x1f3c6;、 简历模板、学习资料、面试题库【关注我&#xff0c;都给你】 &…

Python||五城P.M.2.5数据分析与可视化_使用复式柱状图分析各个城市的P.M.2.5月度差异情况(上)

目录 1.北京市空气质量月度差异 2.成都市空气质量月度差异 3.上海市空气质量月度差异 五城P.M.2.5数据分析与可视化_使用复式柱状图分析各个城市的P.M.2.5月度差异情况 1.北京市空气质量月度差异 import numpy as np import pandas as pd import matplotlib.pyplot as plt#读入…

小白Linux学习笔记-Linux文件系统和磁盘管理

Linux文件系统和磁盘管理 文章目录 Linux文件系统和磁盘管理文件系统资源虚拟化文件系统的概念文件系统的类型文件系统的结构文件系统的区别文件系统的简单操作dfdu 磁盘的分割、格式化与挂载分割 fdisk磁盘格式化 mkfs挂载mount 的用法mount 的查看umount /etc/fstab 将永久生…

ReactNative实现宽度变化实现的动画效果

效果如上图所示,通过修改设备宽度实现动画效果 import React, {useRef, useEffect, useState} from react; import {Animated, Text, View, Image} from react-native;const FadeInView = props => {const fadeAnim = useRef(new Animated.Value(0)).current;React.useEff…

基于YOLOv8算法的照片角度分类项目实践

目录 一、任务概述二、YOLOv8算法简介2.1 算法改进2.2 算法特点2.3 网络结构2.4 性能比较 三、工程实践3.1 安装算法框架库ultralytics3.2 库存照片预处理3.2.1 提取所有图片3.2.2 去除冗余的相同照片3.2.3 去除无车辆照片3.2.4 随机提取指定数量的图片 3.3 照片朝向分类3.3.1 …