Skip to content

GPT 实现挑战 (Level 4)

难度: 困难 | 前置知识: Transformer 全部内容、PyTorch 熟练使用 | 预计时间: 3-5 小时

挑战目标

从零实现一个可训练、可推理的 mini-GPT 模型。不使用 nn.TransformerDecoderLayer 等高层封装,所有核心组件手写。

完成后,你的模型应该能在一个小数据集(如莎士比亚全集 / 唐诗三百首)上训练,并生成连贯的文本。


热身练习

在挑战完整模型之前,先完成以下三个小练习,确保你掌握了核心组件。

热身 1:手写 GELU 激活函数

GELU 的数学定义为 GELU(x)=xΦ(x),其中 Φ(x) 是标准正态分布的 CDF。

GPT-2 使用的是 tanh 近似版本:

GELU(x)x12[1+tanh(2π(x+0.044715x3))]
python
import torch
import math

def gelu(x):
    """
    实现 GELU 激活函数(tanh 近似版本)
    提示: 使用 torch.tanh, torch.pow, math.sqrt, torch.pi
    """
    # TODO: 你的实现
    pass

# 测试
x = torch.tensor([-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0])
result = gelu(x)
expected = torch.tensor([-0.0454, -0.1588, -0.1543, 0.0000, 0.3457, 0.8412, 1.9546])
assert torch.allclose(result, expected, atol=1e-4), \
    f"GELU 输出不正确!\n得到: {result}\n期望: {expected}"
print("GELU 测试通过!")

热身 2:构造因果掩码

因果掩码是 Decoder-Only 模型的核心——确保每个 token 只能看到自己和之前的 token。

python
import torch
import torch.nn.functional as F

def create_causal_mask(seq_len, neg_inf=-1e5):
    """
    构造 additive 因果掩码

    返回: (seq_len, seq_len) 的矩阵
          下三角(含对角线)= 0,上三角 = neg_inf

    提示: 使用 torch.tril 和 torch.ones
    """
    # TODO: 你的实现
    pass

# 测试
mask = create_causal_mask(4)
print(mask)
# 期望输出:
# tensor([[     0., -100000., -100000., -100000.],
#         [     0.,      0., -100000., -100000.],
#         [     0.,      0.,      0., -100000.],
#         [     0.,      0.,      0.,      0.]])

# 验证 softmax 后的效果
scores = torch.zeros(4, 4)  # 假设注意力分数全为 0
masked_scores = scores + mask
probs = F.softmax(masked_scores, dim=-1)
print(probs)
# 期望: 位置 0 只看自己(概率 1.0),位置 3 均匀看所有(各 0.25)
assert torch.allclose(probs[0], torch.tensor([1.0, 0.0, 0.0, 0.0]), atol=1e-4)
assert torch.allclose(probs[3], torch.tensor([0.25, 0.25, 0.25, 0.25]), atol=1e-4)
print("因果掩码测试通过!")

热身 3:实现 KV Cache 更新

KV Cache 的核心操作是:每生成一个新 token 时,将其 K、V 追加到缓存中。

python
import torch

def update_kv_cache(kv_cache, new_k, new_v):
    """
    更新 KV Cache

    参数:
        kv_cache: None(首次调用)或 [cached_k, cached_v]
                  cached_k/v shape: (batch, cached_len, dim)
        new_k: 新 token 的 Key (batch, new_len, dim)
        new_v: 新 token 的 Value (batch, new_len, dim)

    返回:
        updated_cache: [full_k, full_v]
        full_k: (batch, total_len, dim)
        full_v: (batch, total_len, dim)

    提示: 首次调用直接存储,后续用 torch.cat 拼接
    """
    # TODO: 你的实现
    pass

# 测试
batch, dim = 2, 4

# 第一步:prefill,输入 3 个 token
k1 = torch.randn(batch, 3, dim)
v1 = torch.randn(batch, 3, dim)
cache = update_kv_cache(None, k1, v1)
assert cache[0].shape == (2, 3, 4), f"Prefill 后 K shape 错误: {cache[0].shape}"

# 第二步:decode,输入 1 个新 token
k2 = torch.randn(batch, 1, dim)
v2 = torch.randn(batch, 1, dim)
cache = update_kv_cache(cache, k2, v2)
assert cache[0].shape == (2, 4, 4), f"Decode 后 K shape 错误: {cache[0].shape}"

