LoRA学习笔记

Background

  1. 全参微调
    在这里插入图片描述
    全量微调指的是,在下游任务的训练中,对预训练模型的每一个参数都做更新。例如图中,给出了Transformer的Q/K/V矩阵的全量微调示例,对每个矩阵来说,在微调时,其d*d个参数,都必须参与更新。
  • 全量微调的显著缺点是,训练代价昂贵。例如GPT3的参数量有175B,我等单卡贵族只能望而却步,更不要提在微调中发现有bug时的覆水难收。同时,由于模型在预训练阶段已经吃了足够多的数据,收获了足够的经验。
  • 因此我只要想办法给模型增加一个额外知识模块,让这个小模块去适配我的下游任务,模型主体保持不变(freeze)即可。
  1. 局部微调办法

Adapter Tuning:
在这里插入图片描述

  • 图例中的左边是一层Transformer Layer结构,其中的Adapter就是我们说的“额外知识模块”;右边是Adatper的具体结构。在微调时,除了Adapter的部分,其余的参数都是被冻住的(freeze),这样我们就能有效降低训练的代价。

但这样的设计架构存在一个显著劣势:添加了Adapter后,模型整体的层数变深,会增加训练速度和推理速度,原因是:

  • 需要耗费额外的运算量在Adapter上
  • 当我们采用并行训练时(例如Transformer架构常用的张量模型并行),Adapter层会产生额外的通讯量,增加通讯时间

Prefix Tuning

在这里插入图片描述

通过对输入数据增加前缀(prefix)来做微调。当然,prefix也可以不止加载输入层,还可以加在Transformer Layer输出的中间层。

对于GPT这样的生成式模型,在输入序列的最前面加入prefix token,图例中加入2个prefix token,在实际应用中,prefix token的个数是个超参,可以根据模型实际微调效果进行调整。

对于BART这样的Encoder-Decoder架构模型,则在x和y的前面同时添加prefix token。在后续微调中,我们只需要冻住模型其余部分,单独训练prefix token相关的参数即可,每个下游任务都可以单独训练一套prefix token。


  • 那么prefix的含义是什么呢?

prefix的作用是引导模型提取x相关的信息,进而更好地生成y。
例如,我们要做一个summarization的任务,那么经过微调后,prefix就能领悟到当前要做的是个“总结形式”的任务,然后引导模型去x中提炼关键信息;
如果我们要做一个情感分类的任务,prefix就能引导模型去提炼出x中和情感相关的语义信息,以此类推。这样的解释可能不那么严谨,但大家可以大致体会一下prefix的作用。


Prefix Tuning虽然看起来方便,但也存在以下两个显著劣势;

  1. 较难训练,且模型的效果并不严格随prefix参数量的增加而上升,这点在原始论文中也有指出
  2. 会使得输入层有效信息长度减少。为了节省计算量和显存,我们一般会固定输入数据长度。增加了prefix之后,留给原始文字数据的空间就少了,因此可能会降低原始文字中prompt的表达能力。

LoRA

全参数微调太贵,Adapter Tuning存在训练和推理延迟,Prefix Tuning难训且会减少原始训练数据中的有效文字长度,那是否有一种微调办法,能改善这些不足呢?

  • 在这样动机的驱动下,作者提出了LoRA(Low-Rank Adaptation,低秩适配器)这样一种微调方法。

在这里插入图片描述
在这里插入图片描述

核心思想 - SVD

在这里插入图片描述
在这里插入图片描述

  • 小小的总结一下:W矩阵SVD分解(近似1),然后取三个分解矩阵的top r行(近似2)= W最重要的特征

SVD Code

import torch
import numpy as np
torch.manual_seed(0)# ------------------------------------
# n:输入数据维度
# m:输出数据维度
# ------------------------------------
n = 10
m = 10# ------------------------------------
# 随机初始化权重W
# 之所以这样初始化,是为了让W不要满秩,
# 这样才有低秩分解的意义
# ------------------------------------
nr = 10
mr = 2
W = torch.randn(nr,mr)@torch.randn(mr,nr)# ------------------------------------
# 随机初始化输入数据x
# ------------------------------------
x = torch.randn(n)# ------------------------------------
# 计算Wx
# ------------------------------------
y = W@x
print("原始权重W计算出的y值为:\n", y)# ------------------------------------
# 计算W的秩
# ------------------------------------
r= np.linalg.matrix_rank(W)
print("W的秩为: ", r)# ------------------------------------
# 对W做SVD分解
# ------------------------------------
U, S, V = torch.svd(W)# ------------------------------------
# 根据SVD分解结果,
# 计算低秩矩阵A和B
# ------------------------------------
U_r = U[:, :r]
S_r = torch.diag(S[:r])
V_r = V[:,:r].t()B = U_r@S_r # shape = (d, r)
A = V_r     # shape = (r, d)# ------------------------------------
# 计算y_prime = BAx
# ------------------------------------
y_prime = B@A@xprint("SVD分解W后计算出的y值为:\n", y)print("原始权重W的参数量为: ", W.shape[0]*W.shape[1])
print("低秩适配后权重B和A的参数量为: ", A.shape[0]*A.shape[1] + B.shape[0]*B.shape[1])
  • 输出的结果不变,参数量减小很多
