Skip to content

注意力机制

一句话总结: 注意力机制是 Transformer 的灵魂——从 Scaled Dot-Product Attention 的数学本质,到 MHA/MQA/GQA/MLA 的架构演进,再到 Flash Attention 的工程极致优化,每一步都在平衡建模能力计算效率

在大模型体系中的位置

Input Token → Embedding + Positional Encoding

        ┌────────────────────────┐
        │  Attention (this page) │  ← Core: lets each token "see" other tokens
        └────────────────────────┘

           FFN / MoE                ← Per-token nonlinear transform

           LayerNorm                ← Stabilize training

          x N layers

           Output logits

注意力层决定了"信息如何在序列内流动"。模型的上下文理解能力、长距离依赖建模、推理速度和显存消耗,都与注意力机制的设计直接相关。


Scaled Dot-Product Attention

核心公式

Attention(Q,K,V)=softmax(QKTdk)V

其中 QRn×dkKRn×dkVRn×dv,分别由输入 X 经线性投影得到:

Q=XWQ,K=XWK,V=XWV

为什么要除以 dk?——方差证明

以下方差推导来自 Vaswani et al. (2017) "Attention Is All You Need" 原始论文的脚注 4。

假设 qi,ki 均为独立随机变量,服从标准正态分布 N(0,1)。对于点积:

qk=i=1dkqiki

逐元素分析:

E[qiki]=E[qi]E[ki]=0Var(qiki)=(Var(qi)+E[qi]2)(Var(ki)+E[ki]2)E[qi]2E[ki]2=1

对整个向量求和:

E[i=1dkqiki]=0,Var(i=1dkqiki)=dk

点积的方差随维度线性增长!当 dk=1024 时,点积分布在 [100,100] 量级,softmax 输出会极度尖锐(趋近 one-hot),梯度几乎消失

除以 dk 后,利用 Var(cX)=c2Var(X)

Var(qkdk)=1dkdk=1

注意力分数回到标准正态分布,softmax 输出分布温和,梯度稳定。

为什么不除以 dk 除以 dk 会导致方差为 1/dk,分布过于集中,softmax 趋近均匀分布,注意力失去区分能力。

代码实现

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ScaleDotProductAttention(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.q_proj = nn.Linear(dim_in, dim_out)
        self.k_proj = nn.Linear(dim_in, dim_out)
        self.v_proj = nn.Linear(dim_in, dim_out)
        self.out_proj = nn.Linear(dim_in, dim_out)

    def forward(self, X, mask=None):
        batch_size, seq_len, dim = X.shape
        Q, K, V = self.q_proj(X), self.k_proj(X), self.v_proj(X)
        # 注意力分数,除以 sqrt(d_k) 稳定梯度
        S = Q @ K.transpose(1, 2) / math.sqrt(dim)
        if mask is not None:
            S = S + mask  # 因果掩码:未来位置设为 -inf
        P = F.softmax(S, dim=-1)
        O = P @ V
        return self.out_proj(O)

Multi-Head Attention (MHA)

核心思想

单头注意力只能在一个子空间中捕捉关系。多头注意力将 Q、K、V 拆分到 h 个头,每个头独立计算注意力,最后拼接:

MultiHead(Q,K,V)=Concat(head1,,headh)WOheadi=Attention(QWiQ,KWiK,VWiV)

其中每个头的维度 dk=dmodel/h,总参数量不变。

完整过程:拆分 → 并行计算 → 拼接

python
class MultiHeadsAttention(nn.Module):
    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        bsz, seq_len, dim = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # 拆分多头: (bsz, seq_len, dim) → (bsz, n_heads, seq_len, head_dim)
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        # 每个头独立计算注意力
        s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v

        # 拼接所有头: (bsz, n_heads, seq_len, head_dim) → (bsz, seq_len, dim)
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)

Multi-Query Attention (MQA) 与 Grouped-Query Attention (GQA)

演进动机

推理阶段需要缓存历史 K、V(KV Cache),其大小为 [2, bsz, seq_len, n_heads, head_dim]。当模型有 64/128 个头时,KV Cache 占用巨大,限制了 batch size 和序列长度。

核心问题: 多头的 K、V 是否存在冗余?能否在减少头数的同时保持精度?

MQA:所有 Q 头共享 1 组 KV

python
class MultiQueryAttention(nn.Module):
    """所有 Query 头共享同一组 K 和 V"""
    def __init__(self, dim=512, n_heads=8):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim)          # 多头 Q
        self.wk = nn.Linear(dim, self.head_dim) # 单头 K
        self.wv = nn.Linear(dim, self.head_dim) # 单头 V
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        bsz, seq_len, dim = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k[:, None, :, :]  # (bsz, 1, seq_len, head_dim) — 广播到所有头
        v = v[:, None, :, :]

        s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = (p @ v).transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)

GQA:分组共享 KV(Llama 2/3 采用)

GQA 是 MHA 与 MQA 的折中——将 h 个 Q 头分成 g 组,每组共享一套 KV:

