旋转位置编码原理及代码
旋转位置编码
- RoPE(Rotary Positional Encoding)
- 当位置发生偏移时,只需要旋转角度
- 外推性,指大模型输入长度超过预训练文本长度时,输出表现变化情况。
- 使用绝对位置编码具有外推性上的限制,旋转位置编码则没有
- 旋转位置编码还具有一个相对位置编码的优点
- 两个token之间如果具有的相对位置,无论两个token存在句子的哪个位置都会有相同的表示,从下图中可以看出,角度即为相对位置偏移量,token在句中的位置不同可以在坐标系上体现
- 在二维数据中可以看出是在qk上乘了一个旋转矩阵
- 旋转矩阵的特质
- R(a)的转置 = R(-a)(正反两个方向的转动)
- R(a)R(b) = R(a+b)(转动了a+b的角度)
旋转位置编码的核心是找到对应的旋转矩阵
LLaMA中旋转矩阵相关代码
def precompute_freqs_cis(dim: int, seq_len: int, theta: float = 10000.0):# 计算词向量元素两两分组之后,每组元素对应的旋转角度freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))# 生成 token 序列索引 t = [0, 1,..., seq_len-1]t = torch.arange(seq_len, device=freqs.device)# freqs.shape = [seq_len, dim // 2] freqs = torch.outer(t, freqs).float()# torch.polar 的文档# https://pytorch.org/docs/stable/generated/torch.polar.html# 计算结果是个复数向量# 假设 freqs = [x, y]# 则 freqs_cis = [cos(x) + sin(x)i, cos(y) + sin(y)i]freqs_cis = torch.polar(torch.ones_like(freqs), freqs)return freqs_cisdef apply_rotary_emb(xq: torch.Tensor,xk: torch.Tensor,freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:# xq.shape = [batch_size, seq_len, dim]# xq_.shape = [batch_size, seq_len, dim // 2, 2]xq_ = xq.float().reshape(*xq.shape[:-1], -1, 2)xk_ = xk.float().reshape(*xk.shape[:-1], -1, 2)# 转为复数域xq_ = torch.view_as_complex(xq_)xk_ = torch.view_as_complex(xk_)# 应用旋转操作,然后将结果转回实数域# xq_out.shape = [batch_size, seq_len, dim]xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(2)xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(2)return xq_out.type_as(xq), xk_out.type_as(xk)class Attention(nn.Module):def __init__(self, args: ModelArgs):super().__init__()self.wq = Linear(...)self.wk = Linear(...)self.wv = Linear(...)self.freqs_cis = precompute_freqs_cis(dim, max_seq_len * 2)def forward(self, x: torch.Tensor):bsz, seqlen, _ = x.shapexq, xk, xv = self.wq(x), self.wk(x), self.wv(x)xq = xq.view(batch_size, seq_len, dim)xk = xk.view(batch_size, seq_len, dim)xv = xv.view(batch_size, seq_len, dim)# attention 操作之前,应用旋转位置编码xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)# scores.shape = (bs, seqlen, seqlen)scores = torch.matmul(xq, xk.transpose(1, 2)) / math.sqrt(dim)scores = F.softmax(scores.float(), dim=-1)output = torch.matmul(scores, xv) # (batch_size, seq_len, dim)