原始权重W计算出的y值为:tensor([ 3.3896,  1.0296,  1.5606, -2.3891, -0.4213, -2.4668, -4.4379, -0.0375,-3.2790, -2.9361])
W的秩为:  2
SVD分解W后计算出的y值为:tensor([ 3.3896,  1.0296,  1.5606, -2.3891, -0.4213, -2.4668, -4.4379, -0.0375,-3.2790, -2.9361])
原始权重W的参数量为:  100
低秩适配后权重B和A的参数量为:  40

很有意思的自相矛盾

在这里插入图片描述

超参数 α \alpha α

在这里插入图片描述

实验验证
尽管理论上我们可以在模型的任意一层嵌入低秩适配器(比如Embedding, Attention,MLP等),但LoRA中只选咋在Attention层嵌入,并做了相关实验

在这里插入图片描述

LoRA使用

下游任务的example

LoRA源码

class LoRALayer():def __init__(self, r: int, # 矩阵的秩lora_alpha: int, # 超参数alora_dropout: float,merge_weights: bool,):self.r = rself.lora_alpha = lora_alpha# Optional dropoutif lora_dropout > 0.:self.lora_dropout = nn.Dropout(p=lora_dropout)else:self.lora_dropout = lambda x: x# Mark the weight as unmergedself.merged = Falseself.merge_weights = merge_weights

Embedding层

class Embedding(nn.Embedding, LoRALayer):# LoRA implemented in a dense layerdef __init__(self,num_embeddings: int,embedding_dim: int,r: int = 0,lora_alpha: int = 1,merge_weights: bool = True,**kwargs):nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=0,merge_weights=merge_weights)# Actual trainable parametersif r > 0:self.lora_A = nn.Parameter(self.weight.new_zeros((r, num_embeddings)))self.lora_B = nn.Parameter(self.weight.new_zeros((embedding_dim, r)))self.scaling = self.lora_alpha / self.r# Freezing the pre-trained weight matrixself.weight.requires_grad = Falseself.reset_parameters()def reset_parameters(self):nn.Embedding.reset_parameters(self)if hasattr(self, 'lora_A'):# initialize A the same way as the default for nn.Linear and B to zeronn.init.zeros_(self.lora_A)nn.init.normal_(self.lora_B)def train(self, mode: bool = True):nn.Embedding.train(self, mode)if mode:if self.merge_weights and self.merged:# Make sure that the weights are not mergedif self.r > 0:self.weight.data -= (self.lora_B @ self.lora_A).transpose(0, 1) * self.scalingself.merged = Falseelse:if self.merge_weights and not self.merged:# Merge the weights and mark itif self.r > 0:self.weight.data += (self.lora_B @ self.lora_A).transpose(0, 1) * self.scalingself.merged = Truedef forward(self, x: torch.Tensor):if self.r > 0 and not self.merged:result = nn.Embedding.forward(self, x)after_A = F.embedding(x, self.lora_A.transpose(0, 1), self.padding_idx, self.max_norm,self.norm_type, self.scale_grad_by_freq, self.sparse)result += (after_A @ self.lora_B.transpose(0, 1)) * self.scalingreturn resultelse:return nn.Embedding.forward(self, x)

Linear层实现

class Linear(nn.Linear, LoRALayer):# LoRA implemented in a dense layerdef __init__(self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.,fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)merge_weights: bool = True,**kwargs):nn.Linear.__init__(self, in_features, out_features, **kwargs)LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,merge_weights=merge_weights)self.fan_in_fan_out = fan_in_fan_out# Actual trainable parametersif r > 0:self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))self.scaling = self.lora_alpha / self.r# Freezing the pre-trained weight matrixself.weight.requires_grad = Falseself.reset_parameters()if fan_in_fan_out:self.weight.data = self.weight.data.transpose(0, 1)def reset_parameters(self):nn.Linear.reset_parameters(self)if hasattr(self, 'lora_A'):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))nn.init.zeros_(self.lora_B)def train(self, mode: bool = True):def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else wnn.Linear.train(self, mode)if mode:if self.merge_weights and self.merged:# Make sure that the weights are not mergedif self.r > 0:self.weight.data -= T(self.lora_B @ self.lora_A) * self.scalingself.merged = Falseelse:if self.merge_weights and not self.merged:# Merge the weights and mark itif self.r > 0:self.weight.data += T(self.lora_B @ self.lora_A) * self.scalingself.merged = True       def forward(self, x: torch.Tensor):def T(w):return w.transpose(0, 1) if self.fan_in_fan_out else wif self.r > 0 and not self.merged:result = F.linear(x, T(self.weight), bias=self.bias)            result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scalingreturn resultelse:return F.linear(x, T(self.weight), bias=self.bias)class MergedLinear(nn.Linear, LoRALayer):# LoRA implemented in a dense layerdef __init__(self, in_features: int, out_features: int, r: int = 0, lora_alpha: int = 1, lora_dropout: float = 0.,enable_lora: List[bool] = [False],fan_in_fan_out: bool = False,merge_weights: bool = True,**kwargs):nn.Linear.__init__(self, in_features, out_features, **kwargs)LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout,merge_weights=merge_weights)assert out_features % len(enable_lora) == 0, \'The length of enable_lora must divide out_features'self.enable_lora = enable_loraself.fan_in_fan_out = fan_in_fan_out# Actual trainable parametersif r > 0 and any(enable_lora):self.lora_A = nn.Parameter(self.weight.new_zeros((r * sum(enable_lora), in_features)))self.lora_B = nn.Parameter(self.weight.new_zeros((out_features // len(enable_lora) * sum(enable_lora), r))) # weights for Conv1D with groups=sum(enable_lora)self.scaling = self.lora_alpha / self.r# Freezing the pre-trained weight matrixself.weight.requires_grad = False# Compute the indicesself.lora_ind = self.weight.new_zeros((out_features, ), dtype=torch.bool).view(len(enable_lora), -1)self.lora_ind[enable_lora, :] = Trueself.lora_ind = self.lora_ind.view(-1)self.reset_parameters()if fan_in_fan_out:self.weight.data = self.weight.data.transpose(0, 1)def reset_parameters(self):nn.Linear.reset_parameters(self)if hasattr(self, 'lora_A'):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))nn.init.zeros_(self.lora_B)def zero_pad(self, x):result = x.new_zeros((len(self.lora_ind), *x.shape[1:]))result[self.lora_ind] = xreturn result

卷积层

class ConvLoRA(nn.Module, LoRALayer):def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs):super(ConvLoRA, self).__init__()self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs)LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights)assert isinstance(kernel_size, int)# Actual trainable parametersif r > 0:self.lora_A = nn.Parameter(self.conv.weight.new_zeros((r * kernel_size, in_channels * kernel_size)))self.lora_B = nn.Parameter(self.conv.weight.new_zeros((out_channels//self.conv.groups*kernel_size, r*kernel_size)))self.scaling = self.lora_alpha / self.r# Freezing the pre-trained weight matrixself.conv.weight.requires_grad = Falseself.reset_parameters()self.merged = Falsedef reset_parameters(self):self.conv.reset_parameters()if hasattr(self, 'lora_A'):# initialize A the same way as the default for nn.Linear and B to zeronn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))nn.init.zeros_(self.lora_B)def train(self, mode=True):super(ConvLoRA, self).train(mode)if mode:if self.merge_weights and self.merged:if self.r > 0:# Make sure that the weights are not mergedself.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scalingself.merged = Falseelse:if self.merge_weights and not self.merged:if self.r > 0:# Merge the weights and mark itself.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scalingself.merged = Truedef forward(self, x):if self.r > 0 and not self.merged:return self.conv._conv_forward(x, self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling,self.conv.bias)return self.conv(x)class Conv2d(ConvLoRA):def __init__(self, *args, **kwargs):super(Conv2d, self).__init__(nn.Conv2d, *args, **kwargs)class Conv1d(ConvLoRA):def __init__(self, *args, **kwargs):super(Conv1d, self).__init__(nn.Conv1d, *args, **kwargs)# Can Extend to other ones like thisclass Conv3d(ConvLoRA):def __init__(self, *args, **kwargs):super(Conv3d, self).__init__(nn.Conv3d, *args, **kwargs)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.hqwc.cn/news/88629.html

