Skip to content

Llama 实现挑战 (Level 3-4)

从零构建一个 mini-Llama 模型。先通过热身练习掌握各子模块,再在主挑战中将它们组装为完整的 Decoder 模型。


热身练习

练习 1: RMSNorm 实现(Level 2)

背景

RMSNorm 是 LayerNorm 的简化版,去掉了 re-center(减均值)操作,仅使用 RMS(Root Mean Square)统计量进行缩放。其公式为:

RMS(x)=1Nixi2RMSNorm(x)=γxRMS(x)

RMSNorm 有输入尺度不变性:xRMS(x)=sxRMS(sx),且计算比 LayerNorm 更简单(不需要计算均值和方差)。

任务

python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        # ===== 填空 1: 定义可学习参数 gamma =====
        self.gamma = _____  # 提示: nn.Parameter, 初始化为全 1
        self.eps = eps

    def forward(self, x):
        # x: [bsz, seq_len, dim]

        # ===== 填空 2: 计算 x^2 在最后一维的均值 =====
        mean_sq = _____  # 提示: (x**2).mean(...),注意 keepdim

        # ===== 填空 3: 计算 RMS 归一化 =====
        x_normed = _____  # 提示: x / sqrt(mean_sq + eps)

        # ===== 填空 4: 乘以可学习参数 =====
        return _____

提示

  • gammadim 维的向量,初始化为 torch.ones(dim)
  • 计算均值时 dim=-1, keepdim=True,保持维度以便广播
  • eps 防止除零
参考答案
python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = (x ** 2).mean(-1, keepdim=True)
        x_normed = x / torch.sqrt(mean_sq + self.eps)
        return self.gamma * x_normed

验证:

python
norm = RMSNorm(dim=512)
x = torch.randn(2, 6, 512)
y = norm(x)
print(y.shape)       # torch.Size([2, 6, 512])
# 验证 RMS 归一化后统计量为 1
rms = torch.sqrt((y**2).mean(-1))
print(rms.mean())    # 接近 1.0

练习 2: SwiGLU 实现(Level 2-3)

背景

SwiGLU 将 Swish 激活函数与 GLU(门控线性单元)结合:

h=Swish(Wgate(x))W1(x)y=W2(h)

其中 Swish 即 F.silu(即 x * sigmoid(x)), 为逐元素乘法。GLU 门控机制让模型对特征做 token-wise 的细粒度选择。

为保持与标准 FFN 相近的参数量,hidden_dim 设为 8d/3(标准 FFN 用 4d)。

任务

python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # ===== 填空 1: 计算 hidden_dim (保持与 FFN 同等参数量) =====
        hidden_dim = _____  # 提示: 8 * dim // 3

        # ===== 填空 2: 定义三个线性层 (无 bias) =====
        self.w1 = _____      # dim -> hidden_dim
        self.w_gate = _____  # dim -> hidden_dim (门控)
        self.w2 = _____      # hidden_dim -> dim

    def forward(self, x):
        # ===== 填空 3: 门控特征 =====
        gate = _____  # 提示: 对 w_gate(x) 做 silu 激活

        # ===== 填空 4: 上投影特征 =====
        x_up = _____  # 提示: w1(x)

        # ===== 填空 5: GLU 门控乘法 + 下投影 =====
        h = _____  # 提示: gate * x_up
        return _____  # 提示: w2(h)