python
class GroupQueryAttention(nn.Module):
    """分组共享 KV,n_kv_heads 组,每组 repeats = n_heads // n_kv_heads 个 Q 头"""
    def __init__(self, dim=512, n_heads=8, n_kv_heads=2):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.head_dim = dim // n_heads
        self.repeats = n_heads // n_kv_heads
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, self.head_dim * n_kv_heads)  # 分组 K
        self.wv = nn.Linear(dim, self.head_dim * n_kv_heads)  # 分组 V
        self.wo = nn.Linear(dim, dim)

    def forward(self, x, mask=None):
        bsz, seq_len, dim = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # 每组 KV 重复 repeats 次,对齐 Q 的头数
        k = torch.repeat_interleave(k, self.repeats, dim=1)
        v = torch.repeat_interleave(v, self.repeats, dim=1)

        s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = (p @ v).transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)

GQA 的本质:n_kv_heads = n_heads 时退化为 MHA;当 n_kv_heads = 1 时退化为 MQA。Llama 2 70B 使用 n_kv_heads = 8,在质量和效率间取得了极佳平衡。


Multi-Latent Attention (MLA)

DeepSeek 的创新思路

MQA/GQA 通过减少 KV 头数来压缩 KV Cache,但这本质上是一种"特征丢弃"。DeepSeek-V2 提出 MLA,换了一个思路:用低秩压缩代替头数削减

核心思想: 先将输入压缩到一个低维 latent 向量 c,再通过 up-projection 恢复完整的多头 KV。KV Cache 只需存储低维的 c

cKV=WKVdownX,WKVdownRd×dcK=WKupcKV,V=WVupcKV

传统 MHA 的 KV Cache 大小为 2×nh×dh×l,MLA 只需存储 dc×ldcd),压缩比可达 16 倍以上

代码实现

python
class MultiHeadsLatentAttention(nn.Module):
    """MLA: 低秩压缩 KV,KV Cache 只存 latent 向量 c"""
    def __init__(self, dim=64, n_heads=8, d_c=4, q_compress_dim=4):
        super().__init__()
        self.n_heads = n_heads
        self.dim = dim
        self.head_dim = dim // n_heads

        # Q 的低秩分解: X → latent_q → Q
        self.q_down_proj = nn.Linear(dim, q_compress_dim, bias=False)
        self.q_up_proj = nn.Linear(q_compress_dim, dim, bias=False)

        # KV 共享一个 latent 压缩: X → c_kv → K, V
        self.kv_down_proj = nn.Linear(dim, d_c, bias=False)  # 压缩
        self.k_up_proj = nn.Linear(d_c, dim, bias=False)     # 恢复 K
        self.v_up_proj = nn.Linear(d_c, dim, bias=False)     # 恢复 V

        self.wo = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        B, seq_len, _ = x.shape

        # 低秩压缩后再升维
        latent_q = self.q_down_proj(x)      # (B, seq_len, q_compress_dim)
        Q = self.q_up_proj(latent_q)        # (B, seq_len, dim)

        C_KV = self.kv_down_proj(x)    # (B, seq_len, d_c) ← KV Cache 只存这个!
        K = self.k_up_proj(C_KV)       # (B, seq_len, dim)
        V = self.v_up_proj(C_KV)       # (B, seq_len, dim)

        # 后续与标准多头注意力一致
        Q = Q.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(B, seq_len, self.n_heads, self.head_dim).transpose(1, 2)

        S = Q @ K.transpose(2, 3) / math.sqrt(self.head_dim)
        P = F.softmax(S.float(), dim=-1)
        Z = (P @ V).transpose(1, 2).contiguous().view(B, seq_len, -1)
        return self.wo(Z)

矩阵吸收技巧(推理优化)

训练完成后,低秩矩阵可以合并,避免推理时的额外计算:

Q=WQup(WQdownX)=(WQupWQdown)WQabsorbedX
python
# 推理时合并权重:训练时省显存,推理时保精度
WQ = (mla.q_up_proj.weight.data @ mla.q_down_proj.weight.data).t()
Q = X @ WQ  # 等价于 mla.q_up_proj(mla.q_down_proj(X)),但只需一次矩阵乘

同理,WUV 可以被 WO 吸收,减少推理时的参数量和计算量。


Flash Attention

GPU 内存层次:SRAM vs HBM

存储层级容量带宽特点
SRAM(片上缓存)~20 MB~19 TB/s极快,但容量很小
HBM(显存)40-80 GB~1.5 TB/s容量大,但带宽是瓶颈

标准 Attention 的 IO 瓶颈

标准 attention 的计算流程:

  1. 从 HBM 读取 Q、K,计算 S=QKT写回 HBMO(n2) 中间矩阵!)
  2. 从 HBM 读取 S,计算 P=softmax(S)写回 HBM
  3. 从 HBM 读取 PV,计算 O=PV写回 HBM

n2 大小的中间矩阵反复在 HBM 上读写,IO 成为瓶颈,而非计算本身。

Flash Attention 的分块策略 + Online Softmax

核心思想: 将 Q、K、V 分成小块,每块放进 SRAM 中完成全部计算,避免将 n2 中间矩阵写回 HBM。