# 第三步:再生成一个 token
k3 = torch.randn(batch, 1, dim)
v3 = torch.randn(batch, 1, dim)
cache = update_kv_cache(cache, k3, v3)
assert cache[0].shape == (2, 5, 4), f"第二次 Decode 后 K shape 错误: {cache[0].shape}"

# 验证拼接正确性
assert torch.equal(cache[0][:, :3, :], k1)
assert torch.equal(cache[0][:, 3:4, :], k2)
assert torch.equal(cache[0][:, 4:5, :], k3)
print("KV Cache 测试通过!")

需求规格

模型配置

python
config = {
    "vocab_size": 50257,       # GPT-2 词表大小(或自定义)
    "max_seq_len": 256,        # 最大序列长度
    "d_model": 384,            # 隐藏维度
    "n_heads": 6,              # 注意力头数
    "n_layers": 6,             # Transformer 层数
    "d_ff": 1536,              # FFN 中间维度 (4 * d_model)
    "dropout": 0.1,            # Dropout 概率
}

必须实现的组件

  1. Token Embedding + Position Embedding
  2. Causal Self-Attention(带因果掩码)
  3. Feed-Forward Network(带 GELU 激活)
  4. Layer Normalization(Pre-Norm 架构)
  5. 残差连接
  6. 语言模型头(输出 logits over vocabulary)
  7. 文本生成(Top-k / Top-p 采样)

类骨架

以下是建议的类结构。你需要补全所有 TODO 部分。

python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


class GPTConfig:
    """模型配置"""
    def __init__(self, **kwargs):
        self.vocab_size = kwargs.get("vocab_size", 50257)
        self.max_seq_len = kwargs.get("max_seq_len", 256)
        self.d_model = kwargs.get("d_model", 384)
        self.n_heads = kwargs.get("n_heads", 6)
        self.n_layers = kwargs.get("n_layers", 6)
        self.d_ff = kwargs.get("d_ff", 1536)
        self.dropout = kwargs.get("dropout", 0.1)


class CausalSelfAttention(nn.Module):
    """因果自注意力层"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        assert config.d_model % config.n_heads == 0

        # TODO: 定义 Q、K、V 投影层(可以用一个合并的线性层)
        # 提示: 参考 nn.Linear(d_model, d_model * 3),一次性投影 QKV
        # TODO: 定义输出投影层
        # TODO: 注册因果掩码 buffer(下三角矩阵)
        # 提示: self.register_buffer('mask', torch.tril(...))
        # TODO: 定义 dropout

        self.n_heads = config.n_heads
        self.d_k = config.d_model // config.n_heads

    def forward(self, x):
        """
        参数: x (batch_size, seq_len, d_model)
        返回: (batch_size, seq_len, d_model)
        """
        B, T, C = x.size()

        # TODO: 计算 Q, K, V(用合并投影后 split)
        # TODO: reshape 为多头: (B, T, C) → (B, n_heads, T, d_k)
        # TODO: 计算注意力分数 S = Q @ K^T / sqrt(d_k)
        # TODO: 应用因果掩码(只截取 [:T, :T] 部分)
        # TODO: softmax + dropout
        # TODO: 加权求和 Z = P @ V
        # TODO: 合并多头 (B, n_heads, T, d_k) → (B, T, C),输出投影

        pass


class FeedForward(nn.Module):
    """前馈网络 (FFN)"""

    def __init__(self, config: GPTConfig):
        super().__init__()

        # TODO: 两层线性变换 + GELU 激活 + Dropout
        # 结构: d_model → d_ff → d_model
        # 提示: nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model), nn.Dropout

        pass

    def forward(self, x):
        # TODO
        pass


class TransformerBlock(nn.Module):
    """一个 Transformer Decoder 块 (Pre-Norm 架构)"""

    def __init__(self, config: GPTConfig):
        super().__init__()

        # TODO: 定义 ln1, attn, ln2, ffn
        # 提示: 用 nn.LayerNorm(d_model) 或手写 LayerNorm

        pass

    def forward(self, x):
        """
        Pre-Norm 架构:
            x = x + Attention(LayerNorm(x))
            x = x + FFN(LayerNorm(x))
        注意: 残差连接不经过 LayerNorm!
        """
        # TODO
        pass


class MiniGPT(nn.Module):
    """完整的 GPT 模型"""

    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config

        # TODO: Token Embedding (nn.Embedding)
        # TODO: Position Embedding (nn.Embedding,可学习的)
        # TODO: Dropout
        # TODO: N 个 TransformerBlock (nn.ModuleList)
        # TODO: 最终的 LayerNorm(Pre-Norm 架构必须有!)
        # TODO: 语言模型头 (nn.Linear,映射到 vocab_size)
        # 进阶: LM head 可以与 Token Embedding 共享权重 (weight tying)

        pass

    def forward(self, idx, targets=None):
        """
        参数:
            idx: token indices (batch_size, seq_len)
            targets: 目标 token indices (batch_size, seq_len),训练时提供

        返回:
            logits: (batch_size, seq_len, vocab_size)
            loss: 交叉熵损失(仅在提供 targets 时返回)
        """
        B, T = idx.size()
        assert T <= self.config.max_seq_len

        # TODO: Token Embedding + Position Embedding + Dropout
        # 提示: pos = torch.arange(T, device=idx.device)
        # TODO: 依次通过所有 TransformerBlock
        # TODO: 最终 LayerNorm
        # TODO: 语言模型头 → logits

        # TODO: 如果提供了 targets,计算交叉熵损失
        # 提示: logits 和 targets 需要 reshape
        # logits: (B*T, vocab_size), targets: (B*T,)
        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        自回归文本生成

        参数:
            idx: 初始 token indices (batch_size, seq_len)
            max_new_tokens: 生成的最大 token 数
            temperature: 采样温度 (>1 更随机, <1 更确定)
            top_k: Top-k 采样的 k 值 (None 表示不使用)
        """
        for _ in range(max_new_tokens):
            # TODO: 截取最后 max_seq_len 个 token(防止超长)
            # TODO: 前向传播获取 logits
            # TODO: 取最后一个位置的 logits: logits[:, -1, :]
            # TODO: 除以 temperature
            # TODO: 可选 Top-k 过滤(将 top-k 之外的 logits 设为 -inf)
            # TODO: softmax → 概率 → torch.multinomial 采样
            # TODO: 拼接新 token
            pass

        return idx

