Skip to content

RoPE 代码填空 (Level 2-3)

难度: 中高 | 前置知识: Transformer 架构、旋转位置编码原理 | 预计时间: 45-60 分钟

RoPE(Rotary Position Embedding)是 Llama 等现代大模型的标准位置编码方案。它通过对 Q、K 向量施加位置相关的旋转,将相对位置信息编码到注意力分数中。

本练习包含 3 个代码填空,分别覆盖 2D 旋转基础、频率计算和旋转应用。

背景知识速览

为什么需要 RoPE?

传统绝对位置编码(PE)的问题:两个 token 的注意力分数展开后包含绝对位置信息:

(Em+Pm)(En+Pn)T=EmEnT+EmPnT+PmEnT+PmPnT

其中 PmEnTEmPnT 项让注意力依赖于绝对位置。例如"大模型"在位置 0-2 和位置 100-102 的注意力分数不同,导致在训练范围外的位置(如 10000+)表示失效。

RoPE 的目标: 找到一种变换 f,使得 f(q,m)f(k,n)T=g(q,k,mn),即注意力分数只依赖相对位置 mn

核心思想:2D 旋转

RoPE 将特征向量每两个维度分为一组,对每组施加 2D 旋转矩阵:

(x2ix2i+1)=(cos(mθi)sin(mθi)sin(mθi)cos(mθi))(x2ix2i+1)

关键性质: 旋转矩阵满足 R(mθ)R(nθ)T=R((mn)θ),因此:

Smn,i=Qm,(i)R((mn)θi)Kn,(i)T

这正是我们想要的——注意力分数只依赖相对位置 mn

频率公式(控制每组维度旋转的速度):

θi=1100002i/d,i=0,1,,d/21

低维度的 θi 大(旋转快,捕捉短程依赖),高维度的 θi 小(旋转慢,捕捉长程依赖)。


练习 0(热身):2D 旋转矩阵

理解 2D 旋转是理解 RoPE 的基础。

python
import torch
import math

def rotate_2d(theta):
    """
    构造 2D 旋转矩阵

    参数: theta - 旋转角度(弧度)
    返回: (2, 2) 旋转矩阵
    """
    # 空白: 填写 2x2 旋转矩阵
    # 提示: [[cos, -sin], [sin, cos]]
    mat = _____
    return mat

# 测试:将 [1, 0] 旋转 45 度
vec = torch.tensor([[1.0], [0.0]])
mat = rotate_2d(math.radians(45))
v_rot = mat @ vec
print(f"旋转前: {vec[:, 0].tolist()}")
print(f"旋转后: {v_rot[:, 0].tolist()}")
# 期望: [0.7071, 0.7071]
assert torch.allclose(v_rot, torch.tensor([[0.7071], [0.7071]]), atol=1e-3)
print("2D 旋转测试通过!")
查看答案
python
mat = torch.tensor(
    [[math.cos(theta), -math.sin(theta)],
     [math.sin(theta),  math.cos(theta)]]
)

解析: 这就是标准的 2D 旋转矩阵。旋转矩阵是正交矩阵,满足 RT=R1,所以旋转不改变向量的范数。


练习 1:频率计算

python
import torch

def precompute_rope_frequencies(dim, max_seq_len, base=10000.0):
    """
    预计算 RoPE 所需的 cos 和 sin 值

    参数:
        dim: 每个注意力头的维度 d_k (必须是偶数)
        max_seq_len: 支持的最大序列长度
        base: 频率基数 (默认 10000)

    返回:
        cos_cached: (max_seq_len, dim) 预计算的 cos 值(已 repeat 到全维度)
        sin_cached: (max_seq_len, dim) 预计算的 sin 值(已 repeat 到全维度)
    """
    assert dim % 2 == 0, "维度必须是偶数"

    # 空白1: 计算频率向量 theta
    # theta_i = 1 / (base^(2i/dim)), i = 0, 1, ..., dim/2 - 1
    # 提示: 先构造 i = [0, 2, 4, ..., dim-2] / dim,再用 base 的负指数
    theta = _____
    # theta shape: (dim//2,)

    # 空白2: 构造位置索引并计算 m * theta
    # positions = [0, 1, 2, ..., max_seq_len-1]
    # 使用外积 (outer product) 计算每个位置和每个频率的乘积
    positions = torch.arange(max_seq_len, dtype=torch.float32)
    m_theta = _____
    # m_theta shape: (max_seq_len, dim//2)

    # 空白3: 将 cos/sin 扩展到全维度
    # 每个频率对应两个维度(偶数和奇数),需要 repeat
    # 提示: 构造 (max_seq_len, dim) 的矩阵,偶数列和奇数列填相同的值
    cos_cached = torch.zeros(max_seq_len, dim)
    sin_cached = torch.zeros(max_seq_len, dim)
    cos_cached[:, 0::2] = _____
    cos_cached[:, 1::2] = _____
    sin_cached[:, 0::2] = _____
    sin_cached[:, 1::2] = _____

    return cos_cached, sin_cached