难点在于:softmax 需要全局 max 和 sum,分块后怎么办?答案是 Online Softmax

Online Softmax 原理

对于向量 X=[x1,,xn],标准 softmax 需要两遍扫描(求 max + 求 sum)。Online Softmax 可以单遍扫描,通过递推更新:

m(t)=max(m(t1),xt)l(t)=l(t1)em(t1)m(t)+extm(t)

分块版本: 每块内部独立算 max 和 sum,块间通过上述递推公式合并。

Flash Attention v1 实现(先 KV 后 Q)

python
# Flash Attention v1: 外层遍历 KV 块,内层遍历 Q 块
NEG_INF = -1e10
EPSILON = 1e-10

O = torch.zeros_like(Q)
l = torch.zeros(Q.shape[:-1])[..., None]        # 累计 softmax 分母
m = torch.ones(Q.shape[:-1])[..., None] * NEG_INF  # 累计 max

Q_BLOCKS = torch.split(Q, Q_BLOCK_SIZE, dim=2)
K_BLOCKS = torch.split(K, KV_BLOCK_SIZE, dim=2)
V_BLOCKS = torch.split(V, KV_BLOCK_SIZE, dim=2)
O_BLOCKS = list(torch.split(O, Q_BLOCK_SIZE, dim=2))
l_BLOCKS = list(torch.split(l, Q_BLOCK_SIZE, dim=2))
m_BLOCKS = list(torch.split(m, Q_BLOCK_SIZE, dim=2))

for j in range(Tc):           # 外层遍历 KV 块
    Kj, Vj = K_BLOCKS[j], V_BLOCKS[j]
    for i in range(Tr):       # 内层遍历 Q 块
        Qi, Oi, li, mi = Q_BLOCKS[i], O_BLOCKS[i], l_BLOCKS[i], m_BLOCKS[i]

        S_ij = Qi @ Kj.transpose(2, 3)                       # 块内注意力分数
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
        P_ij = torch.exp(S_ij - m_block_ij)                  # 数值稳定的 exp
        l_block_ij = torch.sum(P_ij, dim=-1, keepdims=True) + EPSILON

        mi_new = torch.maximum(m_block_ij, mi)               # 更新全局 max
        li_new = torch.exp(mi - mi_new) * li \
               + torch.exp(m_block_ij - mi_new) * l_block_ij  # 更新全局 sum

        # 在线更新输出(关键!无需存储完整 n×n 矩阵)
        O_BLOCKS[i] = (li / li_new) * torch.exp(mi - mi_new) * Oi \
                     + (torch.exp(m_block_ij - mi_new) / li_new) * (P_ij @ Vj)

        l_BLOCKS[i] = li_new
        m_BLOCKS[i] = mi_new

O = torch.cat(O_BLOCKS, dim=2)  # 数学上与标准 attention 完全等价!

Flash Attention v2 改进(先 Q 后 KV)

v2 将外层改为遍历 Q 块、内层遍历 KV 块,减少 O 的读写次数,并将 scale 操作推迟到最后:

python
# Flash Attention v2: 外层遍历 Q 块,内层遍历 KV 块
for i in range(Tr):           # 外层遍历 Q 块(O 只在最后写一次)
    Qi, Oi, li, mi = Q_BLOCKS[i], O_BLOCKS[i], l_BLOCKS[i], m_BLOCKS[i]
    for j in range(Tc):       # 内层遍历 KV 块
        Kj, Vj = K_BLOCKS[j], V_BLOCKS[j]

        S_ij = Qi @ Kj.transpose(2, 3)
        m_block_ij, _ = torch.max(S_ij, dim=-1, keepdims=True)
        mi_new = torch.maximum(m_block_ij, mi)
        P_ij_hat = torch.exp(S_ij - mi_new)
        l_block_ij = torch.sum(P_ij_hat, dim=-1, keepdims=True) + EPSILON

        li_new = torch.exp(mi - mi_new) * li + l_block_ij
        Oi = torch.exp(mi - mi_new) * Oi + P_ij_hat @ Vj  # 不除以 l,延迟 scale

        li, mi = li_new, mi_new

    O_BLOCKS[i] = Oi / li_new   # 最后一次性做 scale

Flash Attention 深度实现

核心算法:Online Softmax 与分块计算

Flash Attention 的关键挑战在于:softmax 是一个全局操作,需要知道整个序列的 max 和 sum。分块计算时,每个 block 只能看到部分数据,如何保证结果的精确性?

答案是 Online Softmax 的分块递推公式。假设我们已经处理了前 j1 个 KV 块,当前要合并第 j 个块:

m(j)=max(m(j1),max(Sij))l(j)=em(j1)m(j)l(j1)+rowsum(eSijm(j))O(j)=em(j1)m(j)O(j1)+eSijm(j)Vj

最终输出为 O=O(Tc)/l(Tc)。这个递推保证了数学上与标准 Attention 完全等价

前向传播伪代码

