Deep Dive · 上下文扩展技术

从 512 到 100 万 Token
上下文窗口扩展技术全解析

为什么标准 Transformer 的上下文这么难扩展?每一项技术突破的核心思想是什么?本文逐一拆解,配原理图讲清楚。

00根本问题

为什么上下文这么难扩展?

一切问题的根源:注意力机制的平方复杂度

标准 Transformer 里,自注意力机制要让序列中的每个 Token,都和其他所有 Token 计算关系。如果序列长度是 n,那就是 n × n 次配对运算。

Attention(Q, K, V) = softmax( QKᵀ / √d ) · V Q、K、V 矩阵形状均为 [n × d],QKᵀ 的形状是 [n × n]

这个 n × n 的注意力矩阵就是一切麻烦的根源:

注意力矩阵可视化 — 序列越长,矩阵越大
n=8 8×8=64 计算量: 64 显存: 1× n=32 32×32 =1,024 计算量: 1024 显存: 16× n=128 128×128 =16,384 计算量: 16K 显存: 256× n=1024 1024×1024 =1,048,576 计算量: 1M 显存: 16,384× ×4 ×4 ×8 序列长度翻倍 → 注意力矩阵面积变 4 倍(平方增长)

这带来两个严重问题:算力(FLOPs)显存(HBM)都是 O(n²) 增长。实际跑 128K token 的注意力,光是这个矩阵就要占用约 32GB 显存——一块 A100 都不够。

核心矛盾:Transformer 的能力来自"每个词看所有词",但这个"看所有词"在长序列下算不起、存不下。后面所有技术,都是在不同层面解决这个矛盾。
时间复杂度
O(n²·d)
n 个 Token,d 维向量,两两点积
空间复杂度(注意力矩阵)
O(n²)
需要把整个 n×n 矩阵存入显存
128K token 注意力矩阵大小
≈ 32 GB
仅注意力矩阵,还不含权重参数
1M token 注意力矩阵大小
≈ 2 TB
完全不可能硬塞——需要根本性的算法改进
01Flash Attention

Flash Attention

不改数学,只改顺序——用 IO 感知的分块计算绕过显存瓶颈

Flash Attention(Tri Dao et al., 2022)是近年来最重要的工程突破之一。它不改变任何数学结果,只重新安排计算顺序,使得注意力计算几乎不需要把那个巨大的 n×n 矩阵存到显存里。

理解 Flash Attention,先要理解 GPU 的存储层次:

GPU 存储层次结构
SRAM (on-chip cache) 20–40 MB | 19 TB/s 带宽 | 极快 搬运数据(慢) HBM (GPU 显存) 40–80 GB | 2 TB/s 带宽 | 快但慢于 SRAM 搬运数据(更慢) CPU DRAM | TB 级 | 但 PCIe 带宽仅 ~32 GB/s

标准注意力计算的瓶颈不是"算得慢",而是数据搬运太多:它要把整个 n×n 注意力矩阵反复写入/读出 HBM,而 HBM 的带宽远低于 SRAM。Flash Attention 的核心思想是:把数据分成小块,让计算尽量在 SRAM 里完成,减少 HBM 读写次数

Flash Attention 分块计算 vs 标准注意力
标准注意力 n×n 完整注意力矩阵 存在 HBM 中 ① 写入 HBM: QKᵀ → S ② 读出 HBM: softmax(S) ③ 写入 HBM: softmax(S) ④ 读出 HBM: · V → 输出 HBM 读写: O(n²) 次 Flash Attention 在SRAM计算 在SRAM计算 在SRAM计算 在SRAM ① 分块加载 Q/K/V 到 SRAM ② 在 SRAM 内完成该块的计算 ③ 只把结果写回 HBM ④ 用 online softmax 合并各块 HBM 读写: O(n) 次 ↓ 大幅减少

Flash Attention 用了一个叫 online softmax 的技巧,可以在不存下完整矩阵的情况下,用分块的方式得到完全相同的计算结果。这是一个数值等价的重新推导,不是近似。

效果:Flash Attention 2 让长序列训练速度提升 2~4×,显存占用从 O(n²) 降到 O(n)。这是使得 128K+ 上下文从"理论可行"变成"实际可跑"的最关键一步。Flash Attention 3(2024)针对 H100 架构进一步优化,速度再提升约 75%。
02RoPE 位置编码