如若内容造成侵权/违法违规/事实不符,请联系编程知识网进行投诉反馈email:809451989@qq.com,一经查实,立即删除!

相关文章

虚拟化技术:云计算发展的核心驱动力

文章目录 虚拟化技术的概念和作用虚拟化技术的优势虚拟化技术对未来发展的影响结论 🎉欢迎来到AIGC人工智能专栏~虚拟化技术:云计算发展的核心驱动力 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒🍹✨博客主页:IT陈寒的博客🎈该系…

Git基础教程-常用命令整理:学会Git使用方法和错误解决

目录 一、了解Git的基本概念 二、Git的安装和配置 Git的安装 Git的配置 用户信息 文本编辑器 差异分析工具 查看配置信息 三、Git的基本操作 基本原理 基本操作命令 基本操作示例 场景一:创建新仓库 场景二:拉取并编辑远程仓库 四、常见问…

docker项目实战

目录 1、使用mysql:5.6和 owncloud 镜像,构建一个个人网盘。 1)拉取mysql:5.6和owncloud镜像 2)后台运行容器 3)通过ip:端口的方式访问owncloud 2、安装搭建私有仓库 Harbor 1)首先准备所需包 2)安装h…

【PLSQL】PLSQL基础

文章目录 一:记录类型1.语法2.代码实例 二:字符转换三:%TYPE和%ROWTYPE1.%TYPE2.%ROWTYPE 四:循环1.LOOP2.WHILE(推荐)3.数字式循环 五:游标1.游标定义及读取2.游标属性3.NO_DATA_FOUND和%NOTFO…