算法: Flash Attention 前向传播
输入: Q, K, V ∈ R^{N×d}, 块大小 B_r, B_c
输出: O ∈ R^{N×d}

1. 将 Q 分成 T_r = ⌈N/B_r⌉ 块, K/V 分成 T_c = ⌈N/B_c⌉ 块
2. 初始化 O = 0, l = 0, m = -∞  (均为 R^{N} 向量)
3. for j = 1 to T_c:                    # 外层遍历 KV 块
4.     从 HBM 加载 K_j, V_j 到 SRAM
5.     for i = 1 to T_r:                # 内层遍历 Q 块
6.         从 HBM 加载 Q_i, O_i, l_i, m_i 到 SRAM
7.         计算 S_ij = Q_i @ K_j^T ∈ R^{B_r × B_c}    (在 SRAM 中)
8.         计算 m_ij = rowmax(S_ij)
9.         计算 P_ij = exp(S_ij - m_ij)
10.        计算 l_ij = rowsum(P_ij)
11.        更新 m_new = max(m_i, m_ij)
12.        更新 l_new = exp(m_i - m_new) * l_i + exp(m_ij - m_new) * l_ij
13.        更新 O_i = exp(m_i - m_new) * O_i + exp(m_ij - m_new) * P_ij @ V_j
14.        将 O_i, l_new, m_new 写回 HBM
15. 返回 O = O / l   (逐行 scale)

关键点:Sij 矩阵从未写入 HBM,它在 SRAM 中计算、使用、然后丢弃。这就是为什么显存从 O(N2) 降到了 O(N)

PyTorch 实现

python
import torch
import math

def flash_attention_forward(Q, K, V, block_size=64):
    """
    Flash Attention 前向传播的纯 PyTorch 模拟实现。
    实际 Flash Attention 使用 CUDA kernel 在 SRAM 中执行,
    这里用 Python 展示算法逻辑。

    Args:
        Q, K, V: (batch, heads, seq_len, head_dim)
        block_size: SRAM 块大小
    Returns:
        O: (batch, heads, seq_len, head_dim)
    """
    B, H, N, d = Q.shape
    Br = min(block_size, N)  # Q 块大小
    Bc = min(block_size, N)  # KV 块大小
    Tr = math.ceil(N / Br)   # Q 块数量
    Tc = math.ceil(N / Bc)   # KV 块数量

    O = torch.zeros_like(Q)
    l = torch.zeros(B, H, N, 1, device=Q.device)
    m = torch.full((B, H, N, 1), float('-inf'), device=Q.device)

    # 将 Q, K, V 分块
    Q_blocks = list(Q.split(Br, dim=2))
    K_blocks = list(K.split(Bc, dim=2))
    V_blocks = list(V.split(Bc, dim=2))
    O_blocks = list(O.split(Br, dim=2))
    l_blocks = list(l.split(Br, dim=2))
    m_blocks = list(m.split(Br, dim=2))

    scale = 1.0 / math.sqrt(d)

    for j in range(Tc):
        Kj = K_blocks[j]     # (B, H, Bc, d)
        Vj = V_blocks[j]     # (B, H, Bc, d)
        for i in range(Tr):
            Qi = Q_blocks[i]
            Oi = O_blocks[i]
            li = l_blocks[i]
            mi = m_blocks[i]

            # Step 1: 计算注意力分数(在 SRAM 中)
            Sij = (Qi @ Kj.transpose(-2, -1)) * scale  # (B, H, Br, Bc)

            # Step 2: 块内 softmax 统计量
            mij = Sij.max(dim=-1, keepdim=True).values  # (B, H, Br, 1)
            Pij = torch.exp(Sij - mij)                   # (B, H, Br, Bc)
            lij = Pij.sum(dim=-1, keepdim=True)           # (B, H, Br, 1)

            # Step 3: 合并新旧统计量
            mi_new = torch.maximum(mi, mij)
            li_new = (torch.exp(mi - mi_new) * li
                      + torch.exp(mij - mi_new) * lij)

            # Step 4: 更新输出(在线累积)
            O_blocks[i] = (torch.exp(mi - mi_new) * Oi
                           + torch.exp(mij - mi_new) * (Pij @ Vj))

            l_blocks[i] = li_new
            m_blocks[i] = mi_new

    # 最终 scale
    O = torch.cat([O_blocks[i] / l_blocks[i] for i in range(Tr)], dim=2)
    return O

# 验证与标准 Attention 等价
torch.manual_seed(42)
B, H, N, d = 2, 4, 128, 64
Q = torch.randn(B, H, N, d)
K = torch.randn(B, H, N, d)
V = torch.randn(B, H, N, d)

# 标准 Attention
scale = 1.0 / math.sqrt(d)
S = (Q @ K.transpose(-2, -1)) * scale
P = torch.softmax(S, dim=-1)
O_standard = P @ V

# Flash Attention
O_flash = flash_attention_forward(Q, K, V, block_size=32)

print(f"最大误差: {(O_standard - O_flash).abs().max():.2e}")  # ~1e-6,浮点精度误差

反向传播:重计算 vs 存储

