Deep Dive · 上下文扩展技术
从 512 到 100 万 Token
上下文窗口扩展技术全解析
为什么标准 Transformer 的上下文这么难扩展?每一项技术突破的核心思想是什么?本文逐一拆解,配原理图讲清楚。
00根本问题
为什么上下文这么难扩展?
一切问题的根源:注意力机制的平方复杂度
标准 Transformer 里,自注意力机制要让序列中的每个 Token,都和其他所有 Token 计算关系。如果序列长度是 n,那就是 n × n 次配对运算。
Attention(Q, K, V) = softmax( QKᵀ / √d ) · V
这个 n × n 的注意力矩阵就是一切麻烦的根源:
注意力矩阵可视化 — 序列越长,矩阵越大
这带来两个严重问题:算力(FLOPs)和显存(HBM)都是 O(n²) 增长。实际跑 128K token 的注意力,光是这个矩阵就要占用约 32GB 显存——一块 A100 都不够。
核心矛盾:Transformer 的能力来自"每个词看所有词",但这个"看所有词"在长序列下算不起、存不下。后面所有技术,都是在不同层面解决这个矛盾。
时间复杂度
O(n²·d)
n 个 Token,d 维向量,两两点积
空间复杂度(注意力矩阵)
O(n²)
需要把整个 n×n 矩阵存入显存
128K token 注意力矩阵大小
≈ 32 GB
仅注意力矩阵,还不含权重参数
1M token 注意力矩阵大小
≈ 2 TB
完全不可能硬塞——需要根本性的算法改进
01Flash Attention
Flash Attention
不改数学,只改顺序——用 IO 感知的分块计算绕过显存瓶颈
Flash Attention(Tri Dao et al., 2022)是近年来最重要的工程突破之一。它不改变任何数学结果,只重新安排计算顺序,使得注意力计算几乎不需要把那个巨大的 n×n 矩阵存到显存里。
理解 Flash Attention,先要理解 GPU 的存储层次:
GPU 存储层次结构
标准注意力计算的瓶颈不是"算得慢",而是数据搬运太多:它要把整个 n×n 注意力矩阵反复写入/读出 HBM,而 HBM 的带宽远低于 SRAM。Flash Attention 的核心思想是:把数据分成小块,让计算尽量在 SRAM 里完成,减少 HBM 读写次数。
Flash Attention 分块计算 vs 标准注意力
Flash Attention 用了一个叫 online softmax 的技巧,可以在不存下完整矩阵的情况下,用分块的方式得到完全相同的计算结果。这是一个数值等价的重新推导,不是近似。
效果:Flash Attention 2 让长序列训练速度提升 2~4×,显存占用从 O(n²) 降到 O(n)。这是使得 128K+ 上下文从"理论可行"变成"实际可跑"的最关键一步。Flash Attention 3(2024)针对 H100 架构进一步优化,速度再提升约 75%。
02RoPE 位置编码
RoPE:旋转式位置编码
把位置信息编进旋转角度,让模型自然感知"相对距离"
Transformer 的自注意力本身不区分顺序——"猫吃鱼"和"鱼吃猫"的词向量输入是一样的,必须额外注入位置信息。不同的位置编码方案直接影响模型能否扩展到更长序列。
传统方案对比:
| 方案 | 做法 | 长度外推能力 | 代表模型 |
| 绝对正弦位置编码 |
预先计算每个位置的固定向量,直接加到词嵌入上 |
差,训练长度即上限 |
原始 BERT, GPT-2 |
| 可学习绝对位置 |
训练一组可学习的位置 embedding |
差,位置数量在训练时固定 |
GPT-3 |
| ALiBi |
不加位置向量,在注意力分数上直接加线性惩罚项 |
一般,可外推但质量下降 |
BLOOM |
| RoPE |
把位置编进 Q/K 向量的旋转角度,通过相对角度差来感知位置 |
好,相对位置自然编码 |
LLaMA, Mistral, Qwen, GPT-NeoX |
RoPE 的核心思想:如果我们把每两个维度视作一个二维平面上的点,那么"位置 m 处的向量"就是"位置 0 处的向量"旋转了 mθ 角度。这样,位置 m 的 Query 和位置 n 的 Key 做点积时,结果只取决于它们的相对距离 (m−n),和绝对位置无关。
RoPE 旋转原理(以 2D 为例)
q_m = R(mθ) · W_q · x_m
k_n = R(nθ) · W_k · x_n
q_m · k_n = xᵀ_q W_q R((m−n)θ) W_k x_n
为什么外推好? 绝对位置编码需要在训练时"见过"第 n 个位置的 embedding;而 RoPE 只关心相对距离,即使推理时遇到训练时没见过的超长位置,相对关系的计算方式不变——只是角度更大而已。但超出训练长度后,角度可能进入没有训练过的范围,精度仍会下降,这是 YaRN 要解决的问题。
03YaRN / LongRoPE
YaRN & LongRoPE:让 RoPE 优雅地外推
训练 4K,推理用 128K——让旋转角度平滑地"缩放"到更长范围
RoPE 在超出训练长度后会退化,原因是:高频维度(小 θ)的旋转角在长距离时变化太快,进入了训练时从未见过的角度范围,模型不知道如何处理。
最朴素的解决方案是位置插值(PI, Position Interpolation):把所有位置编号等比例压缩,让 0~max_len 的位置映射回 0~train_len。但这会压缩近距离的精度。
YaRN:按频率分区处理(NTK-aware interpolation)
①
确定缩放比例 s
s = 目标长度 ÷ 训练长度。例如训练用 4K、目标 128K,则 s = 32
②
按维度的频率特性分区
高频维度(θ 小)保持不变,低频维度(θ 大)整体除以 s,中间做线性插值
③
少量微调
用少量长文本数据(约几百步)微调,让模型适应新的角度分布。不需要从头训练
④
推理时的长度扩展
LongRoPE(2024)更进一步:在推理时动态调整每一层的缩放比例,实现无需微调的 2M token 外推
实际效果:LLaMA 3.1 用 YaRN 变种将上下文从 8K 扩展到 128K,仅用了约 0.1% 的训练计算量做长上下文微调。Phi-3 用类似方法从 4K 扩展到 128K。这使得"长上下文"不再需要昂贵的从头训练。
04GQA / MQA
GQA 与 MQA:精简注意力头
多个 Query 头共享一组 K/V,大幅削减推理时的 KV Cache 显存
标准多头注意力(MHA)中,每个注意力头都有独立的 Q、K、V 投影矩阵。假设有 32 个头,那就有 32 组 K 和 32 组 V,推理时每个 Token 都要缓存这 64 组向量——这是 KV Cache 显存爆炸的直接原因。
三种注意力结构对比
为什么对长上下文特别重要? 长上下文推理时,KV Cache 的显存占用随序列长度线性增长。128K token + 70B 参数模型的 KV Cache 可能需要 100GB+。GQA 将 KV heads 从 64 减到 8,显存占用直接砍掉 8×,这是长上下文可以在单机运行的关键。
05稀疏 & 滑动窗口注意力
稀疏注意力 & 滑动窗口
不是每个词都需要看所有词——有选择地计算注意力
前面的技术都是在"同样计算全量注意力"的前提下优化。稀疏注意力走了另一条路:根本就不计算所有的 n×n 注意力对,只计算"有意义"的那些,把复杂度从 O(n²) 降到 O(n·k)。
四种注意力模式(黑色 = 计算,白色 = 跳过)
滑动窗口注意力的关键局限:位置 1000 的 Token 无法直接"看到"位置 1 的信息,必须依靠中间层的激活值逐层传递。在实践中,Mistral 通过叠加 32 层,每层窗口 4K,理论感受野约 4K×32 = 128K——但实验表明远距离信息的质量不如全量注意力。
最新趋势:2024 年后,大部分前沿模型(GPT-4o, Claude 3, Gemini 1.5)转向了"全量注意力 + Flash Attention + GQA"的组合,而不是用稀疏注意力——因为全量注意力质量更好,而 Flash Attention + GQA 已经把开销控制在可接受范围内。
06KV Cache 原理与优化
KV Cache:推理时最大的显存杀手
自回归生成为什么必须缓存?又如何在长上下文下把它压下去?
大语言模型生成文本时,每生成一个新 Token,就要把整个上下文重新跑一遍注意力——这代价太大。KV Cache 的思想是:把历史 Token 的 Key 和 Value 向量缓存起来,生成新 Token 时只需计算新 Token 的 Q,然后和缓存的 K/V 做注意力即可。
KV Cache 工作原理:第 t 步生成新 token
KV Cache 的显存优化策略:
①
KV Cache 量化(INT8 / INT4)
把缓存的 K/V 向量从 FP16(2 字节/元素)压缩到 INT8(1 字节)或 INT4(0.5 字节)。精度略微损失,但显存减半/减四倍。GPTQ-KV、KVQuant 等方案可以做到几乎无损的 INT4。
②
PagedAttention(vLLM)
借鉴操作系统"内存分页"思想。把 KV Cache 切成固定大小的"页",不连续存储,按需分配,避免内存碎片。让服务器同时处理数百并发请求成为可能。
③
Token 淘汰 / 稀疏 KV Cache
动态丢弃"不重要"的历史 Token 的 KV。StreamingLLM 保留所有"Sink Tokens"(开头)+ 滑动窗口尾部,实现无限长上下文(但真正理解受限)。H2O 算法根据注意力得分保留重要 Token。
④
Prefix Caching / Prompt Caching
多次请求共享同一个"系统提示"时,只计算一次其 KV Cache,后续请求复用。Anthropic 的 Claude、OpenAI API 均支持此功能。可将重复长提示的成本降低 90%。
07长文本训练策略
有窗口,不等于会用
工程解决了容量,训练解决了能力——让模型真正"善用"长上下文
一个常见误区:认为只要技术上支持了 128K 上下文,模型就能"真正理解"128K 的内容。实际上,草堆中找针(Needle-in-a-Haystack)测试表明,很多号称支持长上下文的模型,在上下文超过 32K 后,提取中间段信息的能力就急剧下降——这不是窗口大小的问题,而是训练问题。
"Lost in the Middle" 现象:模型对上下文位置的注意力分布
针对这个问题,现代训练策略包括:
长文本数据配比
专项数据混合
在预训练后期或微调阶段,专门加入长文档数据(书籍、代码仓库、长对话),占比约 5–20%
渐进式长度课程
Curriculum Learning
先用短序列训练,逐步增大序列长度。让模型先学会短距离推理,再学长距离,避免训练初期损失爆炸
RLHF 长上下文对齐
针对性强化学习
专门构造需要从长文本中提取、综合信息的问题,通过人类反馈或 AI 反馈强化模型的"深度阅读"能力
长文本合成数据
数据增强
自动生成"把答案藏在文档中间"的 QA 对,强制模型练习检索中段信息,对抗 Lost in the Middle
08技术对比总结
各技术综合对比
解决的问题、代价、现状
| 技术 |
解决什么问题 |
复杂度改善 |
代价 / 局限 |
主流应用 |
| Flash Attention 2/3 |
HBM 带宽瓶颈,显存 O(n²) |
IO: O(n²)→O(n) 速度 2–4× |
算法复杂,需与硬件强绑定 |
几乎所有现代模型训练 |
| RoPE + YaRN |
位置编码无法外推 |
无计算量变化,仅改变位置编码 |
极远距离仍有精度下降;需少量微调 |
LLaMA 3, Qwen2, Mistral, Phi-3 |
| GQA / MQA |
KV Cache 显存占用 |
KV Cache:减少 4–8× |
质量略低于 MHA;需重新训练 |
LLaMA 3, Gemma, Mistral, Falcon |
| 滑动窗口注意力 |
注意力计算 O(n²) |
O(n²)→O(n·w) |
远距离信息靠层间传递,质量下降 |
Mistral 7B(早期版本) |
| KV Cache 量化 |
推理显存峰值 |
显存:减少 2–4× |
量化误差;需专门的量化感知训练 |
vLLM, TensorRT-LLM, llama.cpp |
| PagedAttention |
显存碎片,并发吞吐 |
并发量提升 2–4× |
实现复杂,主要是推理服务层优化 |
vLLM, SGLang |
| 长文本训练数据 |
有窗口但不会用 |
无复杂度改变,提升质量 |
数据稀缺;训练成本高 |
Claude 3, GPT-4, Gemini 1.5 |
现代长上下文模型的标准"配方"(以 Gemini 1.5 Pro 1M / Claude 3 200K 为例):
Flash Attention 3(IO 优化)+ RoPE + YaRN 外推 + GQA(减少 KV 显存)+ 专项长文本训练数据 + 长上下文 RLHF 对齐 + 推理端 KV Cache 量化 + PagedAttention 服务部署
没有单一银弹——是多个技术协同叠加的结果。
未来方向
Mamba / SSM 线性复杂度架构
Linear Attention 近似
外部记忆 (Memory Augmented)
RAG + 长上下文混合
Speculative Decoding 加速
MoE × 长上下文