Skip to content

推理模型(O1/R1 范式)

推理模型通过在推理时投入更多计算来提升回答质量,是大模型从"快思考"走向"慢思考"的关键范式

在大模型体系中的位置

预训练 (Pre-training)          → 学习语言知识和世界知识

监督微调 (SFT)                 → 学习指令跟随能力

偏好对齐 (RLHF/DPO/GRPO)      → 学习人类偏好,安全有用

推理增强  ← 你在这里            → 在推理时花更多计算换更好结果
  ├── PRM + Best-of-N          → 过程奖励引导搜索
  ├── MCTS + LLM               → 树搜索式推理
  ├── GRPO for Reasoning        → DeepSeek-R1 方案
  └── 推理蒸馏                  → 大模型 CoT → 小模型

传统大模型在生成每个 token 时花费的计算量是固定的。推理模型的核心洞察是:困难问题需要更多的思考时间。这与人类的 System 2 Thinking 类似——面对复杂数学题,我们不会瞬间给出答案,而是会一步步推导。

从 System 2 Thinking 到 Test-time Compute Scaling

Daniel Kahneman 在《Thinking, Fast and Slow》中将人类思维分为两个系统:

思维系统特点LLM 对应
System 1快速、自动、直觉标准 LLM 的单次前向传播
System 2慢速、刻意、逻辑推理模型的多步推理过程

Test-time Compute Scaling 的核心思想:与其在训练时投入所有计算(更大模型、更多数据),不如在推理时动态分配计算资源——简单问题快速回答,困难问题深入思考。

Performancef(Train-time Compute,Test-time Compute)

OpenAI 的 O1 模型首次验证了这一思路:通过在推理时生成长链式思考(Chain-of-Thought),模型在数学、编程、科学推理等任务上取得了巨大提升。

Reasoning Token 与 Chain-of-Thought

什么是 Reasoning Token?

O1 模型在生成最终答案之前,会先生成一段 Reasoning Token(推理 token),这些 token 不展示给用户,但对提升回答质量至关重要:

用户问题: "求解方程 3x + 5 = 20"

[Reasoning Tokens - 用户不可见]
<think>
我需要求解 3x + 5 = 20。
第一步:两边减去 5,得到 3x = 15。
第二步:两边除以 3,得到 x = 5。
让我验证:3 * 5 + 5 = 15 + 5 = 20,正确。
</think>

[最终回答 - 用户可见]
x = 5

关键区别在于:传统 CoT 需要 prompt 诱导("Let's think step by step"),只是浅层模仿推理格式。而 O1 的推理能力是通过 RL 训练出来的,模型自主决定何时深入思考、何时快速回答,并能发现错误后回溯纠正。

Process Reward Model (PRM)

ORM vs PRM

传统的 Outcome Reward Model (ORM) 只对最终结果打分——答案对了给高分,错了给低分。这存在一个问题:无法区分"方法正确但计算出错"和"方法完全错误"

Process Reward Model (PRM) 对推理的每一步进行打分:

问题: "一个矩形,长是宽的2倍,周长为24,求面积。"

Step 1: 设宽为 x,则长为 2x        → PRM 评分: 0.95 ✓
Step 2: 周长 = 2(x + 2x) = 24      → PRM 评分: 0.92 ✓
Step 3: 6x = 24, x = 4             → PRM 评分: 0.90 ✓
Step 4: 面积 = 4 × 8 = 32          → PRM 评分: 0.93 ✓

对比错误推理:
Step 1: 设宽为 x,则长为 2x        → PRM 评分: 0.95 ✓
Step 2: 周长 = x + 2x = 24         → PRM 评分: 0.15 ✗  ← PRM 在这里就能发现错误!
Step 3: 3x = 24, x = 8             → PRM 评分: 0.20 ✗
Step 4: 面积 = 8 × 16 = 128        → PRM 评分: 0.10 ✗

PRM 的数学形式

给定问题 x 和推理过程 s=(s1,s2,,sn),PRM 为每一步计算一个正确性分数:

PRMθ(x,s1,,sk)=P(step sk is correctx,s1,,sk1)

整个推理过程的分数可以用多种聚合方式:

Scoremin(x,s)=mink=1nPRMθ(x,s1,,sk)Score(x,s)=k=1nPRMθ(x,s1,,sk)Scorelast(x,s)=PRMθ(x,s1,,sn)

实践中 Scoremin 效果最好——因为一个错误步骤足以导致整个推理失败。

PRM 训练数据构建(PRM800K)

OpenAI 发布的 PRM800K 数据集包含约 80 万条人工逐步标注的数学推理步骤,每一步标注为 positive / negative / neutral。由于人工标注成本极高,后续工作(如 Math-Shepherd)提出了Monte Carlo 自动标注——对每一步,多次采样后续推理并检查最终答案,用正确率作为该步的分数:

python
def auto_label_step(model, question, steps_so_far, correct_answer, n_samples=64):
    """Monte Carlo 估计某一步的正确概率"""
    correct_count = 0
    for _ in range(n_samples):
        completion = model.generate(
            prompt=question + "\n".join(steps_so_far),
            temperature=0.8, max_tokens=512)
        if extract_answer(completion) == correct_answer:
            correct_count += 1
    return correct_count / n_samples

