Skip to content

预训练

预训练是让模型从海量文本中学会语言的过程,是 LLM 所有能力的基础。模型在这一阶段通过预测下一个 token,从万亿级语料中压缩出语言结构、世界知识和推理模式。

在大模型体系中的位置

预训练(本章)──> SFT 微调 ──> RLHF/DPO 对齐 ──> 部署推理
   │                │              │
   │                │              └─ 偏好数据(chosen/rejected)
   │                └─ 指令数据(instruction-response)
   └─ 海量无标注文本(万亿 token)

预训练是整个流水线中计算量最大、耗时最长、成本最高的阶段。以 Llama 3 70B 为例,预训练在 15T token 上进行,消耗约 6.4M GPU-hours(H100)。但正是这一阶段赋予模型所有的基础能力——后续的 SFT 和 RLHF 只是在预训练的基础上"激活"和"对齐"。


预训练的本质

Next Token Prediction 目标

预训练的核心目标极其简单:给定前面的所有 token,预测下一个 token

给定序列 x1,x2,,xT,模型学习条件概率分布:

P(xt|x1,x2,,xt1;θ)

训练目标是最大化整个序列的对数似然(等价于最小化交叉熵损失):

L(θ)=1Tt=1TlogP(xt|x<t;θ)

交叉熵损失的直觉:当模型预测正确类别的概率越高,logP 越小(损失越低)。

python
# 交叉熵的本质
import torch
import torch.nn.functional as F

# 给定两个概率分布,计算交叉熵
q = torch.tensor([0.05, 0.9, 0.05])  # 模型预测(预测较准确)
p = torch.tensor([0.0, 1.0, 0.0])    # 目标分布(真实标签是第2类)

entropy = -p * torch.log(q)
print(entropy.sum())  # tensor(0.1054) —— 预测准确时损失低

q_bad = torch.tensor([0.3, 0.5, 0.2])  # 模型预测(不够准确)
entropy_bad = -p * torch.log(q_bad)
print(entropy_bad.sum())  # tensor(0.6931) —— 预测不准时损失高

在实际分类问题中,由于目标分布是 one-hot 的(只有正确类别概率为 1),交叉熵可以简化为:

L=logqy

其中 y 是正确类别的索引。这就是 PyTorch 中 nn.CrossEntropyLoss 的实现方式:

python
# 手动实现批量交叉熵
def manual_ce_loss(targets, logits):
    """
    targets: [batch]       — 每个样本的正确类别索引
    logits:  [batch, num_cls] — 模型输出的未归一化分数
    """
    batch, _ = logits.shape
    probs = F.softmax(logits, dim=-1)
    row_idx = torch.arange(batch)
    # 只取正确类别的 log 概率
    log_probs = probs[row_idx, targets].log()
    loss = -log_probs.mean()
    return loss

# 验证与 PyTorch 实现一致
import torch.nn as nn
loss_fn = nn.CrossEntropyLoss()
logits = torch.randn(6, 8)
targets = torch.randint(high=8, size=(6,))
print(manual_ce_loss(targets, logits))  # tensor(2.xxxx)
print(loss_fn(logits, targets))          # tensor(2.xxxx) —— 结果一致

交叉熵梯度的优美性质:CE + Softmax 对 logits 的梯度有一个极简的形式:

Lzi=qipi

预测概率减去目标概率。这意味着:正确类别的梯度方向是增大其概率(负梯度),错误类别的梯度方向是减小其概率。

python
# 验证梯度公式:CE + Softmax 对 logits 的梯度
z = torch.tensor([1.0, 2.0, 3.0])
p = torch.tensor([0.0, 1.0, 0.0])  # 目标分布(one-hot)
q = F.softmax(z, dim=0)             # softmax 输出

# 通过链式法则手动计算 dL/dz
dL_dp = -(p / q)                              # dL/dq
J_softmax = torch.diag(q) - torch.outer(q, q) # softmax 雅可比矩阵
dL_dz = dL_dp @ J_softmax                     # 链式法则
print(dL_dz)    # tensor([ 0.0900, -0.7553,  0.6652])
print(q - p)     # tensor([ 0.0900, -0.7553,  0.6652]) —— 完全一致

Causal Language Modeling vs Masked Language Modeling

