Flow Matching
一句话总结
Flow Matching 是一种比 Diffusion 更简洁的生成式建模方法——不再需要定义复杂的噪声 schedule 和去噪过程,而是直接学习一个向量场(velocity field),让噪声沿着 ODE 轨迹流向目标分布。训练目标简单到一行公式:
在大模型体系中的位置
┌────────────────────────────────────────────────────────────┐
│ Generative Models for AI │
│ │
│ VAE ──→ GAN ──→ Diffusion ──→ [Flow Matching] ──→ ??? │
│ ↑ ↑ │
│ 加噪+去噪 直接学 ODE 向量场 │
│ (复杂 schedule) (线性插值,更简洁) │
│ │
│ 应用: 图像生成 (Stable Diffusion 3) │
│ 语音合成 (Voicebox) │
│ 蛋白质结构 (FrameFlow) │
│ LLM 文本生成 (MDLM, Flow Matching for Discrete Data) │
└────────────────────────────────────────────────────────────┘Flow Matching 正在成为 Diffusion 的下一代替代方案。Stable Diffusion 3、Meta 的 Voicebox 等都已采用 Flow Matching。更重要的是,它正在被探索用于离散 token 的生成,可能为 LLM 带来全新的生成范式。
1. 从 Diffusion 到 Flow Matching
1.1 Diffusion 回顾
Diffusion Model 的核心思路:
前向过程(加噪): x_0 → x_1 → x_2 → ... → x_T ≈ N(0, I)
数据 逐步加噪 纯噪声
反向过程(去噪): x_T → x_{T-1} → ... → x_0
噪声 逐步去噪 生成数据前向过程(固定,不可学习):
反向过程(需要学习):
训练目标:预测噪声
Diffusion 的不足
- 噪声 schedule(
或 )需要精心设计,不同 schedule 影响巨大 - 前向/反向过程依赖 SDE(随机微分方程),数学推导复杂
- 采样需要数百到上千步,速度慢
- 理论框架和实际训练目标之间有 gap
1.2 Flow Matching 的核心思想
Flow Matching 的思路极其简洁:
直接学习一个向量场
,使得数据点沿着这个向量场从噪声分布"流"到目标分布。
Flow Matching:
t=0 t=1
噪声 x_0 ─────────── 沿向量场流动 ──────────────→ 数据 x_1
∼ N(0,I) dx/dt = v_θ(x_t, t) ∼ p_data用 ODE(常微分方程)描述:
给定
1.3 关键区别对比
| 特性 | Diffusion | Flow Matching |
|---|---|---|
| 核心方程 | SDE(随机微分方程) | ODE(常微分方程) |
| 学习目标 | 预测噪声 | 预测速度场 |
| 路径 | 由噪声 schedule 决定 | 可以自由设计(如线性) |
| 训练目标 | ||
| 采样 | 反向 SDE 求解(需要很多步) | ODE 求解(可以更少步) |
| 噪声 schedule | 需要精心设计 | 不需要 |
2. 数学基础
2.1 常微分方程(ODE)
给定向量场
这条路径定义了一个流(flow)
流的关键性质:
- 每个
有且只有一条轨迹 - 概率密度沿轨迹变化可以用连续性方程精确描述
2.2 概率路径
我们希望构造一个时间连续的概率分布族
中间的
连续性方程将概率路径
如果我们能找到一个
2.3 速度场(Velocity Field)
速度场
对于最简单的线性插值路径:
对应的速度场为:
这就是为什么 Flow Matching 的训练目标如此简洁——目标速度就是
3. Conditional Flow Matching (CFM)
3.1 为什么需要 "Conditional"?
直接优化 Flow Matching 目标需要知道边际向量场
核心技巧:条件化(conditioning)。我们不直接学边际向量场,而是学条件向量场
3.2 CFM 目标函数
Conditional Flow Matching (CFM) 损失:
对于线性条件路径
因此 CFM 损失简化为:
为什么比 Diffusion 训练更简单?
- 无需噪声 schedule:
就是线性插值,没有 这些超参数 - 目标直观:学习速度
(从噪声到数据的位移),比预测噪声 更直观 - 数学等价:Lipman et al. (2023) 证明优化 CFM 损失等价于优化真正的 Flow Matching 目标
3.3 训练算法
算法: Conditional Flow Matching 训练
Repeat:
1. 从数据集采样 x_1 ~ p_data
2. 从噪声分布采样 x_0 ~ N(0, I)
3. 均匀采样时间 t ~ U[0, 1]
4. 计算插值点 x_t = (1-t) * x_0 + t * x_1
5. 计算目标速度 v_target = x_1 - x_0
6. 预测速度 v_pred = v_θ(x_t, t)
7. 计算损失 L = ||v_pred - v_target||²
8. 反向传播更新 θ3.4 采样算法
算法: Flow Matching 采样(欧拉法)
1. 采样 x_0 ~ N(0, I)
2. 设置步长 Δt = 1/N
3. For t = 0, Δt, 2Δt, ..., 1-Δt:
v = v_θ(x_t, t) # 预测速度
x_{t+Δt} = x_t + v * Δt # 欧拉更新
4. 返回 x_1 作为生成样本4. Rectified Flow
4.1 直线路径的优势
Rectified Flow(Liu et al., 2023)进一步强调了直线路径的优越性。
核心观察:如果向量场
- ODE 求解只需要很少的步数(极端情况下一步就够)
- 采样速度大幅提升
一般 Flow Matching: Rectified Flow:
x_0 ~~~曲线~~~→ x_1 x_0 ──直线──→ x_1
需要很多步 可能只需 1 步4.2 Reflow 操作
Rectified Flow 的关键技巧——Reflow:
- 训练一个初始 flow model
- 用
从噪声 生成样本 - 用
配对重新训练一个新的 flow model
每次 Reflow 都会让轨迹更接近直线。经过 2-3 次 Reflow,即使用 1-2 步欧拉法也能得到高质量样本。
4.3 一步生成的可能性
如果轨迹是完美直线,一步欧拉法就是精确解。这使得 Flow Matching + Rectified Flow 成为最快的 diffusion 类生成方法之一。
Stable Diffusion 3 的选择
Stability AI 在 Stable Diffusion 3 中采用了 Rectified Flow 作为采样方法,相比 SD 1/2 的 DDPM/DDIM 采样,在相同步数下获得了更好的图像质量。
5. 代码实现
5.1 基础 Flow Matching(2D 演示)
用一个简单的 2D 例子理解 Flow Matching 的全过程——让噪声点"流向"正弦曲线:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
# ===== 超参数 =====
dim = 2 # 2D 数据
num_samples = 200
lr = 1e-2
epochs = 100000
num_steps = 50 # ODE 求解步数
# ===== 1. 准备数据 =====
# 目标分布:正弦曲线上的点
x1_samples = torch.rand(num_samples, 1) * 4 * torch.pi # [0, 4π]
y1_samples = torch.sin(x1_samples)
target_data = torch.cat([x1_samples, y1_samples], dim=1) # (200, 2)
# 噪声分布:标准高斯
noise_data = torch.randn(num_samples, dim) * 2 # (200, 2)
# ===== 2. 定义向量场网络 =====
class VectorField(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim + 1, 64), # 输入: x(2) + t(1) = 3
nn.ReLU(),
nn.Linear(64, dim) # 输出: 速度向量(2)
)
def forward(self, x, t):
return self.net(torch.cat([x, t], dim=1))
# ===== 3. 训练 =====
model = VectorField()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
idx = torch.randperm(num_samples)
x0 = noise_data[idx] # 起点:噪声
x1 = target_data[idx] # 终点:正弦曲线
t = torch.rand(x0.size(0), 1) # 随机时间 t ~ U[0,1]
xt = (1 - t) * x0 + t * x1 # 线性插值
vt_pred = model(xt, t) # 预测速度
vt_target = x1 - x0 # 目标速度
loss = torch.mean((vt_pred - vt_target) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# ===== 4. 采样(ODE 欧拉法) =====
x = noise_data[0:1] # 一个噪声起点
trajectory = [x.detach().numpy()]
t = 0
delta_t = 1 / num_steps
with torch.no_grad():
for i in range(num_steps):
vt = model(x, torch.tensor([[t]]))
x = x + vt * delta_t # 欧拉更新
t += delta_t
trajectory.append(x.detach().numpy())
# trajectory 记录了从噪声到正弦曲线的完整轨迹核心只有 4 行:
xt = (1 - t) * x0 + t * x1— 线性插值vt_pred = model(xt, t)— 预测速度vt_target = x1 - x0— 目标速度loss = ||vt_pred - vt_target||^2— MSE 损失
5.2 Conditional Flow Matching(带条件)
加入条件信息 tag,可以控制生成特定区域的样本:
class ConditionalVectorField(nn.Module):
def __init__(self):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim + 2, 64), # 输入: x(2) + t(1) + tag(1) = 4
nn.ReLU(),
nn.Linear(64, dim)
)
def forward(self, x, t, tag):
return self.net(torch.cat([x, t, tag], dim=1))
# 训练时加入条件
model = ConditionalVectorField()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# 为每个样本分配条件标签(如按 x 坐标分区)
tags = (x1_samples / (4 * torch.pi / 10.0)).int()
for epoch in range(5000):
idx = torch.randperm(num_samples)
x0 = noise_data[idx]
x1 = target_data[idx]
t = torch.rand(x0.size(0), 1)
xt = (1 - t) * x0 + t * x1
vt_pred = model(xt, t, tags[idx].float()) # 传入条件
vt_target = x1 - x0
loss = torch.mean((vt_pred - vt_target) ** 2)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 采样时指定条件
tag_num = 3 # 生成 tag=3 区域的样本
x = torch.randn(1, dim)
t = 0
delta_t = 1 / 50
with torch.no_grad():
for i in range(50):
vt = model(x, torch.tensor([[t]]),
torch.tensor([[tag_num]], dtype=torch.float))
x = x + vt * delta_t
t += delta_t
# x 现在应该落在 tag=3 对应的正弦曲线区域Conditional Flow Matching 的意义
条件信息可以是任何东西:类别标签、文本 embedding、图像特征等。这就是为什么 CFM 能用于 text-to-image(文本条件)、text-to-speech(文本条件)等任务。
6. 在 LLM 中的应用
6.1 从连续到离散:离散 Flow Matching
传统 Flow Matching 处理连续数据(图像像素、音频波形)。但 LLM 的输入是离散 token。如何把 Flow Matching 用于离散数据?
几种主要方向:
方向一:嵌入空间 Flow Matching
离散 token → Embedding 空间(连续)→ Flow Matching → 连续向量 → 最近邻解码 → token代表工作:MDLM (Masked Diffusion Language Model)
方向二:离散概率 Flow Matching
在 simplex(概率单纯形)上定义 flow:
方向三:CTMC(连续时间马尔可夫链)
将离散扩散重新表述为连续时间框架下的转移速率矩阵,与 flow matching 的 ODE 形式对偶。
6.2 与自回归 LLM 的对比
| 特性 | 自回归 LLM | Flow Matching LLM |
|---|---|---|
| 生成方式 | 逐 token 生成 | 所有 token 并行精炼 |
| 生成速度 | O(n) 步 | O(T) 步(T 为 flow 步数,与序列长度无关) |
| 左到右依赖 | 强制从左到右 | 无方向限制,可并行 |
| 编辑能力 | 只能 append | 可以修改任意位置 |
| 训练目标 | next-token prediction | flow matching loss |
6.3 未来展望
Flow Matching 在 LLM 领域仍处于早期探索阶段,但有几个值得关注的方向:
- 非自回归文本生成:打破"从左到右"的限制,一次性生成整个句子
- 文本编辑与补全:利用 flow 的可逆性,对现有文本进行局部修改
- 多模态统一:图像和文本在同一个 flow 框架下生成
- 加速采样:Rectified Flow + 蒸馏,用 1-4 步生成完整文本
苏格拉底时刻
在继续之前,尝试回答这些问题:
核心区别:Flow Matching 和 Diffusion 的最本质区别是什么?(提示:SDE vs ODE,随机 vs 确定性)
为什么线性插值 work:
这么简单的路径就够了吗?凭什么不需要 Diffusion 那种精心设计的噪声 schedule? CFM 的巧妙之处:为什么"条件化"能让 intractable 的边际向量场变成 tractable 的训练目标?(提示:类比 VAE 的 ELBO)
一步生成:Rectified Flow 号称可以"一步生成",但实际一步生成的质量如何?瓶颈在哪里?
离散化挑战:将 Flow Matching 应用于离散 token 的核心困难是什么?(提示:梯度、连续空间与离散空间的映射)
参考思路
- Diffusion 的反向过程是 SDE(有随机噪声项),Flow Matching 用的是 ODE(纯确定性)。ODE 的轨迹确定且不交叉,使理论分析和加速采样更容易。
- 线性插值定义了最简单的"条件概率路径"。CFM 的理论保证了:只要条件路径合理,优化条件损失就等价于优化真正的 Flow Matching 目标。噪声 schedule 是 Diffusion 为了保证数学性质而"设计"出来的,Flow Matching 从构造上就绕过了这个需求。
- 边际向量场是对所有可能的
配对求期望,无法直接计算。但 CFM 证明了条件损失的梯度与边际损失的梯度相等(类似 score matching 的 Stein's identity 技巧),所以优化条件版本就等于优化边际版本。 - 一步生成需要轨迹是完美直线。实际中轨迹总有弯曲,Reflow 可以减少弯曲但不能完全消除。实际一步生成的质量低于多步,通常 4-8 步是质量和速度的 sweet spot。
- 离散空间没有连续梯度,不能直接定义 ODE。需要将离散 token 映射到连续空间(embedding/simplex),在连续空间做 flow,再映射回离散空间。这个"连续-离散"转换会引入误差。
推荐资源
| 资源 | 链接 | 说明 |
|---|---|---|
| Flow Matching 原论文 (Lipman et al.) | arxiv.org/abs/2210.02747 | CFM 的理论基础 |
| Rectified Flow 原论文 (Liu et al.) | arxiv.org/abs/2209.03003 | 直线路径 + Reflow 技巧 |
| Flow Matching Guide | arxiv.org/abs/2412.06264 | Meta 的综合教程 |
| Stable Diffusion 3 技术报告 | arxiv.org/abs/2403.03206 | Rectified Flow 在工业界的应用 |
| MDLM 论文 | arxiv.org/abs/2406.07524 | 离散 Flow Matching 用于语言模型 |
| Voicebox (Meta) | arxiv.org/abs/2306.15687 | CFM 用于语音合成 |
| 知乎 Flow Matching 入门 | zhuanlan.zhihu.com/p/28731517852 | 中文入门教程 |