PRM 模型架构

PRM 的架构与奖励模型类似(参见 alignment.md),但在步骤分隔符位置输出分数:

python
import torch
import torch.nn as nn

class ProcessRewardModel(nn.Module):
    """过程奖励模型:在步骤分隔符位置输出正确性分数"""
    def __init__(self, base_model, step_token_id):
        super().__init__()
        self.base_model = base_model
        self.reward_head = nn.Linear(base_model.config.hidden_size, 1)
        self.step_token_id = step_token_id

    def forward(self, input_ids, attention_mask=None):
        hidden = self.base_model(
            input_ids=input_ids, attention_mask=attention_mask,
            output_hidden_states=True).hidden_states[-1]
        # 提取步骤分隔符位置的隐藏状态并打分
        positions = (input_ids == self.step_token_id).nonzero(as_tuple=True)
        step_hidden = hidden[positions[0], positions[1]]
        return torch.sigmoid(self.reward_head(step_hidden).squeeze(-1))

训练使用 BCE Loss,只在步骤位置计算损失。完整训练代码见下方代码实战。

Monte Carlo Tree Search (MCTS)

MCTS 核心思想

MCTS 是一种在大搜索空间中寻找最优策略的树搜索算法。在 LLM 推理中,我们将生成过程看作树搜索

Root = Question

    ┌─────┼─────────┐
    ↓     ↓         ↓
  Step1a  Step1b  Step1c     ← Different first reasoning steps
    │       │       │
  ┌─┼─┐   ┌┼┐     ┌┼┐
  ↓ ↓ ↓   ↓ ↓     ↓ ↓
 ...       ...     ...       ← Different follow-up reasoning

MCTS 四步循环

MCTS 通过四步循环不断改进搜索策略:

1. Selection(选择):从根节点出发,用 UCB 公式选择最有潜力的子节点:

UCB(s)=Q(s)N(s)+clnN(parent)N(s)

其中 Q(s) 是累计奖励,N(s) 是访问次数,c 是探索系数。

2. Expansion(扩展):到达叶节点后,用 LLM 生成新的推理步骤作为子节点。

3. Simulation(模拟):从新节点出发,用 LLM 快速生成完整解答,检查结果是否正确。

4. Backpropagation(回传):将模拟结果沿路径回传,更新所有祖先节点的统计信息。

代码示例:简化版 MCTS for LLM Reasoning

MCTS 搜索实现参考了 AlphaGo 论文 (Silver et al., 2016) 的 UCB 选择策略,并适配了 LLM 推理场景。

python
import math, random

class MCTSNode:
    """MCTS 节点:代表推理过程中的一个状态"""
    def __init__(self, state, parent=None, step_text=""):
        self.state = state            # 到当前步的完整推理文本
        self.parent = parent
        self.step_text = step_text
        self.children = []
        self.visits = 0               # 访问次数 N(s)
        self.total_reward = 0.0       # 累计奖励 Q(s)

    def ucb_score(self, c=1.414):
        if self.visits == 0:
            return float('inf')
        return (self.total_reward / self.visits +
                c * math.sqrt(math.log(self.parent.visits) / self.visits))


class MCTSReasoner:
    """将 MCTS 应用于 LLM 推理"""
    def __init__(self, llm, prm=None, max_steps=6, n_candidates=3):
        self.llm = llm
        self.prm = prm
        self.max_steps = max_steps
        self.n_candidates = n_candidates

    def search(self, question, n_iterations=50):
        root = MCTSNode(state=question)
        for _ in range(n_iterations):
            node = self._select(root)                      # 1. Selection
            if node.visits > 0 and len(node.children) == 0:
                node = self._expand(node)                  # 2. Expansion
            reward = self._simulate(node)                  # 3. Simulation
            self._backpropagate(node, reward)               # 4. Backpropagation
        return self._best_path(root)

    def _select(self, node):
        while node.children:
            node = max(node.children, key=lambda n: n.ucb_score())
        return node

    def _expand(self, node):
        for _ in range(self.n_candidates):
            next_step = self.llm.generate_step(node.state, temperature=0.7)
            child = MCTSNode(node.state + "\n" + next_step, node, next_step)
            node.children.append(child)
        return random.choice(node.children)

    def _simulate(self, node):
        completion = self.llm.complete_reasoning(node.state, self.max_steps)
        return self.prm.score(completion) if self.prm else self._check_answer(completion)

    def _backpropagate(self, node, reward):
        while node is not None:
            node.visits += 1
            node.total_reward += reward
            node = node.parent

    def _best_path(self, root):
        path, node = [], root
        while node.children:
            node = max(node.children, key=lambda n: n.visits)
            path.append(node.step_text)
        return path

    def _check_answer(self, completion):
        return 1.0 if "正确答案" in completion else 0.0

MCTS 在推理模型中的应用

MCTS + LLM 的核心优势在于可以系统地探索推理空间:

方法搜索策略优点缺点
Greedy Decoding每步选概率最高的速度快可能错过更好的路径
Beam Search保留 top-k 路径比贪心好路径之间缺乏多样性
Best-of-N独立采样 N 个简单有效无法利用部分推理的信息
MCTS动态树搜索智能分配计算资源计算成本高

GRPO for Reasoning(DeepSeek-R1 方案)