提示:

  • 空白1:用 torch.arange 生成 [0, 2, 4, ..., dim-2],除以 dim,然后作为 base 的负指数
  • 空白2:torch.outer(a, b) 计算两个一维向量的外积
  • 空白3:cos/sin 的偶数列和奇数列填相同的值,因为每对维度共享一个频率
查看答案
python
# 空白1: 频率向量
i = torch.arange(0, dim // 2, dtype=torch.float32)
theta = base ** (-2 * i / dim)
# 等价写法: theta = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

# 空白2: 外积计算 m * theta
m_theta = torch.outer(positions, theta)

# 空白3: repeat 到全维度
cos_cached[:, 0::2] = torch.cos(m_theta)
cos_cached[:, 1::2] = torch.cos(m_theta)
sin_cached[:, 0::2] = torch.sin(m_theta)
sin_cached[:, 1::2] = torch.sin(m_theta)

解析:

空白1 实现的是 θi=1100002i/d

  • torch.arange(0, dim//2) 生成 [0, 1, 2, ..., dim/2-1],对应公式中的 i
  • base ** (-2 * i / dim) 直接计算 100002i/d=1100002i/d

结果是一个长度为 dim//2 的向量,每个元素对应一个频率。低维度的频率高(变化快),高维度的频率低(变化慢)。

空白2 用外积计算所有位置-频率的组合:

  • positions: shape (max_seq_len,),值为 [0, 1, 2, ..., max_seq_len-1]
  • theta: shape (dim//2,)
  • torch.outer(positions, theta): shape (max_seq_len, dim//2)
  • (m, i) 个元素就是 mθi

空白3dim//2 扩展到 dim:每对相邻维度 (2i, 2i+1) 共享同一个频率 θi,因此偶数列和奇数列填相同的 cos/sin 值。


练习 2:旋转应用

python
def apply_rope(x, cos, sin):
    """
    对输入张量应用旋转位置编码

    参数:
        x: 输入张量 (batch_size, n_heads, seq_len, d_k)
        cos: 预计算的 cos 值 (max_seq_len, d_k)
        sin: 预计算的 sin 值 (max_seq_len, d_k)

    返回:
        x_rotated: 旋转后的张量 (batch_size, n_heads, seq_len, d_k)
    """
    bs, n_heads, seq_len, d = x.shape

    # 截取当前序列长度需要的 cos/sin
    cos = cos[:seq_len]  # (seq_len, d)
    sin = sin[:seq_len]  # (seq_len, d)

    # 空白1: 构造"旋转交换"向量 X_shift
    # 旋转公式的线性化版本:
    #   x_new = cos * x + sin * x_shift
    # 其中 x_shift 将每对相邻维度交换并取负:
    #   x_shift[..., 0::2] = -x[..., 1::2]  (偶数位 = 负的奇数位)
    #   x_shift[..., 1::2] =  x[..., 0::2]  (奇数位 = 偶数位)
    X_shift = torch.zeros_like(x)
    X_shift[..., 0::2] = _____
    X_shift[..., 1::2] = _____

    # 空白2: 应用旋转
    # 提示: cos/sin 的 shape 是 (seq_len, d)
    # 需要广播到 (batch, heads, seq_len, d)
    # 使用 None 或 unsqueeze 添加 batch 和 heads 维度
    Y = _____

    return Y

提示:

  • 空白1:这是将矩阵乘法 R(θ)x 转化为逐元素运算的技巧。将 [x0, x1, x2, x3, ...] 变为 [-x1, x0, -x3, x2, ...]
  • 空白2:cos[None, None, :seq_len, :] 将 shape 从 (seq_len, d) 变为 (1, 1, seq_len, d),支持广播
查看答案
python
# 空白1: 构造旋转交换向量
X_shift[..., 0::2] = -x[..., 1::2]  # 偶数位 = 负的奇数位
X_shift[..., 1::2] = x[..., 0::2]   # 奇数位 = 偶数位

# 空白2: 应用旋转
Y = cos[None, None, :seq_len, :] * x + sin[None, None, :seq_len, :] * X_shift

解析:

空白1 实现了旋转矩阵乘法的线性化。对于每对维度 (x2i,x2i+1),旋转公式为:

x2i=cosθx2isinθx2i+1x2i+1=sinθx2i+cosθx2i+1

将其统一写成向量形式:Y=cosX+sinXshift

其中 Xshift 将每对维度交换并给偶数位取负:[-x1, x0, -x3, x2, ...]

这样就避免了显式构造旋转矩阵再做矩阵乘法,只需要逐元素运算。

空白2 使用 None 索引(等价于 unsqueeze)将 cos/sin 从 (seq_len, d) 广播到 (batch, heads, seq_len, d),然后做逐元素乘加。


验证代码

python
# 完整测试
dim = 64
max_seq_len = 128
batch_size, n_heads, seq_len = 2, 8, 32

# 预计算频率
cos_cached, sin_cached = precompute_rope_frequencies(dim, max_seq_len)
print(f"cos shape: {cos_cached.shape}")  # (128, 64)
print(f"sin shape: {sin_cached.shape}")  # (128, 64)

# 应用旋转
Q = torch.randn(batch_size, n_heads, seq_len, dim)
Q_rotated = apply_rope(Q, cos_cached, sin_cached)
print(f"Q shape: {Q.shape}")
print(f"Q_rotated shape: {Q_rotated.shape}")
assert Q_rotated.shape == Q.shape, "Shape 不匹配!"

# 验证关键性质: 旋转后向量的范数应该不变(旋转矩阵是正交矩阵)
norm_before = torch.norm(Q, dim=-1)
norm_after = torch.norm(Q_rotated, dim=-1)
print(f"范数差异: {(norm_before - norm_after).abs().max().item():.6f}")  # 应接近 0
assert torch.allclose(norm_before, norm_after, atol=1e-5), "范数应该不变!"

# 验证多头共享: 不同 head 使用相同的 RoPE 参数
K = torch.randn(batch_size, n_heads, seq_len, dim)
K_rotated = apply_rope(K, cos_cached, sin_cached)
print(f"K_rotated shape: {K_rotated.shape}")

print("所有测试通过!")

RoPE 在模型中的使用

完成填空后,了解 RoPE 如何集成到完整模型中:

python
# 伪代码:RoPE 在 GPT 模型中的位置
class Attention:
    def __init__(self):
        self.wq, self.wk, self.wv, self.wo = ...

    def forward(self, X, mask, sin, cos):
        q, k, v = self.wq(X), self.wk(X), self.wv(X)
        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # RoPE 只作用于 Q 和 K(不作用于 V!)
        q = apply_rope(q, cos, sin)
        k = apply_rope(k, cos, sin)

        # KV Cache 存储的是 apply_rope 后的 K, V
        # ... 后续注意力计算 ...

class Model:
    def __init__(self):
        # RoPE 参数在 model 层面预计算,所有层共享
        sin, cos = precompute_rope_frequencies(head_dim, max_len)
        self.register_buffer('rope_sin', sin)  # 不需要梯度
        self.register_buffer('rope_cos', cos)

    def forward(self, x):
        for block in self.decoder:
            # 只传当前序列长度的 sin/cos
            x = block(x, self.rope_sin[:seq_len], self.rope_cos[:seq_len])

注意几个要点:

  1. RoPE 的维度是 head_dim(不是 dim),多头共享一份 RoPE 参数
  2. RoPE 只作用于 Q 和 K,不作用于 V
  3. KV Cache 中存储的 K 是已经应用过 RoPE 的
  4. sin/cos 使用 register_buffer 存储,不参与梯度计算

思考延伸

完成填空后,尝试回答:

  1. 为什么 base=10000 是一个好的默认值?如果改成 100 或 1000000,会发生什么?
  2. RoPE 的频率从高到低覆盖了不同的"波长" -- 低维度的频率高,高维度的频率低。这与傅里叶变换有什么联系?
  3. 如果想让模型支持比训练时更长的序列(如训练时 4K,推理时 32K),可以如何修改频率?(提示:搜索 NTK-aware scaling / YaRN)
  4. RoPE 的本质是什么?它与传统绝对位置编码的根本区别在哪里?
  5. 为什么 RoPE 只作用于 Q 和 K 而不作用于 V?

MLM 代码训练模式

完成上面的固定填空后,试试随机挖空模式 -- 每次点击「刷新」会随机遮盖不同的代码片段,帮你彻底记住每一行。

RoPE 频率预计算

频率向量与 cos/sin 预计算
共 56 个可挖空位 | 已挖 0 个
def precompute_rope_frequencies(dim, max_seq_len, base=10000.0):
    i = torch.arange(0, dim // 2, dtype=torch.float32)
    theta = base ** (-2 * i / dim)

    positions = torch.arange(max_seq_len, dtype=torch.float32)
    m_theta = torch.outer(positions, theta)

    cos_cached = torch.zeros(max_seq_len, dim)
    sin_cached = torch.zeros(max_seq_len, dim)
    cos_cached[:, 0::2] = torch.cos(m_theta)
    cos_cached[:, 1::2] = torch.cos(m_theta)
    sin_cached[:, 0::2] = torch.sin(m_theta)
    sin_cached[:, 1::2] = torch.sin(m_theta)
    return cos_cached, sin_cached

RoPE 旋转应用

apply_rope 旋转交换与广播
共 29 个可挖空位 | 已挖 0 个
def apply_rope(x, cos, sin):
    bs, n_heads, seq_len, d = x.shape
    cos = cos[:seq_len]
    sin = sin[:seq_len]

    X_shift = torch.zeros_like(x)
    X_shift[..., 0::2] = -x[..., 1::2]
    X_shift[..., 1::2] = x[..., 0::2]

    Y = cos[None, None, :seq_len, :] * x + sin[None, None, :seq_len, :] * X_shift
    return Y