RoPE:旋转式位置编码

把位置信息编进旋转角度,让模型自然感知"相对距离"

Transformer 的自注意力本身不区分顺序——"猫吃鱼"和"鱼吃猫"的词向量输入是一样的,必须额外注入位置信息。不同的位置编码方案直接影响模型能否扩展到更长序列。

传统方案对比:

方案做法长度外推能力代表模型
绝对正弦位置编码 预先计算每个位置的固定向量,直接加到词嵌入上 差,训练长度即上限 原始 BERT, GPT-2
可学习绝对位置 训练一组可学习的位置 embedding 差,位置数量在训练时固定 GPT-3
ALiBi 不加位置向量,在注意力分数上直接加线性惩罚项 一般,可外推但质量下降 BLOOM
RoPE 把位置编进 Q/K 向量的旋转角度,通过相对角度差来感知位置 好,相对位置自然编码 LLaMA, Mistral, Qwen, GPT-NeoX

RoPE 的核心思想:如果我们把每两个维度视作一个二维平面上的点,那么"位置 m 处的向量"就是"位置 0 处的向量"旋转了 mθ 角度。这样,位置 m 的 Query 和位置 n 的 Key 做点积时,结果只取决于它们的相对距离 (m−n),和绝对位置无关。

RoPE 旋转原理(以 2D 为例)
位置 m=0 x 原始向量 位置 m(旋转 mθ) R(mθ)x 旋转后的 Query/Key 点积结果 Q_m · K_n = f(x_q, x_k, m − n ) 只依赖相对距离! 不依赖绝对位置 → 外推时相对关系依然稳定
q_m = R(mθ) · W_q · x_m
k_n = R(nθ) · W_k · x_n
q_m · k_n = xᵀ_q W_q R((m−n)θ) W_k x_n R 是旋转矩阵;每对维度 (2i, 2i+1) 用不同的基础频率 θᵢ = 10000^(−2i/d)

为什么外推好? 绝对位置编码需要在训练时"见过"第 n 个位置的 embedding;而 RoPE 只关心相对距离,即使推理时遇到训练时没见过的超长位置,相对关系的计算方式不变——只是角度更大而已。但超出训练长度后,角度可能进入没有训练过的范围,精度仍会下降,这是 YaRN 要解决的问题。

03YaRN / LongRoPE

YaRN & LongRoPE:让 RoPE 优雅地外推

训练 4K,推理用 128K——让旋转角度平滑地"缩放"到更长范围

RoPE 在超出训练长度后会退化,原因是:高频维度(小 θ)的旋转角在长距离时变化太快,进入了训练时从未见过的角度范围,模型不知道如何处理。

最朴素的解决方案是位置插值(PI, Position Interpolation):把所有位置编号等比例压缩,让 0~max_len 的位置映射回 0~train_len。但这会压缩近距离的精度。

YaRN:按频率分区处理(NTK-aware interpolation)
维度频率 高频(近距离敏感) 低频(长距离感知) 不插值 维持原始旋转频率 保留近距离精度 θ' = θ 高频维度 线性混合插值 平滑过渡,兼顾 近距离和远距离 θ' = lerp(θ, θ/s) 中频维度 全量插值 大幅拉伸旋转频率 覆盖超长距离 θ' = θ / s 低频维度 s = 目标长度 / 训练长度(如 128K / 4K = 32) 只需少量微调(约 400 steps),模型即可适应新的上下文长度
确定缩放比例 s
s = 目标长度 ÷ 训练长度。例如训练用 4K、目标 128K,则 s = 32
按维度的频率特性分区
高频维度(θ 小)保持不变,低频维度(θ 大)整体除以 s,中间做线性插值
少量微调
用少量长文本数据(约几百步)微调,让模型适应新的角度分布。不需要从头训练
推理时的长度扩展
LongRoPE(2024)更进一步:在推理时动态调整每一层的缩放比例,实现无需微调的 2M token 外推
实际效果:LLaMA 3.1 用 YaRN 变种将上下文从 8K 扩展到 128K,仅用了约 0.1% 的训练计算量做长上下文微调。Phi-3 用类似方法从 4K 扩展到 128K。这使得"长上下文"不再需要昂贵的从头训练。
04GQA / MQA

GQA 与 MQA:精简注意力头

多个 Query 头共享一组 K/V,大幅削减推理时的 KV Cache 显存

