# Llama 实现挑战 (Level 3-4)

从零构建一个 mini-Llama 模型。先通过热身练习掌握各子模块,再在主挑战中将它们组装为完整的 Decoder 模型。

## 热身练习
### 练习 1: RMSNorm 实现(Level 2)

**背景**

RMSNorm 是 LayerNorm 的简化版,去掉了 re-center(减均值)操作,仅使用 RMS(Root Mean Square)统计量进行缩放。其公式为:

$$\mathrm{RMSNorm}(x) = \gamma \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}$$

RMSNorm 具有输入尺度不变性:对任意 $\alpha > 0$,有 $\mathrm{RMSNorm}(\alpha x) = \mathrm{RMSNorm}(x)$(忽略 $\epsilon$ 的影响)。
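上述尺度不变性可以用一个纯函数版的小实验快速验证(示意代码,不含可学习参数,并非练习答案):

```python
import torch

def rms_norm(x, eps=1e-12):
    # 纯函数版 RMS 归一化(无 gamma),仅用于演示尺度不变性
    return x / torch.sqrt((x ** 2).mean(-1, keepdim=True) + eps)

x = torch.randn(2, 6, 512)
y1 = rms_norm(x)
y2 = rms_norm(100.0 * x)  # 输入整体放大 100 倍
print(torch.allclose(y1, y2, atol=1e-4))  # True
```

直观上,分母的 RMS 与输入同步放大,缩放因子被约掉,因此输出不变。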
**任务**

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        # ===== 填空 1: 定义可学习参数 gamma =====
        self.gamma = _____  # 提示: nn.Parameter, 初始化为全 1
        self.eps = eps

    def forward(self, x):
        # x: [bsz, seq_len, dim]
        # ===== 填空 2: 计算 x^2 在最后一维的均值 =====
        mean_sq = _____  # 提示: (x**2).mean(...),注意 keepdim
        # ===== 填空 3: 计算 RMS 归一化 =====
        x_normed = _____  # 提示: x / sqrt(mean_sq + eps)
        # ===== 填空 4: 乘以可学习参数 =====
        return _____
```

**提示**

- `gamma` 是 dim 维的向量,初始化为 `torch.ones(dim)`
- 计算均值时用 `dim=-1, keepdim=True`,保持维度以便广播
- `eps` 防止除零
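背景里说 RMSNorm "去掉了减均值",这一点也可以对照 `nn.LayerNorm` 验证:对恰好零均值的输入,二者输出应当一致(示意代码;`rms_norm` 为演示用的纯函数,eps 取与 LayerNorm 相同的 1e-5):

```python
import torch
import torch.nn as nn

def rms_norm(x, eps=1e-5):
    # 无参数版 RMS 归一化,仅用于与 LayerNorm 对比
    return x / torch.sqrt((x ** 2).mean(-1, keepdim=True) + eps)

ln = nn.LayerNorm(512, eps=1e-5)  # 默认初始化: gamma=1, beta=0
x = torch.randn(2, 6, 512)
x_zm = x - x.mean(-1, keepdim=True)  # 构造最后一维零均值的输入
print(torch.allclose(ln(x_zm), rms_norm(x_zm), atol=1e-5))  # True
```

对零均值输入,方差等于二阶矩,LayerNorm 退化为 RMSNorm;二者的差别全部来自 re-center 这一步。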
**参考答案**

```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = (x ** 2).mean(-1, keepdim=True)
        x_normed = x / torch.sqrt(mean_sq + self.eps)
        return self.gamma * x_normed
```

验证:

```python
norm = RMSNorm(dim=512)
x = torch.randn(2, 6, 512)
y = norm(x)
print(y.shape)  # torch.Size([2, 6, 512])

# 验证 RMS 归一化后统计量为 1
rms = torch.sqrt((y**2).mean(-1))
print(rms.mean())  # 接近 1.0
```

### 练习 2: SwiGLU 实现(Level 2-3)
**背景**

SwiGLU 将 Swish 激活函数与 GLU(门控线性单元)结合:

$$\mathrm{SwiGLU}(x) = W_2\big(\mathrm{Swish}(W_{gate}\,x) \odot W_1 x\big)$$

其中 Swish 即 `F.silu`,定义为 $\mathrm{Swish}(x) = x \cdot \sigma(x)$。为保持与标准 FFN 相近的参数量,hidden_dim 设为 $8d/3$(标准 FFN 用 $4d$)。
**任务**

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # ===== 填空 1: 计算 hidden_dim (保持与 FFN 同等参数量) =====
        hidden_dim = _____  # 提示: 8 * dim // 3
        # ===== 填空 2: 定义三个线性层 (无 bias) =====
        self.w1 = _____      # dim -> hidden_dim
        self.w_gate = _____  # dim -> hidden_dim (门控)
        self.w2 = _____      # hidden_dim -> dim

    def forward(self, x):
        # ===== 填空 3: 门控特征 =====
        gate = _____  # 提示: 对 w_gate(x) 做 silu 激活
        # ===== 填空 4: 上投影特征 =====
        x_up = _____  # 提示: w1(x)
        # ===== 填空 5: GLU 门控乘法 + 下投影 =====
        h = _____  # 提示: gate * x_up
        return _____  # 提示: w2(h)
```

**提示**

- `F.silu(x)` 等价于 `x * torch.sigmoid(x)`,即 Swish 函数
- 三个 Linear 层均不带 bias(`bias=False`)
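提示中 `F.silu` 与 `x * torch.sigmoid(x)` 的等价性可以直接数值验证(示意):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1000)
print(torch.allclose(F.silu(x), x * torch.sigmoid(x), atol=1e-6))  # True
```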
**参考答案**

```python
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        gate = F.silu(self.w_gate(x))
        x_up = self.w1(x)
        h = gate * x_up
        return self.w2(h)
```

验证:

```python
ffn = SwiGLU(dim=512)
x = torch.randn(2, 3, 512)
print(ffn(x).shape)  # torch.Size([2, 3, 512])

# 对比参数量: FFN(4d) vs SwiGLU(8d/3)
ffn_params = 2 * 512 * 2048      # w1 + w2
swiglu_params = 3 * 512 * 1365   # w1 + w_gate + w2
print(f"FFN: {ffn_params}, SwiGLU: {swiglu_params}")
# 两者参数量接近
```

### 练习 3: GQA 的 KV Repeat(Level 2-3)
**背景**

GQA(Grouped Query Attention)让多个 Q head 共享同一组 KV head,减少 KV Cache 存储。例如 8 个 Q head 共享 4 个 KV head,每 2 个 Q head 共用 1 个 KV head。实现时需要将 KV 沿 head 维度复制(repeat),使其与 Q 的 head 数匹配。
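除 `repeat_interleave` 外,另一种常见写法是 `expand + reshape`(HuggingFace 的 Llama 实现采用类似思路)。下面的 `repeat_kv_expand` 是本文之外的示意函数,与 `torch.repeat_interleave(kv, n_rep, dim=1)` 等价:

```python
import torch

def repeat_kv_expand(kv: torch.Tensor, n_rep: int) -> torch.Tensor:
    # 在 head 维后插入一个复制维,expand 后再合并回 head 维
    bsz, n_kv_heads, seq_len, head_dim = kv.shape
    if n_rep == 1:
        return kv
    kv = kv[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seq_len, head_dim)
    return kv.reshape(bsz, n_kv_heads * n_rep, seq_len, head_dim)

k = torch.randn(2, 4, 16, 64)
print(repeat_kv_expand(k, 2).shape)  # torch.Size([2, 8, 16, 64])
```

两种写法结果相同;`expand` 本身不复制内存,拷贝发生在 `reshape` 这一步。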
**任务**

```python
import torch

def repeat_kv(
    k: torch.Tensor,  # [bsz, n_kv_heads, seq_len, head_dim]
    v: torch.Tensor,  # [bsz, n_kv_heads, seq_len, head_dim]
    n_rep: int,       # 每个 kv head 需要复制的次数
):
    """
    将 KV 的 head 维度复制 n_rep 次
    例: [bsz, 4, seq_len, head_dim] -> [bsz, 8, seq_len, head_dim] (n_rep=2)
    """
    if n_rep == 1:
        return k, v
    # ===== 填空 1: 使用 repeat_interleave 复制 =====
    k = _____  # 提示: torch.repeat_interleave(k, n_rep, dim=?)
    v = _____  # 提示: 同上,注意 dim 参数
    return k, v

# 验证
bsz, n_kv_heads, seq_len, head_dim = 2, 4, 16, 64
k = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
v = torch.randn(bsz, n_kv_heads, seq_len, head_dim)
k_rep, v_rep = repeat_kv(k, v, n_rep=2)
print(k_rep.shape)  # 应为 torch.Size([2, 8, 16, 64])

# ===== 填空 2: 验证复制正确性 =====
# 第 0 个 kv head 应该等于第 0 和第 1 个 q head 的 kv
assert torch.equal(k_rep[:, 0], _____), "复制不正确"  # 提示: k 的哪个 head?
assert torch.equal(k_rep[:, 1], _____), "复制不正确"  # 提示: 应与 head 0 相同
```

**参考答案**
```python
# 填空 1
k = torch.repeat_interleave(k, n_rep, dim=1)
v = torch.repeat_interleave(v, n_rep, dim=1)

# 填空 2
assert torch.equal(k_rep[:, 0], k[:, 0])
assert torch.equal(k_rep[:, 1], k[:, 0])  # head 1 也来自原始 kv head 0
```

## 主挑战: 构建完整 Llama Decoder(Level 4)
**目标**

从零实现一个完整的 mini-Llama 模型,包含以下组件:

- RMSNorm -- Pre-Norm 结构
- RoPE -- 旋转位置编码
- GQA -- 分组查询注意力
- SwiGLU -- 门控前馈网络
- LlamaDecoderBlock -- 单个 Decoder 层
- LlamaModel -- 完整模型(Embedding + N 层 Block + LM Head)
**配置**

```python
from dataclasses import dataclass

@dataclass
class LlamaConfig:
    vocab_size: int = 200
    max_len: int = 512
    dim: int = 512
    n_heads: int = 8
    n_kv_heads: int = 4  # GQA: 每 2 个 Q head 共享 1 个 KV head
    num_layers: int = 6
    position_encoding_base: float = 10000.0
    attention_bias: bool = False  # Linear 层不带 bias
```

**要求**

### Part 1: RoPE(旋转位置编码)
实现两个函数:

```python
def create_rope(max_len, dim, base=10000.0):
    """
    创建 RoPE 的 sin/cos 缓存
    返回: (sin, cos),形状均为 [max_len, dim]

    关键步骤:
    1. m = [0, 1, ..., max_len-1]
    2. theta_i = base^(-2i/dim), i = [0, 1, ..., dim/2-1]
    3. m_theta = outer(m, theta)
    4. sin[:, 0::2] = sin[:, 1::2] = sin(m_theta)
       cos[:, 0::2] = cos[:, 1::2] = cos(m_theta)
    """
    pass

def apply_rope(X, sin, cos):
    """
    对 X 施加 RoPE
    X: [bsz, n_heads, seq_len, head_dim]

    关键步骤:
    1. 构造 X_shift: 偶数位取负奇数位,奇数位取偶数位
       X_shift[..., 0::2] = -X[..., 1::2]
       X_shift[..., 1::2] = X[..., 0::2]
    2. Y = cos * X + sin * X_shift
    """
    pass
```

### Part 2: GroupQueryAttention
```python
class GroupQueryAttention(nn.Module):
    """
    分组查询注意力

    __init__ 参数:
    - wq: Linear(dim, dim) -- Q 投影
    - wk: Linear(dim, head_dim * n_kv_heads) -- K 投影 (KV head 数少于 Q)
    - wv: Linear(dim, head_dim * n_kv_heads) -- V 投影
    - wo: Linear(dim, dim) -- 输出投影
    - 所有 Linear 不带 bias

    forward(x, mask, sin, cos):
    1. Q, K, V 投影并 reshape 为 multi-head 形式
    2. K, V 用 repeat_interleave 复制到与 Q 相同 head 数
    3. 对 Q, K 施加 RoPE
    4. Scaled Dot-Product Attention + causal mask
    5. 拼接 + 输出投影
    """
    pass
```

### Part 3: LlamaDecoderBlock
```python
class LlamaDecoderBlock(nn.Module):
    """
    Pre-Norm 残差结构:
    X = GQA(RMSNorm(X)) + X
    X = SwiGLU(RMSNorm(X)) + X
    """
    pass
```

### Part 4: LlamaModel
```python
class LlamaModel(nn.Module):
    """
    完整模型:
    1. Embedding (无位置编码,由 RoPE 在 attention 中引入)
    2. N 个 LlamaDecoderBlock
    3. 最终 RMSNorm
    4. LM Head (Linear, 无 bias)
    5. 缓存 causal mask 和 RoPE sin/cos
    """
    pass
```

### 评估标准
| 项目 | 标准 | 分值 |
|---|---|---|
| RMSNorm | 正确实现,仅有 gamma 参数 | 10 |
| RoPE | create_rope 和 apply_rope 正确 | 15 |
| GQA | KV repeat + RoPE + Scaled Attention | 25 |
| SwiGLU | Swish 门控 + 参数量对齐 | 10 |
| LlamaDecoderBlock | Pre-Norm 残差结构正确 | 15 |
| LlamaModel | 整体组装、mask、forward 正确 | 15 |
| 代码规范 | 无 bias、head_dim 计算正确 | 10 |
| **总分** | | 100 |
### 测试用例

```python
config = LlamaConfig()
model = LlamaModel(config)

# 测试 1: 基本 forward
input_ids = torch.randint(config.vocab_size, (2, 32))
logits = model(input_ids)
assert logits.shape == (2, 32, config.vocab_size), \
    f"输出形状错误: {logits.shape}"

# 测试 2: 不同序列长度
for seq_len in [1, 16, 64, 128]:
    x = torch.randint(config.vocab_size, (1, seq_len))
    y = model(x)
    assert y.shape == (1, seq_len, config.vocab_size)

# 测试 3: 参数检查 (无 bias)
for name, param in model.named_parameters():
    if 'bias' in name:
        print(f"警告: 发现 bias 参数 {name}")

# 测试 4: RoPE 正确性
sin, cos = create_rope(512, 64)
assert sin.shape == (512, 64)
x = torch.randn(1, 8, 10, 64)
x_rope = apply_rope(x, sin, cos)
assert x_rope.shape == x.shape

# 测试 5: 可训练
loss = logits.mean()
loss.backward()
print("反向传播成功")
print("所有测试通过!")
```

### 参考实现
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass

# ============ Config ============
@dataclass
class LlamaConfig:
    vocab_size: int = 200
    max_len: int = 512
    dim: int = 512
    n_heads: int = 8
    n_kv_heads: int = 4
    num_layers: int = 6
    position_encoding_base: float = 10000.0
    attention_bias: bool = False

# ============ RMSNorm ============
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = (x ** 2).mean(-1, keepdim=True)
        x_normed = x / torch.sqrt(mean_sq + self.eps)
        return self.gamma * x_normed

# ============ SwiGLU ============
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        h = F.silu(self.gate_proj(x)) * self.w1(x)
        return self.w2(h)

# ============ RoPE ============
def create_rope(max_len=1024, dim=512, base=10000.0):
    m = torch.arange(max_len).float()
    i = torch.arange(dim // 2).float()
    theta = base ** (-2 * i / dim)
    m_theta = torch.outer(m, theta)
    cos = torch.zeros(max_len, dim)
    sin = torch.zeros(max_len, dim)
    cos[:, 0::2] = cos[:, 1::2] = torch.cos(m_theta)
    sin[:, 0::2] = sin[:, 1::2] = torch.sin(m_theta)
    return sin, cos

def apply_rope(X, sin, cos):
    bsz, n_heads, seq_len, d = X.shape
    X_shift = torch.zeros_like(X)
    X_shift[..., 0::2] = -X[..., 1::2]
    X_shift[..., 1::2] = X[..., 0::2]
    Y = cos[None, None, :seq_len, :] * X + \
        sin[None, None, :seq_len, :] * X_shift
    return Y
```
```python
# ============ GQA ============
class GroupQueryAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.n_heads = config.n_heads
        self.n_kv_heads = config.n_kv_heads
        self.head_dim = self.dim // self.n_heads
        self.repeats = self.n_heads // self.n_kv_heads
        self.wq = nn.Linear(self.dim, self.dim,
                            bias=config.attention_bias)
        self.wk = nn.Linear(self.dim, self.head_dim * self.n_kv_heads,
                            bias=config.attention_bias)
        self.wv = nn.Linear(self.dim, self.head_dim * self.n_kv_heads,
                            bias=config.attention_bias)
        self.wo = nn.Linear(self.dim, self.dim,
                            bias=config.attention_bias)

    def forward(self, x, mask=None, sin=None, cos=None):
        bsz, seq_len, dim = x.shape
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # reshape to multi-head
        q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
        # GQA: repeat KV heads
        k = torch.repeat_interleave(k, self.repeats, dim=1)
        v = torch.repeat_interleave(v, self.repeats, dim=1)
        # apply RoPE
        q = apply_rope(q, sin, cos)
        k = apply_rope(k, sin, cos)
        # scaled dot-product attention(缩放因子是 sqrt(head_dim),不是 sqrt(dim))
        s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask 为加法形式: 可见位置 0,遮盖位置 -inf
            s = s + mask.unsqueeze(0).unsqueeze(0)
        p = F.softmax(s, dim=-1)
        z = p @ v
        # concat heads
        z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
        return self.wo(z)
```
```python
# ============ Decoder Block ============
class LlamaDecoderBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dim = config.dim
        self.norm1 = RMSNorm(dim=self.dim)
        self.attn = GroupQueryAttention(config)
        self.norm2 = RMSNorm(dim=self.dim)
        self.ffn = SwiGLU(dim=self.dim)

    def forward(self, X, mask=None, sin=None, cos=None):
        # Pre-Norm + Residual
        X = self.attn(self.norm1(X), mask, sin, cos) + X
        X = self.ffn(self.norm2(X)) + X
        return X
```
```python
# ============ Llama Model ============
class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.decoder = nn.ModuleList(
            [LlamaDecoderBlock(config)
             for _ in range(config.num_layers)]
        )
        self.ln = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size,
                                 bias=False)
        # 缓存加法形式的 causal mask: 可见位置为 0,遮盖位置为 -inf
        tril = torch.tril(torch.ones(config.max_len, config.max_len))
        self.cache_mask = torch.zeros(config.max_len, config.max_len) \
            .masked_fill(tril == 0, float('-inf'))
        # 缓存 RoPE sin/cos(注意按 head_dim 而非 dim 生成)
        self.rope_sin, self.rope_cos = create_rope(
            config.max_len,
            config.dim // config.n_heads)

    def forward(self, x):
        bsz, seq_len = x.shape
        add_mask = self.cache_mask[:seq_len, :seq_len]
        X = self.embd(x)
        for block in self.decoder:
            X = block(X, mask=add_mask,
                      sin=self.rope_sin,
                      cos=self.rope_cos)
        X = self.ln(X)
        logits = self.lm_head(X)
        return logits
```

运行验证:
```python
config = LlamaConfig()
model = LlamaModel(config)
input_ids = torch.randint(config.vocab_size, (2, 32))
y = model(input_ids)
print(y.shape)  # torch.Size([2, 32, 200])

# 参数量统计
total = sum(p.numel() for p in model.parameters())
print(f"总参数量: {total:,}")  # 约 17.5M
```

## 进阶思考
完成主挑战后,可以思考以下问题:

- KV Cache 集成: 如何在 GQA 中加入 KV Cache 支持,使其可用于自回归推理?需要修改哪些部分?
- RoPE 与 KV Cache 的顺序: 为什么 RoPE 要在 KV Cache 拼接之前施加?如果在拼接之后施加会怎样?
- 参数量分析: 计算模型中 Attention 和 FFN 各占多少参数,理解为什么 MoE 选择扩展 FFN 而非 Attention。
- Pre-Norm vs Post-Norm: Llama 使用 Pre-Norm(先 Norm 再 Attention/FFN),原始 Transformer 与 BERT 使用 Post-Norm。分析两者在梯度传播上的差异。
- Scaling 因子: 注意力中除以 $\sqrt{d_k}$ 的作用是什么?如果改为除以 $d_k$,结果会有什么不同?
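关于 RoPE 的性质,可以动手做一个小实验:施加 RoPE 后,q 与 k 的点积只依赖相对位置 $m-n$,这正是 RoPE 能与 KV Cache 良好配合的原因。下面的代码自带一份与参考实现一致的最小 `create_rope`,并用辅助函数 `rope_at`(文中未定义,仅作演示)对单个向量施加指定位置的旋转:

```python
import torch

def create_rope(max_len, dim, base=10000.0):
    # 与参考实现一致的 sin/cos 缓存
    m = torch.arange(max_len).float()
    i = torch.arange(dim // 2).float()
    theta = base ** (-2 * i / dim)
    m_theta = torch.outer(m, theta)
    sin = torch.zeros(max_len, dim)
    cos = torch.zeros(max_len, dim)
    sin[:, 0::2] = sin[:, 1::2] = torch.sin(m_theta)
    cos[:, 0::2] = cos[:, 1::2] = torch.cos(m_theta)
    return sin, cos

def rope_at(x, pos, sin, cos):
    # 对单个向量施加位置 pos 的旋转
    x_shift = torch.zeros_like(x)
    x_shift[0::2] = -x[1::2]
    x_shift[1::2] = x[0::2]
    return cos[pos] * x + sin[pos] * x_shift

dim = 64
sin, cos = create_rope(128, dim)
q, k = torch.randn(dim), torch.randn(dim)
# 两组位置的相对距离同为 4: (7, 3) 与 (24, 20)
d1 = torch.dot(rope_at(q, 7, sin, cos), rope_at(k, 3, sin, cos))
d2 = torch.dot(rope_at(q, 24, sin, cos), rope_at(k, 20, sin, cos))
print(torch.allclose(d1, d2, atol=1e-3))  # True
```

这解释了 RoPE 与 KV Cache 的顺序问题:只要每个 token 在进入 cache 前按自己的绝对位置旋转一次,后续任意 query 与它的点积就自动只依赖相对距离,无需重算历史 K。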
## MLM 代码训练模式

完成上面的固定填空后,试试随机挖空模式 -- 每次点击「刷新」会随机遮盖不同的代码片段,帮你彻底记住每一行。

### RMSNorm 实现
```python
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        mean_sq = (x ** 2).mean(-1, keepdim=True)
        x_normed = x / torch.sqrt(mean_sq + self.eps)
        return self.gamma * x_normed
```

### SwiGLU 门控前馈网络
```python
class SwiGLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        hidden_dim = dim * 8 // 3
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.gate_proj = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        h = F.silu(self.gate_proj(x)) * self.w1(x)
        return self.w2(h)
```

### GQA 分组查询注意力(含 RoPE)
```python
def forward(self, x, mask=None, sin=None, cos=None):
    bsz, seq_len, dim = x.shape
    q, k, v = self.wq(x), self.wk(x), self.wv(x)
    q = q.reshape(bsz, seq_len, self.n_heads, self.head_dim).transpose(1, 2)
    k = k.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
    v = v.reshape(bsz, seq_len, self.n_kv_heads, self.head_dim).transpose(1, 2)
    k = torch.repeat_interleave(k, self.repeats, dim=1)
    v = torch.repeat_interleave(v, self.repeats, dim=1)
    q = apply_rope(q, sin, cos)
    k = apply_rope(k, sin, cos)
    s = q @ k.transpose(3, 2) / math.sqrt(self.head_dim)
    if mask is not None:
        s = s + mask.unsqueeze(0).unsqueeze(0)
    p = F.softmax(s, dim=-1)
    z = p @ v
    z = z.transpose(1, 2).reshape(bsz, seq_len, self.dim)
    return self.wo(z)
```

### Llama 模型组装与前向
```python
class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embd = nn.Embedding(config.vocab_size, config.dim)
        self.decoder = nn.ModuleList(
            [LlamaDecoderBlock(config) for _ in range(config.num_layers)]
        )
        self.ln = RMSNorm(config.dim)
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)
        # 加法 causal mask: 可见位置为 0,遮盖位置为 -inf
        tril = torch.tril(torch.ones(config.max_len, config.max_len))
        self.cache_mask = torch.zeros(config.max_len, config.max_len) \
            .masked_fill(tril == 0, float('-inf'))
        self.rope_sin, self.rope_cos = create_rope(
            config.max_len, config.dim // config.n_heads)

    def forward(self, x):
        bsz, seq_len = x.shape
        add_mask = self.cache_mask[:seq_len, :seq_len]
        X = self.embd(x)
        for block in self.decoder:
            X = block(X, mask=add_mask, sin=self.rope_sin, cos=self.rope_cos)
        X = self.ln(X)
        logits = self.lm_head(X)
        return logits
```
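模型中缓存的加法式 causal mask 也值得单独验证:遮盖位置填 `-inf`,经 softmax 后未来位置的注意力权重严格为 0,而每行权重仍归一化为 1(示意):

```python
import torch
import torch.nn.functional as F

seq_len = 4
tril = torch.tril(torch.ones(seq_len, seq_len))
add_mask = torch.zeros(seq_len, seq_len).masked_fill(tril == 0, float('-inf'))

s = torch.randn(seq_len, seq_len)      # 任意注意力分数
p = F.softmax(s + add_mask, dim=-1)
print(p[0])  # 第 0 行只可见位置 0: tensor([1., 0., 0., 0.])
```

注意:如果直接把 0/1 的下三角矩阵加到分数上,并不会屏蔽未来位置,只是给可见位置加了 1;必须先转换为 0/-inf 的加法形式(或改用 `masked_fill`)。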