特性Causal LM (GPT 系列)Masked LM (BERT 系列)
预测方向从左到右,预测下一个 token双向,预测被 mask 的 token
注意力模式Causal mask(下三角)全注意力
生成能力天然支持自回归生成不擅长开放式生成
代表模型GPT, Llama, QwenBERT, RoBERTa
当前主流几乎所有 LLM 都采用主要用于理解任务

现代 LLM 几乎全部采用 Causal LM。原因:生成能力是 LLM 最核心的能力,而 Causal LM 的训练目标天然与自回归生成对齐。

预训练赋予模型什么能力?

预训练并非简单的"记忆",而是通过预测下一个 token 这一目标,迫使模型学会:

  1. 语法和语言结构:主谓宾搭配、时态一致性、代词指代
  2. 世界知识:"北京是中国的首都"、"水在 100 度沸腾"
  3. 推理模式:因果推理、类比推理、数学逻辑
  4. 代码理解:编程语言的语法、API 使用模式、算法结构
  5. 多语言能力:如果训练数据包含多种语言,模型会学到跨语言的表示

Scaling Laws

Kaplan et al. (2020) 的发现

OpenAI 在 2020 年发现了语言模型性能与三个关键变量之间的幂律关系:

L(N)=(NcN)αN,αN0.076L(D)=(DcD)αD,αD0.095L(C)=(CcC)αC,αC0.050

其中:

  • N = 模型参数量(非嵌入层参数)
  • D = 训练数据量(token 数)
  • C = 计算量(FLOPs)
  • L = 测试集交叉熵损失

核心发现:性能与这三个变量分别呈幂律关系,且在很大范围内保持平滑。这意味着可以用小实验预测大模型的性能。

Chinchilla 法则

DeepMind 的 Chinchilla 论文(Hoffmann et al., 2022)修正了 Kaplan 的结论,提出了计算最优的训练策略:

给定固定的计算预算 C,模型参数量 N 和训练数据量 D 应当等比例增长。

具体而言:

NoptC0.50,DoptC0.50

经验法则:每个参数对应约 20 个训练 token。即:

模型参数量最优训练 token 数计算量 (FLOPs)
400M8B1.9e19
1B20B1.2e20
7B140B5.8e21
13B260B2.0e22
70B1.4T5.9e23
175B3.5T3.7e24

对实践的指导意义

  1. 不要只加大模型:Chinchilla (70B, 1.4T tokens) 打败了 Gopher (280B, 300B tokens),因为后者严重"过大欠训练"
  2. 数据是瓶颈:对于大模型,高质量数据的需求量巨大。Llama 3 在 15T token 上训练 8B 和 70B 模型,远超 Chinchilla 最优比例,说明过度训练(over-training)在推理成本敏感场景下是合理的
  3. 预测训练成本C6ND(近似公式),其中 C 是 FLOPs,N 是参数量,D 是 token 数

实际案例:训练 Llama 3 8B 在 15T tokens 上,计算量约 6×8×109×15×1012=7.2×1023 FLOPs。以 H100 (989 TFLOPS BF16) 和 40% MFU 计算,需要约 7.2×1023/(989×1012×0.4×3600)505,000 GPU-hours。

Scaling Laws 深度:从理论到实践

Kaplan Scaling Laws(OpenAI 2020)详解

Kaplan 等人在 2020 年系统地研究了语言模型的 loss 与三个核心变量之间的幂律关系。这些关系在数个数量级上保持惊人的平滑:

单变量幂律:固定其中两个变量,只增长第三个:

L(N)=(NcN)αN(8.8×1013N)0.076L(D)=(DcD)αD(5.4×1013D)0.095L(C)=(CcC)αC(3.1×108C)0.050

联合幂律:当同时优化 ND 时:

L(N,D)=[(NcN)αN/αD+DcD]αD

Kaplan 的关键结论

  1. 模型大小比数据量重要αN=0.076>αD=0.095(注意:α 越小意味着对 loss 的边际收益越大)。Kaplan 认为在计算预算有限时,应该优先加大模型
  2. 最优分配:给定计算预算 CNoptC0.73DoptC0.27——这意味着大部分预算应该给模型参数
  3. 架构细节次要:在合理范围内,层数/宽度/头数的具体选择对 scaling 影响不大

Chinchilla Scaling Laws(DeepMind 2022)的修正