标准多头注意力(MHA)中,每个注意力头都有独立的 Q、K、V 投影矩阵。假设有 32 个头,那就有 32 组 K 和 32 组 V,推理时每个 Token 都要缓存这 64 组向量——这是 KV Cache 显存爆炸的直接原因。

三种注意力结构对比
MHA(标准多头) Q₁ K₁ V₁ Q₂ K₂ V₂ Q₃ K₃ V₃ Q₄ K₄ V₄ KV heads = 4 × Q heads KV Cache: 100% MQA(多查询注意力) Q₁ Q₂ Q₃ Q₄ K V KV heads = 1,速度最快 KV Cache: 25% ▼ 质量略降 GQA(分组查询) Q₁ Q₂ K₁, V₁ Q₃ Q₄ K₂, V₂ KV heads = 2,平衡方案 KV Cache: 50% ▼ 质量接近MHA LLaMA 3 使用 GQA:Q heads=32, KV heads=8,显存节省 4× Mistral 7B 使用 GQA:Q heads=32, KV heads=8
为什么对长上下文特别重要? 长上下文推理时,KV Cache 的显存占用随序列长度线性增长。128K token + 70B 参数模型的 KV Cache 可能需要 100GB+。GQA 将 KV heads 从 64 减到 8,显存占用直接砍掉 8×,这是长上下文可以在单机运行的关键。
05稀疏 & 滑动窗口注意力

稀疏注意力 & 滑动窗口

不是每个词都需要看所有词——有选择地计算注意力

前面的技术都是在"同样计算全量注意力"的前提下优化。稀疏注意力走了另一条路:根本就不计算所有的 n×n 注意力对,只计算"有意义"的那些,把复杂度从 O(n²) 降到 O(n·k)。

四种注意力模式(黑色 = 计算,白色 = 跳过)
全量注意力 ■■■ ■■■ ■■■ O(n²) 每个看所有 滑动窗口(Mistral) O(n·w) 只看附近 w 个 全局+局部(Longformer) O(n·(g+w)) 特殊 token 看全局 分块因果(GPT类) O(n²/2) 因果遮蔽(标准) ■ 深色 = 计算注意力 · 空白 = 跳过(直接置为 −∞) Mistral 7B 的滑动窗口大小 w = 4096,配合滚动 KV Cache 使用 代价:远距离信息只能逐层"接力传递",可能丢失精度

滑动窗口注意力的关键局限:位置 1000 的 Token 无法直接"看到"位置 1 的信息,必须依靠中间层的激活值逐层传递。在实践中,Mistral 通过叠加 32 层,每层窗口 4K,理论感受野约 4K×32 = 128K——但实验表明远距离信息的质量不如全量注意力。

最新趋势:2024 年后,大部分前沿模型(GPT-4o, Claude 3, Gemini 1.5)转向了"全量注意力 + Flash Attention + GQA"的组合,而不是用稀疏注意力——因为全量注意力质量更好,而 Flash Attention + GQA 已经把开销控制在可接受范围内。
06KV Cache 原理与优化

KV Cache:推理时最大的显存杀手

自回归生成为什么必须缓存?又如何在长上下文下把它压下去?

大语言模型生成文本时,每生成一个新 Token,就要把整个上下文重新跑一遍注意力——这代价太大。KV Cache 的思想是:把历史 Token 的 Key 和 Value 向量缓存起来,生成新 Token 时只需计算新 Token 的 Q,然后和缓存的 K/V 做注意力即可。

KV Cache 工作原理:第 t 步生成新 token
KV Cache(已缓存,不重复计算) Token₁ K₁ [d×1] V₁ [d×1] Token₂ K₂ V₂ · · · Token_{t-1} K_{t-1} V_{t-1} 新 Token_t(只计算这个) Token_t Q_t 只有 Q K_t → 存入Cache V_t → 存入Cache Q_t 与所有缓存 K 做注意力 KV Cache 显存 = 2 × n_layers × n_kv_heads × d_head × seq_len × bytes_per_elem 例:70B 模型,128K context → KV Cache ≈ 160 GB(FP16);INT4 量化后 ≈ 40 GB

KV Cache 的显存优化策略:

