背景
序列建模的挑战
理解为什么需要 Attention 机制
在 Attention 机制出现之前,序列到序列 (Seq2Seq) 模型主要依赖 循环神经网络 (RNN) 实现。但 RNN 面临根本性的瓶颈:
1
信息瓶颈 (Information Bottleneck)
整个输入序列被压缩为单一的固定大小上下文向量 c。长序列的信息大量丧失。
2
长距离依赖困难
梯度通过多个时间步传播时产生梯度消失/爆炸,难以学习长距离的语义关联。
3
线性时间依赖
RNN 必须顺序处理,难以并行化,计算效率低下。
RNN Seq2Seq 的信息瓶颈
Attention 的核心洞察:与其强制将整个输入压缩为单一向量,不如让解码器在每个步骤动态地从输入中选择相关信息。这就是 Attention 的本质。
2015
Bahdanau Attention (加法注意力)
神经机器翻译的突破
Bahdanau et al. 在 2015 年提出的加法注意力 (Additive Attention) 为每个解码时间步计算一个动态的上下文向量。关键创新是使用一个对齐模型(alignment model)来计算注意力权重。
score(s_t, h_i) = v^T tanh(W_s s_t + W_h h_i)
1
计算对齐分数 (Alignment Scores)
对于解码状态 s_t 和每个编码隐藏状态 h_i,通过非线性变换计算相似度。
2
Softmax 归一化
α_ti = softmax_i(score(s_t, h_i)) 得到注意力权重,和为 1。
3
加权求和得上下文
c_t = Σ_i α_ti h_i — 编码表示的加权组合。
4
拼接与解码
[s_t; c_t] 拼接后送入解码器得到输出。
Bahdanau 注意力机制流程
计算复杂度
O(n²d)
n 是序列长度,d 是隐藏维度。对齐分数计算需要 O(n) 次非线性变换。
内存占用
O(n²)
需存储 n×n 的注意力权重矩阵,长序列时成为瓶颈。
优势: 解决了 RNN 的信息瓶颈,每个解码步骤可动态关注不同输入。
劣势: 非线性变换的计算成本高;无法并行计算整个解码序列。
2015
Luong Attention (乘法注意力)
简化高效的注意力机制
同年,Luong et al. 提出了乘法注意力(Multiplicative Attention),通过一个简单的点积操作替代了 Bahdanau 的非线性变换,大幅降低计算成本。
score(s_t, h_i) = s_t^T h_i
1
点积注意力 (Dot)
score = s_t^T h_i。最快,但要求 s_t 和 h_i 维度相同。
2
一般注意力 (General)
score = s_t^T W h_i。通过矩阵 W 适配维度差异,灵活性最好。
3
拼接注意力 (Concat)
score = v^T tanh(W[s_t; h_i])。相当于 Bahdanau,但拼接而非求和。
关键优势:乘法注意力可以用矩阵乘法高效实现。对于向量化计算(GPU)来说,这比逐元素非线性操作快一个数量级。
全局注意力 vs 本地注意力
| 特征 |
Bahdanau |
Luong |
| 打分函数 |
v^T tanh(W_s s_t + W_h h_i) |
s_t^T W h_i (一般形式) |
| 计算速度 |
慢 (非线性操作) |
快 (矩阵乘法) |
| 参数量 |
多 (d_a × (d_s + d_h) + d_a) |
少 (d_s × d_h) |
| 可并行化 |
困难 |
容易 (内积高度并行) |
| 应用 |
小规模、精度优先 |
大规模、效率优先 |
面试考点
Q: 为什么 Luong 注意力比 Bahdanau 快?
A: Bahdanau 在计算对齐分数时使用 v^T tanh(...),需要多次矩阵乘法和非线性激活。Luong 使用点积 s_t^T h_i,直接可用 BLAS 优化,GPU 上的矩阵乘法操作远快于逐步非线性变换。
2017
Self-Attention (自注意力)
Transformer 的核心机制
Self-attention 是指序列中的每个位置都能与同一序列中的所有其他位置建立直接的关联,而不依赖递归结构。这使得整个序列可以并行处理,成为 Transformer 的基石。
Attention(Q, K, V) = softmax(Q K^T / √d_k) V
Q
查询 (Query)
当前位置"想问什么问题"。来自 x_i W_Q。
K
键 (Key)
"每个位置是什么"。来自 x_j W_K。与 Q 计算相似度。
V
值 (Value)
"每个位置的信息"。来自 x_j W_V。被加权平均。
Self-Attention 计算流程
为什么能并行? Self-attention 中,每个位置的输出只依赖输入矩阵,不依赖前一个隐藏状态。因此整个序列可以一次性处理,无需循环迭代。
Causal Mask (因果掩码) 用于解码器:在 Transformer 解码器中,防止模型看到未来 token,掩码注意力权重使其为 -∞。
Attention(Q, K, V) = softmax(Q K^T / √d_k + M) V
2017
Multi-Head Attention (多头注意力)
捕获多个语义子空间
单个注意力头只能学习一种"注意模式"。多头注意力将 d_model 维度分成 h 个头,每个头独立地学习,最后拼接。这使模型能捕获多个不同的语义关系。
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) W_O
其中 head_i = Attention(Q W_Q^i, K W_K^i, V W_V^i)
1
投影到子空间
Q, K, V 分别投影到 h 个不同的 d_k 维子空间 (d_k = d_model / h)。
2
并行计算头部注意力
h 个注意力计算可完全并行,每个消耗 O(n² d_k) 时间。
3
拼接与输出投影
[head_1; ...; head_h] 维度为 n × d_model,再投影回原维度。
多头注意力的并行结构
参数数量
h × 3 × d_k² + d²
由于 d_k = d/h,总参数与单头相同,但计算可并行。
不同头学到的内容
多样化
头1: 语法关系;头2: 语义相似度;头3: 位置偏移...
为什么多头比单头好? 单个 512-dim 的注意力头必须同时捕捉所有类型的关系。多头让每个头专注于一个特定的语义子空间,增加了表达能力。实验证明 h=8 是一个很好的平衡。
2019-2023
MQA & GQA (减少 KV Cache)
加速大模型推理的关键优化
在多头注意力中,每个查询头都有独立的 K 和 V。这在推理时产生巨大的内存开销:KV Cache 大小为 O(batch × seq_len × num_heads × d_k)。MQA 和 GQA 通过共享键值来大幅降低这个开销。
MHA
多头注意力 (原始)
h 个 Q,h 个 K,h 个 V。每个头都独立计算。
MQA
多查询注意力
h 个 Q,1 个 K,1 个 V。所有查询共享一套键值。
GQA
分组查询注意力
h 个 Q,g 个 K,g 个 V (g ≪ h)。查询分成 h/g 个组,每组共享一对 K、V。
三种注意力机制的 KV 缓存对比
| 方案 |
查询头数 |
KV 头数 |
推理速度 |
质量 |
| MHA |
h |
h |
基准 |
最好 |
| MQA |
h |
1 |
最快 |
稍差 (训练仔细的话可弥补) |
| GQA |
h |
h/g |
快 |
很好 (平衡) |
为什么共享 KV 可行? 经验表明,多个查询可以"重用"同一组键值。虽然理论上损失了一些表达力,但实践中调整学习率和初始化,MQA/GQA 的性能与 MHA 接近,同时推理速度快 2-4 倍。
面试考点
Q: GQA 中的 g 应该设为多少?
A: g 通常为 1 (MQA) 或 8 (GQA)。LLaMA 2 70B 使用 g=8,平衡了内存与质量。如果内存充足优先选择 MHA;如果需要最快推理选 MQA;否则 GQA 是一个很好的折中。
2022-2024
Flash Attention (1/2/3)
GPU 内存的革命性优化
标准注意力实现将整个 n×n 注意力矩阵加载到 HBM (高带宽内存) 中。对于长序列,这会导致 OOM。Flash Attention 通过 tiling 和 online softmax 减少 HBM 访问,实现 O(n) 内存 和显著的性能提升。
GPU 内存层级与 Flash Attention 的 Tiling 策略
Online Softmax:
m_new = max(m_old, max(Q K^T))
l_new = l_old × exp(m_old - m_new) + Σ exp(Q K^T - m_new)
FA1
Flash Attention (2022)
核心 tiling 算法,O(n) 内存,2x 快。关键:online softmax。
FA2
改进 (2023)
减少非矩阵乘法 FLOPs,更好的并行性,4x 快。
FA3
H100 优化 (2024)
利用 WGMMA 张量操作和异步 DMA,极致性能。
内存复杂度
O(n) vs O(n²)
对于 n=4096,从 64 MB 降至 <160 KB。
执行时间
2-4x 加速
FA1: 2x;FA2: 3x;FA3: 4x (理想条件)。
为什么这么快? 标准实现将 n×n 的 S 矩阵、n×n 的 P 矩阵写回 HBM。Flash Attention 只在 SRAM 中处理,将 HBM 访问从 O(n²) 降至 O(n),带宽瓶颈消除。
面试考点
Q: Flash Attention 的瓶颈是什么?
A: 取决于硬件。在 A100 上,主瓶颈是 HBM 带宽。Flash Attention 将其改为计算瓶颈(矩阵乘法),利用率更高。在 H100 上,可进一步优化至接近理论峰值。
2020-2023
稀疏与线性注意力
处理超长序列的方法
虽然 Flash Attention 大幅优化了计算,但 O(n²) 的时间复杂度仍对超长序列 (n > 100K) 成为瓶颈。稀疏注意力和线性注意力提供了不同的折中方案。
1
滑动窗口注意力 (Sliding Window)
每个位置只关注最近的 W 个位置。Mistral-7B 使用 W=4096。复杂度: O(n×W)。
2
全局+窗口 (Longformer)
某些 token (如 [CLS]) 全局关注所有位置,其余只看窗口。
3
随机稀疏 (BigBird)
窗口 + 随机采样 + 全局 token。细粒度控制覆盖率。
4
线性注意力 (Kernel Trick)
用核方法近似 softmax,O(n) 复杂度,但表达力下降。
不同稀疏注意力模式的对比
面试考点
Q: Ring Attention 是什么?
A: 用于分布式环境的长序列注意力。将序列分割到多个 GPU,通过环形通信模式循环计算注意力。每个 GPU 处理一个序列块,通过环形同步传递,实现 O(nlog p) 通信复杂度 (p 为 GPU 数)。适合处理百万级长度。
复习
总结与面试题
注意力机制的演进路线与高频问题
注意力机制演进时间线
| 机制 |
时间复杂度 |
空间复杂度 |
并行化 |
应用场景 |
| Bahdanau |
O(n²d) |
O(n²) |
困难 |
序列 Seq2Seq,小规模 |
| Luong |
O(n²d) |
O(n²) |
容易 |
标准神经翻译 |
| Self-Attention |
O(n²d) |
O(n²) |
完全 |
Transformer 基础 |
| Flash Attention |
O(n²d) |
O(n) |
完全 |
主流生产环境 |
| MQA/GQA |
O(n²d) |
O(n) |
完全 |
KV Cache 受限的推理 |
| 稀疏注意力 |
O(n×W) |
O(n) |
完全 |
超长上下文 (>4K) |
| 线性注意力 |
O(n) |
O(n) |
完全 |
极端长序列,精度可接受 |
10+ 高频面试题
Q1
Self-attention 如何实现并行计算?
A: Self-attention 不依赖递归(如 RNN 需要 h_t-1 来计算 h_t)。整个输入序列可一次性投影到 Q、K、V,然后通过矩阵乘法并行计算所有位置的注意力权重,无需循环迭代。这是 Transformer 相比 RNN 的核心优势。
Q2
缩放因子 √d_k 的数学原因是什么?
A: 点积 Q K^T 的方差约为 d_k(当 Q、K 的元素 iid)。大维度下方差增大,softmax 变尖锐,梯度消失。除以 √d_k 使方差回到 1,梯度稳定。可证明这使梯度范数保持恒定。
Q3
多头注意力为什么比单头好?
A: 不同头可学习不同的"投影子空间"。单头 512-dim 必须同时捕捉语法、语义、位置等所有关系。多头让每个头专注一个方面,总参数不增加(d_k = d/h),但表达能力更强。实验验证 h=8~12 效果最好。
Q4
Bahdanau vs Luong 注意力的核心区别?
A: Bahdanau 用非线性打分 v^T tanh(Ws + Wh),精度高但计算慢。Luong 用点积 s^T h,快但需要维度匹配。现代系统(含 Transformer)多用 Luong 的点积形式,通过 Flash Attention 优化性能。
Q5
Flash Attention 如何解决 O(n²) 内存问题?
A: 标准实现在 HBM 中维护整个 n×n 的 S(分数)和 P(概率)矩阵。Flash Attention 用分块 (tiling) 策略:逐块加载 Q、K、V 到 SRAM(快速),在 SRAM 中计算块级注意力,只保存必要的中间值(最大值、指数和),内存从 O(n²) 降至 O(n)。
Q6
GQA 和 MQA 为什么能减少 KV Cache?
A: MHA 为 h 个查询头维护 h 套 KV。推理时,KV Cache 大小为 batch×seq_len×h×d_k,这往往是最大的内存消耗。MQA 共享 1 套 K、V(只有 1 倍),MQA 共享 h/g 套(g 倍)。虽然牺牲一点精度,但推理吞吐量能提升 2-8 倍。
Q7
如何计算 Self-Attention 的 FLOPs?
A: 对于 n×d 的输入:① Q K^T: O(n² d);② softmax: O(n²);③ (P)V: O(n² d)。总计 O(n²d)。对于 seq_len=4096, d=768,约 12B FLOPs。这是为什么长序列注意力成为瓶颈,Flash Attention 和稀疏注意力很关键。
Q8
Decoder 中的因果掩码如何实现?
A: 在 softmax 前,对注意力分数矩阵的上三角(j > i 位置)设为 -∞。softmax(-∞) = 0,使未来位置的权重为 0。这保证了自回归生成时,token 只能看到已生成的历史,不会"作弊"。
Q9
线性注意力的核心思想是什么?
A: 标准注意力用 softmax(QK^T)V,计算 O(n²)。线性注意力用核方法:将 softmax 近似为可分离的核 φ(q)·φ(k),得到 Σ_i φ(k_i)⊙v_i φ(q) = φ(q)·(Σ_i φ(k_i)v_i),变成 O(n)。代价是精度下降,适合超长序列且精度要求低的场景。
Q10
为什么 Luong 注意力更容易被 GPU 加速?
A: Luong 注意力的核心操作是矩阵乘法(Q K^T),这是 GPU BLAS 库(cuBLAS)高度优化的操作。Bahdanau 需要逐步的非线性变换(tanh、element-wise 乘法),无法充分利用 GPU 的向量化能力。矩阵乘法 FLOPs 与内存访问比高,更能隐藏 HBM 延迟。
面试建议:深入理解注意力的数学原理(维度、复杂度分析)、实现细节(online softmax、causal mask)和工程权衡(精度 vs 速度 vs 内存)。能够从一个方法引申到相关方法,展示系统性的思维。