无需 Reward Model 的 RL 训练

DeepSeek-R1 提出了一个令人兴奋的发现:不需要复杂的 Reward Model,仅用简单的规则奖励 + GRPO 就能训练出强大的推理模型

训练流程:

基座模型 (DeepSeek-V3-Base)

冷启动 SFT(少量高质量 CoT 数据)

GRPO + 规则奖励(大规模 RL 训练)        ← 核心阶段

拒绝采样 + SFT(格式规范化)

再次 GRPO(进一步提升 + 对齐)

DeepSeek-R1

规则奖励函数设计

DeepSeek-R1 的奖励函数由两部分组成:

python
def deepseek_r1_reward(response, ground_truth):
    """DeepSeek-R1 的规则奖励函数"""
    reward = 0.0

    # 1. 格式奖励:检查是否使用了 <think>...</think> 格式
    has_think_start = "<think>" in response
    has_think_end = "</think>" in response
    if has_think_start and has_think_end:
        think_content = response.split("<think>")[1].split("</think>")[0]
        if len(think_content.strip()) > 0:
            reward += 1.0  # 格式正确 +1

    # 2. 正确性奖励:检查最终答案是否正确
    predicted_answer = extract_boxed_answer(response)
    if predicted_answer == ground_truth:
        reward += 1.0  # 答案正确 +1

    return reward


def extract_boxed_answer(response):
    """提取 \\boxed{...} 中的答案"""
    import re
    match = re.search(r'\\boxed\{(.+?)\}', response)
    return match.group(1).strip() if match else ""

"Aha Moment" 涌现

DeepSeek-R1 在训练过程中观察到了一个惊人的现象——模型自发学会了反思和自我纠错,被称为 "Aha Moment":

问题: "计算 28 × 15"

[训练早期的输出]
<think>
28 × 15 = 28 × 10 + 28 × 5 = 280 + 140 = 410
</think>
答案是 410。    ← 错误,没有反思

[训练后期的输出 - Aha Moment]
<think>
28 × 15 = 28 × 10 + 28 × 5 = 280 + 140 = 410
等一下,让我重新验证一下。
28 × 15 = 30 × 15 - 2 × 15 = 450 - 30 = 420     ← 自发重新验证!
两种方法结果不同,我再算一次:
28 × 5 = 140,这是对的。
28 × 10 = 280,这也是对的。
280 + 140 = 420,不是 410!之前算错了。        ← 发现并纠正错误!
</think>
答案是 \boxed{420}。

这种自我反思能力不是通过标注数据教会的,而是在 RL 训练过程中自然涌现的——模型发现"检查自己的推理"能提高获得正确答案的概率,从而被奖励强化。

GRPO 推理训练核心逻辑

python
import torch

def grpo_reasoning_step(model, ref_model, tokenizer, question, answer,
                         group_size=8, beta=0.04, epsilon=0.2):
    """GRPO 推理训练的核心步骤(简化版,完整实现见 alignment.md)"""
    # 1. 采样 G 个回答
    responses, log_probs_old = [], []
    for _ in range(group_size):
        with torch.no_grad():
            output = model.generate(
                tokenizer(question, return_tensors='pt').input_ids,
                max_new_tokens=512, temperature=0.7, do_sample=True,
                return_dict_in_generate=True, output_scores=True)
            responses.append(tokenizer.decode(output.sequences[0]))
            log_probs_old.append(compute_log_prob(output))

    # 2. 规则奖励
    rewards = torch.tensor([deepseek_r1_reward(r, answer) for r in responses])

    # 3. 组内相对优势(全对/全错时跳过)
    if rewards.std() < 1e-6:
        return torch.tensor(0.0)
    advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)

    # 4. PPO-clip + KL 正则化
    total_loss = 0
    for i, resp in enumerate(responses):
        ids = tokenizer(resp, return_tensors='pt').input_ids
        log_p_new = compute_token_log_probs(model(ids).logits, ids)
        with torch.no_grad():
            log_p_ref = compute_token_log_probs(ref_model(ids).logits, ids)
        ratio = torch.exp(log_p_new - log_probs_old[i])
        clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)
        pg = -torch.min(ratio * advantages[i], clipped * advantages[i]).mean()
        kl = (log_p_ref.exp() / log_p_new.exp() - (log_p_ref - log_p_new) - 1).mean()
        total_loss += pg + beta * kl
    return total_loss / group_size

推理蒸馏

用大推理模型的 CoT 教小模型

推理蒸馏的核心思路:用大推理模型(如 DeepSeek-R1、O1)生成包含详细推理过程的数据,然后用这些数据对小模型进行 SFT

大推理模型(Teacher)                  小模型(Student)
         │                                 │
    生成高质量 CoT 推理数据 ──────→  用 CoT 数据进行 SFT
         │                                 │
  "设宽为x,长为2x,              "设宽为x,长为2x,
   周长=2(x+2x)=6x=24,            周长=2(x+2x)=6x=24,
   x=4,面积=4×8=32"               x=4,面积=4×8=32"

DeepSeek-R1 的实验表明:用 R1 生成的 80 万条推理数据微调 Qwen-2.5-32B,效果接近甚至超过用 RL 直接训练