Hoffmann 等人(DeepMind)在 2022 年发表了 Chinchilla 论文,通过更严格的实验设计修正了 Kaplan 的结论。

实验方法:他们用三种不同方法估计最优分配:

  1. 固定计算预算,训练大量不同 ND 的模型
  2. 固定模型大小,变化训练 token 数
  3. 参数化拟合 L(N,D)

Chinchilla 的幂律

L(N,D)=ANα+BDβ+E

其中 α0.34β0.28E1.69(不可约损失)。

compute-optimal 训练的核心结论

NoptC0.50,DoptC0.50

这与 Kaplan 的结论(NC0.73截然不同!Chinchilla 认为模型参数和数据量应该等比例增长

为什么 Kaplan 和 Chinchilla 结论不同?

  1. Kaplan 的实验中,大模型没有训练到收敛,导致高估了"增大模型"的收益
  2. Kaplan 对计算量 C6ND 的近似中忽略了 embedding 层
  3. Chinchilla 使用了更大范围的模型-数据组合(400 多个实验点)

Chinchilla 对工业界的深远影响

在 Chinchilla 之前,业界的信条是"bigger is better"——GPT-3 175B 只用了 300B tokens 训练。Chinchilla 70B 用 1.4T tokens 训练后,性能超过了 Gopher 280B(300B tokens),仅用约 1/4 的推理成本。

这直接改变了后续所有模型的训练策略:

模型参数量训练 Token 数Token/参数比策略
GPT-3175B300B1.7欠训练
Gopher280B300B1.1严重欠训练
Chinchilla70B1.4T20compute-optimal
Llama 2 70B70B2T29略超 Chinchilla
Llama 3 8B8B15T1875大幅过训练
Llama 3 70B70B15T214大幅过训练

为什么 Llama 3 大幅超过 Chinchilla 比例? 因为 Chinchilla 优化的是训练计算量,而实际部署时还要考虑推理成本。小模型训练更久,推理更快、更省钱。对于部署量大的模型,多花训练成本换来的推理节省远超投入。

推理 Scaling Laws(OpenAI 2024)

2024 年,随着 OpenAI o1 等"推理模型"的出现,一种新的 Scaling Law 被提出——test-time compute scaling

L(ctest)ctestγ

其中 ctest 是推理时使用的计算量(如 Chain-of-Thought 的 token 数、搜索树的大小)。

核心发现

  1. 推理时的计算量也遵循幂律:让模型"思考更久"(生成更多推理步骤),可以持续提升准确率
  2. 训练计算和推理计算可以互换:在某些任务上,一个小模型 + 大量推理计算,可以匹敌大模型 + 少量推理计算
  3. 最优分配:给定总预算(训练 + 推理),应该在两者之间寻找最优分配
传统 Scaling:       推理 Scaling:
  更大模型 → 更好      更多推理步骤 → 更好
  1B → 7B → 70B       1 step → 10 steps → 100 steps
  训练时确定能力        推理时释放能力

这意味着 Scaling Laws 的维度从"预训练三角"(N,D,Ctrain)扩展到了四维:N,D,Ctrain,Ctest

Scaling Laws 的实用价值

1. 训练成本预估

在启动大规模训练前,先用小模型验证 scaling trend:

python
import numpy as np
from scipy.optimize import curve_fit

def power_law(x, a, b, c):
    """幂律函数: L = a / x^b + c"""
    return a / np.power(x, b) + c

# 用小规模实验数据拟合 Scaling Law
# 模型参数量(单位:百万)
N_values = np.array([25, 50, 100, 200, 400, 800])
# 对应的验证集 loss
loss_values = np.array([3.85, 3.52, 3.25, 3.05, 2.88, 2.75])

# 拟合幂律参数
params, _ = curve_fit(power_law, N_values, loss_values,
                      p0=[10, 0.1, 2.0], maxfev=10000)

a, b, c = params
print(f"拟合结果: L(N) = {a:.2f} / N^{b:.4f} + {c:.4f}")
print(f"不可约损失 (irreducible loss): {c:.4f}")

# 预测更大模型的性能
for target_N in [7000, 13000, 70000]:  # 7B, 13B, 70B
    predicted_loss = power_law(target_N, *params)
    print(f"预测 {target_N}M 参数模型 loss: {predicted_loss:.4f}")

2. 最优模型大小选择

给定计算预算 C(FLOPs),按 Chinchilla 法则:

python
def chinchilla_optimal(compute_budget_flops):
    """
    给定计算预算,计算 Chinchilla-optimal 的模型大小和数据量
    基于 C ≈ 6ND 和 N_opt ∝ C^0.5, D_opt ∝ C^0.5
    经验关系: D_opt ≈ 20 * N_opt
    """
    # 从 C = 6ND 和 D = 20N 得: C = 120 * N^2
    N_opt = np.sqrt(compute_budget_flops / 120)
    D_opt = 20 * N_opt

    return {
        'params': N_opt,
        'tokens': D_opt,
        'params_B': N_opt / 1e9,
        'tokens_T': D_opt / 1e12,
        'compute_flops': compute_budget_flops,
    }

# 不同计算预算下的最优配置
budgets = {
    '小型实验': 1e19,
    '7B 级别': 6e21,
    '70B 级别': 6e23,
    '175B 级别': 4e24,
}

for name, budget in budgets.items():
    result = chinchilla_optimal(budget)
    print(f"{name} (C={budget:.0e}):")
    print(f"  最优参数量: {result['params_B']:.1f}B")
    print(f"  最优数据量: {result['tokens_T']:.2f}T tokens")
    print()

# 输出:
# 小型实验 (C=1e+19): 最优参数量: 0.3B, 最优数据量: 0.01T tokens
# 7B 级别 (C=6e+21): 最优参数量: 7.1B, 最优数据量: 0.14T tokens
# 70B 级别 (C=6e+23): 最优参数量: 70.7B, 最优数据量: 1.41T tokens
# 175B 级别 (C=4e+24): 最优参数量: 182.6B, 最优数据量: 3.65T tokens

3. 数据需求规划

python
def estimate_data_needs(target_params_B, strategy='chinchilla'):
    """估算不同训练策略下的数据需求"""
    N = target_params_B * 1e9

    if strategy == 'chinchilla':
        # Chinchilla-optimal: 20 tokens/param
        tokens = 20 * N
    elif strategy == 'llama3':
        # Llama 3 风格: 过度训练,优化推理成本
        # 经验值: 小模型 ~2000 tokens/param, 大模型 ~200 tokens/param
        tokens_per_param = 2000 if target_params_B < 20 else 200
        tokens = tokens_per_param * N
    elif strategy == 'balanced':
        # 折中: 100 tokens/param
        tokens = 100 * N

    compute = 6 * N * tokens  # FLOPs

    return {
        'tokens_T': tokens / 1e12,
        'compute_flops': compute,
        'h100_hours': compute / (989e12 * 0.4 * 3600),
    }

for model_size in [1, 7, 13, 70]:
    print(f"\n=== {model_size}B 模型 ===")
    for strategy in ['chinchilla', 'llama3', 'balanced']:
        result = estimate_data_needs(model_size, strategy)
        print(f"  {strategy:12s}: {result['tokens_T']:8.1f}T tokens, "
              f"{result['h100_hours']:10,.0f} H100-hours")

# 输出示例:
# === 7B 模型 ===
#   chinchilla  :      0.1T tokens,      2,960 H100-hours
#   llama3      :     14.0T tokens,    414,444 H100-hours
#   balanced    :      0.7T tokens,     20,723 H100-hours

4. 拟合自己的 Scaling Law 曲线

python
import matplotlib.pyplot as plt

def fit_and_plot_scaling_law(N_values, loss_values, title="Scaling Law"):
    """
    拟合并可视化 Scaling Law 曲线
    N_values: 模型参数量(列表)
    loss_values: 对应的验证 loss(列表)
    """
    N = np.array(N_values, dtype=np.float64)
    L = np.array(loss_values, dtype=np.float64)

    # 拟合 L(N) = a / N^b + c
    params, cov = curve_fit(power_law, N, L,
                            p0=[10, 0.1, 2.0], maxfev=10000)
    a, b, c = params

    # 绘制拟合曲线
    N_range = np.logspace(np.log10(N.min() * 0.5),
                          np.log10(N.max() * 100), 200)
    L_predicted = power_law(N_range, *params)

    plt.figure(figsize=(10, 6))
    plt.scatter(N, L, s=100, c='red', zorder=5, label='实验数据')
    plt.plot(N_range, L_predicted, 'b--', label=f'拟合: L = {a:.2f}/N^{b:.4f} + {c:.4f}')
    plt.axhline(y=c, color='gray', linestyle=':', label=f'不可约损失 = {c:.4f}')
    plt.xscale('log')
    plt.xlabel('模型参数量 (M)')
    plt.ylabel('验证集 Loss')
    plt.title(title)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig('scaling_law.png', dpi=150)
    print(f"Scaling Law 拟合完成, 图表已保存")

    return params

# 使用示例
N_experiment = [25, 50, 100, 200, 400, 800, 1500]
loss_experiment = [3.85, 3.52, 3.25, 3.05, 2.88, 2.75, 2.65]

params = fit_and_plot_scaling_law(N_experiment, loss_experiment,
                                  title="Validation Loss Scaling Law")

面试考点:Scaling Laws 的意义是什么? Scaling Laws 使得大模型训练从"黑盒炼丹"变为"工程可预测"。通过小规模实验拟合幂律参数,可以在花费数百万美元训练前,预估目标模型的大致性能、所需数据量和训练时间。这是大模型训练从艺术走向科学的关键一步。


数据准备

数据来源

现代 LLM 的预训练数据通常来自以下几个大类:

数据源规模特点
Common Crawl数万亿 token覆盖面广但质量参差不齐
Wikipedia~40 亿 token(英文)高质量、结构化
GitHub 代码数千亿 token提升模型代码能力
书籍语料数百亿 token长文本、高质量
学术论文 (ArXiv)数百亿 token数学和科学推理
StackOverflow/论坛数百亿 token问答格式、实践知识
python
# GPT-2 预训练数据加载示例
import json

def load_jsonl(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                data.append(json.loads(line)['text'])
    return data

data = load_jsonl('./data.jsonl')
print(len(data))    # 5000 篇文档
print(data[0][:100]) # "The Technology Report empowers or enlightens..."

# 预训练数据处理:文档拼接 -> tokenize -> 切分成固定长度的 block
data_total = '\n'.join(data)  # 将所有文档拼接
print(len(data_total))        # 15,075,691 字符

数据清洗流水线

原始网页数据极其嘈杂,需要多层清洗:

原始 Common Crawl 快照

    ├── 1. URL 过滤:移除成人站点、广告站点、已知低质量域名

    ├── 2. 语言识别:使用 fastText 分类器,保留目标语言(如英文 > 0.65)

    ├── 3. 质量过滤:
    │       ├── 基于规则:行长度、特殊字符比例、重复行比例
    │       ├── 基于困惑度:用 KenLM 计算 perplexity,过滤高困惑度文档
    │       └── 基于分类器:训练质量分类器(以 Wikipedia 为正例)

    ├── 4. 去重:
    │       ├── 精确去重:文档级 SHA-256 哈希
    │       └── 模糊去重:MinHash + LSH(下文详述)

    └── 5. 去污染:移除与评测集(MMLU, HumanEval 等)重叠的内容

    最终保留约 10-15% 的原始数据

FineWeb 与 RedPajama

FineWeb(HuggingFace, 2024)是目前最大的开源英文网页数据集之一(15T token)。其关键创新在于:

  • 使用更严格的质量过滤器,基于教育内容评分(educational score)
  • FineWeb-Edu 子集在教育类 benchmark 上显著优于全量数据
  • 证明了数据质量筛选比简单扩大规模更有效

RedPajama(Together AI)旨在复现 Llama 的训练数据:

  • RedPajama v1: 1.2T token,复现 Llama 1 的数据配比
  • RedPajama v2: 30T token 的原始数据 + 质量信号标注,支持用户自定义过滤

数据配比

Llama 3 的数据配比策略(Meta, 2024):

数据类型占比说明
网页文本~50%经过严格质量过滤的 Common Crawl
代码~17%GitHub 代码
数学~4.5%数学推理相关内容
书籍~4.5%长文本、连贯叙述
学术论文~4.5%科学推理
多语言~20%非英语数据

关键策略:

  • 数据上采样(upsampling):对高质量数据(代码、数学)进行多次重复使用
  • 动态调整:训练后期增加高质量数据比例(即 Data Curriculum)
  • 去重比重复更重要:Llama 3 对数据进行了 4 轮去重

优化器

AdamW 详解

几乎所有现代 LLM 的预训练都使用 AdamW 优化器。它是 Adam 的改进版本,核心修复了 weight decay 的实现方式。

Adam 算法推导

给定参数 θ 和梯度 gt=θLt,Adam 维护两个指数移动平均:

  1. 一阶矩估计(梯度的均值):mt=β1mt1+(1β1)gt
  2. 二阶矩估计(梯度平方的均值):vt=β2vt1+(1β2)gt2

由于初始化为零,需要偏差校正:

m^t=mt1β1t,v^t=vt1β2t

参数更新:

θt+1=θtηm^tv^t+ϵ

Weight Decay vs L2 正则化

在标准 SGD 中,L2 正则化与 weight decay 是等价的。但在 Adam 中两者不等价:

  • L2 正则化将正则项加入损失函数,梯度会经过 Adam 的自适应缩放
  • Weight Decay 直接在参数更新时减去 λθt,不经过自适应缩放

AdamW 选择后者(解耦 weight decay),在实践中表现更好。

python
# AdamW 手动实现
class Adam:
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.w = params
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.eps = eps
        self.m = torch.zeros_like(params)
        self.v = torch.zeros_like(params)
        self.t = 0

    def step(self, w, grad, weight_decay=1e-2):
        self.t += 1

        # 一阶矩估计(动量方向)
        self.m = self.beta1 * self.m + (1 - self.beta1) * grad
        # 二阶矩估计(自适应学习率)
        self.v = self.beta2 * self.v + (1 - self.beta2) * grad.pow(2)

        # 偏差校正
        m_hat = self.m / (1 - self.beta1 ** self.t)
        v_hat = self.v / (1 - self.beta2 ** self.t)

        if weight_decay is not None:  # AdamW: 解耦的 weight decay
            return w - self.lr * (m_hat / (v_hat.sqrt() + self.eps) + weight_decay * w)

        return w - self.lr * m_hat / (v_hat.sqrt() + self.eps)  # 标准 Adam

# 验证收敛
w = torch.randn(10, 1)
optimizer = Adam(w)
input_data = torch.randn(8, 10)
target = torch.randn(8, 1)

for epoch in range(1000):
    output = input_data @ w
    grad = input_data.transpose(1, 0) @ (output - target)
    if epoch % 200 == 0:
        loss = (0.5 / 8) * ((output - target) ** 2).sum()
        print(f"Epoch {epoch}: loss = {loss:.4f}")
    w = optimizer.step(w, grad, weight_decay=1e-2)
# 输出: loss 从 0.30 持续下降到 0.05

典型超参数(Llama 3):β1=0.9,β2=0.95,ϵ=108,λ=0.1

学习率调度

学习率调度对训练稳定性至关重要。

Warmup 的必要性

  • 训练初期,Adam 的二阶矩估计 vt 尚不准确(接近 0),导致更新步长过大
  • Warmup 阶段线性增加学习率,给 vt 积累时间
  • 典型 warmup 步数:2000 步

Cosine Annealing(余弦退火)

ηt=ηmin+12(ηmaxηmin)(1+cos(tTwTTwπ))

其中 Tw 是 warmup 步数,T 是总训练步数。直觉:学习率从峰值平滑下降到接近 0,前期下降慢(探索),后期下降快(收敛)。

WSD 调度器(Warmup-Stable-Decay)

MiniCPM 和部分新模型采用的三阶段策略:

Learning Rate
  │     ┌──────────────┐
  │    /│   Stable      │\
  │   / │               │ \
  │  /  │               │  \
  │ /   │               │   \
  │/    │               │    \
  └─────┴───────────────┴─────── Training Steps
  Warmup     Stable        Decay
  (~2000)   (most steps)   (~last 10%)

WSD 的优势:Stable 阶段可以随时决定何时开始 Decay,方便灵活调整训练长度。

python
# 学习率调度实现
import math

def cosine_schedule(step, total_steps, warmup_steps, max_lr, min_lr=0):
    """Cosine Annealing with Warmup"""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

def wsd_schedule(step, total_steps, warmup_steps, decay_steps, max_lr, min_lr=0):
    """Warmup-Stable-Decay 调度器"""
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    elif step < total_steps - decay_steps:
        return max_lr
    else:
        decay_progress = (step - (total_steps - decay_steps)) / decay_steps
        return min_lr + (max_lr - min_lr) * (1 - decay_progress)

Mixed Precision Training

FP32, FP16, BF16 的精度对比

格式符号位指数位尾数位数值范围精度
FP321823~3.4e38
FP161510~6.5e4
BF16187~3.4e38低(但范围大)

Loss Scaling 技巧

FP16 的数值范围有限(最大 65504),训练中容易出现梯度下溢(很小的梯度变为 0)。Loss Scaling 的做法:

  1. 将 loss 乘以一个大常数(如 1024)
  2. 反向传播得到放大后的梯度
  3. 更新参数前将梯度除以同一个常数

动态 Loss Scaling 会自动调整缩放因子:如果出现 NaN/Inf,减半缩放因子并跳过当前步。

为什么 BF16 比 FP16 更适合训练

  • BF16 的指数位与 FP32 相同(8位),数值范围一致,几乎不会溢出
  • 不需要 Loss Scaling,训练流程更简单
  • 代价是精度略低(7 位尾数 vs FP16 的 10 位),但实践证明对训练影响很小
  • Llama 3, Qwen 2.5 等主流模型全部使用 BF16 训练

混合精度训练的典型配置:

  • 模型参数和梯度:BF16
  • 优化器状态(mt,vt):FP32(需要高精度累加)
  • 损失计算:FP32

Gradient Checkpointing

内存 vs 计算的 trade-off

标准训练需要保存所有中间激活值用于反向传播。对于一个 L 层的 Transformer:

  • 不用 checkpointing:内存 O(L),计算 1×
  • 全量 checkpointing:内存 O(L),计算约 1.33×(多做一次前向)

实现原理

标准训练:
  前向: 保存 layer1_out, layer2_out, ..., layerL_out  (内存占用大)
  反向: 用保存的激活值计算梯度

Gradient Checkpointing:
  前向: 只保存 layer1_out, layer_k_out, layer_2k_out...  (每隔 k 层保存)
  反向: 遇到未保存的激活值时,从最近的 checkpoint 重新前向计算

典型配置:对每个 Transformer block 做 checkpointing,即每层保存输入,层内的中间激活不保存。代价是约 33% 的额外计算时间,但可以显著减少显存占用。


训练监控

关键指标:loss, grad norm, learning rate

指标正常范围异常信号
Training loss持续平滑下降突然飙升(loss spike)
Gradient norm稳定在 0.1-10突然变大(梯度爆炸)或趋近 0
Learning rate按照 schedule 变化应当与 loss 曲线配合
Tokens/sec基本恒定突然下降说明有硬件问题
python
# 训练循环中的监控
self.model.train()
self.optimizer.zero_grad()

logits = self.model(input_tensor)
loss = self.criterion(
    logits.view(-1, logits.size(-1)),
    label_tensor.view(-1)
)
loss.backward()

# 梯度裁剪:防止梯度爆炸
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

self.optimizer.step()
self.loss_history.append(loss.item())

# 定期打印统计信息
avg_loss = np.mean(self.loss_history[-100:])

Loss Spike 的处理

Loss spike(损失突然飙升)在大规模训练中很常见。处理策略:

  1. 梯度裁剪(gradient clipping):限制梯度范数,通常 max_norm=1.0
  2. 回滚到之前的 checkpoint:如果 spike 持续不恢复
  3. 降低学习率:从 spike 前的 checkpoint 恢复,学习率减半
  4. 数据检查:某些低质量数据 batch 可能导致 spike
  5. 跳过异常步:如果梯度出现 NaN/Inf,跳过当前更新

MFU (Model FLOPs Utilization) 的计算

MFU 衡量 GPU 的实际利用率:

MFU=实际每秒 FLOPsGPU 理论峰值 FLOPs

对于 Transformer 模型,每个 token 的前向 + 反向 FLOPs 约为 6NN 为参数量):

实际 FLOPs/s=6×N×batch_size×seq_len每步耗时(秒)

典型 MFU:30-50%。达到 50% 以上说明系统优化较好。


预训练实战

完整训练配置示例

以训练一个 7B 参数模型为例:

python
# 模型配置
model_config = {
    "hidden_size": 4096,
    "num_layers": 32,
    "num_heads": 32,
    "vocab_size": 32000,
    "max_seq_len": 4096,
    "intermediate_size": 11008,  # FFN 中间层
}

# 训练配置
train_config = {
    "total_tokens": 2_000_000_000_000,  # 2T tokens
    "batch_size": 4_000_000,             # ~4M tokens per batch (global)
    "micro_batch_size": 4,               # 每个 GPU 的 micro batch
    "seq_len": 4096,
    "learning_rate": 3e-4,
    "min_lr": 3e-5,
    "warmup_steps": 2000,
    "weight_decay": 0.1,
    "grad_clip": 1.0,
    "optimizer": "AdamW",
    "betas": (0.9, 0.95),
    "precision": "bf16",
    "gradient_checkpointing": True,
}

# 数据处理配置
from torch.utils.data import Dataset, DataLoader

class PretrainedLanguageModelDataset(Dataset):
    def __init__(self, data, max_len=4096, pad_token_id=0):
        self.data = data
        self.max_len = max_len
        self.pad_token_id = pad_token_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        if len(item) < self.max_len:
            item = item + (self.max_len - len(item)) * [self.pad_token_id]
        return torch.tensor(item, dtype=torch.long)

训练成本估算

问题:训练一个 7B 模型,2T tokens,需要多少 GPU-hours?

计算步骤:

  1. 总 FLOPs = 6×N×D=6×7×109×2×1012=8.4×1022
  2. H100 BF16 理论峰值 = 989 TFLOPS
  3. 假设 MFU = 40%,有效算力 = 989×1012×0.4=3.96×1014 FLOPS
  4. 总 GPU 秒数 = 8.4×1022/3.96×10142.12×108
  5. 总 GPU-hours 59,000 H100-hours

如果使用 128 张 H100:

  • 训练时间 59000 / 128 461 小时 19 天
  • 按 H100 云价格 $3/GPU-hour 计算:约 $177,000

苏格拉底时刻

  1. 预训练的 loss 从 10+ 降到 2-3 的过程中,模型分别在学什么? Loss 从 10 到 5 的阶段,模型主要在学习词频分布和简单的语法模式;从 5 到 3 的阶段,开始学习语义关联和世界知识;从 3 到 2.5 以下,开始涌现推理能力。

  2. Chinchilla 法则说每个参数需要 20 个 token,但 Llama 3 用了远超此比例的数据(8B 模型训练 15T token),为什么? 因为 Chinchilla 法则优化的是训练计算量,但推理成本也很重要。小模型训练更久,推理时更省算力——这在部署阶段的收益远大于额外的训练成本。

  3. 交叉熵损失的梯度为 qp,这个简洁的形式意味着什么? 意味着每次更新,模型在"向正确答案靠近"的同时"远离错误答案",而且调整的幅度与当前预测的误差成正比。这是一个"自我校正"的过程。

  4. 为什么预训练数据中包含大量低质量网页内容,模型仍能学到有用知识? 因为有用的语言模式在高质量和低质量文本中都存在(语法、常见搭配)。但数据质量过低会导致模型学到错误知识和有毒内容。这就是为什么数据清洗如此重要。

  5. 如果将所有预训练数据重复训练多个 epoch,效果会变差吗? 会。实验表明数据重复 4 次以上,模型开始"记忆"而非"泛化"。这也是为什么高质量去重数据如此珍贵。


常见问题 & 面试考点

问题要点
预训练的损失函数是什么?交叉熵损失,等价于最大化下一个 token 的对数似然
解释 Chinchilla Scaling Law给定计算预算,模型参数和数据量应等比增长,每参数约 20 tokens
AdamW 和 Adam 的区别AdamW 将 weight decay 与自适应学习率解耦
为什么用 BF16 而不是 FP16?BF16 指数位与 FP32 相同,不易溢出,无需 loss scaling
Gradient checkpointing 的代价约 33% 额外计算时间,换取显著内存节省
什么是 MFU?好的 MFU 是多少?实际 FLOPs / 理论峰值 FLOPs,40-50% 是较好的水平
Warmup 的作用让 Adam 的二阶矩估计充分积累,避免初始阶段步长过大
预训练数据去重为什么重要?重复数据导致过拟合、降低多样性、浪费计算资源

推荐资源