多头自注意力(Multi-Head Self-Attention, MHSA)是 Transformer 结构的核心模块之一,其时间复杂度主要受输入序列长度 ( L ) 和隐藏维度 ( d ) 影响。下面我们详细分析其计算复杂度。
1. 多头自注意力计算流程
假设输入张量为:
- 输入: ( X \in \mathbb{R}^{L \times d} ),其中 ( L ) 是序列长度,( d ) 是特征维度。
- 头数: ( h ),每个头的维度为 ( d_h = d / h )。
- 权重矩阵:
- 查询矩阵 ( W_Q \in \mathbb{R}^{d \times d} )
- 键矩阵 ( W_K \in \mathbb{R}^{d \times d} )
- 值矩阵 ( W_V \in \mathbb{R}^{d \times d} )
- 输出变换矩阵 ( W_O \in \mathbb{R}^{d \times d} )
整个计算流程如下:
-
计算 ( Q, K, V ) 矩阵:
[
Q = X W_Q, \quad K = X W_K, \quad V = X W_V
]- 每个运算的时间复杂度:( O(L d^2) )。
-
计算注意力权重(点积注意力):
[
A = \text{softmax} \left( \frac{Q K^T}{\sqrt{d_h}} \right)
]- 矩阵乘法 ( Q K^T ) 复杂度:( O(L^2 d) )(每个头的计算量是 ( O(L^2 d_h) ),总共 ( h ) 个头,总体积累后仍为 ( O(L^2 d) ))。
- softmax 操作复杂度:( O(L^2) ),可忽略不计。
-
计算注意力输出:
[
Z = A V
]- 矩阵乘法 ( A V ) 复杂度:( O(L^2 d) )。
-
投影回输出空间:
[
Y = Z W_O
]- 复杂度:( O(L d^2) )。
2. 计算时间复杂度
综合上述步骤,各部分时间复杂度如下:
计算步骤 | 复杂度 |
---|---|
计算 ( Q, K, V ) | ( O(L d^2) ) |
计算注意力权重 ( Q K^T ) | ( O(L^2 d) ) |
计算注意力加权 ( A V ) | ( O(L^2 d) ) |
输出投影 ( Z W_O ) | ( O(L d^2) ) |
总复杂度 | ( O(L d^2 + L^2 d) ) |
通常,在 Transformer 中:
- 如果 ( L ) 较小,( O(L d^2) ) 主导计算;
- 如果 ( L ) 较大,( O(L^2 d) ) 变成主要计算瓶颈。
在标准 Transformer 结构中,( d ) 通常是固定的,比如 512 或 1024,而 ( L ) 可能较大(如 512、1024、2048 等),所以 时间复杂度通常由 ( O(L^2 d) ) 主导,导致 MHSA 在长序列任务中计算开销较大。
3. 复杂度优化方法
为了降低计算复杂度,研究者提出了多种优化方法:
- Sparse Attention(稀疏注意力):仅计算部分 ( QK^T )(如 Longformer、BigBird),将复杂度降至 ( O(L \log L) ) 或 ( O(L) )。
- Low-rank Approximation(低秩近似):使用因式分解(如 Linformer)降低复杂度到 ( O(L d) )。
- Performer / Linformer(线性注意力):用核技巧近似计算 ( QK^T ),降低复杂度到 ( O(L d) )。
- Sliding Window Attention(滑动窗口注意力):限制注意力范围(如 Transformer-XL)。
4. 总结
- 标准多头自注意力的时间复杂度是 ( O(L^2 d) ),主要受 ( QK^T ) 计算 影响。
- 当序列长度 ( L ) 变大时,计算瓶颈主要来自 ( O(L^2 d) )。
- 许多优化方法(如稀疏注意力、低秩分解)可以降低复杂度,使 Transformer 适用于长序列任务。
如果你对某个优化方法的具体实现感兴趣,可以告诉我,我可以详细讲解或提供代码示例! 😊