流程:(1) Teacher 对每个问题多次采样,选出答案正确的 CoT 推理;(2) 用这些 (问题, <think>推理</think>答案) 数据对 Student 做 SFT。详见 distillation.md

蒸馏 vs RL 的效果对比

方法适合场景优点缺点
RL 训练有算力、追求极致效果能涌现新能力训练不稳定,成本极高
推理蒸馏快速部署、成本敏感简单稳定,效果好受限于 Teacher 的能力上限

DeepSeek-R1 论文指出:蒸馏是"站在巨人的肩膀上",而 RL 是"自己成为巨人"。对于大多数实际场景,蒸馏是更务实的选择。

Scaling Law for Test-time Compute

OpenAI 和 DeepSeek 的实验都表明,推理时计算量与性能之间存在类似 Scaling Law 的关系:

Performance(Ctest)alog(Ctest)+b

增加 test-time compute 的方式:

方式具体做法效率
生成更多 token允许模型思考更长边际递减
Best-of-N 采样采样 N 个回答,选最好的O(N)
树搜索(MCTS)系统性搜索推理空间最高效
集成(Majority Vote)多次采样取多数投票简单有效

关键发现:在某些困难推理任务上,小模型 + 大量 test-time compute 可以匹敌大模型 + 少量 test-time compute。这意味着我们可以用更小的模型实现同等的推理能力,只要在推理时给予足够的计算资源。

代码实战:PRM 训练 + MCTS 搜索 + Best-of-N 采样

实战目标

将前面的 PRM 理论和 MCTS 搜索算法落地为可运行的代码。包含四部分:

  1. PRM 训练 — 步级 Reward Model 的数据格式与训练流程
  2. MCTS 搜索实现 — 完整的 Selection → Expansion → Simulation → Backpropagation 循环
  3. PRM 引导的逐步搜索 — 生成一步、验证一步、拒绝重采样
  4. 端到端示例 — 用 MCTS + PRM 解一道 GSM8K 数学题

Part 1:PRM 训练代码实战

PRM 的核心思路:复用 LLM 的 next-token prediction head,在步骤分隔符位置预测 positive / negative 两个特殊 token 的概率。这比额外加一个 reward head 更高效,因为可以直接用 LoRA 微调。

python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

# ============================================
# 1. 数据格式:PRM 训练数据长什么样
# ============================================
# 每条数据 = prompt + 若干 step,每个 step 以 SEP_TOKEN 结尾
# label 标注该 step 是 positive 还是 negative

SEP_TOKEN = "<|reserved_special_token_1|>"   # 步骤分隔符
POSITIVE_TOKEN = "<|reserved_special_token_2|>"  # 正确标签 token
NEGATIVE_TOKEN = "<|reserved_special_token_3|>"  # 错误标签 token

prm_train_example = {
    "prompt": "计算 (7+3) × (12-8) 的值",
    "steps": [
        "第一步:计算括号内,7+3=10",      # ← positive
        "第二步:计算括号内,12-8=4",       # ← positive
        "第三步:计算乘法,10×4=40",        # ← positive
    ],
    "labels": [True, True, True],  # 每步是否正确
}

prm_train_negative = {
    "prompt": "计算 (7+3) × (12-8) 的值",
    "steps": [
        "第一步:计算括号内,7+3=10",      # ← positive
        "第二步:直接算 10×12=120",         # ← negative(跳过了减法)
        "第三步:120-8=112",               # ← negative(后续全错)
    ],
    "labels": [True, False, False],
}

def format_prm_training_text(example, tokenizer):
    """将一条 PRM 数据转换为训练文本 + label 序列"""
    text = example["prompt"]
    label_ids = []  # 只在 SEP 位置有 label,其余位置 = -100(忽略)
    for step, is_correct in zip(example["steps"], example["labels"]):
        text += "\n" + step + SEP_TOKEN
        target_token = POSITIVE_TOKEN if is_correct else NEGATIVE_TOKEN
        label_ids.append(tokenizer.convert_tokens_to_ids(target_token))
    return text, label_ids

# ============================================
# 2. PRM 模型:复用 LM head,在 SEP 位置做二分类
# ============================================

class ProcessRewardTrainer:
    """PRM 训练器:在步骤分隔符位置,用 LM head 预测 positive/negative"""
    def __init__(self, model, tokenizer, lr=1e-5):
        self.model = model
        self.tokenizer = tokenizer
        self.sep_id = tokenizer.convert_tokens_to_ids(SEP_TOKEN)
        self.pos_id = tokenizer.convert_tokens_to_ids(POSITIVE_TOKEN)
        self.neg_id = tokenizer.convert_tokens_to_ids(NEGATIVE_TOKEN)
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    def train_step(self, examples):
        """一个 batch 的训练步骤"""
        self.model.train()
        total_loss = 0.0
        for ex in examples:
            self.optimizer.zero_grad()
            text, step_labels = format_prm_training_text(ex, self.tokenizer)
            enc = self.tokenizer(text, return_tensors="pt")
            input_ids = enc["input_ids"]

            # 前向传播,拿到所有位置的 logits
            logits = self.model(input_ids).logits  # (1, seq_len, vocab_size)

            # 找到所有 SEP_TOKEN 的位置
            sep_positions = (input_ids[0] == self.sep_id).nonzero(as_tuple=True)[0]

            # 只取 SEP 位置前一个 token 的 logits(next-token prediction)
            # 因为 model 在位置 t 预测的是 t+1 的 token
            pred_positions = sep_positions - 1
            sep_logits = logits[0, pred_positions, :]  # (n_steps, vocab_size)

            # 只看 positive / negative 两个 token 的 logits
            binary_logits = sep_logits[:, [self.neg_id, self.pos_id]]  # (n_steps, 2)
            targets = torch.tensor(
                [1 if l == self.pos_id else 0 for l in step_labels],
                dtype=torch.long
            )

            loss = F.cross_entropy(binary_logits, targets)
            loss.backward()
            self.optimizer.step()
            total_loss += loss.item()
        return total_loss / len(examples)