KV Cache 量化(INT8 / INT4)
把缓存的 K/V 向量从 FP16(2 字节/元素)压缩到 INT8(1 字节)或 INT4(0.5 字节)。精度略微损失,但显存减半/减四倍。GPTQ-KV、KVQuant 等方案可以做到几乎无损的 INT4。
PagedAttention(vLLM)
借鉴操作系统"内存分页"思想。把 KV Cache 切成固定大小的"页",不连续存储,按需分配,避免内存碎片。让服务器同时处理数百并发请求成为可能。
Token 淘汰 / 稀疏 KV Cache
动态丢弃"不重要"的历史 Token 的 KV。StreamingLLM 保留所有"Sink Tokens"(开头)+ 滑动窗口尾部,实现无限长上下文(但真正理解受限)。H2O 算法根据注意力得分保留重要 Token。
Prefix Caching / Prompt Caching
多次请求共享同一个"系统提示"时,只计算一次其 KV Cache,后续请求复用。Anthropic 的 Claude、OpenAI API 均支持此功能。可将重复长提示的成本降低 90%。
07长文本训练策略

有窗口,不等于会用

工程解决了容量,训练解决了能力——让模型真正"善用"长上下文

一个常见误区:认为只要技术上支持了 128K 上下文,模型就能"真正理解"128K 的内容。实际上,草堆中找针(Needle-in-a-Haystack)测试表明,很多号称支持长上下文的模型,在上下文超过 32K 后,提取中间段信息的能力就急剧下降——这不是窗口大小的问题,而是训练问题。

"Lost in the Middle" 现象:模型对上下文位置的注意力分布
信息在上下文中的位置(开始 → 结尾) (Liu et al. 2023, "Lost in the Middle") 检索准确率 开头信息记得好 中间段容易"遗忘" 结尾信息还行 ---- 长文本专项训练后(更均匀)

针对这个问题,现代训练策略包括:

长文本数据配比
专项数据混合
在预训练后期或微调阶段,专门加入长文档数据(书籍、代码仓库、长对话),占比约 5–20%
渐进式长度课程
Curriculum Learning
先用短序列训练,逐步增大序列长度。让模型先学会短距离推理,再学长距离,避免训练初期损失爆炸
RLHF 长上下文对齐
针对性强化学习
专门构造需要从长文本中提取、综合信息的问题,通过人类反馈或 AI 反馈强化模型的"深度阅读"能力
长文本合成数据
数据增强
自动生成"把答案藏在文档中间"的 QA 对,强制模型练习检索中段信息,对抗 Lost in the Middle
08技术对比总结

各技术综合对比

解决的问题、代价、现状

技术 解决什么问题 复杂度改善 代价 / 局限 主流应用
Flash Attention 2/3 HBM 带宽瓶颈,显存 O(n²) IO: O(n²)→O(n)
速度 2–4×
算法复杂,需与硬件强绑定 几乎所有现代模型训练
RoPE + YaRN 位置编码无法外推 无计算量变化,仅改变位置编码 极远距离仍有精度下降;需少量微调 LLaMA 3, Qwen2, Mistral, Phi-3
GQA / MQA KV Cache 显存占用 KV Cache:减少 4–8× 质量略低于 MHA;需重新训练 LLaMA 3, Gemma, Mistral, Falcon
滑动窗口注意力 注意力计算 O(n²) O(n²)→O(n·w) 远距离信息靠层间传递,质量下降 Mistral 7B(早期版本)
KV Cache 量化 推理显存峰值 显存:减少 2–4× 量化误差;需专门的量化感知训练 vLLM, TensorRT-LLM, llama.cpp
PagedAttention 显存碎片,并发吞吐 并发量提升 2–4× 实现复杂,主要是推理服务层优化 vLLM, SGLang
长文本训练数据 有窗口但不会用 无复杂度改变,提升质量 数据稀缺;训练成本高 Claude 3, GPT-4, Gemini 1.5
现代长上下文模型的标准"配方"(以 Gemini 1.5 Pro 1M / Claude 3 200K 为例):
Flash Attention 3(IO 优化)+ RoPE + YaRN 外推 + GQA(减少 KV 显存)+ 专项长文本训练数据 + 长上下文 RLHF 对齐 + 推理端 KV Cache 量化 + PagedAttention 服务部署

没有单一银弹——是多个技术协同叠加的结果。
未来方向
Mamba / SSM 线性复杂度架构 Linear Attention 近似 外部记忆 (Memory Augmented) RAG + 长上下文混合 Speculative Decoding 加速 MoE × 长上下文