CSS 滚动容器与固定 Tabbar 自适应的几种方式

问题 容器高度使用 px 定高时,随着页面高度发生变化,组件展示的数量不能最大化的铺满,导致出现底部留白。容器高度使用 vw 定高时,随着页面宽度发生变化,组件展示的数量不能最大化的铺满,导致出现底部留白…

三、pikachu之文件上传

文章目录 1、文件上传概述2、客户端检测2.1 客户端检测原理及绕过方法2.2 实际操作之client check 3、服务端检测3.1 MIME type3.3.1 检测原理3.3.2 绕过方法3.3.3 实际操作之MIME type 3.2 文件内容检测3.2.1 检测原理3.2.2 绕过方式3.2.3 实际操作之getimagesize() 3.3 其他服…

MySQL binlog的几种日志录入格式以及区别

🏆作者简介,黑夜开发者,CSDN领军人物,全栈领域优质创作者✌,CSDN博客专家,阿里云社区专家博主,2023年6月CSDN上海赛道top4。 🏆数年电商行业从业经验,历任核心研发工程师…

L1-044 稳赢(Python实现) 测试点全过

题目 大家应该都会玩“锤子剪刀布”的游戏:两人同时给出手势,胜负规则如图所示: 现要求你编写一个稳赢不输的程序,根据对方的出招,给出对应的赢招。但是!为了不让对方输得太惨,你需要每隔K次就…

STM32之17.PWM脉冲宽度调制

一LED0脉冲宽度调制在TIM14_CHI&#xff0c;先将LED&#xff08;PF9&#xff09;代码配置为AF推挽输出模式&#xff0c;将PF9引脚连接到TIM14&#xff0c; #include <stm32f4xx.h>static GPIO_InitTypeDef GPIO_InitStruct;void Led_init(void) {//打开端口F的硬件时钟&a…

Android——基本控件(下)(十九)

1. 菜单&#xff1a;Menu 1.1 知识点 &#xff08;1&#xff09;掌握Android中菜单的使用&#xff1b; &#xff08;2&#xff09;掌握选项菜单&#xff08;OptionsMenu&#xff09;的使用&#xff1b; &#xff08;3&#xff09;掌握上下文菜单&#xff08;ContextMenu&am…

运维Shell脚本小试牛刀(二)

运维Shell脚本小试牛刀(一) 运维Shell脚本小试牛刀(二) 一: if---else.....fi 条件判断演示 [rootwww shelldic]# cat checkpass.sh #!/bin/bash - # # # # FILE: checkpass.sh # USAGE: ./checkpass.sh # DESCRI…

简易虚拟培训系统-UI控件的应用3

目录 Button组件的组成 Button组件方法1-在Button组件中设置OnClick()回调 Button组件方法2-在脚本中添加Button类的监听 上一篇使用了文件流读取硬盘数据并显示在Text组件中&#xff0c;本篇增加使用按钮来控制显示哪一篇文字信息。 Button组件的组成 1. 新建Button&#…