# 使用示例
# trainer = ProcessRewardTrainer(model, tokenizer)
# for epoch in range(3):
#     loss = trainer.train_step(train_data)
#     print(f"Epoch {epoch+1}, PRM Loss: {loss:.4f}")
PRM 训练的工程细节
  • LoRA 微调:实践中通常用 QLoRA 4-bit 量化 + LoRA 微调 LM head 附近的层,显存占用约为全量微调的 1/4
  • 数据来源:可以用 PRM800K 人工标注数据,也可以用 Monte Carlo 自动标注(见上文 auto_label_step 函数)
  • SEP token 选择:不同模型有不同的 reserved token,Llama 3 用 <|reserved_special_token_N|>,Qwen 可以自定义添加

Part 2:MCTS 搜索实现——完整四步循环

将前面的简化版 MCTS 扩展为可与真实 LLM 配合的完整实现,核心改进包括:PRM 打分替代随机 rollout、KV Cache 复用、步级粒度搜索。

python
import math
import random
import copy
from dataclasses import dataclass, field
from typing import List, Optional, Tuple

# ============================================
# MCTS 节点定义
# ============================================

@dataclass
class ReasoningNode:
    """MCTS 节点:代表推理过程中一个步骤的状态"""
    state_text: str              # 从根到当前节点的完整推理文本
    step_text: str = ""          # 当前节点对应的推理步骤文本
    parent: Optional['ReasoningNode'] = None
    children: List['ReasoningNode'] = field(default_factory=list)
    visits: int = 0              # N(s) — 访问次数
    total_reward: float = 0.0    # Q(s) — 累计奖励
    prm_score: float = 0.0       # PRM 对当前步骤的打分
    is_terminal: bool = False    # 是否到达终止状态(得出最终答案)

    @property
    def avg_reward(self) -> float:
        return self.total_reward / self.visits if self.visits > 0 else 0.0

    def ucb_score(self, c: float = 1.414) -> float:
        """UCB1 公式:平衡探索与利用"""
        if self.visits == 0:
            return float('inf')  # 未访问的节点优先探索
        exploitation = self.total_reward / self.visits
        exploration = c * math.sqrt(math.log(self.parent.visits) / self.visits)
        return exploitation + exploration


# ============================================
# MCTS 搜索引擎
# ============================================

