解码策略(Decoding Strategies)
一句话总结: 解码策略决定了模型"如何从概率分布中选择下一个 Token"——同一个模型搭配不同的解码策略,可以生成从严谨精确到天马行空的截然不同的文本。
在大模型体系中的位置
语言模型的前向传播输出的是下一个 Token 在整个词表上的概率分布(logits → softmax → 概率)。解码策略作用于这个概率分布之上,决定最终选哪个 Token。它不影响模型参数,却深刻影响生成质量——选择合适的解码策略是 LLM 应用落地的关键环节。
输入 Token 序列 → [Transformer 前向传播] → logits (词表大小的向量)
→ [Temperature 缩放] → [Top-k / Top-p 过滤] → [采样或取 argmax] → 下一个 Token核心概念
Greedy Search(贪心搜索)
最简单的解码策略:每一步都选择概率最高的 Token。
优点: 速度快,确定性输出,适合有唯一正确答案的任务(如分类、提取)。
缺点: 容易陷入局部最优。因为每一步都只看当前最优,可能错过整体更优的序列。例如,当前概率最高的词可能把后续生成引入一条"死胡同",导致整体序列质量下降。
# Greedy Search 实现
def greedy_decode(model, input_ids, max_length):
for _ in range(max_length):
logits = model(input_ids) # [batch, seq_len, vocab_size]
next_token = logits[:, -1, :].argmax(dim=-1) # 取概率最高的
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
if next_token == eos_token_id:
break
return input_idsBeam Search(束搜索)
Beam Search 是 Greedy Search 的扩展:不只保留一条路径,而是同时维护
工作流程:
- 从起始 Token 出发,生成 Top-k 个候选
- 对每条候选路径,分别扩展下一步的 Top-k 个 Token
- 在所有
个候选中,保留总概率最高的 条 - 重复直到所有路径都生成了结束符或达到最大长度
核心公式: 每条路径的得分是各步 log 概率之和:
通常还需要做长度归一化(除以序列长度),否则 Beam Search 会倾向于生成更短的序列(因为每多一步,log 概率之和就更小)。
优点: 比 Greedy 能找到更优的序列,适合翻译、摘要等需要精确输出的任务。
缺点: 生成的文本缺乏多样性,容易重复;计算量随 beam width 线性增长。在开放式生成(如对话、创意写作)中效果不佳。
Temperature(温度参数)
Temperature 通过缩放 logits 来控制概率分布的"尖锐程度":
其中
| Temperature | 效果 | 适用场景 |
|---|---|---|
| 分布趋近 one-hot(等价于 Greedy) | 事实性问答、代码生成 | |
| 保持模型原始分布 | 通用场景 | |
| 分布更平坦,低概率 Token 获得更多机会 | 创意写作、头脑风暴 |
直觉理解: Temperature 低 = 模型更"自信"更保守;Temperature 高 = 模型更"随机"更有创意。
Top-k Sampling(Top-k 采样)
在采样前,只保留概率最高的
# Top-k Sampling
def top_k_sampling(logits, k):
top_k_logits, top_k_indices = torch.topk(logits, k)
probs = F.softmax(top_k_logits, dim=-1)
next_token_index = torch.multinomial(probs, num_samples=1)
return top_k_indices[next_token_index]问题:
Top-p / Nucleus Sampling(核采样)
Top-p 采样解决了 Top-k 中
其中 Token 按概率从高到低排序后依次加入,直到累积概率
# Top-p (Nucleus) Sampling
def top_p_sampling(logits, p):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 找到累积概率超过 p 的位置,移除之后的 Token
mask = cumulative_probs - sorted_probs > p
sorted_logits[mask] = -float('inf')
probs = F.softmax(sorted_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
return sorted_indices.gather(-1, next_token)优点: 自适应候选集大小。当模型很确定时,可能只保留 2-3 个 Token;当模型不确定时,可能保留数百个 Token。
实践中的组合策略: 现代 LLM 推理通常同时使用 Temperature + Top-p(有时再加 Top-k)。例如 OpenAI API 的默认配置是 temperature=1.0, top_p=1.0,用户可根据任务调整。
各策略对比总结
| 策略 | 确定性 | 多样性 | 典型应用 |
|---|---|---|---|
| Greedy | 完全确定 | 无 | 分类、提取、简单问答 |
| Beam Search | 近似确定 | 低 | 机器翻译、摘要 |
| Top-k Sampling | 随机 | 中 | 通用生成 |
| Top-p Sampling | 随机 | 中高 | 通用生成(更自适应) |
| Temperature < 1 | 偏确定 | 低 | 事实性任务、代码 |
| Temperature > 1 | 偏随机 | 高 | 创意写作、多样化输出 |
进阶话题:重复惩罚与停止条件
实际部署中还需要考虑:
- 重复惩罚(Repetition Penalty): 降低已生成 Token 的概率,避免模型陷入重复循环
- 频率惩罚 / 存在惩罚(Frequency / Presence Penalty): OpenAI API 提供的两种不同粒度的重复控制
- 停止条件: 生成 EOS Token、达到最大长度、或匹配到指定的停止字符串
代码实战
本节从零实现四种解码策略,并用同一个 prompt 对比生成效果。所有代码基于 PyTorch,可直接在 HuggingFace 模型上运行。
统一的采样框架
import torch
import torch.nn.functional as F
def sample_next_token(logits, strategy="greedy", temperature=1.0, top_k=50, top_p=0.9):
"""
统一的 Token 采样函数,支持多种策略组合
Args:
logits: 模型输出的 logits,shape [vocab_size]
strategy: "greedy" | "sample"(采样时可叠加 temperature/top_k/top_p)
temperature: 温度参数,越低越确定
top_k: Top-k 截断,0 表示不使用
top_p: Top-p 核采样阈值,1.0 表示不使用
Returns:
选中的 token id(标量)
"""
if strategy == "greedy":
return logits.argmax(dim=-1)
# Step 1: Temperature 缩放
if temperature != 1.0:
logits = logits / temperature
# Step 2: Top-k 过滤
if top_k > 0:
top_k = min(top_k, logits.size(-1))
kth_value = logits.topk(top_k).values[-1]
logits = logits.where(logits >= kth_value, torch.tensor(float('-inf')))
# Step 3: Top-p 过滤
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
# 移除累积概率超过 p 的 Token(保留第一个超过的)
mask = cumulative_probs - sorted_probs > top_p
sorted_logits[mask] = float('-inf')
# 还原到原始顺序
logits = sorted_logits.scatter(-1, sorted_indices.argsort(-1), sorted_logits)
# Step 4: 采样
probs = F.softmax(logits, dim=-1)
return torch.multinomial(probs, num_samples=1).squeeze(-1)Beam Search 从零实现
def beam_search(model, input_ids, max_length, beam_width=4, length_penalty=0.6):
"""
Beam Search 实现
Args:
model: 语言模型(输入 token ids,输出 logits)
input_ids: 初始 token 序列,shape [1, seq_len]
max_length: 最大生成长度
beam_width: 束宽度
length_penalty: 长度惩罚系数(>0 鼓励更长序列)
Returns:
最优序列的 token ids
"""
device = input_ids.device
# 每条 beam: (累积 log 概率, token 序列)
beams = [(0.0, input_ids.squeeze(0).tolist())]
completed = []
for step in range(max_length):
all_candidates = []
for score, seq in beams:
# 如果这条 beam 已结束,直接保留
if seq[-1] == model.config.eos_token_id:
completed.append((score, seq))
continue
# 前向传播获取下一步 logits
ids = torch.tensor([seq], device=device)
with torch.no_grad():
logits = model(ids).logits[0, -1, :] # [vocab_size]
log_probs = F.log_softmax(logits, dim=-1)
# 取 Top-k 个候选
topk_log_probs, topk_ids = log_probs.topk(beam_width)
for i in range(beam_width):
new_score = score + topk_log_probs[i].item()
new_seq = seq + [topk_ids[i].item()]
all_candidates.append((new_score, new_seq))
if not all_candidates:
break
# 按长度归一化得分排序,保留 Top beam_width 条
def normalized_score(item):
score, seq = item
return score / (len(seq) ** length_penalty)
all_candidates.sort(key=normalized_score, reverse=True)
beams = all_candidates[:beam_width]
# 从 completed + beams 中选最优
all_results = completed + beams
best = max(all_results, key=normalized_score)
return best[1]完整运行示例:对比四种策略
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载模型(用小模型便于本地实验)
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()
prompt = "The future of artificial intelligence is"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
def generate(input_ids, max_new_tokens=30, **kwargs):
"""通用生成函数"""
ids = input_ids.clone()
for _ in range(max_new_tokens):
with torch.no_grad():
logits = model(ids).logits[0, -1, :]
next_id = sample_next_token(logits, **kwargs)
ids = torch.cat([ids, next_id.unsqueeze(0).unsqueeze(0)], dim=-1)
if next_id == tokenizer.eos_token_id:
break
return tokenizer.decode(ids[0], skip_special_tokens=True)
# 对比不同策略
strategies = {
"Greedy": dict(strategy="greedy"),
"Sample (T=0.7)": dict(strategy="sample", temperature=0.7, top_p=1.0),
"Top-k (k=50)": dict(strategy="sample", top_k=50, top_p=1.0),
"Top-p (p=0.9)": dict(strategy="sample", top_k=0, top_p=0.9),
"Top-p + Low T (T=0.3)":dict(strategy="sample", temperature=0.3, top_p=0.9),
}
for name, params in strategies.items():
result = generate(input_ids, max_new_tokens=30, **params)
print(f"[{name}]")
print(f" {result}\n")动手实验
- 把
temperature分别设为 0.1、0.5、1.0、2.0,观察输出的确定性和多样性变化 - 把
top_p从 0.1 逐步增加到 1.0,对比生成文本的质量 - 对同一 prompt 多次运行采样策略,观察输出的方差——这就是为什么 ChatGPT 每次回答不同
- 尝试用 Beam Search 生成,对比它与采样策略在"事实性问题"vs"创意写作"上的表现差异
苏格拉底时刻
请停下来思考以下问题,不急于查看答案:
- Beam Search 的 beam width 越大结果越好吗?为什么在开放式对话中 Beam Search 表现不如采样方法?
- Temperature 和 Top-p 都能控制生成的"随机性"——它们的作用机制有什么本质区别?能否只用其中一个?
- 如果你需要让 LLM 生成 JSON 格式的结构化输出,应该如何选择解码策略?为什么?
- 为什么 Greedy Decoding 不能保证找到全局最优序列?能否构造一个具体的反例?
- "采样导致的随机性"和"模型能力的不确定性"是一回事吗?如何区分一个错误答案是因为解码策略不好还是因为模型本身能力不足?
常见问题 & 面试考点
- Q: Top-k 和 Top-p 可以同时使用吗? 可以,先做 Top-k 截断,再在 Top-k 结果中做 Top-p 过滤。很多推理框架支持同时设置。
- Q: Temperature = 0 和 Greedy 完全等价吗? 数学极限上等价(
时 softmax 退化为 argmax)。实现中 temperature=0通常直接走 argmax 逻辑。 - Q: 为什么 ChatGPT 同一个问题每次回答不同? 因为使用了采样策略(Temperature > 0),每次从概率分布中随机采样的结果不同。
- Q: Speculative Decoding 是什么? 用一个小模型快速生成候选 Token,再用大模型并行验证,从而加速推理。这不改变生成分布,只加速推理过程。
推荐资源
- HuggingFace "How to generate text" — 各种解码策略的交互式教程和代码示例
- Holtzman et al.《The Curious Case of Neural Text Degeneration》 — 提出 Top-p/Nucleus Sampling 的论文
- Leviathan et al.《Fast Inference from Transformers via Speculative Decoding》 — Speculative Decoding 原始论文
- Welleck et al.《Neural Text Generation with Unlikelihood Training》 — 分析和缓解重复生成问题
- OpenAI API 文档中的参数说明 — 理解 temperature、top_p、frequency_penalty 等参数的实际效果