📖 深度解析子页面

想要更深入了解?我们为面试准备了专题深度解析页面:

1. 背景:序列建模的演进与梯度消失

1.1 RNN 基础与梯度消失的数学根源

循环神经网络在每个时间步 t 迭代更新隐藏状态:

RNN 隐状态更新公式 h_t = f(W_h · h_{t-1} + W_x · x_t + b_h)
y_t = W_y · h_t + b_y

其中 f 通常是 tanh 或 ReLU。然而这种递归结构在处理长序列时遭遇根本性困难。

梯度消失的数学原因

对损失函数 L 关于权重 W_h 的梯度,使用链式法则(反向时间传播,BPTT):

时间反向传播梯度 ∂L/∂W_h = Σ_t (∂L/∂y_t · ∂y_t/∂h_t · ∂h_t/∂W_h)

关键是 ∂h_t/∂W_h 的计算。展开到 k 步之前:

长期依赖梯度路径 ∂h_t/∂h_{t-k} = ∏_{i=0}^{k-1} [f'(h_{t-i}) · W_h]
= (f'(...) · W_h)^k (近似为常数乘积)

当梯度经过多个时间步时,它被乘以 k 个相同的雅可比矩阵。若 W_h 的最大特征值 λ_max < 1(对 tanh 约 0.9-0.95),梯度呈指数衰减:

指数衰减规律 ||∂h_t/∂h_{t-k}|| ≈ λ_max^k → 0 当 k → ∞

当 k > 50-100 时,梯度变得极微小(< 10^-10),导致模型无法学习距离较远的依赖。这就是 梯度消失问题(Vanishing Gradient Problem)。

1.2 LSTM:细胞状态与门控机制

LSTM(Hochreiter & Schmidhuber, 1997)通过引入 细胞状态 C_t 和 门机制 来缓解梯度消失:

遗忘门 (Forget Gate) f_t = σ(W_f · [h_{t-1}, x_t] + b_f)
输入门 (Input Gate) i_t = σ(W_i · [h_{t-1}, x_t] + b_i)
候选细胞状态 (Candidate) C̃_t = tanh(W_c · [h_{t-1}, x_t] + b_c)
细胞状态更新(关键!) C_t = f_t ⊙ C_{t-1} + i_t ⊙ C̃_t
输出门 (Output Gate) o_t = σ(W_o · [h_{t-1}, x_t] + b_o)
h_t = o_t ⊙ tanh(C_t)

其中 σ 是 sigmoid 函数,⊙ 表示逐元素乘法(Hadamard product)。

LSTM 为何缓解梯度消失

关键观察:细胞状态通过 加法 更新,而非 RNN 的乘法:

细胞状态梯度流 ∂C_t/∂C_{t-1} = f_t (直接的加法连接)
梯度可以不经过激活函数的非线性操作直接流动

若 f_t ≈ 1(遗忘门完全打开),梯度可以不减衰地通过多个时间步。即使 f_t ≠ 1,梯度也是加法而非乘法,从而避免指数衰减。这使得 LSTM 能学习距离 > 100 的依赖。

1.3 GRU:LSTM 的简化版本

GRU(Cho et al., 2014)用两个门简化 LSTM:

重置门 (Reset Gate) r_t = σ(W_r · [h_{t-1}, x_t] + b_r)
更新门 (Update Gate) z_t = σ(W_z · [h_{t-1}, x_t] + b_z)
候选隐状态 h̃_t = tanh(W_h · [r_t ⊙ h_{t-1}, x_t] + b_h)
隐状态更新 h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h̃_t

GRU 参数数量约为 LSTM 的 3/4,训练速度快 10-15%,但在某些任务上性能略低。

1.4 Seq2Seq + Attention 的突破

Bahdanau et al. (2015) 提出 Attention 机制来解决固定向量瓶颈。解码器在每一步动态关注编码器的不同部分:

Bahdanau 注意力(加法) score(s_{t-1}, h_i) = v^T · tanh(W_s · s_{t-1} + W_h · h_i)
α_i = softmax(score(s_{t-1}, h_i))
c_t = Σ_i α_i · h_i

其中 s_{t-1} 是解码器隐状态,h_i 是编码器隐状态,c_t 是上下文向量。

Luong et al. (2015) 改进为乘法注意力(点积),计算更高效:

Luong 注意力(乘法) score(s_{t-1}, h_i) = s_{t-1}^T · W · h_i
α_i = softmax(score(s_{t-1}, h_i) / √d_k)
特性 Bahdanau (加法) Luong (乘法)
计算方式 concat + tanh + 投影 直接点积
时间复杂度 O(d²) O(d)
可扩展性 较差 优秀
对齐质量 稍优 实际更好

面试必考:梯度消失的数学原因?梯度是乘积形式 (f'·W)^k,包含 k 个小于 1 的因子,导致指数衰减。LSTM 如何缓解?通过加法更新细胞状态,使梯度可直接流动。为什么仍不够导致需要 Transformer?RNN 仍需逐时步计算,无法并行化,且对超长序列仍困难。

2. Self-Attention 机制:Transformer 的基石

2.1 Self-Attention 的定义与计算

Self-Attention 允许序列中每个位置与所有其他位置交互。给定输入序列 X ∈ ℝ^(n×d)(n 个 token,维度 d),首先投影为三个向量序列:

线性投影矩阵 Q = X · W_Q ∈ ℝ^(n×d_k)
K = X · W_K ∈ ℝ^(n×d_k)
V = X · W_V ∈ ℝ^(n×d_v)
其中 W_Q, W_K ∈ ℝ^(d×d_k), W_V ∈ ℝ^(d×d_v)
Self-Attention 核心公式 Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V

逐步解析:

  1. 相似度计算 (Q·K^T):得到 n×n 矩阵,其中 (i,j) 元素是第 i 个查询与第 j 个键的点积
  2. 缩放 (÷√d_k):归一化分数,防止 softmax 进入饱和区
  3. 权重计算 (softmax):对每一行应用 softmax,得到权重矩阵 A ∈ ℝ^(n×n)
  4. 加权聚合 (A·V):用权重对值进行加权求和,输出维度 n×d_v
Self-Attention 计算流程图
Input X Q = XW_Q K = XW_K V = XW_V Q · K^T (n × n) ÷ √d_k (Scaling) softmax() A ∈ [0,1]^(n×n) A · V (n × d_v) Output (n × d_v) Self-Attention 核心公式: Attention(Q,K,V) = softmax(Q·K^T / √d_k) · V 时间复杂度: O(n²d) | 空间复杂度: O(n²) [注意力矩阵] | d = d_model

2.2 缩放因子 √d_k 的必要性

这个细节对 Self-Attention 成功至关重要。当 d_k 较大时,点积 Q·K^T 会变得很大。假设 Q 和 K 的元素独立同分布(均值 0,方差 1):

点积方差分析 Q_i · K_i 是 d_k 个独立随机变量之和
Var(Q · K^T) = Var(Σ_{j=1}^{d_k} Q_j · K_j)
= Σ_{j=1}^{d_k} Var(Q_j · K_j)
= d_k · E[Q_j²] · E[K_j²]
= d_k · 1 · 1 = d_k

因此标准差为 √d_k。对于 d_k = 64,点积的典型范围是 [-8, 8],使得 softmax 输入位于 sigmoid 函数的平坦区域:

梯度问题 当 x ∈ [-8, 8] 时,σ'(x) = σ(x)(1-σ(x)) ≈ 0
导致 ∂softmax/∂input ≈ 0,梯度流停滞

通过除以 √d_k,我们将方差归一化为 1:

缩放后的方差 Var(Q · K^T / √d_k) = (1/d_k) · Var(Q · K^T) = 1

现在点积范围约为 [-1, 1],softmax 梯度有合理幅度,模型可以有效学习。

2.3 计算复杂度分析

操作 时间复杂度 空间复杂度 备注
投影 (Q,K,V) O(n·d²) O(n·d) 3 个线性变换
点积 Q·K^T O(n²·d_k) O(n²) 主要计算瓶颈
Softmax O(n²) O(n²) 需存储完整注意力矩阵
A·V O(n²·d_v) O(n·d_v) 输出聚合
总计 O(n²·d) O(n²) n=序列长度,d=维度

实际问题:当 n = 4096(标准 LLM 窗口)时,注意力矩阵有 16M 个元素。若用 float32,需要 64 MB 内存。当 n = 32768 时,暴增至 4 GB!这成为超长上下文的硬瓶颈。

面试必考:Self-Attention 和传统 Attention 的区别?Self-Attention 中 Q、K、V 都来自同一序列,允许序列内部的每个位置与其他所有位置交互。传统 Attention(Seq2Seq)的 Q 来自解码器,K/V 来自编码器。

3. Multi-Head Attention (MHA):多头注意力

3.1 为什么需要多头

单个注意力头的表达能力可能不足。一个头被迫在一个 d_k 维的投影空间中学习所有特征交互(语法、语义、长程依赖等)。多头注意力将问题分解为多个并行的子问题:

单个头的计算 head_i = Attention(Q·W_i^Q, K·W_i^K, V·W_i^V)
其中 W_i^Q, W_i^K, W_i^V ∈ ℝ^(d_model × d_k)
d_k = d_model / h (h 是头数)
多头拼接与投影 MultiHead(Q,K,V) = Concat(head_1, head_2, ..., head_h) · W^O
其中 W^O ∈ ℝ^(h·d_k × d_model)

3.2 参数数量与计算复杂度

标准配置:d_model = 768, h = 12, d_k = 64

MHA 参数数量 参数 = 3 × d_model × d_k × h + d_model × (h·d_k)
= 3 × 768 × 64 × 12 + 768 × 768
= 1,769,472 + 589,824
= 2,359,296 ≈ 2.36M

计算复杂度为什么不增加:每个头处理 d_k 维而非完整 d 维:

单头复杂度 × h = 总复杂度 O(n² × d_k) × h = O(n² × d_model)
与单头相同!

3.3 多头的优势

为什么不用单头(h=1)?

如果 h=1,模型在单一 d_model 维的投影空间中学习所有交互。实验表明,h=1 模型性能下降约 5-10%。多头的好处:(1) 捕捉多种相关性,(2) 更好的梯度流,(3) 正则化效应。

3.4 头数的选择

标准配置为 h = 12 (BERT, GPT-2) 或 h = 32 (GPT-3)。理由:

面试必考:多头注意力的优势?在表达能力不变的情况下增加多样性。为什么 h=12 而不是 h=1 或 h=24?

4. 多查询与分组查询注意力 (MQA & GQA)

4.1 推理时的 KV 缓存瓶颈

在自回归生成中,每个新 token 需要重新计算注意力。为避免重复计算所有历史 token 对,我们缓存它们的 K 和 V:

KV 缓存大小 (MHA) Size = 2 × num_layers × num_heads × d_k × seq_len
= 2 × 32层 × 32头 × 64维 × 2048长度
= 268 MB (仅 KV!)
对于 70B 参数模型,总权重 ~140GB,但 KV 缓存可达 4GB!

这成为推理的主要瓶颈,限制了批处理大小和上下文长度。

4.2 Multi-Query Attention (MQA)

Facebook 的 Falcon-40B/180B 采用的设计:所有注意力头 共享相同的 K 和 V,只有 Q 不同:

MQA 计算 head_i = Attention(Q_i, K_shared, V_shared)
Q_i = X · W_i^Q (每个头不同)
K_shared = X · W^K (所有头共享)
V_shared = X · W^V (所有头共享)
MQA KV 缓存大小 Size = 2 × 32层 × 1头 × 64维 × 2048长度
= 8.4 MB (相比 MHA 减少 97%!)
内存节省惊人,但表达能力可能下降

权衡:MQA 在某些任务上性能略低(约 2-3%),但推理加速 5-10 倍。

4.3 Grouped-Query Attention (GQA):折中方案

Ainslie et al. (2023) 的创新:将 h 个查询头分成 g 个组,每组共享 K 和 V(g < h < hMHA):

GQA 结构 h = 32 查询头,g = 8 组
每个组有 32/8 = 4 个查询头共享 1 组 KV
head_{i,j} = Attention(Q_{i,j}, K_i, V_i)
i ∈ [1, g], j ∈ [1, h/g]
GQA KV 缓存大小 Size = 2 × 32层 × 8组 × 64维 × 2048长度
= 67 MB (相比 MHA 的 268 MB,减少 75%)
相比 MQA 的 8MB,略增但仍显著优化

4.4 三种方案对比

方案 KV 缓存 推理速度 表达能力 典型采用
MHA 268 MB 基准 100% BERT, GPT-2, T5
GQA (g=8) 67 MB (-75%) +25-30% 98-99% LLaMA 2 70B, Mistral 7B
MQA 8 MB (-97%) +50-100% 95-97% Falcon-40B/180B, PaLM

LLaMA 2 在 7B 和 70B 版本中都采用 GQA (8 groups),证明了这个折中的价值。它既保持了 MHA 的性能,又显著降低了推理成本。

面试必考:为什么需要 MQA/GQA?答:推理时 KV 缓存成为主要瓶颈(特别是长上下文和大批量场景)。为什么 K/V 可以共享而 Q 不能?因为 Q 用于匹配不同位置的注意力模式,需要多样化;而 K/V 是被查询的内容,共享不影响多头各自的关注模式。

5. 位置编码 (Positional Encoding)

5.1 Self-Attention 的位置不变性问题

Self-Attention 的计算 softmax(Q·K^T/√d_k)·V 对位置顺序完全不敏感。输入 [x_1, x_2, ..., x_n] 和任意排列得到的输出不同,但模型参数相同。这意味着需要 显式 编码位置信息。

5.2 绝对位置编码:正弦/余弦编码

Vaswani et al. (2017) 提出的标准方案,后来被大量论文采用:

正弦位置编码 PE(pos, 2i) = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))
pos ∈ [0, max_seq_len), i ∈ [0, d_model/2)

具体例子(d_model=512, pos=10):

PE[10, 0] = sin(10 / 10000^0) = sin(10) ≈ -0.544 PE[10, 1] = cos(10 / 10000^0) = cos(10) ≈ -0.839 PE[10, 2] = sin(10 / 10000^(2/512)) = sin(10 / 1.0091) ≈ -0.532 PE[10, 3] = cos(10 / 10000^(2/512)) ≈ -0.847 ... PE[10, 510] = sin(10 / 10000) = sin(0.001) ≈ 0.001 PE[10, 511] = cos(10 / 10000) ≈ 1.0

设计特性

相对位置线性性(关键性质) PE(pos+k) = PE(pos) · R(k) + ...
其中 R(k) 依赖于 k,这允许注意力学习相对距离而非绝对位置

5.3 可学习位置编码

BERT 和 GPT-2 使用可学习的位置嵌入:

可学习位置编码 PE[pos] ∈ ℝ^(d_model), pos = 0, 1, ..., max_seq_len-1
每个 PE[pos] 都是可训练参数(初始化随机)

限制

5.4 RoPE:旋转位置编码(现代标准)

Su et al. (2021) 提出的 RoPE (Rotary Position Embedding) 现已成为 LLaMA、Qwen、Mistral 等所有现代 LLM 的标准:

RoPE 核心思想 对 Q 和 K 的每对维度 (q_{2i}, q_{2i+1}) 应用旋转矩阵:
[q_{2i}', q_{2i+1}'] = [cos(mθ_i), -sin(mθ_i); sin(mθ_i), cos(mθ_i)] · [q_{2i}; q_{2i+1}]
其中 θ_i = 10000^(-2i/d_model), m 是位置
RoPE 矩阵形式 f(x, m) = x ⊗ Rotation(mθ)
即对向量的每个 2D 分块应用旋转

关键性质:两个不同位置 m 和 n 的 Q 和 K 的内积只依赖相对位置 m-n:

RoPE 内积的相对位置性质(最重要!) 其中 θ(q,k) 是 q 和 k 的夹角
只依赖于 (m-n),不依赖于 m 或 n 的绝对值!

这个性质使 RoPE 天然地在 attention 计算中编码相对位置。所有现代 LLM 都采用它,因为:

5.5 ALiBi:替代方案

Ofir Press et al. (2022) 提出更简单的方案:不编码位置信息,而是在 attention 分数中加入线性衰减:

Attention with Linear Biases (ALiBi) Attention(Q,K,V) = softmax((Q·K^T + LinearBias) / √d_k) · V
LinearBias[i,j] = -α_h · |i - j|
α_h 是每个头学习的衰减速率

优势:无位置编码参数,外推能力更好,速度略快。不足:性能通常略低于 RoPE,大多数新模型仍用 RoPE。

面试必考:为什么需要位置编码?Self-Attention 本身完全对位置不敏感,若输入排列改变,注意力分数不变。必须显式编码位置。

6. 完整 Encoder-Decoder 架构

6.1 编码器层堆栈

Transformer 编码器由 N 个相同的层堆叠(标准 N=12 或 24):

编码器单层(Pre-LN 形式,现代标准) z_l = LayerNorm(x_{l-1})
x_l = x_{l-1} + MultiHeadSelfAttn(z_l)
z_l' = LayerNorm(x_l)
x_l = x_l + FFN(z_l')

关键:Pre-LN (先归一化,再子层,再残差),而非原论文的 Post-LN (子层,加,归一化)。

6.2 解码器层堆栈

解码器与编码器的区别:多一个交叉注意力,且自注意力被 因果掩码 限制:

解码器单层 z_l = LayerNorm(x_{l-1})
x_l = x_{l-1} + MaskedMultiHeadSelfAttn(z_l) [仅关注 ≤ 当前位置]
z_l' = LayerNorm(x_l)
x_l = x_l + CrossAttn(z_l', encoder_output) [Q 来自解码器,K,V 来自编码器]
z_l'' = LayerNorm(x_l)
x_l = x_l + FFN(z_l'')

6.3 因果掩码实现

防止解码器看到未来 token,通过在 softmax 前加掩码:

因果掩码定义 mask[i, j] = 0 if j ≤ i (允许关注当前及过去)
= -∞ if j > i (掩盖未来)
scores_masked = Q·K^T + mask
attention_weights = softmax(scores_masked / √d_k)

实现上,通常用上三角矩阵加到分数上。e^(-∞) = 0,所以被掩码的位置的权重为 0。

6.4 完整编码器-解码器数据流

Transformer 编码器-解码器架构
输入文本 Embedding + PosEnc Encoder Stack × N 层 - MultiHeadSelfAttn - FFN - LayerNorm (Pre-LN) Encoder Output 输出文本 (Shifted) Embedding + PosEnc Decoder Stack × N 层 - Masked SelfAttn - CrossAttn - FFN Linear + Softmax Cross-Attention (K,V from Encoder)

7. Feed-Forward Network (FFN):位置式前馈网络

7.1 标准 FFN(ReLU)

标准 FFN 结构 FFN(x) = ReLU(x · W_1 + b_1) · W_2 + b_2
其中 W_1 ∈ ℝ^(d_model × d_ff), W_2 ∈ ℝ^(d_ff × d_model)
d_ff = 4 × d_model (标准配置)

这是一个两层的全连接网络,中间层维度是输入的 4 倍。原始 Transformer 采用此设计。

7.2 GLU 变体(现代标准)

现代 LLM(LLaMA, PaLM, Mistral)用 Gated Linear Unit 替代,效果更好:

SwiGLU (LLaMA 采用) FFN(x) = (Swish(x · W_1) ⊙ (x · W_3)) · W_2
其中 Swish(x) = x · σ(x), σ 是 sigmoid
d_ff = (8/3) × d_model ≈ 2.67 × d_model
GeGLU (替代选择) FFN(x) = (GELU(x · W_1) ⊙ (x · W_3)) · W_2
其中 GELU(x) = x · Φ(x), Φ 是标准正态分布的 CDF

为什么更好

变体 公式 d_ff 性能 采用者
ReLU ReLU(xW_1)W_2 4d 基准 原 Transformer
SwiGLU Swish(xW_1)⊙xW_3)W_2 8d/3 +1-2% LLaMA, Qwen
GeGLU GELU(xW_1)⊙xW_3)W_2 8d/3 +1-2% T5, PaLM

8. 层归一化与残差连接

8.1 Post-LN vs Pre-LN

Post-LN(原始 Transformer)

Post-LN 形式 x_l = LayerNorm(x_{l-1} + Sublayer(x_{l-1}))

问题:深层网络梯度流不稳定,需要 warmup(学习率预热)才能训练。

Pre-LN(现代标准,如 GPT-2+)

Pre-LN 形式(推荐) z_l = LayerNorm(x_{l-1})
x_l = x_{l-1} + Sublayer(z_l)

优势

8.2 RMSNorm 优化

LLaMA 采用 RMSNorm 替代 LayerNorm,计算更快:

LayerNorm y = (x - mean(x)) / √(var(x) + ε) · γ + β
包含中心化和缩放两步
RMSNorm (Root Mean Square Norm) y = x / RMS(x) · γ
其中 RMS(x) = √(1/n · Σ x_i²)
只做缩放,不中心化

优势

8.3 残差连接的数学意义

残差流 x_l = x_{l-1} + f_l(x_{l-1})
梯度: ∂L/∂x_{l-1} = ∂L/∂x_l · (1 + ∂f_l/∂x_{l-1})

关键:梯度中的 "1" 项使得信号直接流向浅层,即使 ∂f_l/∂x_{l-1} 很小也不会消失。这是深层网络能训练的原因。

面试必考:Pre-LN vs Post-LN 有什么区别?为什么 Pre-LN 训练更稳定?Pre-LN 中梯度不经过子层参数就能直接流向前一层,避免了梯度衰减。RMSNorm 相比 LayerNorm 的优势?计算更快,无需中心化,性能相当。

9. Flash Attention:内存高效的注意力计算

9.1 标准 Attention 的内存瓶颈

标准实现必须在内存中存储完整的 n×n 注意力矩阵:

内存使用 Attention_matrix = Q · K^T (n × n 矩阵)
内存 = n² × 4 bytes (float32)
n=4096: 64 MB (可接受)
n=32768: 4 GB (严重瓶颈)

高 GPU 间的 memory bandwidth 是计算速度的主要限制因素,而不是 FLOP。

9.2 Flash Attention 的核心思想

Dao et al. (2022) 的创新:通过 Tiling(分块)和 Kernel Fusion(核函数融合) 避免显式存储注意力矩阵:

Flash Attention 核心步骤 1. 将 Q, K, V 分成块(tiles) 2. 逐块计算注意力并累积结果 3. 只在 SRAM 中保留当前块,不写回 HBM (高带宽内存)

算法概要

  1. 初始化输出 O = 0, 最大值 M_prev = -∞, 归一化因子 Z = 0
  2. 对 K 的每个块 K_j:
    • 计算 S = Q · K_j^T(Q 对应 K_j 的块)
    • 更新最大值 M_curr = max(M_prev, max_row(S))
    • 计算 P = exp(S - M_curr) (安全的 softmax)
    • 累积输出和归一化因子,同时补偿前一块的指数衰减

9.3 内存与速度的改进

IO 复杂度改进 标准: O(n²) 内存,O(n²) I/O
Flash Attn: O(N) 内存 (M 块大小), O(n²d/M) I/O
其中 M ≈ SRAM 大小 / d

实际加速(A100 GPU):

9.4 Flash Attention 的进一步优化

Flash Attention 2 (Dao et al. 2023)

Flash Attention 3 (2024)

Flash Attention 的实际意义

它使得长序列(8K-32K)成为可能。无需 Flash Attention,推理 4K 上下文需 tens of GB 内存;有了它,消耗显著降低,使批处理成为可能。

面试必考:Flash Attention 为什么快?它不是减少了计算(FLOP 相同),而是减少了 IO 操作。通过 tiling 和 kernel fusion,避免了在 HBM 和 SRAM 间的反复移动。

10. 注意力的变体与长序列优化

10.1 稀疏注意力(Sparse Attention)

问题:Dense attention 的 O(n²) 复杂度限制了序列长度。稀疏注意力只计算重要的注意力权重。

Longformer (Beltagy et al. 2020)

BigBird (Zaheer et al. 2021):组合本地、全局和随机注意力。

10.2 线性注意力(Linear Attention)

用核函数近似 softmax,使复杂度从 O(n²) 降到 O(n):

核函数近似 softmax(Q·K^T) ≈ φ(Q) · φ(K)^T
其中 φ 是核函数(如 ELU+1, sigmoid)

问题:性能通常低于标准注意力 5-10%,但对超长序列有用。

10.3 滑动窗口注意力(Sliding Window)

Mistral 7B 采用:每个 token 只关注最近 W 个 token(如 4096)。

因果滑动窗口 attention_mask[i, j] = 0 if j ≤ i and i - j < W
= -∞ 否则

优势:O(n × W) 复杂度,性能只略低于 dense(~1-2%)。

10.4 分布式注意力(Ring Attention、Page Attention)

Ring Attention(Liu et al. 2023):在多 GPU 间分布式计算长序列注意力,支持超长上下文(32K+)。

Page Attention(vLLM):将 KV 缓存分页管理,减少内存碎片,提高 GPU 利用率。

11. Transformer 架构变体与应用

11.1 三种主要范式

架构 结构 预训练任务 典型模型 适用任务
Encoder-only 只有编码器 MLM (Masked LM) + NSP BERT, RoBERTa 分类, NER, 相似度
Decoder-only 只有解码器 CLM (Causal LM) GPT-2, GPT-3, LLaMA 生成, 对话, 推理
Encoder-Decoder 两者都有 多种(MLM, CLM, UAR) T5, BART, mT5 翻译, 摘要, QA

11.2 Encoder-only:BERT 与变体

BERT (Devlin et al. 2019)

RoBERTa (Liu et al. 2019):BERT 的改进版,只用 MLM,移除 NSP,更大数据,更长训练。性能提升明显。

11.3 Decoder-only:GPT 系列为主

GPT-2/3

为什么 Decoder-only 成为主流

11.4 Encoder-Decoder:T5 与应用

T5 (Raffel et al. 2020)

BART (Lewis et al. 2020)

面试必考:Encoder-only vs Decoder-only 的选择?Encoder-only 专注理解(分类),Decoder-only 专注生成。为什么现在主流是 Decoder-only?因为它更简单、更易 scaling,且通过 prompt 可实现所有任务。Encoder-Decoder 架构的优势和劣势?优:灵活,编码和解码独立。劣:参数多,训练复杂。

12. Transformer 关键创新与面试高频题

核心创新点

12.1 Transformer 的面试高频题

题 1:Self-Attention 为什么比 RNN 好?

RNN 逐步计算,无法并行,序列长度 n 需要 O(n) 时间步。Self-Attention 在一步内计算所有位置的交互,可完全并行,时间复杂度只与 attention 矩阵大小 O(n²) 有关,但常数因子使得总体快得多。

题 2:缩放因子 √d_k 的作用?

当 d_k=64 时,Q·K^T 的标准差为 8,导致 softmax 输入范围太大,处于 sigmoid 的平坦区,梯度接近零。除以 √d_k 将标准差归一化为 1,softmax 梯度有合理幅度,模型能有效学习。

题 3:为什么用多头而不是单头?

单头必须在一个 d 维空间学习所有特征交互。多头将其分解为 h 个 d_k 维的子空间,每个头可专注不同特征(语法、语义、长程等),同时总计算量不增加,表达能力大幅增强。

题 4:Pre-LN vs Post-LN,为什么 Pre-LN 更好?

Post-LN 中梯度路径需经过子层(MultiHeadAttn 或 FFN 的参数),深层时容易衰减。Pre-LN 中梯度可直接通过残差连接流向前一层,不被子层参数扰乱,深层网络梯度流更稳定,无需 warmup。

题 5:RoPE 相比可学习位置编码的优势?

RoPE 通过旋转矩阵将位置信息编码为相对位置关系,完全无参数,具有完美的外推能力(可处理 2-4 倍超长序列)。可学习编码最长固定为训练长度,超出需插值,性能下降。

题 6:Flash Attention 的核心思想?

标准 Attention 需在内存中存储 n×n 的注意力矩阵。Flash Attention 通过 tiling 和 kernel fusion,逐块计算注意力并累积结果,避免显式存储矩阵,大幅减少 IO 操作,实现 2-8 倍加速。

题 7:KV 缓存为什么是瓶颈?如何优化?

推理时每个新 token 需缓存其 K 和 V。n=4096 的序列在 MHA 中需 268MB 内存。通过 GQA(分组共享 K/V)可减 75%,MQA 减 97%。这是长上下文和批处理的主要限制。

题 8:GLU (SwiGLU、GeGLU) vs ReLU FFN 的改进?

GLU 用门控机制(第二个线性层作"门")和平滑激活替代简单 ReLU。性能提升 1-2%,同时通过减小 d_ff 可减少参数。LLaMA 用 d_ff=2.67d(而非 4d),性能反而更好。

题 9:Decoder-only 为何成为 LLM 的主流?

Decoder-only 架构更简单(无复杂的编码器-解码器交互),参数更少。通过 prompt 和 in-context learning,可实现理解、生成、推理等所有任务,性能随规模提升更快。Encoder-Decoder 则需特殊微调每种任务。

题 10:Transformer 中哪些成分对训练稳定性最关键?

前三个最关键:(1) 残差连接,避免梯度消失;(2) 层归一化(Pre-LN),使梯度流不被参数扰乱;(3) 缩放的多头注意力(√d_k),梯度幅度合理。这三者结合使得深层网络(100+ 层)可训练。

上一章 ← 首页 深度解析 🔬 Attention 机制演进深度解析