class LLMReasoningMCTS:
    """
    MCTS + LLM + PRM 的推理搜索引擎
    - Selection:  UCB 选择最有潜力的叶节点
    - Expansion:  用 LLM 生成多个候选推理步骤
    - Simulation: 用 PRM 打分(替代传统的随机 rollout)
    - Backprop:   回传奖励,更新路径上所有节点的统计量
    """
    def __init__(self, llm, prm, tokenizer,
                 n_candidates=3, max_depth=8, c_explore=1.414):
        self.llm = llm
        self.prm = prm
        self.tokenizer = tokenizer
        self.n_candidates = n_candidates  # Expansion 时每个节点生成几个子节点
        self.max_depth = max_depth
        self.c_explore = c_explore

    # ---------- 1. Selection ----------
    def _select(self, node: ReasoningNode) -> ReasoningNode:
        """从根递归选择 UCB 最高的子节点,直到叶节点"""
        while node.children and not node.is_terminal:
            node = max(node.children, key=lambda n: n.ucb_score(self.c_explore))
        return node

    # ---------- 2. Expansion ----------
    def _expand(self, node: ReasoningNode) -> ReasoningNode:
        """用 LLM 采样生成多个候选下一步,作为子节点"""
        if node.is_terminal:
            return node

        for _ in range(self.n_candidates):
            # 用 temperature sampling 生成不同的下一步推理
            next_step = self._generate_one_step(node.state_text, temperature=0.7)
            child_state = node.state_text + "\n" + next_step
            child = ReasoningNode(
                state_text=child_state,
                step_text=next_step,
                parent=node,
                is_terminal=self._is_answer_complete(next_step),
            )
            node.children.append(child)

        # 返回一个随机子节点进行 simulation
        return random.choice(node.children)

    # ---------- 3. Simulation ----------
    def _simulate(self, node: ReasoningNode) -> float:
        """
        用 PRM 给当前推理路径打分(替代传统的随机 rollout)
        如果路径未完成,先用 LLM 快速补全到终止,再整体打分
        """
        if node.is_terminal:
            return self._prm_score_path(node.state_text)

        # 快速补全:greedy decoding 到终止
        completion = self._fast_complete(node.state_text)
        full_path = node.state_text + "\n" + completion
        return self._prm_score_path(full_path)

    # ---------- 4. Backpropagation ----------
    def _backpropagate(self, node: ReasoningNode, reward: float):
        """将 reward 沿路径回传到根节点"""
        current = node
        while current is not None:
            current.visits += 1
            current.total_reward += reward
            current = current.parent

    # ---------- 主搜索循环 ----------
    def search(self, question: str, n_iterations: int = 50) -> Tuple[str, float]:
        """执行 MCTS 搜索,返回最佳推理路径和得分"""
        root = ReasoningNode(state_text=question)

        for i in range(n_iterations):
            # 1. Selection — 找到最有潜力的叶节点
            leaf = self._select(root)

            # 2. Expansion — 如果叶节点已被访问过,展开生成子节点
            if leaf.visits > 0 and not leaf.is_terminal:
                leaf = self._expand(leaf)

            # 3. Simulation — 用 PRM 评估当前路径
            reward = self._simulate(leaf)

            # 4. Backpropagation — 回传奖励
            self._backpropagate(leaf, reward)

        # 搜索结束,沿 visits 最多的路径提取最终答案
        return self._extract_best_path(root)

    # ---------- 辅助方法 ----------
    def _generate_one_step(self, context: str, temperature: float = 0.7) -> str:
        """用 LLM 生成一个推理步骤(到换行或 SEP token 为止)"""
        ids = self.tokenizer(context, return_tensors="pt").input_ids
        with torch.no_grad():
            out = self.llm.generate(
                ids, max_new_tokens=64, temperature=temperature,
                do_sample=True, eos_token_id=self.tokenizer.encode("\n")[0])
        new_tokens = out[0, ids.shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    def _fast_complete(self, context: str) -> str:
        """Greedy decoding 快速补全推理到终止"""
        ids = self.tokenizer(context, return_tensors="pt").input_ids
        with torch.no_grad():
            out = self.llm.generate(ids, max_new_tokens=256, do_sample=False)
        return self.tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True)

    def _prm_score_path(self, full_text: str) -> float:
        """用 PRM 对整条推理路径打分,返回 min-step-score"""
        ids = self.tokenizer(full_text, return_tensors="pt").input_ids
        sep_id = self.tokenizer.convert_tokens_to_ids(SEP_TOKEN)
        with torch.no_grad():
            logits = self.prm(ids).logits[0]
        sep_positions = (ids[0] == sep_id).nonzero(as_tuple=True)[0]
        if len(sep_positions) == 0:
            return 0.0
        pos_id = self.tokenizer.convert_tokens_to_ids(POSITIVE_TOKEN)
        neg_id = self.tokenizer.convert_tokens_to_ids(NEGATIVE_TOKEN)
        scores = []
        for pos in sep_positions:
            step_logits = logits[pos, [neg_id, pos_id]]
            prob_positive = torch.softmax(step_logits, dim=0)[1].item()
            scores.append(prob_positive)
        return min(scores)  # min-score 聚合

    def _is_answer_complete(self, step_text: str) -> bool:
        """判断是否已得出最终答案"""
        return "答案" in step_text or "\\boxed" in step_text or "因此" in step_text

    def _extract_best_path(self, root: ReasoningNode) -> Tuple[str, float]:
        """沿访问次数最多的路径,提取最终推理结果"""
        path_steps = []
        node = root
        while node.children:
            node = max(node.children, key=lambda n: n.visits)
            path_steps.append(node.step_text)
        score = node.avg_reward
        return "\n".join(path_steps), score

MCTS 搜索的计算开销

假设 n_iterations=50n_candidates=3,每次 iteration 需要 1 次 LLM 推理(expansion)+ 1 次 LLM 补全(simulation)+ 1 次 PRM 前向传播。总计约 100-150 次模型调用,因此 MCTS 主要用于对准确率要求极高的场景(数学竞赛、代码生成)。对于一般问答任务,Best-of-N 更实用。

Part 3:PRM 引导的逐步验证搜索

与 MCTS 的树搜索不同,这是一种更轻量的逐步生成 + 逐步验证策略:每生成一步就用 PRM 打分,如果当前步骤被判定为错误,则拒绝并重新采样。这个思路来自 O1 风格的 PRM Search。