训练脚本骨架

模型实现完成后,你还需要编写训练循环:

python
# 伪代码框架
def train():
    # 1. 准备数据
    #    - 加载文本,用 tokenizer 编码
    #    - 切分为固定长度的训练样本
    #    - 构造 DataLoader

    # 2. 初始化模型和优化器
    #    - model = MiniGPT(config)
    #    - optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    #    AdamW 核心: w = w - lr * (m_hat / (sqrt(v_hat) + eps) + wd * w)

    # 3. 训练循环
    #    for epoch in range(n_epochs):
    #        for batch in dataloader:
    #            logits, loss = model(batch_x, batch_y)
    #            loss.backward()
    #            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    #            optimizer.step()
    #            optimizer.zero_grad()

    # 4. 生成测试
    #    model.eval()
    #    prompt = tokenizer.encode("从前有座山")
    #    generated = model.generate(prompt, max_new_tokens=200)
    #    print(tokenizer.decode(generated))

评估标准

基础要求(必须达成)

检查项要求
模型可实例化MiniGPT(config) 无报错
前向传播输出 shape 正确 (B, T, vocab_size)
损失计算交叉熵损失可正常计算和反向传播
参数量合理~25M 参数(给定默认 config)
训练可收敛loss 在训练过程中持续下降
可生成文本generate() 能输出文本(哪怕质量不高)

进阶要求(挑战自我)

检查项要求
Weight TyingToken Embedding 和 LM Head 共享权重
权重初始化使用合理的初始化方案(如 Xavier / 正态分布缩放)
学习率调度使用 Cosine Annealing 或 Warmup + Decay
梯度裁剪使用 clip_grad_norm_
Top-k + Top-pgenerate 方法同时支持 Top-k 和 Top-p 采样
KV Cache推理时缓存 K、V 避免重复计算
生成质量训练 1-2 小时后能生成基本通顺的文本

高阶挑战(选做)

  • 替换位置编码为 RoPE(复用 RoPE 填空练习 的代码)
  • 替换 LayerNorm 为 RMSNorm
  • 替换 GELU 为 SwiGLU
  • 实现 GQA(Grouped-Query Attention)
  • 用 Flash Attention 替代手写注意力

常见陷阱

