注意力机制
一句话总结: 注意力机制是 Transformer 的灵魂——从 Scaled Dot-Product Attention 的数学本质,到 MHA/MQA/GQA/MLA 的架构演进,再到 Flash Attention 的工程极致优化,每一步都在平衡建模能力与计算效率。
在大模型体系中的位置
Input Token → Embedding + Positional Encoding
↓
┌────────────────────────┐
│ Attention (this page) │ ← Core: lets each token "see" other tokens
└────────────────────────┘
↓
FFN / MoE ← Per-token nonlinear transform
↓
LayerNorm ← Stabilize training
↓
x N layers
↓
Output logits注意力层决定了"信息如何在序列内流动"。模型的上下文理解能力、长距离依赖建模、推理速度和显存消耗,都与注意力机制的设计直接相关。
Scaled Dot-Product Attention
核心公式
其中
为什么要除以 ?——方差证明
以下方差推导来自 Vaswani et al. (2017) "Attention Is All You Need" 原始论文的脚注 4。
假设
逐元素分析:
对整个向量求和:
点积的方差随维度线性增长!当
除以
注意力分数回到标准正态分布,softmax 输出分布温和,梯度稳定。
为什么不除以
? 除以 会导致方差为 ,分布过于集中,softmax 趋近均匀分布,注意力失去区分能力。
代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaleDotProductAttention(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.q_proj = nn.Linear(dim_in, dim_out)
self.k_proj = nn.Linear(dim_in, dim_out)
self.v_proj = nn.Linear(dim_in, dim_out)
self.out_proj = nn.Linear(dim_in, dim_out)
def forward(self, X, mask=None):
batch_size, seq_len, dim = X.shape
Q, K, V = self.q_proj(X), self.k_proj(X), self.v_proj(X)
# 注意力分数,除以 sqrt(d_k) 稳定梯度
S = Q @ K.transpose(1, 2) / math.sqrt(dim)
if mask is not None:
S = S + mask # 因果掩码:未来位置设为 -inf
P = F.softmax(S, dim=-1)
O = P @ V
return self.out_proj(O)Multi-Head Attention (MHA)
核心思想
单头注意力只能在一个子空间中捕捉关系。多头注意力将 Q、K、V 拆分到
其中每个头的维度
完整过程:拆分 → 并行计算 → 拼接
class MultiHeadsAttention(nn.Module):
def __init__(self, dim=512, n_heads=8):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wv = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
def forward(self, x, mask=None):
bsz, seq_len, dim = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
# 拆分多头: (bsz, seq_len, dim) → (bsz, n_heads, seq_len, head_dim)
q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = k.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
v = v.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
# 每个头独立计算注意力
s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
if mask is not None:
s = s + mask.unsqueeze(0).unsqueeze(0)
p = F.softmax(s, dim=-1)
z = p @ v
# 拼接所有头: (bsz, n_heads, seq_len, head_dim) → (bsz, seq_len, dim)
z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
return self.wo(z)Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)
演进动机
推理阶段需要缓存历史 K、V(KV Cache),其大小为 [2, bsz, seq_len, n_heads, head_dim]。当模型有 64/128 个头时,KV Cache 占用巨大,限制了 batch size 和序列长度。
核心问题: 多头的 K、V 是否存在冗余?能否在减少头数的同时保持精度?
MQA:所有 Q 头共享 1 组 KV
class MultiQueryAttention(nn.Module):
"""所有 Query 头共享同一组 K 和 V"""
def __init__(self, dim=512, n_heads=8):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.head_dim = dim // n_heads
self.wq = nn.Linear(dim, dim) # 多头 Q
self.wk = nn.Linear(dim, self.head_dim) # 单头 K
self.wv = nn.Linear(dim, self.head_dim) # 单头 V
self.wo = nn.Linear(dim, dim)
def forward(self, x, mask=None):
bsz, seq_len, dim = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = k[:, None, :, :] # (bsz, 1, seq_len, head_dim) — 广播到所有头
v = v[:, None, :, :]
s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
if mask is not None:
s = s + mask.unsqueeze(0).unsqueeze(0)
p = F.softmax(s, dim=-1)
z = (p @ v).transpose(1, 2).reshape(bsz, seq_len, self.dim)
return self.wo(z)GQA:分组共享 KV(Llama 2/3 采用)
GQA 是 MHA 与 MQA 的折中——将
class GroupQueryAttention(nn.Module):
"""分组共享 KV,n_kv_heads 组,每组 repeats = n_heads // n_kv_heads 个 Q 头"""
def __init__(self, dim=512, n_heads=8, n_kv_heads=2):
super().__init__()
self.dim = dim
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = dim // n_heads
self.repeats = n_heads // n_kv_heads
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, self.head_dim * n_kv_heads) # 分组 K
self.wv = nn.Linear(dim, self.head_dim * n_kv_heads) # 分组 V
self.wo = nn.Linear(dim, dim)
def forward(self, x, mask=None):
bsz, seq_len, dim = x.shape
q, k, v = self.wq(x), self.wk(x), self.wv(x)
q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
# 每组 KV 重复 repeats 次,对齐 Q 的头数
k = torch.repeat_interleave(k, self.repeats, dim=1)
v = torch.repeat_interleave(v, self.repeats, dim=1)
s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
if mask is not None:
s = s + mask.unsqueeze(0).unsqueeze(0)
p = F.softmax(s, dim=-1)
z = (p @ v).transpose(1, 2).reshape(bsz, seq_len, self.dim)
return self.wo(z)GQA 的本质: 当
n_kv_heads = n_heads时退化为 MHA;当n_kv_heads = 1时退化为 MQA。Llama 2 70B 使用n_kv_heads = 8,在质量和效率间取得了极佳平衡。
Multi-Latent Attention (MLA)
DeepSeek 的创新思路
MQA/GQA 通过减少 KV 头数来压缩 KV Cache,但这本质上是一种"特征丢弃"。DeepSeek-V2 提出 MLA,换了一个思路:用低秩压缩代替头数削减。
核心思想: 先将输入压缩到一个低维 latent 向量
传统 MHA 的 KV Cache 大小为
代码实现
class MultiHeadsLatentAttention(nn.Module):
"""MLA: 低秩压缩 KV,KV Cache 只存 latent 向量 c"""
def __init__(self, dim=64, n_heads=8, d_c=4, q_compress_dim=4):
super().__init__()
self.n_heads = n_heads
self.dim = dim
self.head_dim = dim // n_heads
# Q 的低秩分解: X → latent_q → Q
self.q_down_proj = nn.Linear(dim, q_compress_dim, bias=False)
self.q_up_proj = nn.Linear(q_compress_dim, dim, bias=False)
# KV 共享一个 latent 压缩: X → c_kv → K, V
self.kv_down_proj = nn.Linear(dim, d_c, bias=False) # 压缩
self.k_up_proj = nn.Linear(d_c, dim, bias=False) # 恢复 K
self.v_up_proj = nn.Linear(d_c, dim, bias=False) # 恢复 V
self.wo = nn.Linear(dim, dim, bias=False)
def forward(self, x):
B, seq_len, _ = x.shape
# 低秩压缩后再升维
latent_q = self.q_down_proj(x) # (B, seq_len, q_compress_dim)
Q = self.q_up_proj(latent_q) # (B, seq_len, dim)
C_KV = self.kv_down_proj(x) # (B, seq_len, d_c) ← KV Cache 只存这个!
K = self.k_up_proj(C_KV) # (B, seq_len, dim)
V = self.v_up_proj(C_KV) # (B, seq_len, dim)
# 后续与标准多头注意力一致
Q = Q.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
K = K.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
V = V.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
S = Q @ K.transpose(2, 3) / math.sqrt(self.head_dim)
P = F.softmax(S.float(), dim=-1)
Z = (P @ V).transpose(1, 2).contiguous().view(B, seq_len, -1)
return self.wo(Z)矩阵吸收技巧(推理优化)
训练完成后,低秩矩阵可以合并,避免推理时的额外计算:
# 推理时合并权重:训练时省显存,推理时保精度
WQ = (mla.q_up_proj.weight.data @ mla.q_down_proj.weight.data).t()
Q = X @ WQ # 等价于 mla.q_up_proj(mla.q_down_proj(X)),但只需一次矩阵乘同理,
Flash Attention
GPU 内存层次:SRAM vs HBM
| 存储层级 | 容量 | 带宽 | 特点 |
|---|---|---|---|
| SRAM(片上缓存) | ~20 MB | ~19 TB/s | 极快,但容量很小 |
| HBM(显存) | 40-80 GB | ~1.5 TB/s | 容量大,但带宽是瓶颈 |
标准 Attention 的 IO 瓶颈
标准 attention 的计算流程:
- 从 HBM 读取 Q、K,计算
,写回 HBM( 中间矩阵!) - 从 HBM 读取
,计算 ,写回 HBM - 从 HBM 读取
、 ,计算 ,写回 HBM
Flash Attention 的分块策略 + Online Softmax
核心思想: 将 Q、K、V 分成小块,每块放进 SRAM 中完成全部计算,避免将
难点在于:softmax 需要全局 max 和 sum,分块后怎么办?答案是 Online Softmax。
Online Softmax 原理
对于向量
分块版本: 每块内部独立算 max 和 sum,块间通过上述递推公式合并。
Flash Attention v1 实现(先 KV 后 Q)
# Flash Attention v1: 外层遍历 KV 块,内层遍历 Q 块
NEG_INF = -1e10
EPSILON = 1e-10
O = torch.zeros_like(Q)
l = torch.zeros(Q.shape[:-1])[..., None] # 累计 softmax 分母
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF # 累计 max
Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))
for j in range(Tc): # 外层遍历 KV 块
Kj, Vj = K_BLOCKS[j], V_BLOCKS[j]
for i in range(Tr): # 内层遍历 Q 块
Qi, Oi, li, mi = Q_BLOCKS[i], O_BLOCKS[i], l_BLOCKS[i], m_BLOCKS[i]
S_ij = Qi @ Kj.transpose(2, 3) # 块内注意力分数
m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
P_ij = torch.exp(S_ij - m_block_ij) # 数值稳定的 exp
l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON
mi_new = torch.maximum(m_block_ij, mi) # 更新全局 max
li_new = torch.exp(mi - mi_new) * li \
+ torch.exp(m_block_ij - mi_new) * l_block_ij # 更新全局 sum
# 在线更新输出(关键!无需存储完整 n×n 矩阵)
O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \
+ (torch.exp(m_block_ij - mi_new) / li_new) * (P_ij @ Vj)
l_BLOCKS[i] = li_new
m_BLOCKS[i] = mi_new
O = torch.cat(O_BLOCKS, dim=2) # 数学上与标准 attention 完全等价!Flash Attention v2 改进(先 Q 后 KV)
v2 将外层改为遍历 Q 块、内层遍历 KV 块,减少 O 的读写次数,并将 scale 操作推迟到最后:
# Flash Attention v2: 外层遍历 Q 块,内层遍历 KV 块
for i in range(Tr): # 外层遍历 Q 块(O 只在最后写一次)
Qi, Oi, li, mi = Q_BLOCKS[i], O_BLOCKS[i], l_BLOCKS[i], m_BLOCKS[i]
for j in range(Tc): # 内层遍历 KV 块
Kj, Vj = K_BLOCKS[j], V_BLOCKS[j]
S_ij = Qi @ Kj.transpose(2, 3)
m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
mi_new = torch.maximum(m_block_ij, mi)
P_ij_hat = torch.exp(S_ij - mi_new)
l_block_ij = torch.sum(P_ij_hat, dim=-1, keepdims=True) + EPSILON
li_new = torch.exp(mi - mi_new) * li + l_block_ij
Oi = torch.exp(mi - mi_new) * Oi + P_ij_hat @ Vj # 不除以 l,延迟 scale
li, mi = li_new, mi_new
O_BLOCKS[i] = Oi / li_new # 最后一次性做 scaleFlash Attention 深度实现
核心算法:Online Softmax 与分块计算
Flash Attention 的关键挑战在于:softmax 是一个全局操作,需要知道整个序列的 max 和 sum。分块计算时,每个 block 只能看到部分数据,如何保证结果的精确性?
答案是 Online Softmax 的分块递推公式。假设我们已经处理了前
最终输出为
前向传播伪代码
算法: Flash Attention 前向传播
输入: Q, K, V ∈ R^{N×d}, 块大小 B_r, B_c
输出: O ∈ R^{N×d}
1. 将 Q 分成 T_r = ⌈N/B_r⌉ 块, K/V 分成 T_c = ⌈N/B_c⌉ 块
2. 初始化 O = 0, l = 0, m = -∞ (均为 R^{N} 向量)
3. for j = 1 to T_c: # 外层遍历 KV 块
4. 从 HBM 加载 K_j, V_j 到 SRAM
5. for i = 1 to T_r: # 内层遍历 Q 块
6. 从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM
7. 计算 S_ij = Q_i @ K_j^T ∈ R^{B_r × B_c} (在 SRAM 中)
8. 计算 m_ij = rowmax(S_ij)
9. 计算 P_ij = exp(S_ij - m_ij)
10. 计算 l_ij = rowsum(P_ij)
11. 更新 m_new = max(m_i, m_ij)
12. 更新 l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
13. 更新 O_i = exp(m_i - m_new) * O_i + exp(m_ij - m_new) * P_ij @ V_j
14. 将 O_i, l_new, m_new 写回 HBM
15. 返回 O = O / l (逐行 scale)关键点:
PyTorch 实现
import torch
import math
def flash_attention_forward(Q, K, V, block_size=64):
"""
Flash Attention 前向传播的纯 PyTorch 模拟实现。
实际 Flash Attention 使用 CUDA kernel 在 SRAM 中执行,
这里用 Python 展示算法逻辑。
Args:
Q, K, V: (batch, heads, seq_len, head_dim)
block_size: SRAM 块大小
Returns:
O: (batch, heads, seq_len, head_dim)
"""
B, H, N, d = Q.shape
Br = min(block_size, N) # Q 块大小
Bc = min(block_size, N) # KV 块大小
Tr = math.ceil(N / Br) # Q 块数量
Tc = math.ceil(N / Bc) # KV 块数量
O = torch.zeros_like(Q)
l = torch.zeros(B, H, N, 1, device=Q.device)
m = torch.full((B, H, N, 1), float('-inf'), device=Q.device)
# 将 Q, K, V 分块
Q_blocks = list(Q.split(Br, dim=2))
K_blocks = list(K.split(Bc, dim=2))
V_blocks = list(V.split(Bc, dim=2))
O_blocks = list(O.split(Br, dim=2))
l_blocks = list(l.split(Br, dim=2))
m_blocks = list(m.split(Br, dim=2))
scale = 1.0 / math.sqrt(d)
for j in range(Tc):
Kj = K_blocks[j] # (B, H, Bc, d)
Vj = V_blocks[j] # (B, H, Bc, d)
for i in range(Tr):
Qi = Q_blocks[i]
Oi = O_blocks[i]
li = l_blocks[i]
mi = m_blocks[i]
# Step 1: 计算注意力分数(在 SRAM 中)
Sij = (Qi @ Kj.transpose(-2, -1)) * scale # (B, H, Br, Bc)
# Step 2: 块内 softmax 统计量
mij = Sij.max(dim=-1, keepdim=True).values # (B, H, Br, 1)
Pij = torch.exp(Sij - mij) # (B, H, Br, Bc)
lij = Pij.sum(dim=-1, keepdim=True) # (B, H, Br, 1)
# Step 3: 合并新旧统计量
mi_new = torch.maximum(mi, mij)
li_new = (torch.exp(mi - mi_new) * li
+ torch.exp(mij - mi_new) * lij)
# Step 4: 更新输出(在线累积)
O_blocks[i] = (torch.exp(mi - mi_new) * Oi
+ torch.exp(mij - mi_new) * (Pij @ Vj))
l_blocks[i] = li_new
m_blocks[i] = mi_new
# 最终 scale
O = torch.cat([O_blocks[i] / l_blocks[i] for i in range(Tr)], dim=2)
return O
# 验证与标准 Attention 等价
torch.manual_seed(42)
B, H, N, d = 2, 4, 128, 64
Q = torch.randn(B, H, N, d)
K = torch.randn(B, H, N, d)
V = torch.randn(B, H, N, d)
# 标准 Attention
scale = 1.0 / math.sqrt(d)
S = (Q @ K.transpose(-2, -1)) * scale
P = torch.softmax(S, dim=-1)
O_standard = P @ V
# Flash Attention
O_flash = flash_attention_forward(Q, K, V, block_size=32)
print(f"最大误差: {(O_standard - O_flash).abs().max():.2e}") # ~1e-6,浮点精度误差反向传播:重计算 vs 存储
标准 Attention 的反向传播需要注意力矩阵
- 前向传播:只保存
(输出和 softmax 统计量),不保存 和 - 反向传播:利用保存的
,在 SRAM 中重新计算 和 的每个块 - 额外计算量:反向传播多做了一次分块矩阵乘法,但由于 IO 大幅减少,总体仍然更快
反向传播关键步骤:
1. 从 HBM 加载 Q_i, K_j, V_j, O_i, l_i, m_i, dO_i
2. 在 SRAM 中重计算: S_ij = Q_i @ K_j^T, P_ij = softmax(S_ij) ← 利用 l_i, m_i
3. 计算 dV_j += P_ij^T @ dO_i
4. 计算 dP_ij = dO_i @ V_j^T
5. 计算 dS_ij = P_ij ⊙ (dP_ij - rowsum(dP_ij ⊙ P_ij)) ← softmax 反向
6. 计算 dQ_i += dS_ij @ K_j, dK_j += dS_ij^T @ Q_i
7. 写回 dQ_i, dK_j, dV_j 到 HBM为什么重计算反而更快? 因为重计算的代价是
Flash Attention 2 的优化
Flash Attention 2 在 v1 基础上做了三项关键优化,将速度进一步提升 ~2x:
优化 1:减少非矩阵乘法 FLOPs
v1 中大量时间花在 rescaling 操作(乘以
# v1: 每个块都做完整 rescaling
O_i = (l_old / l_new) * exp(m_old - m_new) * O_i + (exp(m_block - m_new) / l_new) * PV
# v2: 延迟 rescaling,最后一步才除以 l
O_i = exp(m_old - m_new) * O_i + P_tilde @ Vj # 不除以 l
# ... 内层循环结束后 ...
O_i = O_i / l_final # 一次性 scale在 GPU 上,矩阵乘法(GEMM)由 Tensor Core 加速,而逐元素操作(rescaling)只能用普通 CUDA Core。减少非 matmul FLOPs 能显著提升 Tensor Core 利用率。
优化 2:更好的并行——外层遍历 Q
v1 外层遍历 KV、内层遍历 Q,导致每个 Q 块的输出需要反复读写 HBM。v2 反转循环顺序:
v1: for KV_block → for Q_block # O 被反复读写
v2: for Q_block → for KV_block # 每个 Q 块的 O 只写一次 HBM这使得每个 thread block 独立处理一个 Q 块,不同 Q 块之间无需通信,在 GPU SM 之间实现了完美并行。
优化 3:序列长度维度的并行
v1 只在 batch 和 head 维度做并行。当 batch size 较小时(如推理),SM 利用率不高。v2 额外在序列长度维度做并行(将 Q 的不同块分配到不同 SM),大幅提升了小 batch 场景的效率。
IO 复杂度分析
| 算法 | FLOPs | HBM 读写量 | IO 复杂度 |
|---|---|---|---|
| 标准 Attention | 受 | ||
| Flash Attention |
推导:Flash Attention 的外层有
当
数值示例:
- 标准 Attention IO:
读写(注意力矩阵 和 ) - Flash Attention IO:
看似差不多,但 Flash Attention 避免了写
面试考点:为什么 Flash Attention 更快但 FLOPs 相同?
这是一个非常经典的问题,核心答案是Roofline Model:
- FLOPs 不变:Flash Attention 计算的数学结果与标准 Attention 完全相同,矩阵乘法的次数一样
- IO 大幅减少:标准 Attention 需要将
大小的中间矩阵 和 写入 HBM 再读回,而 Flash Attention 将这些中间结果保持在 SRAM 中 - 瓶颈转移:标准 Attention 是 IO-bound(显存带宽是瓶颈),Flash Attention 通过减少 IO 将瓶颈转移到 compute-bound,从而真正利用上 GPU 的算力
- 重计算的"免费午餐":反向传播多做的那次前向重计算,其 FLOPs 增加约 33%,但 IO 节省远大于此——在 A100 上净加速 2-4 倍
一句话总结: Flash Attention 不是"算得更快",而是"搬数据搬得更少"。在 GPU 上,SRAM 带宽是 HBM 的 ~10 倍,减少 HBM 访问就是最大的加速。
Online Softmax:从两遍扫描到一遍扫描
标准 Softmax 需要两遍扫描才能完成计算:
- 第一遍:扫描所有元素,求全局最大值
(数值稳定性所需)和求和 - 第二遍:再扫描一次,计算每个元素的 softmax 值
这意味着整个向量必须在内存中被访问两次。对于 Flash Attention 的分块计算来说,我们无法一次看到完整的行——每次只能看到一个 block。
Online Softmax 的核心思想
利用指数函数的性质,当新 block 带来更大的 max 值
这样只需一遍扫描就能得到全局正确的 softmax 分母。
逐元素 Online Softmax 实现:
import torch
X = torch.tensor([1.0, 1.5, 1.8, 2.0, 1.4, 2.1])
# ---- 标准 safe softmax(两遍扫描)----
X_max = X.max()
X_safe_softmax = torch.exp(X - X_max) / torch.exp(X - X_max).sum()
# ---- Online Softmax(一遍扫描)----
m_cur = torch.tensor(float('-inf'))
l_cur = torch.tensor(0.0)
for i in range(len(X)):
m_new = torch.max(m_cur, X[i])
# 修正历史 sum + 加入新元素
l_cur = l_cur * torch.exp(m_cur - m_new) + torch.exp(X[i] - m_new)
m_cur = m_new
X_online_softmax = torch.exp(X - m_cur) / l_cur
print(torch.allclose(X_safe_softmax, X_online_softmax)) # True分块 Online Softmax(Flash Attention 实际使用的形式):
BLOCK = 3
X_blocks = X.split(BLOCK)
m_cur = torch.tensor(float('-inf'))
l_cur = torch.tensor(0.0)
for blk in X_blocks:
m_blk = blk.max()
m_new = torch.max(m_cur, m_blk)
l_cur = l_cur * torch.exp(m_cur - m_new) \
+ torch.exp(blk - m_new).sum()
m_cur = m_new
X_block_online_softmax = torch.exp(X - m_cur) / l_cur
print(torch.allclose(X_safe_softmax, X_block_online_softmax)) # True分块 Online Softmax 是 Flash Attention 能在 SRAM 中分块完成 Softmax 的数学基础——看到新 block 时修正旧统计量,而非回头重算。
Flash Attention 反向传播
Flash Attention 反向传播的关键挑战是:前向传播没有保存
核心策略:重计算 (Recomputation)
前向时只保存
梯度推导
给定上游梯度
其中
为什么 ?
因为
分块反向传播实现:
以下实现参考了 Flash Attention 2 论文 (Dao, 2023) 的算法描述,代码经过教学化改写。
import torch
import math
torch.manual_seed(42)
n, dim, nb = 12, 8, 4 # 序列长度, 头维度, block 大小
block = n // nb
Q = torch.randn(n, dim, requires_grad=True)
K = torch.randn(n, dim, requires_grad=True)
V = torch.randn(n, dim, requires_grad=True)
# ---- Flash Attention 前向(保存 O 和 L)----
def flash_attention_forward(Q, K, V):
O = torch.zeros_like(Q)
L = torch.zeros(n, 1)
for tq in range(block):
q = Q[tq*nb:(tq+1)*nb, :]
o_old = torch.zeros_like(q)
l_old = m_old = torch.zeros(nb, 1)
for tk in range(block):
k = K[tk*nb:(tk+1)*nb, :]
v = V[tk*nb:(tk+1)*nb, :]
s = q @ k.t() / math.sqrt(dim)
m = s.max(dim=1, keepdim=True).values
m_new = torch.maximum(m, m_old)
l = torch.exp(s - m_new).sum(dim=1, keepdim=True)
l_new = l_old * torch.exp(m_old - m_new) + l
o_old = l_old * o_old * torch.exp(m_old - m_new) \
+ torch.exp(s - m_new) @ v
o_old = o_old / l_new
l_old, m_old = l_new, m_new
O[tq*nb:(tq+1)*nb, :] = o_old
L[tq*nb:(tq+1)*nb, :] = m_old + l_old.log()
return O, L
O, L = flash_attention_forward(Q, K, V)
# ---- Flash Attention 反向(分块重计算)----
dO = torch.randn_like(O) # 模拟上游梯度
def flash_attention_backward(Q, K, V, O, dO, L):
dQ = torch.zeros_like(Q)
dK = torch.zeros_like(K)
dV = torch.zeros_like(V)
# D_i = rowsum(O * dO),不需要 P
D = (O * dO).sum(dim=1, keepdim=True)
for tk in range(block): # 外层遍历 KV block
k = K[tk*nb:(tk+1)*nb, :]
v = V[tk*nb:(tk+1)*nb, :]
for tq in range(block): # 内层遍历 Q block
q = Q[tq*nb:(tq+1)*nb, :]
o = O[tq*nb:(tq+1)*nb, :]
do = dO[tq*nb:(tq+1)*nb, :]
l = L[tq*nb:(tq+1)*nb, :]
d = D[tq*nb:(tq+1)*nb, :]
# ---- 重计算 attention(无需从 HBM 读 P)----
s = q @ k.t() / math.sqrt(dim)
p = torch.exp(s - l) # 利用 L = m + log(l) 还原 softmax
# ---- 梯度计算 ----
dv = p.t() @ do
dp = do @ v.t()
ds = p * (dp - d) # softmax 反向的紧凑形式
dq = ds @ k / math.sqrt(dim)
dk = ds.t() @ q / math.sqrt(dim)
# ---- 累加到全局梯度 ----
dV[tk*nb:(tk+1)*nb, :] += dv
dQ[tq*nb:(tq+1)*nb, :] += dq
dK[tk*nb:(tk+1)*nb, :] += dk
return dQ, dK, dV
dQ_flash, dK_flash, dV_flash = flash_attention_backward(
Q, K, V, O, dO, L
)用 PyTorch Autograd 验证正确性
我们实现的反向传播是否正确?我们用 PyTorch 自动微分作为 ground truth 进行对比:
# ---- PyTorch 标准 attention + autograd ----
S = Q @ K.t() / math.sqrt(dim)
P = torch.softmax(S, dim=-1)
O_ref = P @ V
# 用相同的 dO 计算 autograd 梯度
O_ref.backward(dO)
print("dQ allclose:", torch.allclose(dQ_flash, Q.grad, atol=1e-5))
print("dK allclose:", torch.allclose(dK_flash, K.grad, atol=1e-5))
print("dV allclose:", torch.allclose(dV_flash, V.grad, atol=1e-5))
# 输出:全部 True反向传播的 IO 分析
标准反向传播需要从 HBM 读取
Tensor Product Attention (TPA)
核心思想
标准注意力中,K 和 V 通过单个线性投影
数学公式
TPA 将 K、V 的计算分解为两个低秩分量
V 的计算方式完全类似。其中
关键洞察: 这本质上是对 KV 投影矩阵做了 CP 分解(Canonical Polyadic Decomposition),每个 token 的 K/V 由两个低秩因子按位相乘得到,兼顾表达能力与压缩效率。
简化版代码实现
以下是简化版 TPA 实现:
import torch
import torch.nn as nn
class TPAProjection(nn.Module):
"""Tensor Product Attention 的 QKV 投影"""
def __init__(self, d_model=512, n_head=8, head_dim=64, rank=4):
super().__init__()
self.n_head = n_head
self.head_dim = head_dim
self.rank = rank
# Q 使用标准投影
self.W_q = nn.Linear(d_model, n_head * head_dim, bias=False)
# K, V 各用两个低秩投影 (CP 分解)
self.W_A_k = nn.Linear(d_model, n_head * rank, bias=False)
self.W_B_k = nn.Linear(d_model, rank * head_dim, bias=False)
self.W_A_v = nn.Linear(d_model, n_head * rank, bias=False)
self.W_B_v = nn.Linear(d_model, rank * head_dim, bias=False)
def forward(self, x):
bs, seq_len, _ = x.size()
q = self.W_q(x).view(bs, seq_len, self.n_head, self.head_dim)
# K = (1/r) * A_k @ B_k
A_k = self.W_A_k(x).view(bs * seq_len, self.n_head, self.rank)
B_k = self.W_B_k(x).view(bs * seq_len, self.rank, self.head_dim)
k = torch.bmm(A_k, B_k).div_(self.rank)
k = k.view(bs, seq_len, self.n_head, self.head_dim)
# V = (1/r) * A_v @ B_v
A_v = self.W_A_v(x).view(bs * seq_len, self.n_head, self.rank)
B_v = self.W_B_v(x).view(bs * seq_len, self.rank, self.head_dim)
v = torch.bmm(A_v, B_v).div_(self.rank)
v = v.view(bs, seq_len, self.n_head, self.head_dim)
return q, k, v
# 使用示例
tpa = TPAProjection(d_model=512, n_head=8, head_dim=64, rank=4)
x = torch.randn(2, 16, 512)
q, k, v = tpa(x)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 8, 64])TPA vs 标准 Attention 对比
| 特性 | 标准 MHA | TPA |
|---|---|---|
| KV 投影参数 | ||
| 参数压缩比 | 基准 | 当 |
| KV Cache | 标准 | 可只缓存 |
| 表达能力 | 基准 | rank 足够时接近 MHA |
| 适用场景 | 通用 | KV Cache 受限的长序列推理 |
注意力变体对比
| 特性 | MHA | MQA | GQA | MLA | TPA |
|---|---|---|---|---|---|
| Q 头数 | |||||
| KV 头数 | 1 | 全头(从 latent 恢复) | |||
| KV Cache 大小 | |||||
| KV 参数量 | |||||
| 精度保持 | 基准 | 略有下降 | 接近 MHA | 接近 MHA | 接近 MHA |
| 代表模型 | GPT-3, BERT | PaLM | Llama 2/3 | DeepSeek-V2/V3 | Tensor Product Attention |
| 核心思想 | 多头并行 | KV 共享 | 分组 KV 共享 | 低秩 KV 压缩 | KV 张量积分解 |
关键洞察: MQA/GQA 是在"头数维度"上压缩;MLA 是在"特征维度"上压缩(类似 LoRA 思想);TPA 是在"投影矩阵"上做 CP 分解——三者从不同角度减少 KV 开销。
苏格拉底时刻
为什么 MHA 不用一个大头? 多头让模型在不同子空间并行捕捉不同类型的关系(语法、语义、位置等)。单头只能在一个空间学习,表达能力受限。
Flash Attention 在数学上完全等价,加速从何而来? 减少了 HBM 读写次数。标准 attention 的 IO 复杂度为
(存储中间矩阵 ),Flash Attention 通过分块将 IO 降至 ,其中 为 SRAM 大小。 GQA 的
repeat_interleave是否增加了计算量? 注意力分数的计算量不变(仍为)。减少的是投影参数和 KV Cache 的存储/传输开销。在 GPU SRAM 中 repeat 操作几乎免费。 MLA 的"矩阵吸收"为什么能 work? 基于矩阵乘法结合律:
。训练时分两步省显存;推理时合并为一步保精度。 Online Softmax 为什么能单遍扫描? 利用指数函数的性质:
,当发现新的 max 时,只需对历史累积量乘一个修正因子。
常见问题 & 面试考点
Q1: 注意力的计算复杂度是多少? 时间复杂度
Q2: Causal Mask 在训练和推理中的作用? 训练时:对注意力矩阵的上三角填充
Q3: KV Cache 为什么只缓存 K 和 V,不缓存 Q? 自回归推理时,每步只有一个新 token 生成新的 Q(长度为 1),无需缓存。而 K、V 需要保留所有历史 token 的结果。
Q4: MQA 为什么效果只略有下降? 高维特征存在冗余,多个头的 KV 投影高度相关。实验表明 KV 的多样性对模型质量影响远小于 Q 的多样性。
Q5: Flash Attention 能用于训练吗? 能。Flash Attention 同时优化了前向和反向传播的 IO,训练速度提升 2-4 倍,显存减少 5-20 倍。
推荐资源
- Attention Is All You Need — 原始 Transformer 论文
- FlashAttention: Fast and Memory-Efficient Exact Attention — Flash Attention 论文
- FlashAttention-2: Faster Attention with Better Parallelism — v2 改进
- GQA: Training Generalized Multi-Query Transformer Models — GQA 论文
- DeepSeek-V2: A Strong, Economical, and Efficient Mixture-of-Experts Language Model — MLA 原始论文
- The Illustrated Transformer — 注意力可视化详解
- Online Normalizer Calculation for Softmax — Online Softmax 原始论文