python
def prm_guided_step_search(
    llm, prm_model, tokenizer, question,
    max_steps=10, max_retries_per_step=5, temperature=0.9
):
    """
    逐步生成 + PRM 验证的推理搜索
    - 每一步:LLM 生成候选 step → PRM 打分 → 通过则继续,否则重采样
    - 如果某一步多次重采样都不通过,则强制接受(避免死循环)
    """
    sep_id = tokenizer.convert_tokens_to_ids(SEP_TOKEN)
    pos_id = tokenizer.convert_tokens_to_ids(POSITIVE_TOKEN)
    neg_id = tokenizer.convert_tokens_to_ids(NEGATIVE_TOKEN)

    current_ids = tokenizer(question, return_tensors="pt").input_ids
    accepted_steps = []

    for step_i in range(max_steps):
        # 检查是否已生成 EOS
        if current_ids[0, -1] == tokenizer.eos_token_id:
            break

        best_step, best_prob = None, -1.0
        for retry in range(max_retries_per_step):
            # 1. 生成一个推理步骤(do_sample 保证多样性)
            with torch.no_grad():
                step_out = llm.generate(
                    current_ids, max_new_tokens=64,
                    temperature=temperature, do_sample=True,
                    eos_token_id=sep_id)
            step_ids = step_out[:, current_ids.shape[1]:]

            # 2. 拼接后用 PRM 验证
            candidate_ids = torch.cat([current_ids, step_ids], dim=1)
            with torch.no_grad():
                logits = prm_model(candidate_ids).logits[0]

            # 在最后一个 SEP 位置取 positive 概率
            sep_positions = (candidate_ids[0] == sep_id).nonzero(as_tuple=True)[0]
            if len(sep_positions) > 0:
                last_sep = sep_positions[-1]
                step_logits = logits[last_sep, [neg_id, pos_id]]
                p_positive = torch.softmax(step_logits, dim=0)[1].item()
            else:
                p_positive = 0.5

            # 记录最好的候选
            if p_positive > best_prob:
                best_prob = p_positive
                best_step = step_ids

            # 3. 如果 PRM 判定为正确,接受该步骤
            if p_positive > 0.5:
                break

        # 接受最好的步骤(即使没通过阈值也选概率最高的)
        current_ids = torch.cat([current_ids, best_step], dim=1)
        step_text = tokenizer.decode(best_step[0], skip_special_tokens=False)
        accepted_steps.append({
            "step": step_text.strip(),
            "prm_positive_prob": best_prob,
            "retries": retry + 1
        })
        print(f"Step {step_i+1}: P(correct)={best_prob:.3f}, "
              f"retries={retry+1}, text={step_text.strip()[:60]}")

    full_text = tokenizer.decode(current_ids[0], skip_special_tokens=True)
    return full_text, accepted_steps
PRM Search vs MCTS vs Best-of-N 对比
维度Best-of-NPRM Step SearchMCTS
搜索策略独立采样 N 条完整路径逐步生成+验证树搜索+回溯
LLM 调用次数N 次max_steps × avg_retriesiterations × 2
能否利用部分正确推理不能能(逐步保留)能(树结构复用)
实现复杂度
推荐场景N ≤ 64,通用任务步骤可验证的数学/逻辑题竞赛级别难题

Part 4:端到端示例——用 MCTS + PRM 解 GSM8K 数学题

将上面三个组件串联起来,展示一个完整的推理搜索流程:

python
import torch
import re

# ============================================
# GSM8K 端到端示例
# ============================================

def solve_gsm8k_with_mcts(question, ground_truth_answer,
                           llm, prm_model, tokenizer):
    """
    用 MCTS + PRM 解一道 GSM8K 数学题
    1. MCTS 搜索最佳推理路径
    2. Best-of-N 做 baseline 对比
    3. PRM 逐步搜索做对比
    """
    print("=" * 60)
    print(f"问题: {question}")
    print(f"标准答案: {ground_truth_answer}")
    print("=" * 60)

    # --- 方法 1:直接 greedy decoding ---
    ids = tokenizer(question, return_tensors="pt").input_ids
    with torch.no_grad():
        greedy_out = llm.generate(ids, max_new_tokens=256, do_sample=False)
    greedy_answer = tokenizer.decode(greedy_out[0], skip_special_tokens=True)
    greedy_result = extract_number(greedy_answer)
    print(f"\n[Greedy] 答案: {greedy_result}")

    # --- 方法 2:Best-of-N + PRM 打分 ---
    best_answer, best_score = best_of_n_with_prm(
        llm, prm_model, tokenizer, question, n=8, temperature=0.7)
    bon_result = extract_number(best_answer)
    print(f"[Best-of-8] 答案: {bon_result}, PRM score: {best_score:.3f}")

    # --- 方法 3:PRM 逐步搜索 ---
    step_answer, steps = prm_guided_step_search(
        llm, prm_model, tokenizer, question,
        max_steps=6, max_retries_per_step=3)
    step_result = extract_number(step_answer)
    print(f"[PRM Search] 答案: {step_result}")

    # --- 方法 4:MCTS 搜索 ---
    mcts = LLMReasoningMCTS(llm, prm_model, tokenizer,
                             n_candidates=3, max_depth=6)
    mcts_path, mcts_score = mcts.search(question, n_iterations=30)
    mcts_result = extract_number(mcts_path)
    print(f"[MCTS] 答案: {mcts_result}, score: {mcts_score:.3f}")
    print(f"  推理路径:\n  {mcts_path.replace(chr(10), chr(10) + '  ')}")

    # --- 对比结果 ---
    print("\n" + "-" * 40)
    for name, result in [("Greedy", greedy_result), ("Best-of-8", bon_result),
                          ("PRM Search", step_result), ("MCTS", mcts_result)]:
        correct = "correct" if str(result) == str(ground_truth_answer) else "wrong"
        print(f"  {name:12s}: {result:>8s}  [{correct}]")