在实现过程中,以下问题最容易踩坑:

  1. 因果掩码的 shape 不对:掩码应该是 (1, 1, T, T) 的下三角矩阵,注意在前向传播时只截取当前序列长度的部分
  2. 损失计算时忘记 shift:GPT 的训练目标是"给定前 n 个 token 预测第 n+1 个",所以 logits 和 targets 需要错位一个位置
  3. Position Embedding 的索引:应该是 torch.arange(T),不要写成 token indices
  4. dropout 在推理时要关闭model.eval() 会自动处理,但要确认 nn.Dropout 而非手动 dropout
  5. 数值溢出:注意力分数在 softmax 前可能很大,确保用了缩放(除以 dk

参考时间分配

阶段内容建议时间
0完成三个热身练习30 分钟
1实现 CausalSelfAttention45 分钟
2实现 FeedForward + TransformerBlock30 分钟
3实现 MiniGPT(forward + loss)45 分钟
4实现 generate 方法30 分钟
5训练脚本 + 调试60-90 分钟
6进阶功能(可选)60+ 分钟

参考实现

完成挑战后点击查看参考实现(请先独立完成!)
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from dataclasses import dataclass


@dataclass
class GPT2Config:
    """GPT-2 模型配置"""
    vocab_size: int = 100
    max_len: int = 512
    dim: int = 512
    heads: int = 8
    num_layers: int = 6
    initializer_range: float = 0.02


def GELU(x):
    """GPT-2 使用的 GELU 激活函数(tanh 近似)"""
    cdf = 0.5 * (1.0 + torch.tanh(
        math.sqrt(2.0 / torch.pi) * (x + 0.044715 * torch.pow(x, 3))
    ))
    return x * cdf


class LayerNorm(nn.Module):
    """手写 Layer Normalization"""
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.epsilon = 1e-8

    def forward(self, X):
        mu = X.mean(dim=-1, keepdim=True)
        var = X.var(dim=-1, keepdim=True)
        X_hat = (X - mu) / torch.sqrt(var + self.epsilon)
        return X_hat * self.gamma + self.beta


class MultiHeadAttention(nn.Module):
    """多头缩放点积注意力(合并 QKV 投影)"""
    def __init__(self, dim, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        # 合并 Q, K, V 为一个线性层,减少 kernel launch
        self.WQKV = nn.Linear(dim, dim * 3)
        self.WO = nn.Linear(dim, dim)

    def forward(self, X, mask=None):
        bs, seq_len, dim = X.shape
        # 一次投影,然后拆分为 Q, K, V
        QKV = self.WQKV(X)
        Q, K, V = QKV.split(split_size=self.dim, dim=2)

        # 拆分多头: (bs, seq_len, dim) → (bs, heads, seq_len, head_dim)
        Q = Q.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)
        K = K.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)
        V = V.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)

        # 缩放点积注意力
        S = Q @ K.transpose(2, 3) / math.sqrt(self.head_dim)
        if mask is not None:
            S = S + mask[:, None, :, :]  # 广播到 heads 维度
        P = torch.softmax(S, dim=-1)
        Z = P @ V

        # 合并多头
        Z = Z.transpose(1, 2).reshape(bs, seq_len, dim)
        return self.WO(Z)


class FeedForwardNetwork(nn.Module):
    """前馈网络: dim → 4*dim (GELU) → dim"""
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)

    def forward(self, X):
        return self.fc2(GELU(self.fc1(X)))


class GPTBlock(nn.Module):
    """GPT-2 Decoder Block (Pre-Norm)"""
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.attn = MultiHeadAttention(dim, heads)
        self.ln1 = LayerNorm(dim)
        self.ffn = FeedForwardNetwork(dim)
        self.ln2 = LayerNorm(dim)

    def forward(self, X, mask=None):
        # Pre-Norm: 先 Norm 再计算,残差连接不经过 Norm
        X_ln = self.ln1(X)
        X_attn = self.attn(X_ln, mask=mask)
        X = X + X_attn                # 残差连接 1

        X_ln = self.ln2(X)
        X_ffn = self.ffn(X_ln)
        X = X + X_ffn                 # 残差连接 2
        return X


