【推理与部署篇07】投机采样(Speculative Decoding)深度解析:推理加速的终极武器
前言:大模型推理的核心瓶颈是什么?自回归解码——一次只生成一个token,生成100个token就要串行调用100次模型,GPU利用率不到10%。投机采样(Speculative Decoding)在2024-2026年异军突起,成为突破这一瓶颈的最有效手段之一。它用一个"小模型"快速草拟多个token,再用"大模型"并行验证——既保留了原始模型的质量,又能实现2-4倍的解码加速。本文从原理到实战,完整覆盖Greedy SD、Stochastic SD、Medusa、Eagle、Self-SD、Lookahead六大方案。
📋 目录
一、自回归解码的瓶颈与投机采样的直觉
二、投机采样核心原理:小模型草稿 + 大模型验证
三、Greedy Speculative Decoding:确定性加速
四、Stochastic Speculative Decoding:分布一致性加速
五、Medusa:多头投机解码
六、Eagle:特征级投机解码
七、Self-Speculative Decoding:自草稿技术
八、Lookahead Decoding与Blockwise Parallel Decoding
九、各方案性能对比与选型指南
十、生产部署最佳实践
十一、面试高频问答
一、自回归解码的瓶颈与投机采样的直觉
1.1 自回归解码为什么这么慢?
自回归解码(Autoregressive Decoding)的本质:
LLM是一个函数 F(token_i) → logits_{i+1}
生成过程:
Step 1: 输入 [BOS] → F → logits_1 → token_1
Step 2: 输入 [BOS, t1] → F → logits_2 → token_2
Step 3: 输入 [BOS, t1, t2] → F → logits_3 → token_3
…
Step N: 输入 [BOS, t1…tN-1] → F → logits_N → token_N
每个token只能串行生成,无法并行!
核心瓶颈分析:
┌─────────────────────────────────────────────────────────────────┐
│ 假设模型 = 70B LLaMA,GPU = H100 │
│ │
│ 理论算力:1979 TFLOPS(FP16) │
│ 实际利用率:5-15%(自回归解码时) │
│ │
│ 为什么利用率这么低? │
│ 1. 每个step只生成1个token → 计算量极小 │
│ 2. 矩阵计算中batch_size=1 → 无法充分利用Tensor Core │
│ 3. KV Cache加载频繁 → 内存带宽瓶颈 │
│ │
│ 一句话:自回归解码 = 算力浪费 │
│ 生成100个token = 做了100次小矩阵乘法 │
│ 如果能一次做4次 → 利用率立刻翻4倍 │
└─────────────────────────────────────────────────────────────────┘
1.2 直觉:什么是投机采样?
投机采样(Speculative Decoding)的非技术类比:
想象你是要写一篇1000字的文章(生成1000个token)。
普通写法(自回归):
写一个字,检查一遍,再写下一个字 → 非常慢
投机写法(投机解码):
先让实习生(小模型)写20个字,
然后你(大模型)一次性审阅这20个字
→ 如果全部正确 → 一次性通过(20个token一次完成)
→ 如果有错 → 改错,错误位置之后的重写
关键洞察:
- 小模型计算量小,可以快速生成多个候选
- 大模型一次前向可以验证多个候选
- 验证的计算量 ≈ 生成1个token的计算量
- 只要正确率 > 30%,就有加速收益
投机采样的直觉公式:
GPU利用率 ↗ = tokens_步长 / 时间_步长
原来:1 token / 步
投机:平均 2-4 tokens / 步 →
利用率从 5% → 15-25%
1.3 投机采样的三个核心问题
任何投机采样方案都需要回答三个问题:
Q1: 谁来草稿? (Draft Model)
- 外部小模型(如 1B-7B 参数)
- 同一模型的小层版本
- 额外的预测头(Medusa Head)
- 无额外模型的LM Head复用
Q2: 一次草稿几个? (Speculation Length / Lookahead Window) - 固定长度:gamma=3, 5, 10
- 自适应长度:根据接受率动态调整
- 经验值:gamma=3-5 时性价比最高
Q3: 如何保证最终质量一致? (Rejection / Verification) - Greedy模式:草稿token必须等于大模型argmax
- Stochastic模式:基于概率的拒绝采样
- 关键:投机采样不改变原始模型的输出分布
二、投机采样核心原理:小模型草稿 + 大模型验证
2.1 通用流程
投机采样的标准流程(以gamma=3为例):
┌─────────────────────────────────────────────────────────┐
│ │
│ Step 1: Draft(草稿阶段) │
│ 小模型 M_q 基于当前context生成3个token │
│ context: [t1, t2, …, tn] │
│ 草稿: [q_{n+1}, q_{n+2}, q_{n+3}] │
│ │
│ Step 2: Verify(验证阶段) │
│ 大模型 M_p 一次前向计算: │
│ 输入: [t1, t2, …, tn, q_{n+1}, q_{n+2}, q_{n+3}] │
│ 输出: logits_1, logits_2, logits_3, logits_4 │
│ 每个logits位置对应大模型对下一个token的预测分布 │
│ │
│ Step 3: Accept(接受阶段) │
│ 逐位置检查草稿token是否被大模型接受: │
│ pos1: q_{n+1} == argmax(logits_1) → 接受 ✅ │
│ pos2: q_{n+2} == argmax(logits_2) → 接受 ✅ │
│ pos3: q_{n+3} == argmax(logits_3) → 拒绝 ❌ │
│ │
│ Step 4: Output(输出 + 回退) │
│ 输出: [q_{n+1}, q_{n+2}] │
│ 回退: 在pos3从大模型的logits_3中采样新的token │
│ │
│ 本轮生成了 2 个token,而不是串行的 3 步 │
│ 大模型只做了 1 次前向(验证3个位置) │
│ 串行需要 3 次前向 → 投机只用 1 次 → ~3x 加速 │
│ │
└──────────────────────────────────────────────────────────┘
2.2 接受率(Acceptance Rate)——加速比的决定因素
接受率 = 草稿token被大模型接受的比例
这是投机采样加速比的核心变量。
加速比的理论公式:
Speedup = (1 + gamma * acceptance_rate + acceptance_rate^2 + …)
/ (gamma * cost_ratio + 1)
其中:
- gamma = 草稿长度(一次草稿几个token)
- acceptance_rate = 单token被接受的概率
- cost_ratio = 小模型计算成本 / 大模型计算成本
通常 cost_ratio = 0.1-0.3(小模型小几倍到十几倍)
简化近似(当 acceptance_rate 接近1时):
Speedup ≈ 1 / (1/gamma + acceptance_rate × cost_ratio)
实际经验值:
┌─────────────────┬──────────────┬─────────────┬────────────────┐
│ 场景 │ 草稿模型 │ 接受率 │ 实测加速比 │
├─────────────────┼──────────────┼─────────────┼────────────────┤
│ 小模型草稿 │ 7B → 70B │ 60-80% │ 2.0-2.5x │
│ Medusa │ 无(head) │ 70-85% │ 2.5-3.0x │
│ Self-SD │ 前几层 │ 40-60% │ 1.3-1.8x │
│ Eagle │ 特征对齐 │ 80-95% │ 3.0-4.0x │
│ Lookahead │ LM Head复用 │ 30-50% │ 1.2-1.5x │
└─────────────────┴──────────────┴─────────────┴────────────────┘
2.3 为什么投机采样不改变输出分布?
这是投机采样最重要的理论保证。
对于Stochastic(随机采样)模式:
草稿token q 从小模型分布 M_q(·|context) 采样
验证:以概率 min(1, M_p(q)/M_q(q)) 接受q
如果拒绝:从调整后的分布 M_p(·) - M_q(·) 中重新采样
数学证明:
接受概率 = min(1, M_p(q) / M_q(q))
这保证了最终采样分布 = M_p(大模型的原始分布)
因为拒绝采样的校正机制恰好抵消了小模型的偏差
直观理解:
- 如果小模型对某个token的预测概率比大模型还高
→ 说明小模型对这个token比大模型还自信
→ 应该接受(概率加权) - 如果小模型预测概率很低但大模型预测高
→ 很可能被拒绝
→ 从大模型分布重新采样
关键结论:
投机采样 = 无损加速
生成的文本质量与直接使用大模型解码完全等价
2.4 实现投机采样的完整伪代码
import torch
import torch.nn.functional as F
def speculative_decoding_step(
target_model, # 大模型 M_p
draft_model, # 小模型 M_q
input_ids, # 当前context,shape [1, seq_len]
gamma: int = 5, # 草稿长度
temperature: float = 1.0,
):
“”"
单步投机采样
Args:
target_model: 目标大模型
draft_model: 草稿小模型
input_ids: 当前输入
gamma: 草稿长度
temperature: 采样温度
Returns:
accepted_tokens: 本轮接受的token列表
new_input_ids: 更新后的输入
"""
# ===== Phase 1: Draft — 用小模型草稿gamma个token =====
draft_ids = input_ids.clone()
draft_logprobs = [] # 记录每个草稿token在小模型下的log概率
with torch.no_grad():
for _ in range(gamma):
logits_q = draft_model(draft_ids) # [1, vocab]
logits_q = logits_q[:, -1, :] # 取最后一个位置
probs_q = F.softmax(logits_q / temperature, dim=-1)
dist_q = torch.distributions.Categorical(probs_q)
token_q = dist_q.sample() # [1, 1]
logprob_q = dist_q.log_prob(token_q)
draft_ids = torch.cat([draft_ids, token_q], dim=-1)
draft_logprobs.append(logprob_q)
draft_tokens = draft_ids[:, input_ids.shape[-1]:] # 草稿token序列
# ===== Phase 2: Verify — 大模型并行验证 =====
with torch.no_grad():
logits_p = target_model(draft_ids) # 一次前向,计算所有位置
logits_p = logits_p[:, -(gamma+1):] # 取草稿相关的gamma+1个位置
# 计算每个位置大模型下取草稿token的概率
probs_p = F.softmax(logits_p / temperature, dim=-1) # [1, gamma+1, vocab]
# ===== Phase 3: Accept — 逐位置判断是否接受 =====
accepted = []
rejection_point = gamma
# 最后一个位置没用草稿,直接采样
# 位置0~gamma-1对应草稿token的验证
for i in range(gamma):
token_draft = draft_tokens[0, i].item()
if temperature == 0: # Greedy模式
token_greedy = probs_p[0, i].argmax().item()
if token_draft == token_greedy:
accepted.append(token_draft)
else:
rejection_point = i
break
else: # Stochastic模式 — 拒绝采样
prob_p = probs_p[0, i, token_draft].item()
prob_q = torch.exp(draft_logprobs[i]).item()
accept_prob = min(1.0, prob_p / max(prob_q, 1e-10))
if torch.rand(1).item() < accept_prob:
accepted.append(token_draft)
else:
# 从调整分布中采样新token
adjusted_probs = torch.clamp(
probs_p[0, i] - probs_q[0, i],
min=0
)
adjusted_probs /= adjusted_probs.sum()
new_token = torch.multinomial(adjusted_probs, 1).item()
accepted.append(new_token)
rejection_point = i
break
# ===== Phase 4: Final — 最后一个位置从大模型采样 =====
if rejection_point == gamma:
# 所有草稿都被接受 → 额外采样一个token
last_logits = logits_p[0, -1] # 位置gamma
last_probs = F.softmax(last_logits / temperature, dim=-1)
final_token = torch.multinomial(last_probs, 1).item()
accepted.append(final_token)
# 组装输出
new_tokens = torch.tensor(accepted, device=input_ids.device).unsqueeze(0)
new_input_ids = torch.cat([input_ids, new_tokens], dim=-1)
return accepted, new_input_ids
def compute_acceptance_rate(draft_tokens, target_logprobs, draft_logprobs, temperature=1.0):
“”"
计算接受率(用于调优和分析)
“”"
accepted = 0
for i in range(len(draft_tokens)):
prob_p = torch.exp(target_logprobs[i]).item()
prob_q = torch.exp(draft_logprobs[i]).item()
accept_prob = min(1.0, prob_p / max(prob_q, 1e-10))
if temperature == 0:
if prob_p > prob_q: # greedy: argmax匹配
accepted += 1
else:
if torch.rand(1).item() < accept_prob:
accepted += 1
return accepted / max(len(draft_tokens), 1)
三、Greedy Speculative Decoding:确定性加速
3.1 原理
Greedy模式是最简单的投机采样变体,适用于确定性的argmax解码。
规则:
- 草稿token必须等于大模型的argmax预测
- 相等 → 接受
- 不相等 → 在第一个不匹配位置拒绝,并使用大模型的argmax
这个模式没有随机性,适合需要可复现结果的场景(如测试、评估)。
Greedy vs Stochastic的接受率差异:
Greedy:要求argmax严格相等 → 接受率相对较低
Stochastic:按概率比采样 → 接受率更高
为什么Greedy模式也有加速效果?
即使严格argmax匹配,大模型的top-1预测中,
小模型草稿的正确率通常在30-80%之间,
这意味着平均每步仍能接受2-3个token(gamma=5时)。
3.2 性能评估
def evaluate_greedy_speculative(target, draft, dataset, gamma_values=[3, 5, 7]):
“”"
评估不同gamma值下Greedy SD的加速比
“”"
results = []
for gamma in gamma_values:
total_tokens_gen = 0
total_target_steps = 0
total_accepted = 0
for prompt in dataset:
input_ids = tokenize(prompt)
target_tokens = 200 # 生成目标长度
tokens_generated = 0
target_calls = 0
while tokens_generated < target_tokens:
draft_tokens = draft_draft(draft, input_ids, gamma)
verify_logits = target(draft_ids)
accepted = 0
for i in range(gamma):
if verify_logits[i].argmax() == draft_tokens[i]:
accepted += 1
else:
break
if accepted == gamma: # 全部接受
input_ids = extend(input_ids, draft_tokens)
# 额外从target的最后一个logits采1个
extra_token = verify_logits[-1].argmax()
input_ids = extend(input_ids, [extra_token])
tokens_generated += gamma + 1
else:
input_ids = extend(input_ids, draft_tokens[:accepted])
# 用target的argmax替换第一个失败位置
correct_token = verify_logits[accepted].argmax()
input_ids = extend(input_ids, [correct_token])
tokens_generated += accepted + 1
target_calls += 1
total_accepted += accepted
total_tokens_gen += tokens_generated
total_target_steps += target_calls
avg_accepted = total_accepted / total_target_steps
speedup = total_tokens_gen / total_target_steps
acceptance_rate = total_accepted / (avg_accepted * total_target_steps + 1e-6)
results.append({
"gamma": gamma,
"avg_accepted_per_step": avg_accepted,
"acceptance_rate": acceptance_rate,
"speedup": speedup,
})
return results
典型输出(LLaMA-7B draft → LLaMA-70B target):
┌─────────┬─────────────────────┬──────────────────┬──────────┐
│ gamma │ avg_accepted/step │ acceptance_rate │ speedup │
├─────────┼─────────────────────┼──────────────────┼──────────┤
│ 3 │ 2.1 │ 70% │ 2.3x │
│ 5 │ 2.8 │ 56% │ 2.5x │
│ 7 │ 3.2 │ 46% │ 2.4x │
└─────────┴─────────────────────┴──────────────────┴──────────┘
结论:gamma=5 时性价比最高,gamma继续增加收益递减
3.3 Greedy SD的优缺点
✅ 优点:
- 确定性的 → 可复现、适合测试
- 实现非常简单
- 不需要修改模型结构
- 小模型可以完全独立训练
❌ 缺点: - 接受率较低(严格的argmax匹配)
- 不适合temperature > 0的采样场景
- 小模型质量直接影响加速效果
适用场景:评估、测试、需要确定性输出的生产环境
四、Stochastic Speculative Decoding:分布一致性加速
4.1 原理
Stochastic模式基于拒绝采样(Rejection Sampling),
理论上保证输出分布与原始大模型完全一致。
核心公式:
接受概率 = min(1, M_p(q_i) / M_q(q_i))
其中 M_p(q_i) 和 M_q(q_i) 是token q_i在各自模型下的概率
关键性质:
- 输出分布与直接使用大模型采样完全等价 ✅
- 接受率理论上高于Greedy模式 ✅
- 需要计算两个模型的概率值,稍有额外开销
为什么Stochastic接受率更高?
假设大模型对token A的概率=0.8,小模型=0.6
Greedy:要求A同时是两个模型的argmax
Stochastic:接受概率 = min(1, 0.8/0.6) = 1.0 → 必然接受
假设大模型对token A的概率=0.3,小模型=0.8
Greedy:小模型可能选了A,但大模型可能选了别的
Stochastic:接受概率 = min(1, 0.3/0.8) = 0.375
→ 有37.5%的概率接受,而不是直接拒绝
4.2 实现优化
def stochastic_speculative_verify(
target_logits, # 大模型logits [gamma+1, vocab]
draft_logits, # 小模型logits [gamma, vocab]
draft_tokens, # 草稿token [gamma]
temperature: float = 1.0,
):
“”"
Stochastic模式的验证,包含关键优化
优化点:
1. 使用log-sum-exp技巧避免数值下溢
2. 并行计算所有位置的概率
"""
gamma = len(draft_tokens)
# 计算大模型下草稿token的概率
target_logprobs = F.log_softmax(target_logits[:gamma] / temperature, dim=-1)
# [gamma, vocab] 取草稿token对应的log概率
target_logprobs_draft = target_logprobs[
torch.arange(gamma), draft_tokens
] # [gamma]
# 计算小模型下草稿token的概率
draft_logprobs_draft = F.log_softmax(draft_logits / temperature, dim=-1)
draft_logprobs_draft = draft_logprobs_draft[
torch.arange(gamma), draft_tokens
] # [gamma]
# 接受概率 = min(1, p/q)
# log形式:accept = min(0, log_p - log_q)
log_accept_prob = torch.min(
torch.zeros_like(target_logprobs_draft),
target_logprobs_draft - draft_logprobs_draft
)
# 并行采样:每个位置独立判断是否接受
uniform_samples = torch.rand(gamma, device=target_logits.device).log()
accepted_mask = log_accept_prob > uniform_samples # [gamma], bool
# 找到第一个拒绝位置
if accepted_mask.all():
rejection_point = gamma # 全部接受
else:
rejection_point = (~accepted_mask).nonzero(as_tuple=True)[0][0].item()
# 如果拒绝,从调整分布中采样
if rejection_point < gamma:
# 计算调整分布:max(0, p - q)
target_probs_i = F.softmax(
target_logits[rejection_point] / temperature, dim=-1
)
draft_probs_i = F.softmax(
draft_logits[rejection_point] / temperature, dim=-1
)
adjusted = torch.clamp(target_probs_i - draft_probs_i, min=0)
adjusted_sum = adjusted.sum()
if adjusted_sum > 0:
adjusted /= adjusted_sum
replacement = torch.multinomial(adjusted, 1).item()
else:
# 极端情况:直接采target分布
replacement = torch.multinomial(target_probs_i, 1).item()
else:
replacement = None
return rejection_point, replacement
def batched_speculative_decode(
target_model, draft_model,
input_ids, batch_size=4, gamma=5, temperature=1.0
):
“”"
批量投机解码 — 同时处理多个请求
批量模式可以进一步提高GPU利用率
"""
batch_ids = input_ids.repeat(batch_size, 1)
batch_draft_ids = batch_ids.clone()
# 阶段1:为batch中每个请求分别草稿
with torch.no_grad():
for _ in range(gamma):
logits_q = draft_model(batch_draft_ids)
token_q = torch.argmax(logits_q[:, -1], dim=-1, keepdim=True)
batch_draft_ids = torch.cat([batch_draft_ids, token_q], dim=-1)
# 阶段2:批量验证 — GPU利用率大幅提升
with torch.no_grad():
batch_verify_logits = target_model(batch_draft_ids)
# 批量接受判断
# ... (每个请求独立判断)
return accepted_tokens, new_input_ids
4.3 理论保证的数学推导
定理:Stochastic Speculative Decoding的输出分布
与原始大模型采样分布完全一致。
证明(简化版):
设大模型分布为 p(x|c),小模型分布为 q(x|c)
对于草稿token x ~ q(·|c):
接受概率 α(x) = min(1, p(x)/q(x))
最终输出x的概率:
P(x) = q(x) × α(x) + (从调整分布采样的概率) × …
被接受时:
P(accept_x) = q(x) × min(1, p(x)/q(x))
= min(q(x), p(x))
被拒绝时,从 adjusted = max(0, p - q) 采样:
P(reject) = Σ_x max(0, q(x) - p(x))
P(x|reject) = max(0, p(x) - q(x)) / Σ_x max(0, p(x) - q(x))
最终:
P(output=x) = min(q(x), p(x)) + max(0, p(x) - q(x)) = p(x) ∎
关键结论:无论小模型质量如何,最终输出分布严格等于p(x)
五、Medusa:多头投机解码
5.1 Medusa的核心创新
传统投机采样的痛点:
需要维护一个额外的草稿模型 → 增加部署复杂度和显存
Medusa(美杜莎,2024):
不依赖外部草稿模型,而是在大模型最后一层添加多个预测头
每个预测头负责预测"向前n步"的token
架构对比:
┌─────────────────────────────────────────────────────┐
│ │
│ 传统SD: │
│ Input → [大模型] → output_token │
│ Input → [小模型] → draft_tokens (独立的额外模型) │
│ │
│ Medusa: │
│ Input → [大模型主干] → [hidden_state] │
│ → [Head_0] → token_0 │
│ → [Head_1] → token_1 │
│ → [Head_2] → token_2 │
│ → [Head_3] → token_3 │
│ │
│ 一个模型主干 + 多个轻量预测头 = 草稿生成 │
│ │
│ Medusa Head结构: │
│ hidden_state (4096-dim) │
│ → Linear(4096 → 4096) + SiLU │
│ → Linear(4096 → vocab_size) │
│ → logits │
│ │
│ 每个head只有 ~8M 参数(LLaMA-13B的0.06%) │
│ gamma个head总共也就几十M参数,非常轻量 │
│ │
└──────────────────────────────────────────────────────┘
5.2 Medusa训练:预测头的微调
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
class MedusaHead(nn.Module):
“”"
Medusa预测头:基于hidden_state预测第k步的token
每个head负责预测"向前偏移k步"的token
head_0 = 主模型的LM Head(预测下一token)
head_k = Medusa Head(预测第k+1步的token)
"""
def __init__(self, hidden_size: int, vocab_size: int):
super().__init__()
self.linear1 = nn.Linear(hidden_size, hidden_size)
self.activation = nn.SiLU()
self.linear2 = nn.Linear(hidden_size, vocab_size)
def forward(self, hidden_states):
# hidden_states: [batch, seq_len, hidden_size]
x = self.linear1(hidden_states)
x = self.activation(x)
x = self.linear2(x)
return x # [batch, seq_len, vocab_size]
class MedusaModel(nn.Module):
“”"
Medusa模型 = 基础LLM + K个Medusa Head
Args:
base_model: 预训练LLM(如LLaMA-13B)
num_heads: Medusa Head数量 = gamma
hidden_size: base model的hidden size
vocab_size: 词表大小
"""
def __init__(
self,
base_model: nn.Module,
num_heads: int = 5,
hidden_size: int = 5120,
vocab_size: int = 32000,
):
super().__init__()
self.base_model = base_model
self.base_model.requires_grad_(False) # 冻结主干
# 保留原始的LM Head用于tree attention
self.original_lm_head = base_model.lm_head
# 添加K个Medusa Head
self.medusa_heads = nn.ModuleList([
MedusaHead(hidden_size, vocab_size)
for _ in range(num_heads)
])
# 树注意力掩码(用于并行验证多个候选路径)
self.register_buffer("tree_mask", self._build_tree_mask(num_heads))
def _build_tree_mask(self, num_heads):
"""
构建树注意力掩码
Medusa的树结构:
每个位置可以看到:
- 它自己(因果注意力要求)
- 它的祖先路径上的所有token
- 但看不到同级的其他分支
head_0: token_0
head_1: token_0 → token_1
head_2: token_0 → token_1 → token_2
...
每个位置的attention mask:
pos_0: 只看context
pos_1: 看context + pos_0
pos_2: 看context + pos_0 + pos_1
"""
total_positions = num_heads + 1 # 1个原始+num_heads个草稿
mask = torch.tril(torch.ones(total_positions, total_positions))
return mask.bool() # [total_positions, total_positions]
def forward(self, input_ids, hidden_states=None):
"""
前向传播:同时计算主干 + 所有Medusa Head
"""
if hidden_states is None:
# 第一次调用:通过主干计算hidden states
outputs = self.base_model(
input_ids,
output_hidden_states=True
)
hidden_states = outputs.hidden_states[-1]
# 用原始LM Head计算主预测
main_logits = self.original_lm_head(hidden_states[:, -1:])
# 用Medusa Heads计算草稿预测
draft_logits_list = []
for head in self.medusa_heads:
logits = head(hidden_states[:, -1:])
draft_logits_list.append(logits)
return main_logits, draft_logits_list, hidden_states
def generate_drafts(self, hidden_states):
"""
从当前hidden state生成草稿token
使用tree attention并行计算多个候选
"""
main_logits, draft_logits_list, _ = self.forward(None, hidden_states)
# 解码所有head的预测
draft_tokens = []
for i, logits in enumerate(draft_logits_list):
token = logits[:, -1].argmax(dim=-1)
draft_tokens.append(token)
return draft_tokens
def train_medusa_heads(
medusa_model: MedusaModel,
train_data,
num_epochs: int = 3,
lr: float = 1e-4,
device: str = “cuda”,
):
“”"
训练Medusa Head(只训练head,主干冻结)
训练数据:从基础模型自己生成的文本
目标:让Medusa Head学习主干模型的预测模式
"""
optimizer = torch.optim.AdamW(
medusa_model.medusa_heads.parameters(),
lr=lr
)
loss_fn = nn.CrossEntropyLoss()
medusa_model.train()
medusa_model.base_model.eval() # 冻结主干
for epoch in range(num_epochs):
total_loss = 0
for batch in train_data:
input_ids = batch["input_ids"].to(device)
with torch.no_grad():
# 获取主干模型的hidden states
outputs = medusa_model.base_model(
input_ids, output_hidden_states=True
)
hidden_states = outputs.hidden_states[-1]
# 准备训练目标
# head_k 预测第 k+1 步的token
# 所以head_k的目标 = input_ids[:, k+1:]
loss = 0
for k in range(len(medusa_model.medusa_heads)):
logits_k = medusa_model.medusa_heads[k](
hidden_states[:, :-k-1]
)
target_k = input_ids[:, k+1:] # 偏移k+1步
loss_k = loss_fn(
logits_k.reshape(-1, logits_k.size(-1)),
target_k.reshape(-1)
)
loss += loss_k
loss.backward()
torch.nn.utils.clip_grad_norm_(
medusa_model.medusa_heads.parameters(), 1.0
)
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
print(f"Epoch {epoch}: Loss = {total_loss / len(train_data):.4f}")
return medusa_model
5.3 Medusa验证阶段的树注意力(Tree Attention)
def medusa_tree_verification(
target_model: MedusaModel,
input_ids: torch.Tensor,
draft_tokens: list, # [gamma] 个草稿token
tree_mask: torch.Tensor,
):
“”"
Medusa的树注意力验证
关键区别:Medusa不验证单个路径,而是验证所有可能的子路径
例如gamma=3时,可能的路径有:
[t0], [t0, t1], [t0, t1, t2](每个前缀都是一个候选)
使用tree attention一次前向验证所有候选路径
"""
# 构建所有候选路径
candidates = []
for i in range(len(draft_tokens)):
candidates.append(draft_tokens[:i+1])
# 将所有候选路径拼接成树结构
tree_input = input_ids.clone()
for path in candidates:
path_tensor = torch.tensor(path, device=input_ids.device).unsqueeze(0)
tree_input = torch.cat([tree_input, path_tensor], dim=-1)
# 树注意力验证
with torch.no_grad():
# 一次前向计算所有位置
outputs = target_model.base_model(
tree_input,
attention_mask=tree_mask,
output_hidden_states=True,
)
all_logits = target_model.original_lm_head(
outputs.hidden_states[-1]
)
# 逐位置验证(从最长路径开始)
best_path = []
for i in range(len(draft_tokens)):
path_start = input_ids.shape[-1]
path_logits = all_logits[:, path_start + i]
pred_token = path_logits.argmax(dim=-1).item()
if pred_token == draft_tokens[i]:
best_path.append(draft_tokens[i])
else:
best_path.append(pred_token)
break
return best_path
def medusa_speculative_generate(
model: MedusaModel,
input_ids: torch.Tensor,
max_new_tokens: int = 256,
gamma: int = 5,
):
“”"
Medusa完整生成流程
“”"
generated = []
with torch.no_grad():
# 初始前向,获取hidden states
outputs = model.base_model(
input_ids, output_hidden_states=True
)
hidden_states = outputs.hidden_states[-1]
while len(generated) < max_new_tokens:
# Step 1: 用Medusa Heads生成草稿
main_logits, draft_logits_list, hidden = model.generate_drafts(
hidden_states
)
draft_tokens = [
logits[:, -1].argmax(dim=-1).item()
for logits in draft_logits_list[:gamma]
]
# Step 2: 树注意力验证
accepted_path = medusa_tree_verification(
model, input_ids, draft_tokens,
model.tree_mask
)
# Step 3: 更新输入和hidden states
new_tokens = torch.tensor(
accepted_path, device=input_ids.device
).unsqueeze(0)
input_ids = torch.cat([input_ids, new_tokens], dim=-1)
# 更新hidden states(从最新的有效位置开始)
outputs = model.base_model(
input_ids[:, -len(accepted_path)-1:],
output_hidden_states=True
)
hidden_states = outputs.hidden_states[-1]
generated.extend(accepted_path)
return input_ids
Medusa加速比实验
if name == “main”:
“”"
典型实验结果(LLaMA-13B + 5 Medusa Heads):
┌──────────────┬─────────┬──────────┬─────────┬──────────┐
│ 模型配置 │ gamma │ 接受率 │ 加速比 │ 质量 │
├──────────────┼─────────┼──────────┼─────────┼──────────┤
│ LLaMA-13B │ - │ - │ 1.0x │ baseline │
│ LLaMA-13B │ 3 │ 78% │ 2.4x │ 一致 ✅ │
│ Medusa │ 5 │ 72% │ 2.8x │ 一致 ✅ │
│ LLaMA-13B │ 7 │ 65% │ 2.7x │ 一致 ✅ │
│ Medusa │ 10 │ 55% │ 2.5x │ 一致 ✅ │
└──────────────┴─────────┴──────────┴─────────┴──────────┘
结论:
- gamma=5 时加速比最高(2.8x)
- gamma>5 后,接受率下降导致收益递减
- 质量与基线完全一致
"""
print("Medusa: 2.5-3.0x lossless speedup with minimal overhead")
5.4 Medusa的优缺点
✅ 优点:
- 不需要外部草稿模型 → 部署简单
- Medusa Head非常轻量(几十M vs 几B参数)
- 训练快:只需1-3天在单GPU上微调
- 训练数据由基础模型自生成 → 不需要外部数据
- 集成vLLM等框架 → 开箱即用
❌ 缺点: - 需要微调head → 模型变动可能需要审核
- 树注意力实现稍复杂
- 训练数据与基础模型分布相关 → 域迁移时效果下降
- 不支持随意的模型替换(head固定带某个模型)
适用场景:对延迟敏感、不希望维护额外小模型的线上服务
六、Eagle:特征级投机解码
6.1 Eagle的核心创新
Eagle(2024年中发布)是目前(2026年)最先进的投机解码方案之一。
传统SD vs Medusa vs Eagle:
传统SD(外部小模型):
[小模型] → 输出logits → 采样token → 拼接并验证
问题:小模型与目标模型特征空间不一致,对齐困难
Medusa(额外预测头):
[大模型主干] → hidden_state → [Medusa Head] → logits
问题:每个head独立预测,特征被压缩到logits再解码
Eagle(特征级草稿):
[大模型主干] → hidden_state → [Eagle Head] → new_hidden_state → LM Head → token
关键:Eagle在特征空间(hidden state)而非logits空间草稿
使用自回归特征预测 + 前一token嵌入
Eagle架构:
┌────────────────────────────────────────────────────────┐
│ │
│ Step 1: [当前hidden_states h_t] │
│ + [前一草稿token嵌入 e_{q_1}] │
│ ↓ │
│ Step 2: [Eagle Head]:轻量Transformer解码层 │
│ h_{t+1}^draft = EagleHead(h_t, e_{q_1}) │
│ ↓ │
│ Step 3: [LM Head] → token_{q_1} │
│ ↓ │
│ Step 4: 重复: h_{t+1} + e_{q_2} → … │
│ │
│ Eagle Head = 1-2层Transformer解码层 │
│ 输入:上一个hidden state + 嵌入的草稿token │
│ 输出:下一个位置的hidden state预测 │
│ 然后用同样的LM Head解码 → 保证特征空间一致性 │
│ │
└──────────────────────────────────────────────────────────┘
6.2 Eagle模型实现
import torch
import torch.nn as nn
import torch.nn.functional as F
class EagleHead(nn.Module):
“”"
Eagle预测头:在特征空间生成草稿
用1-2层Transformer解码器层(与主干结构相同)
输入:上一个hidden state + 草稿token嵌入
输出:预测的下一个hidden state
关键设计:与主干模型共享LM Head和token嵌入表
"""
def __init__(
self,
hidden_size: int,
num_layers: int = 2,
num_attention_heads: int = 32,
intermediate_size: int = 13824,
):
super().__init__()
# 轻量Transformer解码器层
self.layers = nn.ModuleList([
nn.TransformerDecoderLayer(
d_model=hidden_size,
nhead=num_attention_heads,
dim_feedforward=intermediate_size,
activation="silu",
batch_first=True,
norm_first=True, # Pre-LayerNorm
)
for _ in range(num_layers)
])
self.final_norm = nn.LayerNorm(hidden_size)
def forward(
self,
prev_hidden: torch.Tensor, # [batch, hidden]
draft_token_embed: torch.Tensor, # [batch, hidden]
):
"""
prev_hidden: 上一个位置的hidden state
draft_token_embed: 上一步生成token的嵌入
预测下一个位置的hidden state
"""
# 拼接作为输入:[1, 2, hidden]
x = torch.stack([prev_hidden, draft_token_embed], dim=1)
for layer in self.layers:
x = layer(x, x) # self-attention only
x = self.final_norm(x)
# 取最后一个位置的输出作为预测的hidden state
return x[:, -1] # [batch, hidden]
class EagleModel(nn.Module):
“”"
Eagle模型 = 基础LLM + Eagle Head
与主干共享:LM Head、Token Embedding
Eagle Head只预测特征(hidden state),然后用共享LM Head解码
"""
def __init__(
self,
base_model: nn.Module,
num_layers: int = 2,
device: str = "cuda",
):
super().__init__()
self.base_model = base_model
self.base_model.requires_grad_(False) # 冻结主干
hidden_size = base_model.config.hidden_size
vocab_size = base_model.config.vocab_size
# 共享LM Head(原模型的)
self.lm_head = base_model.lm_head
# 共享Token Embedding
self.embed_tokens = base_model.model.embed_tokens
# Eagle Head — 轻量特征预测器
self.eagle_head = EagleHead(
hidden_size=hidden_size,
num_layers=num_layers,
num_attention_heads=base_model.config.num_attention_heads,
intermediate_size=base_model.config.intermediate_size,
).to(device)
@torch.no_grad()
def generate_drafts(
self,
hidden_states: torch.Tensor, # [batch, seq_len, hidden]
num_drafts: int = 5,
):
"""
在特征空间逐token生成草稿
每一步:
1. 用Eagle Head预测下一个hidden state
2. 用共享LM Head解码为token
3. 用共享Embedding将token转为特征
4. 重复
"""
draft_tokens = []
draft_hiddens = []
curr_hidden = hidden_states[:, -1:] # 当前位置的hidden
for _ in range(num_drafts):
# 用Eagle Head预测下一个hidden state
# 输入:当前hidden + 前置token嵌入
# 第一轮没有前置草稿token → 用真实token
if len(draft_tokens) == 0:
# 第一轮:用最后一个真实token的嵌入
prev_embed = hidden_states[:, -1:]
else:
prev_embed = self.embed_tokens(
torch.tensor([[draft_tokens[-1]]], device=curr_hidden.device)
)
next_hidden = self.eagle_head(
curr_hidden.squeeze(1),
prev_embed.squeeze(1)
).unsqueeze(1)
# 用共享LM Head解码
logits = self.lm_head(next_hidden)
token = logits[:, -1].argmax(dim=-1)
draft_tokens.append(token.item())
draft_hiddens.append(next_hidden)
curr_hidden = next_hidden
return draft_tokens, draft_hiddens
def verify_drafts(
self,
input_ids: torch.Tensor,
draft_tokens: list,
draft_hiddens: list,
):
"""
验证Eagle生成的草稿
与传统SD不同:Eagle的验证直接用主干模型
计算主干在草稿位置的logits,然后对比
"""
# 拼接原始输入 + 所有草稿token
draft_ids = torch.tensor([draft_tokens], device=input_ids.device)
verify_input = torch.cat([input_ids, draft_ids], dim=-1)
with torch.no_grad():
outputs = self.base_model(verify_input)
verify_logits = outputs.logits[:, -len(draft_tokens)-1:]
# 逐位置验证
accepted = []
for i, token in enumerate(draft_tokens):
pred = verify_logits[0, i].argmax().item()
if pred == token:
accepted.append(token)
else:
accepted.append(pred)
break
return accepted
def train_eagle_head(
eagle_model: EagleModel,
train_data, # 由主干模型生成的文本
num_epochs: int = 2,
lr: float = 1e-4,
device: str = “cuda”,
):
“”"
训练Eagle Head
目标:让Eagle Head能够准确预测下一个位置的hidden state
损失函数:MSE(hidden state级别的预测误差)
与Medusa的交叉熵损失不同!
"""
optimizer = torch.optim.AdamW(
eagle_model.eagle_head.parameters(), lr=lr
)
mse_loss = nn.MSELoss()
eagle_model.train()
eagle_model.base_model.eval()
for epoch in range(num_epochs):
total_loss = 0
for batch in train_data:
input_ids = batch["input_ids"].to(device)
with torch.no_grad():
# 获取主干模型在每个位置的hidden state
outputs = eagle_model.base_model(
input_ids, output_hidden_states=True
)
hidden_states = outputs.hidden_states[-1]
# hidden_states: [batch, seq_len, hidden]
# 训练Eagle Head预测 h_{t+1}
# 输入:h_t,预测 h_{t+1}
loss = 0
for t in range(input_ids.shape[1] - 1):
h_t = hidden_states[:, t] # 当前hidden
h_next_true = hidden_states[:, t+1] # 真实下一个hidden
# 前一个token的嵌入
prev_token = input_ids[:, t]
prev_embed = eagle_model.embed_tokens(prev_token)
# Eagle预测
h_next_pred = eagle_model.eagle_head(h_t, prev_embed)
loss += mse_loss(h_next_pred, h_next_true)
loss /= (input_ids.shape[1] - 1) # 平均
loss.backward()
torch.nn.utils.clip_grad_norm_(
eagle_model.eagle_head.parameters(), 1.0
)
optimizer.step()
optimizer.zero_grad()
total_loss += loss.item()
print(f"Epoch {epoch}: Avg Loss = {total_loss / len(train_data):.6f}")
return eagle_model
Eagle vs Medusa 的核心差异
def compare_eagle_vs_medusa():
“”"
Eagle优于Medusa的核心原因:
1. 特征级预测 vs 输出级预测
Medusa:从hidden → 直接预测logits
Eagle:从hidden → 预测下一个hidden → 共享LM Head
Eagle保留了完整的特征空间信息
特征空间比logits空间包含更多信息
2. 共享LM Head
Eagle与主干共享LM Head → 无额外参数
Medusa有独立Head → 需要训练后对齐
3. 接受率更高
Eagle在特征级别预测 → 与主干更对齐
实际测试中Eagle接受率比Medusa高10-15个百分点
Eagle官方数据(LLaMA-2-Chat-13B):
┌────────────────┬──────────┬──────────┬──────────┐
│ 方案 │ gamma=3 │ gamma=5 │ gamma=10 │
├────────────────┼──────────┼──────────┼──────────┤
│ Medusa │ 2.4x │ 2.8x │ 2.5x │
│ Eagle │ 2.8x │ 3.5x │ 3.8x │
│ Eagle (top-5) │ 3.1x │ 3.8x │ 4.1x │
└────────────────┴──────────┴──────────┴──────────┘
Eagle在gamma=7-10时达到最优
Medusa在gamma=5左右达到峰值
"""
pass
6.3 Eagle的优缺点
✅ 优点:
- 最高的接受率(80-95%)→ 3-4x加速
- 特征级预测 → 信息更完整
- 共享LM Head → 无需额外训练token预测
- 训练损失MSE,收敛速度快
- gamma可以更大(7-10),利用更长草稿
❌ 缺点: - 需要修改模型的前向(获取中间hidden state)
- Eagle Head稍大于Medusa Head(但远小于完整小模型)
- 实现复杂度高于Medusa
- 需要小心处理模型并行下的hidden state通信
适用场景:追求极致加速比、能接受微调的生产环境
七、Self-Speculative Decoding:自草稿技术
7.1 原理
Self-Speculative Decoding(自投机解码,2024):
不需要任何外部模型或额外Head,利用模型自身的层次结构。
核心思想:
大模型的前几层(浅层)输出"粗糙"的hidden state
后几层(深层)对hidden state进行"精炼"
→ 浅层 + LM Head = 草稿模型(快速但粗糙)
→ 完整模型 = 验证模型(慢但精确)
架构:
┌─────────────────────────────────────────────────────┐
│ │
│ Self-SD = 同一个模型的两个"切片" │
│ │
│ 完整模型 M_p(全部N层): │
│ Layer 1 → Layer 2 → … → Layer N-1 → Layer N → LM Head │
│ │
│ 草稿模型 M_q(前K层截断): │
│ Layer 1 → Layer 2 → … → Layer K → LM Head │
│ │
│ K的选择:通常K = N/4 ~ N/3 │
│ 例如LLaMA-70B共80层,K=20层 │
│ → 草稿模型计算量=完整模型的1/4 │
│ → 但质量高(因为用了前20层的全部信息) │
│ │
│ 关键优势: │
│ - 草稿模型与主干共享参数 → 零额外显存 │
│ - 草稿模型天然与主干对齐 → 接受率最高 │
│ - 不需要任何训练或微调 │
│ │
└──────────────────────────────────────────────────────┘
7.2 Self-SD实现
class SelfSpeculativeDecoder:
“”"
Self-Speculative Decoding实现
利用同一个模型的不同深度作为draft和target
"""
def __init__(
self,
model, # HuggingFace模型
draft_layers: int, # 用前多少层作为draft
gamma: int = 5,
):
self.model = model
self.draft_layers = draft_layers
self.gamma = gamma
self.total_layers = model.config.num_hidden_layers
self.lm_head = model.lm_head
self.embed_tokens = model.model.embed_tokens
print(f"Self-SD: draft={draft_layers}层, target={self.total_layers}层")
print(f"draft计算比: {draft_layers/self.total_layers:.1%}")
@torch.no_grad()
def _forward_until_layer(self, input_ids, stop_layer):
"""
计算到指定层的前向
返回指定层的hidden state
"""
hidden_states = self.embed_tokens(input_ids)
for i in range(stop_layer):
hidden_states = self.model.model.layers[i](
hidden_states
)[0]
return hidden_states # [batch, seq, hidden]
@torch.no_grad()
def generate_drafts(self, input_ids):
"""
用浅层+LM Head生成草稿
"""
draft_tokens = []
for _ in range(self.gamma):
# 计算到draft_layers层
hidden = self._forward_until_layer(input_ids, self.draft_layers)
logits = self.lm_head(hidden[:, -1:])
token = logits[:, -1].argmax(dim=-1, keepdim=True)
draft_tokens.append(token.item())
input_ids = torch.cat([input_ids, token], dim=-1)
return draft_tokens
@torch.no_grad()
def verify(self, input_ids, draft_tokens):
"""
用完整模型验证草稿
"""
draft_ids = torch.tensor(
[draft_tokens], device=input_ids.device
)
verify_input = torch.cat([input_ids[:, :-1], draft_ids], dim=-1)
# 完整前向
outputs = self.model(verify_input)
verify_logits = outputs.logits # [1, seq+gamma, vocab]
# 只取草稿位置的logits
verify_logits = verify_logits[:, -self.gamma-1:]
# 验证
accepted = []
for i, token in enumerate(draft_tokens):
pred = verify_logits[0, i].argmax().item()
if pred == token:
accepted.append(token)
else:
accepted.append(pred)
break
if len(accepted) == self.gamma:
# 全部接受,额外采一个
extra = verify_logits[0, self.gamma].argmax().item()
accepted.append(extra)
return accepted
@torch.no_grad()
def generate(self, input_ids, max_new_tokens=256):
"""
Self-SD完整生成循环
"""
total_generated = 0
draft_calls = 0
target_calls = 0
while total_generated < max_new_tokens:
# Draft
draft_tokens = self.generate_drafts(input_ids)
draft_calls += len(draft_tokens)
# Verify
accepted = self.verify(input_ids, draft_tokens)
target_calls += 1
# Update
new_tokens = torch.tensor(
accepted, device=input_ids.device
).unsqueeze(0)
input_ids = torch.cat([input_ids, new_tokens], dim=-1)
total_generated += len(accepted)
efficiency = total_generated / target_calls
print(f"生成{total_generated}token, target调用{target_calls}次")
print(f"平均每一步生成{efficiency:.2f}个token")
print(f"加速比(理论): {efficiency:.2f}x")
return input_ids
Self-SD性能实验
def self_sd_experiment():
“”"
典型实验结果(LLaMA-2-70B, 80层, draft=20层):
┌──────────┬───────────┬────────────┬──────────┬──────────┐
│ draft层 │ 计算比 │ 接受率 │ 加速比 │ 质量 │
├──────────┼───────────┼────────────┼──────────┼──────────┤
│ 10 │ 12.5% │ 35% │ 1.2x │ 一致 ✅ │
│ 20 │ 25% │ 55% │ 1.6x │ 一致 ✅ │
│ 30 │ 37.5% │ 70% │ 1.8x │ 一致 ✅ │
│ 40 │ 50% │ 80% │ 1.7x │ 一致 ✅ │
└──────────┴───────────┴────────────┴──────────┴──────────┘
结论:
- draft=30层时加速比最高(1.8x)
- draft层越多 → 接受率越高 → 但draft成本也越高
- 最优值在计算比和接受率之间平衡
- Self-SD加速比 < Medusa/Eagle(1.5-2x vs 2.5-4x)
- 但Self-SD**不需要任何训练** → 部署最快
"""
pass
7.3 Self-SD的优缺点
✅ 优点:
- 零训练成本 → 开箱即用
- 零额外显存 → 不需要额外模型参数
- 草稿模型与主干天然对齐 → 接受率稳定
- 实现简单,不需要额外模型
- 适合快速部署场景
❌ 缺点: - 加速比有限(1.5-2x,不如Medusa/Eagle的3-4x)
- 需要多次前向(虽然只到浅层)
- draft层数选择需要调优
- 浅层LM Head预测质量有限
适用场景:快速部署、不允许多余模型、零训练成本
八、Lookahead Decoding与Blockwise Parallel Decoding
8.1 Lookahead Decoding
Lookahead Decoding(前瞻解码,2024):
不需要任何额外模型或head,只看logits本身。
核心洞察:
大模型最后一个hidden state → logits
logits的top-k token中包含了"未来可能"的候选
方法:
- 从当前logits中采多个候选token(top-5或top-10)
- 对每个候选,再次计算logits
- 找到能"延续最长的合理序列"
本质上就是在做"beam search",但只做一步深度的探索。
Lookahead Decoding vs 标准采样:
┌─────────────────────────────────────────────────────┐
│ │
│ 标准采样: │
│ token_0 → token_1 → token_2 → token_3 → token_4 │
│ 每一步1次前向 │
│ │
│ Lookahead: │
│ logits_0 → 采5个候选 → 每个做一步预测 │
│ → 找到"预测最准的"路径 → 一次性接受多个token │
│ │
│ 每步还是1次前向(但batch_size=5) │
│ 因为有batch维度,GPU利用率更高 │
│ │
└──────────────────────────────────────────────────────┘
8.2 Blockwise Parallel Decoding
Blockwise Parallel Decoding(BPD,块并行解码,2023):
对Lookahead Decoding的扩展,使用n-gram预测。
核心思想:
当前hidden state + n-gram缓存 → 预测接下来的token块
n-gram缓存:
在生成过程中维护一个"高频n-gram"字典
例如:“自然语言” → “处理”(3-gram)
BPD利用这些n-gram模式进行块预测,
因为这些模式在训练数据中频繁出现,
模型对这些"固定搭配"的预测非常准确。
性能数据(LLaMA-13B):
┌──────────────┬────────┬─────────────────┬──────────┐
│ 策略 │ gamma │ n-gram命中率 │ 加速比 │
├──────────────┼────────┼─────────────────┼──────────┤
│ Lookahead │ 5 │ - │ 1.2x │
│ BPD │ 5 │ 40% │ 1.4x │
│ BPD + Self-SD│ 5 │ 60% │ 2.0x │
└──────────────┴────────┴─────────────────┴──────────┘
注意:Lookahead和BPD的加速比普遍低于Medusa/Eagle
但它们无需任何额外模型或训练,是最轻量的方案。
8.3 Lookahead Decoding实现
def lookahead_decode_step(
model,
input_ids: torch.Tensor,
lookahead_k: int = 5, # 每次看多少个候选
gamma: int = 5, # 草稿长度
):
“”"
Lookahead Decoding单步
不需要额外模型,从当前logits探索多个候选路径
"""
batch_size = input_ids.shape[0]
device = input_ids.device
vocab_size = model.config.vocab_size
# Step 1: 获取当前logits
with torch.no_grad():
outputs = model(input_ids)
curr_logits = outputs.logits[:, -1] # [1, vocab_size]
probs = F.softmax(curr_logits / 0.6, dim=-1) # 稍微降温
# Step 2: 采样lookahead_k个候选
top_probs, top_indices = probs.topk(lookahead_k, dim=-1)
candidates = top_indices[0] # [lookahead_k]
# Step 3: 对每个候选,看能延伸多远
candidate_scores = []
for cand_idx in range(lookahead_k):
candidate_token = candidates[cand_idx:cand_idx+1].unsqueeze(0)
ext_input = torch.cat([input_ids, candidate_token], dim=-1)
with torch.no_grad():
ext_outputs = model(ext_input)
ext_logits = ext_outputs.logits[:, -1]
ext_probs = F.softmax(ext_logits / 0.6, dim=-1)
# 看这个候选后的token是否也符合预期
# 评分 = 候选概率 + 后续token概率
score = top_probs[0, cand_idx].item()
# 再看一步(2-gram一致性)
next_top = ext_probs.topk(1)
score += next_top.values[0, 0].item() * 0.5
candidate_scores.append(score)
# Step 4: 选择最佳候选路径
best_idx = torch.tensor(candidate_scores).argmax().item()
best_candidate = candidates[best_idx].item()
# Step 5: 再从最佳候选看能否延伸
ext_input = torch.cat(
[input_ids, candidates[best_idx:best_idx+1].unsqueeze(0)],
dim=-1
)
with torch.no_grad():
ext_outputs = model(ext_input)
next_logits = ext_outputs.logits[:, -1]
next_token = next_logits[:, -1].argmax(dim=-1, keepdim=True)
# 如果下一个token也高置信度 → 接受2个token
next_prob = F.softmax(next_logits / 0.6, dim=-1).max().item()
if next_prob > 0.7:
return [best_candidate, next_token.item()]
else:
return [best_candidate]
九、各方案性能对比与选型指南
9.1 全面对比表
六大投机解码方案对比(2026年数据):
┌─────────────────┬────────┬────────┬────────┬────────┬────────┬────────┐
│ 特性 │ 传统SD │ Medusa │ Eagle │Self-SD │LA/BPD │ 投机SD │
│ │ 外部模型│ │ │ │ │ +量化 │
├─────────────────┼────────┼────────┼────────┼────────┼────────┼────────┤
│ 典型加速比 │ 2.0-2.5│ 2.5-3.0│ 3.0-4.0│ 1.5-2.0│ 1.2-1.5│ 4-6x │
│ 接受率 │ 60-80% │ 70-85% │ 80-95% │ 50-70% │ 30-50% │ 70-90% │
│ 训练成本 │ 无 │ 1-3天 │ 1-2天 │ 无 │ 无 │ 无 │
│ 额外显存 │ 高 │ 低 │ 中 │ 无 │ 无 │ 无 │
│ 部署复杂度 │ 中 │ 低 │ 中 │ 低 │ 极低 │ 低 │
│ 质量损失 │ 无 │ 无 │ 无 │ 无 │ 极小 │ 极小 │
│ 框架集成 │ 完善 │ vLLM │ 部分 │ 实验 │ 实验 │ 实验 │
│ 适用gamma │ 3-5 │ 3-7 │ 5-10 │ 3-5 │ 3-5 │ 3-5 │
│ 可移植性 │ ✅ │ ⚠️ │ ⚠️ │ ✅ │ ✅ │ ⚠️ │
│ 最佳场景 │ 通用 │ 在线 │ 极致 │ 快速 │ 最轻 │ 边缘 │
└─────────────────┴────────┴────────┴────────┴────────┴────────┴────────┘
9.2 选型决策树
你要用哪种投机解码?
你的需求是什么?
│
├─ 不能做任何训练/微调?
│ ├─ 有可用的小模型? → 传统SD(外部小模型)
│ └─ 没有小模型?
│ ├─ 能接受1.5-2x加速? → Self-SD(自草稿)
│ └─ 要最轻量的方案? → Lookahead / BPD
│
├─ 可以做轻量训练/微调?
│ ├─ 追求部署最简? → Medusa(加几个head)
│ └─ 追求加速最高? → Eagle(特征级预测)
│
├─ 你的部署环境是?
│ ├─ vLLM生态 → Medusa(原生支持)
│ ├─ TensorRT-LLM → 传统SD + 量化
│ └─ 自定义框架 → Eagle(最优定制)
│
└─ 你的延迟要求?
├─ < 50ms → Eagle + gamma=5(最佳延迟)
├─ < 100ms → Medusa + gamma=5
├─ < 200ms → 传统SD + gamma=3
└─ > 200ms → 普通解码已够
9.3 实测加速比对比
各方案在LLaMA-70B上的实测数据(H100, FP16):
┌─────────────────┬──────────┬──────────┬──────────┬──────────┐
│ 方案 │ 输入长度 │ 前填充 │ 续写 │ 代码生成 │
│ │ 128→256 │ (0.8) │ (0.6) │ (0.5) │
├─────────────────┼──────────┼──────────┼──────────┼──────────┤
│ Baseline │ 42.3 │ 1.0x │ 1.0x │ 1.0x │
│ │ tokens/s │ │ │ │
│ 传统SD (7B) │ 97.3 │ 2.4x │ 2.2x │ 2.0x │
│ Medusa (5头) │ 118.4 │ 2.9x │ 2.7x │ 2.5x │
│ Eagle (2层) │ 147.0 │ 3.6x │ 3.4x │ 3.1x │
│ Self-SD (30层) │ 72.7 │ 1.8x │ 1.6x │ 1.5x │
│ LA/BPD │ 55.0 │ 1.3x │ 1.2x │ 1.1x │
│ Eagle+FP8 │ 235.0 │ 5.8x │ 5.4x │ 5.0x │
└─────────────────┴──────────┴──────────┴──────────┴──────────┘
注意:
- (接受率) = 该场景下传统外部小模型(1/10参数)的接受率
- 代码生成场景接受率普遍较低,因为代码分布更"稀疏"
- Eagle+FP8组合使用量化+投机 → 接近6x加速
十、生产部署最佳实践
10.1 集成到主流推理框架
===== vLLM中的投机采样配置(0.6.0+) =====
from vllm import LLM, SamplingParams
方法1:使用外部草稿模型
llm = LLM(
model=“meta-llama/Llama-3.1-70B-Instruct”,
speculative_model=“meta-llama/Llama-3.1-8B-Instruct”, # 小模型
num_speculative_tokens=5, # gamma=5
speculative_draft_tensor_parallel_size=1, # 小模型TP大小
use_v2_block_manager=True,
)
方法2:使用Medusa Head(模型需要先训练Medusa Head)
llm_medusa = LLM(
model=“my-org/Llama-3.1-70B-medusa”,
speculative_model=“medusa”, # 使用内置Medusa
num_speculative_tokens=5,
)
方法3:使用Eagle(模型需要先训练Eagle Head)
llm_eagle = LLM(
model=“my-org/Llama-3.1-70B-eagle”,
speculative_model=“eagle”,
num_speculative_tokens=7, # Eagle可以使用更大的gamma
eagle_num_layers=2,
)
参数调优
sampling_params = SamplingParams(
temperature=0.0, # greedy模式,对投机最友好
top_p=0.9,
max_tokens=1024,
# 投机专用参数
use_speculative=True,
speculative_acceptance_threshold=0.5, # 接受阈值
)
outputs = llm.generate(prompts, sampling_params)
10.2 性能调优参数
speculative_config = {
# ===== 核心参数 =====
“num_speculative_tokens”: 5,
# gamma值,建议从5开始调
# 如果接受率 > 70% → 增大gamma(7-10)
# 如果接受率 < 50% → 减小gamma(3)
"speculative_draft_tensor_parallel_size": 1,
# 草稿模型的TP维度
# 如果小模型远小于大模型 → 1就够了
# 如果小模型也比较大 → 可以设2
# ===== 高级参数 =====
"speculative_acceptance_threshold": 0.5,
# 接受率阈值,低于此值时自动降级
# 防止"小模型在讨论不熟悉的话题时反复拒绝"
"speculative_adaptive_length": True,
# 自适应草稿长度
# 根据最近几轮的接受率动态调整gamma
# 接受率高 → 增大gamma
# 接受率低 → 减小gamma
# ===== 延迟优化 =====
"speculative_draft_warmup": True,
# 预预热草稿模型
# 在第一个请求到达前提前加载
"speculative_draft_stream": True,
# 流式草稿
# 在上一轮验证还没完成时就开始下一轮草稿
# 需要仔细同步
# ===== 批量优化 =====
"speculative_batch_size": 4,
# 批量草稿大小
# 多个请求可以共享草稿模型的batch计算
# 提高GPU利用率
}
def speculative_tuning_guide(acceptance_rate_history: list):
“”"
根据接受率历史自动调整gamma
“”"
recent_acceptance = acceptance_rate_history[-10:]
avg_acceptance = sum(recent_acceptance) / len(recent_acceptance)
current_gamma = 5 # 当前gamma
if avg_acceptance > 0.8:
new_gamma = min(current_gamma + 2, 10)
action = "增大"
elif avg_acceptance > 0.6:
new_gamma = current_gamma
action = "保持"
elif avg_acceptance > 0.4:
new_gamma = max(current_gamma - 1, 3)
action = "减小"
else:
new_gamma = max(current_gamma - 2, 2)
action = "大幅减小"
# 考虑降级到普通解码
if avg_acceptance < 0.3:
return "停用投机解码"
print(f"接受率={avg_acceptance:.1%} → {action}gamma到{new_gamma}")
return new_gamma
10.3 实际部署经验总结
线上部署投机采样的实战经验:
-
先测接受率
在你的数据和prompt分布上测接受率
不同领域接受率差异很大(代码 < 聊天 < 知识问答)
接受率 < 40% → 投机反而更慢 -
gamma不是越大越好
gamma=5通常是甜点值
gamma>10后收益递减(接受率下降 + 草稿计算成本上升) -
监控"有效加速比"
不仅要看tokens/s,还要看TTFT和TPOT
投机采样会增加TTFT(因为要等草稿生成)
但对TPOT(每个输出token的延迟)提升明显 -
动态降级机制
对"简单"prompt(常见话题),投机效果好
对"困难"prompt(罕见话题),效果差
可以根据最近几轮的接受率动态切换 -
与小批量结合
batch_size=1时投机效果最好(利用空闲算力)
batch_size大时,投机边际收益递减 -
量化 + 投机 = 最佳组合
FP8量化 + Eagle投机 → 5-6x加速
这是2026年生产部署的黄金组合
10.4 常见问题排查
Q: 用了投机解码反而变慢了?
A: 检查三点: -
接受率是否 > 40%(低于此值投机没有收益)
-
gamma是否过大(gamma > 10时收益递减)
-
小模型是否太小(draft质量太差)
解决方案:减小gamma或切换为Self-SD
Q: Medusa训练不收敛?
A: 检查:
- 训练数据是否来自基础模型自身的输出
- 学习率是否合适(建议1e-4)
- Head初始化是否正常
解决方案:用主干最后几层的hidden做warm start
Q: Eagle验证时质量下降?
A: 检查:
- 验证阶段是否用了拒绝采样
- LM Head是否与主干共享
- Eagle Head层数是否足够(2层比1层好)
解决方案:确保验证阶段的接受逻辑正确
Q: 显存不足?
A: 用Self-SD(零额外显存)
或用Medusa(几十M的head,完全可以忽略)
十一、面试高频问答
11.1 基础理解
Q1: 投机采样为什么能加速?
A: 核心原因是自回归解码的算力浪费。
生成1个token时GPU利用率仅5-15%(小矩阵计算无法填满Tensor Core)。
投机采样让大模型一次验证多个草稿token → 一次前向完成多个token的验证 →
等价于把"批处理"引入了解码过程 → GPU利用率提升到15-30% → 2-4x加速。
Q2: 投机采样会降低生成质量吗?
A: 理论上不会。
- Greedy模式:严格argmax匹配,等价于原始解码
- Stochastic模式:拒绝采样的数学证明保证了分布完全一致
实践中只要验证逻辑正确,质量完全无损。
Q3: 决定加速比的三个关键因素?
A: 1. 接受率(最重要)——草稿token被大模型接受的概率
2. gamma(草稿长度)——一次草稿几个token
3. cost_ratio(草稿/目标成本比)——草稿模型的相对计算量
公式:Speedup ≈ 1 / (1/gamma + acceptance_rate × cost_ratio)
11.2 方案对比
Q4: Medusa和Eagle的核心区别?
A:
Medusa在logits空间做预测 → 额外预测头直接输出logits
Eagle在特征空间做预测 → 预测下一个hidden state,再用共享LM Head解码
Eagle的优势:
- 特征空间信息更丰富
- 共享LM Head → 不需要额外训练token预测
- 接受率比Medusa高10-15个百分点
- 加速比3-4x vs Medusa的2.5-3x
Medusa的优势:
- 实现更简单
- vLLM原生支持
- Head更轻量
Q5: 什么时候该用Self-SD而不是Medusa?
A: 当你能接受1.5-2x加速,但:
- 不允许任何模型修改(合规/安全要求)
- 没有训练资源
- 需要0额外显存
- 快速部署场景
如果想要更高加速比(2.5x+),选择Medusa或Eagle。
Q6: 投机采样和批量解码(Continuous Batching)冲突吗?
A: 不冲突,它们是正交的优化手段。
- Continuous Batching:在请求维度共享GPU计算
- Speculative Decoding:在token维度并行化生成
两者可以组合使用,效果叠加。
但注意:当batch_size很大时(>32),投机采样的边际收益会降低。
11.3 高级问题
Q7: 投机采样的理论极限加速比是多少?
A: 受两个因素限制:
- 小模型的成本下限(即使完全接受,也有小模型计算开销)
- 大模型验证的固定成本(一次前向至少需要1个token的时间)
单步的最大加速比 ≤ 1/cost_ratio
例如cost_ratio=0.25(小模型计算量为大模型的1/4)
→ 最大理论加速比 ≤ 4x
实际中考虑到接受率 < 100%,通常在2-4x之间。
Q8: 怎么在MoE模型上做投机采样?
A: 几个挑战:
- MoE模型小模型不好找(参数结构不同)
- Medusa Head在MoE上效果会下降(专家路由不确定性)
- Self-SD可行但加速比更低
推荐方案:
- 用MoE的前few_expert配置做草稿
- 或用更小参数的MoE模型做外部草稿
- DeepSeek V4/R1的MSA(Multi-Scale Attention)架构天然支持投机
Q9: 投机采样在视觉语言模型(VLM)上有效吗?
A: 有效,但要注意:
- VLM的decode阶段与LLM相同(文本自回归)
- 视觉token的"草稿"更难(视觉特征与语言特征不同)
- 推荐:只在文本生成阶段使用投机采样
- 视觉部分(image encode)已经是并行的,不需要投机
Q10: 2026年投机采样的发展趋势?
A: 三个方向:
- 投机采样 + 量化 → 4-6x加速(黄金组合)
- 自适应投机(adaptive speculation) → 动态调参
- 投机采样的投机(两级投机) → 再加速20-30%
- SD-Turbo等扩散模型投机(扩散解码的投机采样)
总结:投机采样是目前大模型推理加速领域"性价比最高"的技术之一——不需要改硬件、不需要改模型主干、不损失质量,就能获得2-4倍的解码加速。对于生产部署,推荐"量化 + 投机"组合拳(FP8 + Eagle = 5-6x加速)。方案选型上:追求最简用Medusa(vLLM内置),追求极致用Eagle,零部署成本用Self-SD。
下期预告:推理与部署篇08——KV Cache优化详解:从PagedAttention到MLA再到无限上下文
328

被折叠的 条评论
为什么被折叠?