def best_of_n_with_prm(llm, prm_model, tokenizer, question,
                        n=8, temperature=0.7):
    """生成 N 条完整推理,用 PRM min-score 选最佳"""
    sep_id = tokenizer.convert_tokens_to_ids(SEP_TOKEN)
    pos_id = tokenizer.convert_tokens_to_ids(POSITIVE_TOKEN)
    neg_id = tokenizer.convert_tokens_to_ids(NEGATIVE_TOKEN)
    candidates, scores = [], []

    for _ in range(n):
        ids = tokenizer(question, return_tensors="pt").input_ids
        with torch.no_grad():
            out = llm.generate(ids, max_new_tokens=256,
                               temperature=temperature, do_sample=True)
        candidates.append(tokenizer.decode(out[0], skip_special_tokens=True))

    for cand in candidates:
        cand_ids = tokenizer(cand, return_tensors="pt").input_ids
        with torch.no_grad():
            logits = prm_model(cand_ids).logits[0]
        sep_positions = (cand_ids[0] == sep_id).nonzero(as_tuple=True)[0]
        if len(sep_positions) == 0:
            scores.append(0.0)
            continue
        step_scores = []
        for pos in sep_positions:
            sl = logits[pos, [neg_id, pos_id]]
            step_scores.append(torch.softmax(sl, dim=0)[1].item())
        scores.append(min(step_scores))  # min-score 聚合

    best_idx = max(range(n), key=lambda i: scores[i])
    return candidates[best_idx], scores[best_idx]


def extract_number(text):
    """从推理文本中提取最终数值答案"""
    # 优先从 \\boxed{} 中提取
    match = re.search(r'\\boxed\{([^}]+)\}', text)
    if match:
        return match.group(1).strip()
    # 否则取最后出现的数字
    numbers = re.findall(r'[-+]?\d*\.?\d+', text)
    return numbers[-1] if numbers else ""


# ============================================
# 运行示例(需要加载实际模型和 PRM)
# ============================================

# gsm8k_question = (
#     "Janet's ducks lay 16 eggs per day. She eats three for breakfast "
#     "every morning and bakes muffins for her friends every day with four. "
#     "She sells the remainder at the farmers' market daily for $2 per "
#     "fresh duck egg. How much in dollars does she make every day at "
#     "the farmers' market?"
# )
# solve_gsm8k_with_mcts(gsm8k_question, "18", llm, prm_model, tokenizer)

预期输出效果

在 GSM8K 基准测试上,不同搜索策略的典型效果(以 Llama-3-8B 为例):

  • Greedy: ~50% 准确率
  • Best-of-8: ~62% 准确率(+12%)
  • PRM Step Search: ~66% 准确率(+16%)
  • MCTS (30 iterations): ~70% 准确率(+20%)

随着搜索预算增加(更多 iterations / 更大 N),准确率持续提升,这就是 Test-time Compute Scaling 的直观体现。

苏格拉底时刻

  1. 为什么 PRM 比 ORM 更有效? ORM 只提供"对/错"的稀疏信号,而 PRM 在每一步都给出反馈。这类似于老师批改数学作业——只看最终答案对错(ORM)远不如逐步批改(PRM)有指导价值。
  2. MCTS 的 exploration-exploitation tradeoff 如何理解? UCB 公式中,第一项是利用(选择历史奖励高的路径),第二项是探索(选择访问少的路径)。c 值越大越倾向探索。推理任务中,适度探索能避免陷入"看似正确实则错误"的推理路径。
  3. DeepSeek-R1 的 "Aha Moment" 为什么能涌现? 因为自我纠错能提高最终答案的正确率,而正确率直接对应奖励。RL 的优化压力使模型自然学会了这种策略。这说明足够简单的奖励信号也能催生复杂的行为。
  4. 推理蒸馏的上限在哪里? 蒸馏模型最多只能达到 Teacher 的水平,因为它学习的是 Teacher 的推理模式。RL 训练则有可能超越 Teacher(因为 RL 是自我探索)。但在实践中,蒸馏的效率优势往往使其成为更好的选择。
  5. 小模型 + 大量 test-time compute 能否替代大模型? 理论上可以在某些任务上实现,但前提是任务有明确的验证信号(如数学题可以验算)。对于开放式生成任务(创意写作、综合分析),test-time compute scaling 的收益有限。

常见问题 & 面试考点

  • Q: O1 和 R1 的核心区别是什么? O1 是闭源产品,具体技术未公开。R1 公开了技术细节:使用 GRPO + 规则奖励进行 RL 训练,并通过推理蒸馏扩展到不同规模的模型。
  • Q: PRM 训练数据如何获取? 三种方式:(1) 人工标注(成本高但质量好,如 PRM800K);(2) Monte Carlo 自动标注(用采样估计每步正确率);(3) 模型自动生成标注(成本最低但质量不稳定)。
  • Q: Best-of-N 和 MCTS 如何选择? Best-of-N 简单高效,适合 N 较小的场景(8~64)。MCTS 更智能但计算开销大,适合需要深入推理的难题。实践中建议先用 Best-of-N,不够再上 MCTS。
  • Q: GRPO 在推理任务上为什么比 PPO 更受欢迎? (1) 不需要训练额外的 Critic 网络,节省一半显存;(2) 规则奖励消除了 Reward Model 的噪声;(3) 组内相对优势在推理任务上信号更清晰。
  • Q: Reasoning Token 增加了推理成本,如何平衡? 可以通过训练"think budget"——让模型学习根据问题难度决定思考的深度。简单问题少思考,难题多思考,实现计算资源的动态分配。

推荐资源