class GPT2Model(nn.Module):
    """完整 GPT-2 模型"""
    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        self.embedding = nn.Embedding(config.vocab_size, config.dim)

        self.decoder = nn.ModuleList(
            [GPTBlock(config.dim, config.heads)
             for _ in range(config.num_layers)]
        )
        self.ln = LayerNorm(config.dim)  # 最终 LayerNorm
        self.lm_head = nn.Linear(config.dim, config.vocab_size, bias=False)

        # 预计算因果掩码
        self.register_buffer(
            'cache_mask',
            torch.tril(torch.ones(config.max_len, config.max_len))
        )

        # 权重初始化
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """GPT-2 风格的权重初始化"""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)

    def forward(self, x, targets=None):
        bs, seq_len = x.shape
        # 构造 additive 因果掩码
        mask = (1 - self.cache_mask[:seq_len, :seq_len]) * (-1e5)
        mask = mask.unsqueeze(0).expand(bs, -1, -1)

        X = self.embedding(x)
        for block in self.decoder:
            X = block(X, mask=mask)
        X = self.ln(X)
        logits = self.lm_head(X)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1)
            )
        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -self.config.max_len:]
            logits, _ = self.forward(idx_cond)
            logits = logits[:, -1, :] / temperature
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat([idx, idx_next], dim=1)
        return idx


# 验证
config = GPT2Config(vocab_size=100, dim=512, heads=8, num_layers=6)
model = GPT2Model(config)
x = torch.randint(100, [2, 16])
logits, _ = model(x)
print(f"logits shape: {logits.shape}")  # (2, 16, 100)

# 参数量统计
n_params = sum(p.numel() for p in model.parameters())
print(f"参数量: {n_params / 1e6:.1f}M")

祝你实现顺利! 遇到困难时,回顾 Transformer 架构注意力机制 的内容会很有帮助。


MLM 代码训练模式

完成上面的固定填空后,试试随机挖空模式 -- 每次点击「刷新」会随机遮盖不同的代码片段,帮你彻底记住每一行。

因果自注意力(合并 QKV 投影 + 多头拆分 + 缩放点积)

CausalSelfAttention 核心实现
共 98 个可挖空位 | 已挖 0 个
class MultiHeadAttention(nn.Module):
    def __init__(self, dim, heads=8):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.head_dim = dim // heads
        self.WQKV = nn.Linear(dim, dim * 3)
        self.WO = nn.Linear(dim, dim)

    def forward(self, X, mask=None):
        bs, seq_len, dim = X.shape
        QKV = self.WQKV(X)
        Q, K, V = QKV.split(split_size=self.dim, dim=2)

        Q = Q.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)
        K = K.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)
        V = V.reshape(bs, seq_len, self.heads, self.head_dim).transpose(1, 2)

        S = Q @ K.transpose(2, 3) / math.sqrt(self.head_dim)
        if mask is not None:
            S = S + mask[:, None, :, :]
        P = torch.softmax(S, dim=-1)
        Z = P @ V

        Z = Z.transpose(1, 2).reshape(bs, seq_len, dim)
        return self.WO(Z)

GPT Block(Pre-Norm 残差结构)

GPTBlock Pre-Norm 前向传播
共 45 个可挖空位 | 已挖 0 个
class GPTBlock(nn.Module):
    def __init__(self, dim=512, heads=8):
        super().__init__()
        self.attn = MultiHeadAttention(dim, heads)
        self.ln1 = LayerNorm(dim)
        self.ffn = FeedForwardNetwork(dim)
        self.ln2 = LayerNorm(dim)

    def forward(self, X, mask=None):
        X = X + self.attn(self.ln1(X), mask=mask)
        X = X + self.ffn(self.ln2(X))
        return X

GPT Forward + Generate

GPT 前向传播与自回归生成
共 112 个可挖空位 | 已挖 0 个
def forward(self, x, targets=None):
    bs, seq_len = x.shape
    mask = (1 - self.cache_mask[:seq_len, :seq_len]) * (-1e5)
    mask = mask.unsqueeze(0).expand(bs, -1, -1)

    X = self.embedding(x)
    for block in self.decoder:
        X = block(X, mask=mask)
    X = self.ln(X)
    logits = self.lm_head(X)

    loss = None
    if targets is not None:
        loss = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            targets.view(-1)
        )
    return logits, loss

@torch.no_grad()
def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -self.config.max_len:]
        logits, _ = self.forward(idx_cond)
        logits = logits[:, -1, :] / temperature
        if top_k is not None:
            v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits[logits < v[:, [-1]]] = float('-inf')
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, idx_next], dim=1)
    return idx