标准 Attention 的反向传播需要注意力矩阵 PRN×N,这正是我们想避免存储的。Flash Attention 的解决方案是重计算(recomputation)

  1. 前向传播:只保存 O,l,m(输出和 softmax 统计量),不保存 SP
  2. 反向传播:利用保存的 l,m,在 SRAM 中重新计算 SP 的每个块
  3. 额外计算量:反向传播多做了一次分块矩阵乘法,但由于 IO 大幅减少,总体仍然更快
反向传播关键步骤:
1. 从 HBM 加载 Q_i, K_j, V_j, O_i, l_i, m_i, dO_i
2. 在 SRAM 中重计算: S_ij = Q_i @ K_j^T, P_ij = softmax(S_ij)  ← 利用 l_i, m_i
3. 计算 dV_j += P_ij^T @ dO_i
4. 计算 dP_ij = dO_i @ V_j^T
5. 计算 dS_ij = P_ij ⊙ (dP_ij - rowsum(dP_ij ⊙ P_ij))    ← softmax 反向
6. 计算 dQ_i += dS_ij @ K_j, dK_j += dS_ij^T @ Q_i
7. 写回 dQ_i, dK_j, dV_j 到 HBM

为什么重计算反而更快? 因为重计算的代价是 O(N2d) 的额外 FLOPs(与前向相同),但节省了 O(N2) 的 HBM 读写。在现代 GPU 上,IO 是瓶颈(HBM 带宽远低于计算吞吐),用计算换 IO 是划算的。

Flash Attention 2 的优化

Flash Attention 2 在 v1 基础上做了三项关键优化,将速度进一步提升 ~2x:

优化 1:减少非矩阵乘法 FLOPs

v1 中大量时间花在 rescaling 操作(乘以 emoldmnew)上。v2 将 rescaling 推迟到内层循环结束后一次性做:

python
# v1: 每个块都做完整 rescaling
O_i = (l_old / l_new) * exp(m_old - m_new) * O_i + (exp(m_block - m_new) / l_new) * PV

# v2: 延迟 rescaling,最后一步才除以 l
O_i = exp(m_old - m_new) * O_i + P_tilde @ Vj  # 不除以 l
# ... 内层循环结束后 ...
O_i = O_i / l_final  # 一次性 scale

在 GPU 上,矩阵乘法(GEMM)由 Tensor Core 加速,而逐元素操作(rescaling)只能用普通 CUDA Core。减少非 matmul FLOPs 能显著提升 Tensor Core 利用率。

优化 2:更好的并行——外层遍历 Q

v1 外层遍历 KV、内层遍历 Q,导致每个 Q 块的输出需要反复读写 HBM。v2 反转循环顺序:

v1: for KV_block → for Q_block   # O 被反复读写
v2: for Q_block → for KV_block   # 每个 Q 块的 O 只写一次 HBM

这使得每个 thread block 独立处理一个 Q 块,不同 Q 块之间无需通信,在 GPU SM 之间实现了完美并行

优化 3:序列长度维度的并行

v1 只在 batch 和 head 维度做并行。当 batch size 较小时(如推理),SM 利用率不高。v2 额外在序列长度维度做并行(将 Q 的不同块分配到不同 SM),大幅提升了小 batch 场景的效率。

IO 复杂度分析

算法FLOPsHBM 读写量IO 复杂度
标准 AttentionO(N2d)O(N2+Nd)N2 中间矩阵 IO 限制
Flash AttentionO(N2d)O(N2d2/M)M = SRAM 大小

推导:Flash Attention 的外层有 Tc=N/Bc 次迭代,内层有 Tr=N/Br 次迭代。每次内层迭代从 HBM 读取 QiBr×d)、KjVjBc×d)。选择最优块大小 Bc=Θ(M/d)Br=Θ(min(M/d,d))(确保块能放入 SRAM),总 HBM 访问量为:

IO=Θ(N2d2M)

M=Θ(Nd) 时(SRAM 足够大),IO 降为 O(Nd),等价于只读写 Q、K、V 各一次。