提示

  • F.silu(x) 等价于 x * torch.sigmoid(x),即 Swish 函数
  • 三个 Linear 层均不带 bias(bias=False
参考答案
python
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        x_up = self.w1(x)
        h = gate * x_up
        return self.w2(h)

验证:

python
ffn = SwiGLU(dim=512)
x = torch.randn(2, 3, 512)
print(ffn(x).shape)  # torch.Size([2, 3, 512])

# 对比参数量: FFN(4d) vs SwiGLU(8d/3)
ffn_params = 2 * 512 * 2048  # w1 + w2
swiglu_params = 3 * 512 * 1365  # w1 + w_gate + w2
print(f"FFN: {ffn_params}, SwiGLU: {swiglu_params}")
# 两者参数量接近

练习 3: GQA 的 KV Repeat(Level 2-3)

背景

GQA(Grouped Query Attention)让多个 Q head 共享同一组 KV head,减少 KV Cache 存储。例如 8 个 Q head 共享 4 个 KV head,每 2 个 Q head 共用 1 个 KV head。

实现时需要将 KV 沿 head 维度复制(repeat),使其与 Q 的 head 数匹配。

任务

python
import torch

def repeat_kv(
    k: torch.Tensor,   # [bsz, n_kv_heads, seq_len, head_dim]
    v: torch.Tensor,    # [bsz, n_kv_heads, seq_len, head_dim]
    n_rep: int,         # 每个 kv head 需要复制的次数
):
    """
    将 KV 的 head 维度复制 n_rep 次
    例: [bsz, 4, seq_len, head_dim] -> [bsz, 8, seq_len, head_dim] (n_rep=2)
    """
    if n_rep == 1:
        return k, v

    # ===== 填空 1: 使用 repeat_interleave 复制 =====
    k = _____  # 提示: torch.repeat_interleave(k, n_rep, dim=?)
    v = _____  # 提示: 同上,注意 dim 参数

    return k, v


# 验证
bsz, n_kv_heads, seq_len, head_dim = 2, 4, 16, 64
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
v = torch.randn(bsz, n_kv_heads, seq_len, head_dim)

k_rep, v_rep = repeat_kv(k, v, n_rep=2)
print(k_rep.shape)  # 应为 torch.Size([2, 8, 16, 64])

# ===== 填空 2: 验证复制正确性 =====
# 第 0 个 kv head 应该等于第 0 和第 1 个 q head 的 kv
assert torch.equal(k_rep[:, 0], _____), "复制不正确"  # 提示: k 的哪个 head?
assert torch.equal(k_rep[:, 1], _____), "复制不正确"  # 提示: 应与 head 0 相同
参考答案
python
# 填空 1
k = torch.repeat_interleave(k, n_rep, dim=1)
v = torch.repeat_interleave(v, n_rep, dim=1)

# 填空 2
assert torch.equal(k_rep[:, 0], k[:, 0])
assert torch.equal(k_rep[:, 1], k[:, 0])  # head 1 也来自原始 kv head 0

主挑战: 构建完整 Llama Decoder(Level 4)

目标

从零实现一个完整的 mini-Llama 模型,包含以下组件:

  1. RMSNorm -- Pre-Norm 结构
  2. RoPE -- 旋转位置编码
  3. GQA -- 分组查询注意力
  4. SwiGLU -- 门控前馈网络
  5. LlamaDecoderBlock -- 单个 Decoder 层
  6. LlamaModel -- 完整模型(Embedding + N 层 Block + LM Head)

配置

python
from dataclasses import dataclass

@dataclass
class LlamaConfig:
    vocab_size: int = 200
    max_len: int = 512
    dim: int = 512
    n_heads: int = 8
    n_kv_heads: int = 4     # GQA: 每 2 个 Q head 共享 1 个 KV head
    num_layers: int = 6
    position_encoding_base: float = 10000.0
    attention_bias: bool = False  # Linear 层不带 bias

要求

Part 1: RoPE(旋转位置编码)

实现两个函数:

python
def create_rope(max_len, dim, base=10000.0):
    """
    创建 RoPE 的 sin/cos 缓存
    返回: (sin, cos),形状均为 [max_len, dim]

    关键步骤:
    1. m = [0, 1, ..., max_len-1]
    2. theta_i = base^(-2i/dim), i = [0, 1, ..., dim/2-1]
    3. m_theta = outer(m, theta)
    4. sin[:, 0::2] = sin[:, 1::2] = sin(m_theta)
       cos[:, 0::2] = cos[:, 1::2] = cos(m_theta)
    """
    pass

def apply_rope(X, sin, cos):
    """
    对 X 施加 RoPE
    X: [bsz, n_heads, seq_len, head_dim]

    关键步骤:
    1. 构造 X_shift: 偶数位取负奇数位,奇数位取偶数位
       X_shift[..., 0::2] = -X[..., 1::2]
       X_shift[..., 1::2] = X[..., 0::2]
    2. Y = cos * X + sin * X_shift
    """
    pass

Part 2: GroupQueryAttention

python
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力

    __init__ 参数:
    - wq: Linear(dim, dim)                       -- Q 投影
    - wk: Linear(dim, head_dim * n_kv_heads)     -- K 投影 (KV head 数少于 Q)
    - wv: Linear(dim, head_dim * n_kv_heads)     -- V 投影
    - wo: Linear(dim, dim)                        -- 输出投影
    - 所有 Linear 不带 bias

    forward(x, mask, sin, cos):
    1. Q, K, V 投影并 reshape 为 multi-head 形式
    2. K, V 用 repeat_interleave 复制到与 Q 相同 head 数
    3. 对 Q, K 施加 RoPE
    4. Scaled Dot-Product Attention + causal mask
    5. 拼接 + 输出投影
    """
    pass

Part 3: LlamaDecoderBlock

python
class LlamaDecoderBlock(nn.Module):
    """
    Pre-Norm 残差结构:
      X = GQA(RMSNorm(X)) + X
      X = SwiGLU(RMSNorm(X)) + X
    """
    pass

Part 4: LlamaModel

python
class LlamaModel(nn.Module):
    """
    完整模型:
    1. Embedding (无位置编码,由 RoPE 在 attention 中引入)
    2. N 个 LlamaDecoderBlock
    3. 最终 RMSNorm
    4. LM Head (Linear, 无 bias)
    5. 缓存 causal mask 和 RoPE sin/cos
    """
    pass

评估标准

项目标准分值
RMSNorm正确实现,仅有 gamma 参数10
RoPEcreate_rope 和 apply_rope 正确15
GQAKV repeat + RoPE + Scaled Attention25
SwiGLUSwish 门控 + 参数量对齐10
LlamaDecoderBlockPre-Norm 残差结构正确15
LlamaModel整体组装、mask、forward 正确15
代码规范无 bias、head_dim 计算正确10
总分100

测试用例

python
config = LlamaConfig()
model = LlamaModel(config)

# 测试 1: 基本 forward
input_ids = torch.randint(config.vocab_size, (2, 32))
logits = model(input_ids)
assert logits.shape == (2, 32, config.vocab_size), \
    f"输出形状错误: {logits.shape}"

# 测试 2: 不同序列长度
for seq_len in [1, 16, 64, 128]:
    x = torch.randint(config.vocab_size, (1, seq_len))
    y = model(x)
    assert y.shape == (1, seq_len, config.vocab_size)

# 测试 3: 参数检查 (无 bias)
for name, param in model.named_parameters():
    if 'bias' in name:
        print(f"警告: 发现 bias 参数 {name}")

# 测试 4: RoPE 正确性
sin, cos = create_rope(512, 64)
assert sin.shape == (512, 64)
x = torch.randn(1, 8, 10, 64)
x_rope = apply_rope(x, sin, cos)
assert x_rope.shape == x.shape

# 测试 5: 可训练
loss = logits.mean()
loss.backward()
print("反向传播成功")

print("所有测试通过!")

参考实现

点击展开完整参考实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# ============ Config ============

from dataclasses import dataclass

@dataclass
class LlamaConfig:
    vocab_size: int = 200
    max_len: int = 512
    dim: int = 512
    n_heads: int = 8
    n_kv_heads: int = 4
    num_layers: int = 6
    position_encoding_base: float = 10000.0
    attention_bias: bool = False

# ============ RMSNorm ============

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean = (x ** 2).mean(-1, keepdim=True)
        out_mean = x / torch.sqrt(mean + self.eps)
        out = self.gamma * out_mean
        return out

# ============ SwiGLU ============

class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        h = F.silu(self.gate_proj(x)) * self.w1(x)
        return self.w2(h)

# ============ RoPE ============

def create_rope(max_len=1024, dim=512, base=10000.0):
    m = torch.arange(0, max_len, 1)
    i = torch.arange(0, dim // 2, 1)
    theta = base ** (-2 * i / dim)
    m_theta = torch.outer(m, theta)
    cos = torch.zeros(max_len, dim)
    sin = torch.zeros(max_len, dim)
    cos[:, 0::2] = cos[:, 1::2] = torch.cos(m_theta)
    sin[:, 0::2] = sin[:, 1::2] = torch.sin(m_theta)
    return sin, cos

def apply_rope(X, sin, cos):
    bs, n_heads, seq_len, d = X.shape
    X_shift = torch.zeros_like(X)
    X_shift[..., 0::2] = -X[..., 1::2]
    X_shift[..., 1::2] = X[..., 0::2]
    Y = cos[None, None, :seq_len, :] * X + \
        sin[None, None, :seq_len, :] * X_shift
    return Y

# ============ GQA ============

class GroupQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.dim // self.n_heads
        self.repeats = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(self.dim, self.dim,
                            bias=config.attention_bias)
        self.wk = nn.Linear(self.dim, self.head_dim * self.n_kv_heads,
                            bias=config.attention_bias)
        self.wv = nn.Linear(self.dim, self.head_dim * self.n_kv_heads,
                            bias=config.attention_bias)
        self.wo = nn.Linear(self.dim, self.dim,
                            bias=config.attention_bias)

    def forward(self, x, mask=None, sin=None, cos=None):
        bsz, seq_len, dim = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)

        # reshape to multi-head
        q = q.reshape(bsz, seq_len, self.n_heads,
                      self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_kv_heads,
                      self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_kv_heads,
                      self.head_dim).transpose(1, 2)

        # GQA: repeat KV heads
        k = torch.repeat_interleave(k, self.repeats, dim=1)
        v = torch.repeat_interleave(v, self.repeats, dim=1)

        # apply RoPE
        q = apply_rope(q, sin, cos)
        k = apply_rope(k, sin, cos)

        # scaled dot-product attention
        s = q @ k.transpose(3, 2) / math.sqrt(self.dim)
        if mask is not None:
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v

        # concat heads
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)

# ============ Decoder Block ============

class LlamaDecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.norm1 = RMSNorm(dim=self.dim)
        self.attn = GroupQueryAttention(config)
        self.norm2 = RMSNorm(dim=self.dim)
        self.ffn = SwiGLU(dim=self.dim)

    def forward(self, X, mask=None, sin=None, cos=None):
        # Pre-Norm + Residual
        X_norm = self.norm1(X)
        X = self.attn(X_norm, mask, sin, cos) + X
        X_norm = self.norm2(X)
        h = self.ffn(X_norm) + X
        return h

# ============ Llama Model ============

class LlamaModel(nn.Module):
    def __init__(self, config=None):
        super().__init__()
        self.config = config
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.decoder = nn.ModuleList(
            [LlamaDecoderBlock(config)
             for _ in range(config.num_layers)]
        )
        self.ln = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size,
                                 bias=False)

        # 缓存 mask 和 RoPE
        self.cache_mask = torch.tril(
            torch.ones(config.max_len, config.max_len))
        self.rope_sin, self.rope_cos = create_rope(
            config.max_len,
            config.dim // config.n_heads)

    def forward(self, x):
        bs, seq_len = x.shape

        # causal mask: 将 0 位置设为 -inf
        add_mask = self.cache_mask[:seq_len, :seq_len]

        X = self.embd(x)
        for block in self.decoder:
            X = block(X, mask=add_mask,
                      sin=self.rope_sin,
                      cos=self.rope_cos)
        X = self.ln(X)
        logits = self.lm_head(X)
        return logits

运行验证:

python
config = LlamaConfig()
model = LlamaModel(config)
input_ids = torch.randint(config.vocab_size, (2, 32))
y = model(input_ids)
print(y.shape)  # torch.Size([2, 32, 200])

# 参数量统计
total = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total:,}")  # 约 25M

进阶思考

完成主挑战后,可以思考以下问题:

  1. KV Cache 集成: 如何在 GQA 中加入 KV Cache 支持,使其可用于自回归推理?需要修改哪些部分?
  2. RoPE 与 KV Cache 的顺序: 为什么 RoPE 要在 KV Cache 拼接之前施加?如果在拼接之后施加会怎样?
  3. 参数量分析: 计算模型中 Attention 和 FFN 各占多少参数,理解为什么 MoE 选择扩展 FFN 而非 Attention。
  4. Pre-Norm vs Post-Norm: Llama 使用 Pre-Norm(先 Norm 再 Attention/FFN),GPT-2 使用 Post-Norm。分析两者在梯度传播上的差异。
  5. Scaling 因子: 注意力中除以 d 的作用是什么?如果改为除以 dhead 结果会有什么不同?

MLM 代码训练模式

完成上面的固定填空后,试试随机挖空模式 -- 每次点击「刷新」会随机遮盖不同的代码片段,帮你彻底记住每一行。

RMSNorm 实现

RMSNorm 完整实现
共 39 个可挖空位 | 已挖 0 个
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = (x ** 2).mean(-1, keepdim=True)
        x_normed = x / torch.sqrt(mean_sq + self.eps)
        return self.gamma * x_normed

SwiGLU 门控前馈网络

SwiGLU 门控 FFN 实现
共 47 个可挖空位 | 已挖 0 个
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        h = F.silu(self.gate_proj(x)) * self.w1(x)
        return self.w2(h)

GQA 分组查询注意力(含 RoPE)

GroupQueryAttention 核心前向
共 83 个可挖空位 | 已挖 0 个
def forward(self, x, mask=None, sin=None, cos=None):
    bsz, seq_len, dim = x.shape
    q, k, v = self.wq(x), self.wk(x), self.wv(x)

    q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
    k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)

    k = torch.repeat_interleave(k, self.repeats, dim=1)
    v = torch.repeat_interleave(v, self.repeats, dim=1)

    q = apply_rope(q, sin, cos)
    k = apply_rope(k, sin, cos)

    s = q @ k.transpose(3, 2) / math.sqrt(self.dim)
    if mask is not None:
        s = s + mask.unsqueeze(0).unsqueeze(0)
    p = F.softmax(s, dim=-1)
    z = p @ v

    z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
    return self.wo(z)

Llama 模型组装与前向

LlamaModel 前向传播
共 99 个可挖空位 | 已挖 0 个
class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.decoder = nn.ModuleList(
            [LlamaDecoderBlock(config) for _ in range(config.num_layers)]
        )
        self.ln = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        self.cache_mask = torch.tril(torch.ones(config.max_len, config.max_len))
        self.rope_sin, self.rope_cos = create_rope(
            config.max_len, config.dim // config.n_heads)

    def forward(self, x):
        bs, seq_len = x.shape
        add_mask = self.cache_mask[:seq_len, :seq_len]
        X = self.embd(x)
        for block in self.decoder:
            X = block(X, mask=add_mask, sin=self.rope_sin, cos=self.rope_cos)
        X = self.ln(X)
        logits = self.lm_head(X)
        return logits