数值示例N=4096, d=128, M=20MB(A100 SRAM)

  • 标准 Attention IO: N216M 读写(注意力矩阵 SP
  • Flash Attention IO: N2d2/M16M×16K/20M13M

看似差不多,但 Flash Attention 避免了写 S,P 到 HBM 再读回的两次完整 IO,且块大小经过优化后实际加速更显著。

面试考点:为什么 Flash Attention 更快但 FLOPs 相同?

这是一个非常经典的问题,核心答案是Roofline Model

  1. FLOPs 不变:Flash Attention 计算的数学结果与标准 Attention 完全相同,矩阵乘法的次数一样
  2. IO 大幅减少:标准 Attention 需要将 N2 大小的中间矩阵 SP 写入 HBM 再读回,而 Flash Attention 将这些中间结果保持在 SRAM 中
  3. 瓶颈转移:标准 Attention 是 IO-bound(显存带宽是瓶颈),Flash Attention 通过减少 IO 将瓶颈转移到 compute-bound,从而真正利用上 GPU 的算力
  4. 重计算的"免费午餐":反向传播多做的那次前向重计算,其 FLOPs 增加约 33%,但 IO 节省远大于此——在 A100 上净加速 2-4 倍

一句话总结: Flash Attention 不是"算得更快",而是"搬数据搬得更少"。在 GPU 上,SRAM 带宽是 HBM 的 ~10 倍,减少 HBM 访问就是最大的加速。

Online Softmax:从两遍扫描到一遍扫描

标准 Softmax 需要两遍扫描才能完成计算:

softmax(xi)=eximjexjm,m=maxjxj
  • 第一遍:扫描所有元素,求全局最大值 m(数值稳定性所需)和求和 =jexjm
  • 第二遍:再扫描一次,计算每个元素的 softmax 值

这意味着整个向量必须在内存中被访问两次。对于 Flash Attention 的分块计算来说,我们无法一次看到完整的行——每次只能看到一个 block。

Online Softmax 的核心思想

利用指数函数的性质,当新 block 带来更大的 max 值 mnew 时,只需对历史累积量乘一个修正因子 emoldmnew 即可:

new=oldemoldmnew+jblockexjmnew

这样只需一遍扫描就能得到全局正确的 softmax 分母。

逐元素 Online Softmax 实现:

python
import torch

X = torch.tensor([1.0, 1.5, 1.8, 2.0, 1.4, 2.1])

# ---- 标准 safe softmax(两遍扫描)----
X_max = X.max()
X_safe_softmax = torch.exp(X - X_max) / torch.exp(X - X_max).sum()

# ---- Online Softmax(一遍扫描)----
m_cur = torch.tensor(float('-inf'))
l_cur = torch.tensor(0.0)

for i in range(len(X)):
    m_new = torch.max(m_cur, X[i])
    # 修正历史 sum + 加入新元素
    l_cur = l_cur * torch.exp(m_cur - m_new) + torch.exp(X[i] - m_new)
    m_cur = m_new

X_online_softmax = torch.exp(X - m_cur) / l_cur
print(torch.allclose(X_safe_softmax, X_online_softmax))  # True

分块 Online Softmax(Flash Attention 实际使用的形式):

python
BLOCK = 3
X_blocks = X.split(BLOCK)

m_cur = torch.tensor(float('-inf'))
l_cur = torch.tensor(0.0)

for blk in X_blocks:
    m_blk = blk.max()
    m_new = torch.max(m_cur, m_blk)
    l_cur = l_cur * torch.exp(m_cur - m_new) \
          + torch.exp(blk - m_new).sum()
    m_cur = m_new

X_block_online_softmax = torch.exp(X - m_cur) / l_cur
print(torch.allclose(X_safe_softmax, X_block_online_softmax))  # True

分块 Online Softmax 是 Flash Attention 能在 SRAM 中分块完成 Softmax 的数学基础——看到新 block 时修正旧统计量,而非回头重算

Flash Attention 反向传播

Flash Attention 反向传播的关键挑战是:前向传播没有保存 N×N 的注意力矩阵 P(否则就失去了省显存的意义),但反向传播需要 P 来计算 dQ,dK,dV

核心策略:重计算 (Recomputation)

前向时只保存 Q,K,V,O 和每行的 logsumexp Li=mi+log(i)。反向时在 SRAM 中重新计算每个 block 的 SijPij,避免从 HBM 读取 N2 大小的矩阵。

梯度推导

给定上游梯度 dO,标准 Attention 的梯度为:

dV=PdOdP=dOVdS=P(dPD),Di=jPijdPij=rowsum(OdO)dQ=dSK/ddK=dSQ/d

其中 Di=jPijdPij 可以用 OdO 直接算出(D=rowsum(OdO)),不需要存储 P

为什么 Di=rowsum(OdO)

因为 Oi=jPijVj,所以 jPij(dOiVj)=dOijPijVj=dOiOi。这正是逐行求和 rowsum(OdO) 的第 i 个元素。

分块反向传播实现:

以下实现参考了 Flash Attention 2 论文 (Dao, 2023) 的算法描述,代码经过教学化改写。

python
import torch
import math

torch.manual_seed(42)
n, dim, nb = 12, 8, 4  # 序列长度, 头维度, block 大小
block = n // nb

Q = torch.randn(n, dim, requires_grad=True)
K = torch.randn(n, dim, requires_grad=True)
V = torch.randn(n, dim, requires_grad=True)

# ---- Flash Attention 前向(保存 O 和 L)----
def flash_attention_forward(Q, K, V):
    O = torch.zeros_like(Q)
    L = torch.zeros(n, 1)
    for tq in range(block):
        q = Q[tq*nb:(tq+1)*nb, :]
        o_old = torch.zeros_like(q)
        l_old = m_old = torch.zeros(nb, 1)
        for tk in range(block):
            k = K[tk*nb:(tk+1)*nb, :]
            v = V[tk*nb:(tk+1)*nb, :]
            s = q @ k.t() / math.sqrt(dim)
            m = s.max(dim=1, keepdim=True).values
            m_new = torch.maximum(m, m_old)
            l = torch.exp(s - m_new).sum(dim=1, keepdim=True)
            l_new = l_old * torch.exp(m_old - m_new) + l
            o_old = l_old * o_old * torch.exp(m_old - m_new) \
                  + torch.exp(s - m_new) @ v
            o_old = o_old / l_new
            l_old, m_old = l_new, m_new
        O[tq*nb:(tq+1)*nb, :] = o_old
        L[tq*nb:(tq+1)*nb, :] = m_old + l_old.log()
    return O, L

O, L = flash_attention_forward(Q, K, V)

# ---- Flash Attention 反向(分块重计算)----
dO = torch.randn_like(O)  # 模拟上游梯度

def flash_attention_backward(Q, K, V, O, dO, L):
    dQ = torch.zeros_like(Q)
    dK = torch.zeros_like(K)
    dV = torch.zeros_like(V)
    # D_i = rowsum(O * dO),不需要 P
    D = (O * dO).sum(dim=1, keepdim=True)

    for tk in range(block):          # 外层遍历 KV block
        k = K[tk*nb:(tk+1)*nb, :]
        v = V[tk*nb:(tk+1)*nb, :]
        for tq in range(block):      # 内层遍历 Q block
            q  = Q[tq*nb:(tq+1)*nb, :]
            o  = O[tq*nb:(tq+1)*nb, :]
            do = dO[tq*nb:(tq+1)*nb, :]
            l  = L[tq*nb:(tq+1)*nb, :]
            d  = D[tq*nb:(tq+1)*nb, :]

            # ---- 重计算 attention(无需从 HBM 读 P)----
            s = q @ k.t() / math.sqrt(dim)
            p = torch.exp(s - l)     # 利用 L = m + log(l) 还原 softmax

            # ---- 梯度计算 ----
            dv = p.t() @ do
            dp = do @ v.t()
            ds = p * (dp - d)        # softmax 反向的紧凑形式
            dq = ds @ k / math.sqrt(dim)
            dk = ds.t() @ q / math.sqrt(dim)

            # ---- 累加到全局梯度 ----
            dV[tk*nb:(tk+1)*nb, :] += dv
            dQ[tq*nb:(tq+1)*nb, :] += dq
            dK[tk*nb:(tk+1)*nb, :] += dk

    return dQ, dK, dV

dQ_flash, dK_flash, dV_flash = flash_attention_backward(
    Q, K, V, O, dO, L
)

用 PyTorch Autograd 验证正确性

我们实现的反向传播是否正确?我们用 PyTorch 自动微分作为 ground truth 进行对比:

python
# ---- PyTorch 标准 attention + autograd ----
S = Q @ K.t() / math.sqrt(dim)
P = torch.softmax(S, dim=-1)
O_ref = P @ V

# 用相同的 dO 计算 autograd 梯度
O_ref.backward(dO)

print("dQ allclose:", torch.allclose(dQ_flash, Q.grad, atol=1e-5))
print("dK allclose:", torch.allclose(dK_flash, K.grad, atol=1e-5))
print("dV allclose:", torch.allclose(dV_flash, V.grad, atol=1e-5))
# 输出:全部 True

反向传播的 IO 分析

标准反向传播需要从 HBM 读取 P(大小 N2),IO 复杂度 O(N2)。Flash Attention 反向通过重计算 P,将 IO 降至 O(N2d/M)(与前向相同),代价是多做了一次 O(N2d) 的矩阵乘法——但在 GPU 上这是 compute-bound 操作,远快于 HBM 读写。


Tensor Product Attention (TPA)

核心思想

标准注意力中,K 和 V 通过单个线性投影 K=XWK 得到,每个 token 独立生成 KV。TPA(Tensor Product Attention) 引入了一种新的分解方式:将 K、V 的投影分解为两个低秩矩阵的张量积,让模型在更低的参数量和 KV Cache 开销下保持表达能力。

数学公式

TPA 将 K、V 的计算分解为两个低秩分量 A(token 级)和 B(token 级)的乘积:

Ak=XWAkRn×h×r,Bk=XWBkRn×r×dhK=1rAkBkRn×h×dh

V 的计算方式完全类似。其中 r 是分解的秩(rank),远小于 dh。Q 仍然使用标准的线性投影。

关键洞察: 这本质上是对 KV 投影矩阵做了 CP 分解(Canonical Polyadic Decomposition),每个 token 的 K/V 由两个低秩因子按位相乘得到,兼顾表达能力与压缩效率。

简化版代码实现

以下是简化版 TPA 实现:

python
import torch
import torch.nn as nn

class TPAProjection(nn.Module):
    """Tensor Product Attention 的 QKV 投影"""
    def __init__(self, d_model=512, n_head=8, head_dim=64, rank=4):
        super().__init__()
        self.n_head = n_head
        self.head_dim = head_dim
        self.rank = rank

        # Q 使用标准投影
        self.W_q = nn.Linear(d_model, n_head * head_dim, bias=False)
        # K, V 各用两个低秩投影 (CP 分解)
        self.W_A_k = nn.Linear(d_model, n_head * rank, bias=False)
        self.W_B_k = nn.Linear(d_model, rank * head_dim, bias=False)
        self.W_A_v = nn.Linear(d_model, n_head * rank, bias=False)
        self.W_B_v = nn.Linear(d_model, rank * head_dim, bias=False)

    def forward(self, x):
        bs, seq_len, _ = x.size()
        q = self.W_q(x).view(bs, seq_len, self.n_head, self.head_dim)

        # K = (1/r) * A_k @ B_k
        A_k = self.W_A_k(x).view(bs * seq_len, self.n_head, self.rank)
        B_k = self.W_B_k(x).view(bs * seq_len, self.rank, self.head_dim)
        k = torch.bmm(A_k, B_k).div_(self.rank)
        k = k.view(bs, seq_len, self.n_head, self.head_dim)

        # V = (1/r) * A_v @ B_v
        A_v = self.W_A_v(x).view(bs * seq_len, self.n_head, self.rank)
        B_v = self.W_B_v(x).view(bs * seq_len, self.rank, self.head_dim)
        v = torch.bmm(A_v, B_v).div_(self.rank)
        v = v.view(bs, seq_len, self.n_head, self.head_dim)

        return q, k, v

# 使用示例
tpa = TPAProjection(d_model=512, n_head=8, head_dim=64, rank=4)
x = torch.randn(2, 16, 512)
q, k, v = tpa(x)
print(q.shape, k.shape, v.shape)
# torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 8, 64]) torch.Size([2, 16, 8, 64])

TPA vs 标准 Attention 对比

特性标准 MHATPA
KV 投影参数2×d×hdh2×d×(hr+rdh)
参数压缩比基准rdh 时显著减少
KV Cache标准可只缓存 AB 因子
表达能力基准rank 足够时接近 MHA
适用场景通用KV Cache 受限的长序列推理

注意力变体对比

特性MHAMQAGQAMLATPA
Q 头数hhhh(低秩分解)h
KV 头数h1g1<g<h全头(从 latent 恢复)h(低秩因子)
KV Cache 大小2hdhl2dhl2gdhldcldchdh2h(r+rdh/h)l
KV 参数量2×d×d2×d×dh2×d×gdhd×dc+2×dc×d2×d×(hr+rdh)
精度保持基准略有下降接近 MHA接近 MHA接近 MHA
代表模型GPT-3, BERTPaLMLlama 2/3DeepSeek-V2/V3Tensor Product Attention
核心思想多头并行KV 共享分组 KV 共享低秩 KV 压缩KV 张量积分解

关键洞察: MQA/GQA 是在"头数维度"上压缩;MLA 是在"特征维度"上压缩(类似 LoRA 思想);TPA 是在"投影矩阵"上做 CP 分解——三者从不同角度减少 KV 开销。


苏格拉底时刻

  1. 为什么 MHA 不用一个大头? 多头让模型在不同子空间并行捕捉不同类型的关系(语法、语义、位置等)。单头只能在一个空间学习,表达能力受限。

  2. Flash Attention 在数学上完全等价,加速从何而来? 减少了 HBM 读写次数。标准 attention 的 IO 复杂度为 O(n2)(存储中间矩阵 S,P),Flash Attention 通过分块将 IO 降至 O(n2d/M),其中 M 为 SRAM 大小。

  3. GQA 的 repeat_interleave 是否增加了计算量? 注意力分数的计算量不变(仍为 n2hdh)。减少的是投影参数和 KV Cache 的存储/传输开销。在 GPU SRAM 中 repeat 操作几乎免费。

  4. MLA 的"矩阵吸收"为什么能 work? 基于矩阵乘法结合律:Wup(WdownX)=(WupWdown)X。训练时分两步省显存;推理时合并为一步保精度。

  5. Online Softmax 为什么能单遍扫描? 利用指数函数的性质:exmnew=exmoldemoldmnew,当发现新的 max 时,只需对历史累积量乘一个修正因子。


常见问题 & 面试考点

Q1: 注意力的计算复杂度是多少? 时间复杂度 O(n2d),空间复杂度 O(n2)(存储注意力矩阵)。这是长序列建模的主要瓶颈。

Q2: Causal Mask 在训练和推理中的作用? 训练时:对注意力矩阵的上三角填充 ,防止 token 看到未来信息,保证自回归训练的正确性。推理时(KV Cache 模式):每步只计算新 token 对所有历史 token 的注意力,mask 隐式生效。

Q3: KV Cache 为什么只缓存 K 和 V,不缓存 Q? 自回归推理时,每步只有一个新 token 生成新的 Q(长度为 1),无需缓存。而 K、V 需要保留所有历史 token 的结果。

Q4: MQA 为什么效果只略有下降? 高维特征存在冗余,多个头的 KV 投影高度相关。实验表明 KV 的多样性对模型质量影响远小于 Q 的多样性。

Q5: Flash Attention 能用于训练吗? 能。Flash Attention 同时优化了前向和反向传播的 IO,训练速度提升 2-4 倍,显存减少 5-20 倍。


推荐资源