【推理与部署篇07】投机采样(Speculative Decoding)深度解析:推理加速的终极武器

【推理与部署篇07】投机采样(Speculative Decoding)深度解析:推理加速的终极武器
前言:大模型推理的核心瓶颈是什么?自回归解码——一次只生成一个token,生成100个token就要串行调用100次模型,GPU利用率不到10%。投机采样(Speculative Decoding)在2024-2026年异军突起,成为突破这一瓶颈的最有效手段之一。它用一个"小模型"快速草拟多个token,再用"大模型"并行验证——既保留了原始模型的质量,又能实现2-4倍的解码加速。本文从原理到实战,完整覆盖Greedy SD、Stochastic SD、Medusa、Eagle、Self-SD、Lookahead六大方案。

📋 目录
一、自回归解码的瓶颈与投机采样的直觉

二、投机采样核心原理:小模型草稿 + 大模型验证

三、Greedy Speculative Decoding:确定性加速

四、Stochastic Speculative Decoding:分布一致性加速

五、Medusa:多头投机解码

六、Eagle:特征级投机解码

七、Self-Speculative Decoding:自草稿技术

八、Lookahead Decoding与Blockwise Parallel Decoding

九、各方案性能对比与选型指南

十、生产部署最佳实践

十一、面试高频问答

一、自回归解码的瓶颈与投机采样的直觉
1.1 自回归解码为什么这么慢?

自回归解码(Autoregressive Decoding)的本质:

LLM是一个函数 F(token_i) → logits_{i+1}
生成过程:

Step 1: 输入 [BOS] → F → logits_1 → token_1
Step 2: 输入 [BOS, t1] → F → logits_2 → token_2
Step 3: 输入 [BOS, t1, t2] → F → logits_3 → token_3

Step N: 输入 [BOS, t1…tN-1] → F → logits_N → token_N

每个token只能串行生成,无法并行!

核心瓶颈分析:
┌─────────────────────────────────────────────────────────────────┐
│ 假设模型 = 70B LLaMA,GPU = H100 │
│ │
│ 理论算力:1979 TFLOPS(FP16) │
│ 实际利用率:5-15%(自回归解码时) │
│ │
│ 为什么利用率这么低? │
│ 1. 每个step只生成1个token → 计算量极小 │
│ 2. 矩阵计算中batch_size=1 → 无法充分利用Tensor Core │
│ 3. KV Cache加载频繁 → 内存带宽瓶颈 │
│ │
│ 一句话:自回归解码 = 算力浪费 │
│ 生成100个token = 做了100次小矩阵乘法 │
│ 如果能一次做4次 → 利用率立刻翻4倍 │
└─────────────────────────────────────────────────────────────────┘
1.2 直觉:什么是投机采样?

投机采样(Speculative Decoding)的非技术类比:

想象你是要写一篇1000字的文章(生成1000个token)。

普通写法(自回归):
写一个字,检查一遍,再写下一个字 → 非常慢

投机写法(投机解码):
先让实习生(小模型)写20个字,
然后你(大模型)一次性审阅这20个字
→ 如果全部正确 → 一次性通过(20个token一次完成)
→ 如果有错 → 改错,错误位置之后的重写

关键洞察:

  • 小模型计算量小,可以快速生成多个候选
  • 大模型一次前向可以验证多个候选
  • 验证的计算量 ≈ 生成1个token的计算量
  • 只要正确率 > 30%,就有加速收益

    投机采样的直觉公式:
    GPU利用率 ↗ = tokens_步长 / 时间_步长
    原来:1 token / 步
    投机:平均 2-4 tokens / 步 →
    利用率从 5% → 15-25%
    1.3 投机采样的三个核心问题

任何投机采样方案都需要回答三个问题:

Q1: 谁来草稿? (Draft Model)

  • 外部小模型(如 1B-7B 参数)
  • 同一模型的小层版本
  • 额外的预测头(Medusa Head)
  • 无额外模型的LM Head复用

    Q2: 一次草稿几个? (Speculation Length / Lookahead Window)
  • 固定长度:gamma=3, 5, 10
  • 自适应长度:根据接受率动态调整
  • 经验值:gamma=3-5 时性价比最高

    Q3: 如何保证最终质量一致? (Rejection / Verification)
  • Greedy模式:草稿token必须等于大模型argmax
  • Stochastic模式:基于概率的拒绝采样
  • 关键:投机采样不改变原始模型的输出分布
    二、投机采样核心原理:小模型草稿 + 大模型验证
    2.1 通用流程

投机采样的标准流程(以gamma=3为例):

┌─────────────────────────────────────────────────────────┐
│ │
│ Step 1: Draft(草稿阶段) │
│ 小模型 M_q 基于当前context生成3个token │
│ context: [t1, t2, …, tn] │
│ 草稿: [q_{n+1}, q_{n+2}, q_{n+3}] │
│ │
│ Step 2: Verify(验证阶段) │
│ 大模型 M_p 一次前向计算: │
│ 输入: [t1, t2, …, tn, q_{n+1}, q_{n+2}, q_{n+3}] │
│ 输出: logits_1, logits_2, logits_3, logits_4 │
│ 每个logits位置对应大模型对下一个token的预测分布 │
│ │
│ Step 3: Accept(接受阶段) │
│ 逐位置检查草稿token是否被大模型接受: │
│ pos1: q_{n+1} == argmax(logits_1) → 接受 ✅ │
│ pos2: q_{n+2} == argmax(logits_2) → 接受 ✅ │
│ pos3: q_{n+3} == argmax(logits_3) → 拒绝 ❌ │
│ │
│ Step 4: Output(输出 + 回退) │
│ 输出: [q_{n+1}, q_{n+2}] │
│ 回退: 在pos3从大模型的logits_3中采样新的token │
│ │
│ 本轮生成了 2 个token,而不是串行的 3 步 │
│ 大模型只做了 1 次前向(验证3个位置) │
│ 串行需要 3 次前向 → 投机只用 1 次 → ~3x 加速 │
│ │
└──────────────────────────────────────────────────────────┘
2.2 接受率(Acceptance Rate)——加速比的决定因素

接受率 = 草稿token被大模型接受的比例
这是投机采样加速比的核心变量。

加速比的理论公式:

Speedup = (1 + gamma * acceptance_rate + acceptance_rate^2 + …)
/ (gamma * cost_ratio + 1)

其中:

  • gamma = 草稿长度(一次草稿几个token)
  • acceptance_rate = 单token被接受的概率
  • cost_ratio = 小模型计算成本 / 大模型计算成本
    通常 cost_ratio = 0.1-0.3(小模型小几倍到十几倍)

    简化近似(当 acceptance_rate 接近1时):
    Speedup ≈ 1 / (1/gamma + acceptance_rate × cost_ratio)

    实际经验值:
    ┌─────────────────┬──────────────┬─────────────┬────────────────┐
    │ 场景 │ 草稿模型 │ 接受率 │ 实测加速比 │
    ├─────────────────┼──────────────┼─────────────┼────────────────┤
    │ 小模型草稿 │ 7B → 70B │ 60-80% │ 2.0-2.5x │
    │ Medusa │ 无(head) │ 70-85% │ 2.5-3.0x │
    │ Self-SD │ 前几层 │ 40-60% │ 1.3-1.8x │
    │ Eagle │ 特征对齐 │ 80-95% │ 3.0-4.0x │
    │ Lookahead │ LM Head复用 │ 30-50% │ 1.2-1.5x │
    └─────────────────┴──────────────┴─────────────┴────────────────┘
    2.3 为什么投机采样不改变输出分布?

这是投机采样最重要的理论保证。

对于Stochastic(随机采样)模式:
草稿token q 从小模型分布 M_q(·|context) 采样
验证:以概率 min(1, M_p(q)/M_q(q)) 接受q
如果拒绝:从调整后的分布 M_p(·) - M_q(·) 中重新采样

数学证明:
接受概率 = min(1, M_p(q) / M_q(q))

这保证了最终采样分布 = M_p(大模型的原始分布)
因为拒绝采样的校正机制恰好抵消了小模型的偏差

直观理解:

  • 如果小模型对某个token的预测概率比大模型还高
    → 说明小模型对这个token比大模型还自信
    → 应该接受(概率加权)
  • 如果小模型预测概率很低但大模型预测高
    → 很可能被拒绝
    → 从大模型分布重新采样

    关键结论:
    投机采样 = 无损加速
    生成的文本质量与直接使用大模型解码完全等价
    2.4 实现投机采样的完整伪代码

import torch
import torch.nn.functional as F

def speculative_decoding_step(
target_model, # 大模型 M_p
draft_model, # 小模型 M_q
input_ids, # 当前context,shape [1, seq_len]
gamma: int = 5, # 草稿长度
temperature: float = 1.0,
):
“”"
单步投机采样

Args:
    target_model: 目标大模型
    draft_model: 草稿小模型
    input_ids: 当前输入
    gamma: 草稿长度
    temperature: 采样温度

Returns:
    accepted_tokens: 本轮接受的token列表
    new_input_ids: 更新后的输入
"""
# ===== Phase 1: Draft — 用小模型草稿gamma个token =====
draft_ids = input_ids.clone()
draft_logprobs = []  # 记录每个草稿token在小模型下的log概率

with torch.no_grad():
    for _ in range(gamma):
        logits_q = draft_model(draft_ids)  # [1, vocab]
        logits_q = logits_q[:, -1, :]      # 取最后一个位置
        probs_q = F.softmax(logits_q / temperature, dim=-1)
        dist_q = torch.distributions.Categorical(probs_q)
        token_q = dist_q.sample()          # [1, 1]
        logprob_q = dist_q.log_prob(token_q)
        
        draft_ids = torch.cat([draft_ids, token_q], dim=-1)
        draft_logprobs.append(logprob_q)

draft_tokens = draft_ids[:, input_ids.shape[-1]:]  # 草稿token序列

# ===== Phase 2: Verify — 大模型并行验证 =====
with torch.no_grad():
    logits_p = target_model(draft_ids)  # 一次前向,计算所有位置
    logits_p = logits_p[:, -(gamma+1):]  # 取草稿相关的gamma+1个位置

# 计算每个位置大模型下取草稿token的概率
probs_p = F.softmax(logits_p / temperature, dim=-1)  # [1, gamma+1, vocab]

# ===== Phase 3: Accept — 逐位置判断是否接受 =====
accepted = []
rejection_point = gamma

# 最后一个位置没用草稿,直接采样
# 位置0~gamma-1对应草稿token的验证
for i in range(gamma):
    token_draft = draft_tokens[0, i].item()
    
    if temperature == 0:  # Greedy模式
        token_greedy = probs_p[0, i].argmax().item()
        if token_draft == token_greedy:
            accepted.append(token_draft)
        else:
            rejection_point = i
            break
    else:  # Stochastic模式 — 拒绝采样
        prob_p = probs_p[0, i, token_draft].item()
        prob_q = torch.exp(draft_logprobs[i]).item()
        
        accept_prob = min(1.0, prob_p / max(prob_q, 1e-10))
        if torch.rand(1).item() < accept_prob:
            accepted.append(token_draft)
        else:
            # 从调整分布中采样新token
            adjusted_probs = torch.clamp(
                probs_p[0, i] - probs_q[0, i], 
                min=0
            )
            adjusted_probs /= adjusted_probs.sum()
            new_token = torch.multinomial(adjusted_probs, 1).item()
            accepted.append(new_token)
            rejection_point = i
            break

# ===== Phase 4: Final — 最后一个位置从大模型采样 =====
if rejection_point == gamma:
    # 所有草稿都被接受 → 额外采样一个token
    last_logits = logits_p[0, -1]  # 位置gamma
    last_probs = F.softmax(last_logits / temperature, dim=-1)
    final_token = torch.multinomial(last_probs, 1).item()
    accepted.append(final_token)

# 组装输出
new_tokens = torch.tensor(accepted, device=input_ids.device).unsqueeze(0)
new_input_ids = torch.cat([input_ids, new_tokens], dim=-1)

return accepted, new_input_ids



def compute_acceptance_rate(draft_tokens, target_logprobs, draft_logprobs, temperature=1.0):
“”"
计算接受率(用于调优和分析)
“”"
accepted = 0
for i in range(len(draft_tokens)):
prob_p = torch.exp(target_logprobs[i]).item()
prob_q = torch.exp(draft_logprobs[i]).item()
accept_prob = min(1.0, prob_p / max(prob_q, 1e-10))
if temperature == 0:
if prob_p > prob_q: # greedy: argmax匹配
accepted += 1
else:
if torch.rand(1).item() < accept_prob:
accepted += 1
return accepted / max(len(draft_tokens), 1)
三、Greedy Speculative Decoding:确定性加速
3.1 原理

Greedy模式是最简单的投机采样变体,适用于确定性的argmax解码。

规则:

  • 草稿token必须等于大模型的argmax预测
  • 相等 → 接受
  • 不相等 → 在第一个不匹配位置拒绝,并使用大模型的argmax

    这个模式没有随机性,适合需要可复现结果的场景(如测试、评估)。

    Greedy vs Stochastic的接受率差异:
    Greedy:要求argmax严格相等 → 接受率相对较低
    Stochastic:按概率比采样 → 接受率更高

    为什么Greedy模式也有加速效果?
    即使严格argmax匹配,大模型的top-1预测中,
    小模型草稿的正确率通常在30-80%之间,
    这意味着平均每步仍能接受2-3个token(gamma=5时)。
    3.2 性能评估

def evaluate_greedy_speculative(target, draft, dataset, gamma_values=[3, 5, 7]):
“”"
评估不同gamma值下Greedy SD的加速比
“”"
results = []

for gamma in gamma_values:
    total_tokens_gen = 0
    total_target_steps = 0
    total_accepted = 0
    
    for prompt in dataset:
        input_ids = tokenize(prompt)
        target_tokens = 200  # 生成目标长度
        tokens_generated = 0
        target_calls = 0
        
        while tokens_generated < target_tokens:
            draft_tokens = draft_draft(draft, input_ids, gamma)
            verify_logits = target(draft_ids)
            
            accepted = 0
            for i in range(gamma):
                if verify_logits[i].argmax() == draft_tokens[i]:
                    accepted += 1
                else:
                    break
            
            if accepted == gamma:  # 全部接受
                input_ids = extend(input_ids, draft_tokens)
                # 额外从target的最后一个logits采1个
                extra_token = verify_logits[-1].argmax()
                input_ids = extend(input_ids, [extra_token])
                tokens_generated += gamma + 1
            else:
                input_ids = extend(input_ids, draft_tokens[:accepted])
                # 用target的argmax替换第一个失败位置
                correct_token = verify_logits[accepted].argmax()
                input_ids = extend(input_ids, [correct_token])
                tokens_generated += accepted + 1
            
            target_calls += 1
            total_accepted += accepted
        
        total_tokens_gen += tokens_generated
        total_target_steps += target_calls
    
    avg_accepted = total_accepted / total_target_steps
    speedup = total_tokens_gen / total_target_steps
    acceptance_rate = total_accepted / (avg_accepted * total_target_steps + 1e-6)
    
    results.append({
        "gamma": gamma,
        "avg_accepted_per_step": avg_accepted,
        "acceptance_rate": acceptance_rate,
        "speedup": speedup,
    })

return results

典型输出(LLaMA-7B draft → LLaMA-70B target):

┌─────────┬─────────────────────┬──────────────────┬──────────┐

│ gamma │ avg_accepted/step │ acceptance_rate │ speedup │

├─────────┼─────────────────────┼──────────────────┼──────────┤

│ 3 │ 2.1 │ 70% │ 2.3x │

│ 5 │ 2.8 │ 56% │ 2.5x │

│ 7 │ 3.2 │ 46% │ 2.4x │

└─────────┴─────────────────────┴──────────────────┴──────────┘

结论:gamma=5 时性价比最高,gamma继续增加收益递减

3.3 Greedy SD的优缺点

✅ 优点:

  1. 确定性的 → 可复现、适合测试
  2. 实现非常简单
  3. 不需要修改模型结构
  4. 小模型可以完全独立训练

    ❌ 缺点:
  5. 接受率较低(严格的argmax匹配)
  6. 不适合temperature > 0的采样场景
  7. 小模型质量直接影响加速效果

适用场景:评估、测试、需要确定性输出的生产环境
四、Stochastic Speculative Decoding:分布一致性加速
4.1 原理

Stochastic模式基于拒绝采样(Rejection Sampling),
理论上保证输出分布与原始大模型完全一致。

核心公式:
接受概率 = min(1, M_p(q_i) / M_q(q_i))

其中 M_p(q_i) 和 M_q(q_i) 是token q_i在各自模型下的概率

关键性质:

  1. 输出分布与直接使用大模型采样完全等价 ✅
  2. 接受率理论上高于Greedy模式 ✅
  3. 需要计算两个模型的概率值,稍有额外开销

    为什么Stochastic接受率更高?
    假设大模型对token A的概率=0.8,小模型=0.6
    Greedy:要求A同时是两个模型的argmax
    Stochastic:接受概率 = min(1, 0.8/0.6) = 1.0 → 必然接受

    假设大模型对token A的概率=0.3,小模型=0.8
    Greedy:小模型可能选了A,但大模型可能选了别的
    Stochastic:接受概率 = min(1, 0.3/0.8) = 0.375
    → 有37.5%的概率接受,而不是直接拒绝
    4.2 实现优化

def stochastic_speculative_verify(
target_logits, # 大模型logits [gamma+1, vocab]
draft_logits, # 小模型logits [gamma, vocab]
draft_tokens, # 草稿token [gamma]
temperature: float = 1.0,
):
“”"
Stochastic模式的验证,包含关键优化

优化点:
1. 使用log-sum-exp技巧避免数值下溢
2. 并行计算所有位置的概率
"""
gamma = len(draft_tokens)

# 计算大模型下草稿token的概率
target_logprobs = F.log_softmax(target_logits[:gamma] / temperature, dim=-1)
# [gamma, vocab] 取草稿token对应的log概率
target_logprobs_draft = target_logprobs[
    torch.arange(gamma), draft_tokens
]  # [gamma]

# 计算小模型下草稿token的概率
draft_logprobs_draft = F.log_softmax(draft_logits / temperature, dim=-1)
draft_logprobs_draft = draft_logprobs_draft[
    torch.arange(gamma), draft_tokens
]  # [gamma]

# 接受概率 = min(1, p/q)
# log形式:accept = min(0, log_p - log_q)
log_accept_prob = torch.min(
    torch.zeros_like(target_logprobs_draft),
    target_logprobs_draft - draft_logprobs_draft
)

# 并行采样:每个位置独立判断是否接受
uniform_samples = torch.rand(gamma, device=target_logits.device).log()
accepted_mask = log_accept_prob > uniform_samples  # [gamma], bool

# 找到第一个拒绝位置
if accepted_mask.all():
    rejection_point = gamma  # 全部接受
else:
    rejection_point = (~accepted_mask).nonzero(as_tuple=True)[0][0].item()

# 如果拒绝,从调整分布中采样
if rejection_point < gamma:
    # 计算调整分布:max(0, p - q)
    target_probs_i = F.softmax(
        target_logits[rejection_point] / temperature, dim=-1
    )
    draft_probs_i = F.softmax(
        draft_logits[rejection_point] / temperature, dim=-1
    )
    adjusted = torch.clamp(target_probs_i - draft_probs_i, min=0)
    adjusted_sum = adjusted.sum()
    
    if adjusted_sum > 0:
        adjusted /= adjusted_sum
        replacement = torch.multinomial(adjusted, 1).item()
    else:
        # 极端情况:直接采target分布
        replacement = torch.multinomial(target_probs_i, 1).item()
else:
    replacement = None

return rejection_point, replacement



def batched_speculative_decode(
target_model, draft_model,
input_ids, batch_size=4, gamma=5, temperature=1.0
):
“”"
批量投机解码 — 同时处理多个请求

批量模式可以进一步提高GPU利用率
"""
batch_ids = input_ids.repeat(batch_size, 1)
batch_draft_ids = batch_ids.clone()

# 阶段1:为batch中每个请求分别草稿
with torch.no_grad():
    for _ in range(gamma):
        logits_q = draft_model(batch_draft_ids)
        token_q = torch.argmax(logits_q[:, -1], dim=-1, keepdim=True)
        batch_draft_ids = torch.cat([batch_draft_ids, token_q], dim=-1)

# 阶段2:批量验证 — GPU利用率大幅提升
with torch.no_grad():
    batch_verify_logits = target_model(batch_draft_ids)

# 批量接受判断
# ... (每个请求独立判断)

return accepted_tokens, new_input_ids

4.3 理论保证的数学推导

定理:Stochastic Speculative Decoding的输出分布
与原始大模型采样分布完全一致。

证明(简化版):

设大模型分布为 p(x|c),小模型分布为 q(x|c)

对于草稿token x ~ q(·|c):
接受概率 α(x) = min(1, p(x)/q(x))

最终输出x的概率:
P(x) = q(x) × α(x) + (从调整分布采样的概率) × …

被接受时:
P(accept_x) = q(x) × min(1, p(x)/q(x))
= min(q(x), p(x))

被拒绝时,从 adjusted = max(0, p - q) 采样:
P(reject) = Σ_x max(0, q(x) - p(x))
P(x|reject) = max(0, p(x) - q(x)) / Σ_x max(0, p(x) - q(x))

最终:
P(output=x) = min(q(x), p(x)) + max(0, p(x) - q(x)) = p(x) ∎

关键结论:无论小模型质量如何,最终输出分布严格等于p(x)
五、Medusa:多头投机解码
5.1 Medusa的核心创新

传统投机采样的痛点:
需要维护一个额外的草稿模型 → 增加部署复杂度和显存

Medusa(美杜莎,2024):
不依赖外部草稿模型,而是在大模型最后一层添加多个预测头
每个预测头负责预测"向前n步"的token

架构对比:
┌─────────────────────────────────────────────────────┐
│ │
│ 传统SD: │
│ Input → [大模型] → output_token │
│ Input → [小模型] → draft_tokens (独立的额外模型) │
│ │
│ Medusa: │
│ Input → [大模型主干] → [hidden_state] │
│ → [Head_0] → token_0 │
│ → [Head_1] → token_1 │
│ → [Head_2] → token_2 │
│ → [Head_3] → token_3 │
│ │
│ 一个模型主干 + 多个轻量预测头 = 草稿生成 │
│ │
│ Medusa Head结构: │
│ hidden_state (4096-dim) │
│ → Linear(4096 → 4096) + SiLU │
│ → Linear(4096 → vocab_size) │
│ → logits │
│ │
│ 每个head只有 ~8M 参数(LLaMA-13B的0.06%) │
│ gamma个head总共也就几十M参数,非常轻量 │
│ │
└──────────────────────────────────────────────────────┘
5.2 Medusa训练:预测头的微调

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM


class MedusaHead(nn.Module):
“”"
Medusa预测头:基于hidden_state预测第k步的token

每个head负责预测"向前偏移k步"的token
head_0 = 主模型的LM Head(预测下一token)
head_k = Medusa Head(预测第k+1步的token)
"""
def __init__(self, hidden_size: int, vocab_size: int):
    super().__init__()
    self.linear1 = nn.Linear(hidden_size, hidden_size)
    self.activation = nn.SiLU()
    self.linear2 = nn.Linear(hidden_size, vocab_size)

def forward(self, hidden_states):
    # hidden_states: [batch, seq_len, hidden_size]
    x = self.linear1(hidden_states)
    x = self.activation(x)
    x = self.linear2(x)
    return x  # [batch, seq_len, vocab_size]



class MedusaModel(nn.Module):
“”"
Medusa模型 = 基础LLM + K个Medusa Head

Args:
    base_model: 预训练LLM(如LLaMA-13B)
    num_heads: Medusa Head数量 = gamma
    hidden_size: base model的hidden size
    vocab_size: 词表大小
"""
def __init__(
    self, 
    base_model: nn.Module,
    num_heads: int = 5,
    hidden_size: int = 5120,
    vocab_size: int = 32000,
):
    super().__init__()
    self.base_model = base_model
    self.base_model.requires_grad_(False)  # 冻结主干
    
    # 保留原始的LM Head用于tree attention
    self.original_lm_head = base_model.lm_head
    
    # 添加K个Medusa Head
    self.medusa_heads = nn.ModuleList([
        MedusaHead(hidden_size, vocab_size)
        for _ in range(num_heads)
    ])
    
    # 树注意力掩码(用于并行验证多个候选路径)
    self.register_buffer("tree_mask", self._build_tree_mask(num_heads))

def _build_tree_mask(self, num_heads):
    """
    构建树注意力掩码
    
    Medusa的树结构:
    每个位置可以看到:
      - 它自己(因果注意力要求)
      - 它的祖先路径上的所有token
      - 但看不到同级的其他分支
    
    head_0: token_0
    head_1: token_0 → token_1
    head_2: token_0 → token_1 → token_2
    ...
    
    每个位置的attention mask:
    pos_0: 只看context
    pos_1: 看context + pos_0
    pos_2: 看context + pos_0 + pos_1
    """
    total_positions = num_heads + 1  # 1个原始+num_heads个草稿
    mask = torch.tril(torch.ones(total_positions, total_positions))
    return mask.bool()  # [total_positions, total_positions]

def forward(self, input_ids, hidden_states=None):
    """
    前向传播:同时计算主干 + 所有Medusa Head
    """
    if hidden_states is None:
        # 第一次调用:通过主干计算hidden states
        outputs = self.base_model(
            input_ids, 
            output_hidden_states=True
        )
        hidden_states = outputs.hidden_states[-1]
    
    # 用原始LM Head计算主预测
    main_logits = self.original_lm_head(hidden_states[:, -1:])
    
    # 用Medusa Heads计算草稿预测
    draft_logits_list = []
    for head in self.medusa_heads:
        logits = head(hidden_states[:, -1:])
        draft_logits_list.append(logits)
    
    return main_logits, draft_logits_list, hidden_states

def generate_drafts(self, hidden_states):
    """
    从当前hidden state生成草稿token
    使用tree attention并行计算多个候选
    """
    main_logits, draft_logits_list, _ = self.forward(None, hidden_states)
    
    # 解码所有head的预测
    draft_tokens = []
    for i, logits in enumerate(draft_logits_list):
        token = logits[:, -1].argmax(dim=-1)
        draft_tokens.append(token)
    
    return draft_tokens



def train_medusa_heads(
medusa_model: MedusaModel,
train_data,
num_epochs: int = 3,
lr: float = 1e-4,
device: str = “cuda”,
):
“”"
训练Medusa Head(只训练head,主干冻结)

训练数据:从基础模型自己生成的文本
目标:让Medusa Head学习主干模型的预测模式
"""
optimizer = torch.optim.AdamW(
    medusa_model.medusa_heads.parameters(),
    lr=lr
)
loss_fn = nn.CrossEntropyLoss()

medusa_model.train()
medusa_model.base_model.eval()  # 冻结主干

for epoch in range(num_epochs):
    total_loss = 0
    
    for batch in train_data:
        input_ids = batch["input_ids"].to(device)
        
        with torch.no_grad():
            # 获取主干模型的hidden states
            outputs = medusa_model.base_model(
                input_ids, output_hidden_states=True
            )
            hidden_states = outputs.hidden_states[-1]
        
        # 准备训练目标
        # head_k 预测第 k+1 步的token
        # 所以head_k的目标 = input_ids[:, k+1:]
        loss = 0
        for k in range(len(medusa_model.medusa_heads)):
            logits_k = medusa_model.medusa_heads[k](
                hidden_states[:, :-k-1]
            )
            target_k = input_ids[:, k+1:]  # 偏移k+1步
            
            loss_k = loss_fn(
                logits_k.reshape(-1, logits_k.size(-1)),
                target_k.reshape(-1)
            )
            loss += loss_k
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            medusa_model.medusa_heads.parameters(), 1.0
        )
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch}: Loss = {total_loss / len(train_data):.4f}")

return medusa_model

5.3 Medusa验证阶段的树注意力(Tree Attention)

def medusa_tree_verification(
target_model: MedusaModel,
input_ids: torch.Tensor,
draft_tokens: list, # [gamma] 个草稿token
tree_mask: torch.Tensor,
):
“”"
Medusa的树注意力验证

关键区别:Medusa不验证单个路径,而是验证所有可能的子路径
例如gamma=3时,可能的路径有:
  [t0], [t0, t1], [t0, t1, t2](每个前缀都是一个候选)

使用tree attention一次前向验证所有候选路径
"""
# 构建所有候选路径
candidates = []
for i in range(len(draft_tokens)):
    candidates.append(draft_tokens[:i+1])

# 将所有候选路径拼接成树结构
tree_input = input_ids.clone()
for path in candidates:
    path_tensor = torch.tensor(path, device=input_ids.device).unsqueeze(0)
    tree_input = torch.cat([tree_input, path_tensor], dim=-1)

# 树注意力验证
with torch.no_grad():
    # 一次前向计算所有位置
    outputs = target_model.base_model(
        tree_input,
        attention_mask=tree_mask,
        output_hidden_states=True,
    )
    all_logits = target_model.original_lm_head(
        outputs.hidden_states[-1]
    )

# 逐位置验证(从最长路径开始)
best_path = []
for i in range(len(draft_tokens)):
    path_start = input_ids.shape[-1]
    path_logits = all_logits[:, path_start + i]
    pred_token = path_logits.argmax(dim=-1).item()
    
    if pred_token == draft_tokens[i]:
        best_path.append(draft_tokens[i])
    else:
        best_path.append(pred_token)
        break

return best_path



def medusa_speculative_generate(
model: MedusaModel,
input_ids: torch.Tensor,
max_new_tokens: int = 256,
gamma: int = 5,
):
“”"
Medusa完整生成流程
“”"
generated = []

with torch.no_grad():
    # 初始前向,获取hidden states
    outputs = model.base_model(
        input_ids, output_hidden_states=True
    )
    hidden_states = outputs.hidden_states[-1]
    
    while len(generated) < max_new_tokens:
        # Step 1: 用Medusa Heads生成草稿
        main_logits, draft_logits_list, hidden = model.generate_drafts(
            hidden_states
        )
        draft_tokens = [
            logits[:, -1].argmax(dim=-1).item()
            for logits in draft_logits_list[:gamma]
        ]
        
        # Step 2: 树注意力验证
        accepted_path = medusa_tree_verification(
            model, input_ids, draft_tokens, 
            model.tree_mask
        )
        
        # Step 3: 更新输入和hidden states
        new_tokens = torch.tensor(
            accepted_path, device=input_ids.device
        ).unsqueeze(0)
        input_ids = torch.cat([input_ids, new_tokens], dim=-1)
        
        # 更新hidden states(从最新的有效位置开始)
        outputs = model.base_model(
            input_ids[:, -len(accepted_path)-1:],
            output_hidden_states=True
        )
        hidden_states = outputs.hidden_states[-1]
        
        generated.extend(accepted_path)

return input_ids


Medusa加速比实验

if name == “main”:
“”"
典型实验结果(LLaMA-13B + 5 Medusa Heads):

┌──────────────┬─────────┬──────────┬─────────┬──────────┐
│ 模型配置     │ gamma   │ 接受率   │ 加速比  │ 质量     │
├──────────────┼─────────┼──────────┼─────────┼──────────┤
│ LLaMA-13B    │ -       │ -        │ 1.0x    │ baseline │
│ LLaMA-13B    │ 3       │ 78%      │ 2.4x    │ 一致 ✅  │
│ Medusa       │ 5       │ 72%      │ 2.8x    │ 一致 ✅  │
│ LLaMA-13B    │ 7       │ 65%      │ 2.7x    │ 一致 ✅  │
│ Medusa       │ 10      │ 55%      │ 2.5x    │ 一致 ✅  │
└──────────────┴─────────┴──────────┴─────────┴──────────┘

结论:
- gamma=5 时加速比最高(2.8x)
- gamma>5 后,接受率下降导致收益递减
- 质量与基线完全一致
"""
print("Medusa: 2.5-3.0x lossless speedup with minimal overhead")

5.4 Medusa的优缺点

✅ 优点:

  1. 不需要外部草稿模型 → 部署简单
  2. Medusa Head非常轻量(几十M vs 几B参数)
  3. 训练快:只需1-3天在单GPU上微调
  4. 训练数据由基础模型自生成 → 不需要外部数据
  5. 集成vLLM等框架 → 开箱即用

    ❌ 缺点:
  6. 需要微调head → 模型变动可能需要审核
  7. 树注意力实现稍复杂
  8. 训练数据与基础模型分布相关 → 域迁移时效果下降
  9. 不支持随意的模型替换(head固定带某个模型)

    适用场景:对延迟敏感、不希望维护额外小模型的线上服务
    六、Eagle:特征级投机解码
    6.1 Eagle的核心创新

Eagle(2024年中发布)是目前(2026年)最先进的投机解码方案之一。

传统SD vs Medusa vs Eagle:

传统SD(外部小模型):
[小模型] → 输出logits → 采样token → 拼接并验证
问题:小模型与目标模型特征空间不一致,对齐困难

Medusa(额外预测头):
[大模型主干] → hidden_state → [Medusa Head] → logits
问题:每个head独立预测,特征被压缩到logits再解码

Eagle(特征级草稿):
[大模型主干] → hidden_state → [Eagle Head] → new_hidden_state → LM Head → token
关键:Eagle在特征空间(hidden state)而非logits空间草稿
使用自回归特征预测 + 前一token嵌入

Eagle架构:
┌────────────────────────────────────────────────────────┐
│ │
│ Step 1: [当前hidden_states h_t] │
│ + [前一草稿token嵌入 e_{q_1}] │
│ ↓ │
│ Step 2: [Eagle Head]:轻量Transformer解码层 │
│ h_{t+1}^draft = EagleHead(h_t, e_{q_1}) │
│ ↓ │
│ Step 3: [LM Head] → token_{q_1} │
│ ↓ │
│ Step 4: 重复: h_{t+1} + e_{q_2} → … │
│ │
│ Eagle Head = 1-2层Transformer解码层 │
│ 输入:上一个hidden state + 嵌入的草稿token │
│ 输出:下一个位置的hidden state预测 │
│ 然后用同样的LM Head解码 → 保证特征空间一致性 │
│ │
└──────────────────────────────────────────────────────────┘
6.2 Eagle模型实现

import torch
import torch.nn as nn
import torch.nn.functional as F


class EagleHead(nn.Module):
“”"
Eagle预测头:在特征空间生成草稿

用1-2层Transformer解码器层(与主干结构相同)
输入:上一个hidden state + 草稿token嵌入
输出:预测的下一个hidden state

关键设计:与主干模型共享LM Head和token嵌入表
"""
def __init__(
    self, 
    hidden_size: int,
    num_layers: int = 2,
    num_attention_heads: int = 32,
    intermediate_size: int = 13824,
):
    super().__init__()
    
    # 轻量Transformer解码器层
    self.layers = nn.ModuleList([
        nn.TransformerDecoderLayer(
            d_model=hidden_size,
            nhead=num_attention_heads,
            dim_feedforward=intermediate_size,
            activation="silu",
            batch_first=True,
            norm_first=True,  # Pre-LayerNorm
        )
        for _ in range(num_layers)
    ])
    
    self.final_norm = nn.LayerNorm(hidden_size)

def forward(
    self, 
    prev_hidden: torch.Tensor,         # [batch, hidden]
    draft_token_embed: torch.Tensor,    # [batch, hidden]
):
    """
    prev_hidden: 上一个位置的hidden state
    draft_token_embed: 上一步生成token的嵌入
    
    预测下一个位置的hidden state
    """
    # 拼接作为输入:[1, 2, hidden]
    x = torch.stack([prev_hidden, draft_token_embed], dim=1)
    
    for layer in self.layers:
        x = layer(x, x)  # self-attention only
    
    x = self.final_norm(x)
    # 取最后一个位置的输出作为预测的hidden state
    return x[:, -1]  # [batch, hidden]



class EagleModel(nn.Module):
“”"
Eagle模型 = 基础LLM + Eagle Head

与主干共享:LM Head、Token Embedding
Eagle Head只预测特征(hidden state),然后用共享LM Head解码
"""
def __init__(
    self, 
    base_model: nn.Module,
    num_layers: int = 2,
    device: str = "cuda",
):
    super().__init__()
    self.base_model = base_model
    self.base_model.requires_grad_(False)  # 冻结主干
    
    hidden_size = base_model.config.hidden_size
    vocab_size = base_model.config.vocab_size
    
    # 共享LM Head(原模型的)
    self.lm_head = base_model.lm_head
    
    # 共享Token Embedding
    self.embed_tokens = base_model.model.embed_tokens
    
    # Eagle Head — 轻量特征预测器
    self.eagle_head = EagleHead(
        hidden_size=hidden_size,
        num_layers=num_layers,
        num_attention_heads=base_model.config.num_attention_heads,
        intermediate_size=base_model.config.intermediate_size,
    ).to(device)

@torch.no_grad()
def generate_drafts(
    self,
    hidden_states: torch.Tensor,  # [batch, seq_len, hidden]
    num_drafts: int = 5,
):
    """
    在特征空间逐token生成草稿
    
    每一步:
      1. 用Eagle Head预测下一个hidden state
      2. 用共享LM Head解码为token
      3. 用共享Embedding将token转为特征
      4. 重复
    """
    draft_tokens = []
    draft_hiddens = []
    
    curr_hidden = hidden_states[:, -1:]  # 当前位置的hidden
    
    for _ in range(num_drafts):
        # 用Eagle Head预测下一个hidden state
        # 输入:当前hidden + 前置token嵌入
        # 第一轮没有前置草稿token → 用真实token
        
        if len(draft_tokens) == 0:
            # 第一轮:用最后一个真实token的嵌入
            prev_embed = hidden_states[:, -1:]
        else:
            prev_embed = self.embed_tokens(
                torch.tensor([[draft_tokens[-1]]], device=curr_hidden.device)
            )
        
        next_hidden = self.eagle_head(
            curr_hidden.squeeze(1),
            prev_embed.squeeze(1)
        ).unsqueeze(1)
        
        # 用共享LM Head解码
        logits = self.lm_head(next_hidden)
        token = logits[:, -1].argmax(dim=-1)
        
        draft_tokens.append(token.item())
        draft_hiddens.append(next_hidden)
        curr_hidden = next_hidden
    
    return draft_tokens, draft_hiddens

def verify_drafts(
    self,
    input_ids: torch.Tensor,
    draft_tokens: list,
    draft_hiddens: list,
):
    """
    验证Eagle生成的草稿
    
    与传统SD不同:Eagle的验证直接用主干模型
    计算主干在草稿位置的logits,然后对比
    """
    # 拼接原始输入 + 所有草稿token
    draft_ids = torch.tensor([draft_tokens], device=input_ids.device)
    verify_input = torch.cat([input_ids, draft_ids], dim=-1)
    
    with torch.no_grad():
        outputs = self.base_model(verify_input)
        verify_logits = outputs.logits[:, -len(draft_tokens)-1:]
    
    # 逐位置验证
    accepted = []
    for i, token in enumerate(draft_tokens):
        pred = verify_logits[0, i].argmax().item()
        if pred == token:
            accepted.append(token)
        else:
            accepted.append(pred)
            break
    
    return accepted



def train_eagle_head(
eagle_model: EagleModel,
train_data, # 由主干模型生成的文本
num_epochs: int = 2,
lr: float = 1e-4,
device: str = “cuda”,
):
“”"
训练Eagle Head

目标:让Eagle Head能够准确预测下一个位置的hidden state

损失函数:MSE(hidden state级别的预测误差)
与Medusa的交叉熵损失不同!
"""
optimizer = torch.optim.AdamW(
    eagle_model.eagle_head.parameters(), lr=lr
)
mse_loss = nn.MSELoss()

eagle_model.train()
eagle_model.base_model.eval()

for epoch in range(num_epochs):
    total_loss = 0
    
    for batch in train_data:
        input_ids = batch["input_ids"].to(device)
        
        with torch.no_grad():
            # 获取主干模型在每个位置的hidden state
            outputs = eagle_model.base_model(
                input_ids, output_hidden_states=True
            )
            hidden_states = outputs.hidden_states[-1]
            # hidden_states: [batch, seq_len, hidden]
        
        # 训练Eagle Head预测 h_{t+1}
        # 输入:h_t,预测 h_{t+1}
        loss = 0
        for t in range(input_ids.shape[1] - 1):
            h_t = hidden_states[:, t]       # 当前hidden
            h_next_true = hidden_states[:, t+1]  # 真实下一个hidden
            
            # 前一个token的嵌入
            prev_token = input_ids[:, t]
            prev_embed = eagle_model.embed_tokens(prev_token)
            
            # Eagle预测
            h_next_pred = eagle_model.eagle_head(h_t, prev_embed)
            
            loss += mse_loss(h_next_pred, h_next_true)
        
        loss /= (input_ids.shape[1] - 1)  # 平均
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            eagle_model.eagle_head.parameters(), 1.0
        )
        optimizer.step()
        optimizer.zero_grad()
        
        total_loss += loss.item()
    
    print(f"Epoch {epoch}: Avg Loss = {total_loss / len(train_data):.6f}")

return eagle_model


Eagle vs Medusa 的核心差异

def compare_eagle_vs_medusa():
“”"
Eagle优于Medusa的核心原因:

1. 特征级预测 vs 输出级预测
   Medusa:从hidden → 直接预测logits
   Eagle:从hidden → 预测下一个hidden → 共享LM Head
   
   Eagle保留了完整的特征空间信息
   特征空间比logits空间包含更多信息
   
2. 共享LM Head
   Eagle与主干共享LM Head → 无额外参数
   Medusa有独立Head → 需要训练后对齐
   
3. 接受率更高
   Eagle在特征级别预测 → 与主干更对齐
   实际测试中Eagle接受率比Medusa高10-15个百分点
   
Eagle官方数据(LLaMA-2-Chat-13B):
┌────────────────┬──────────┬──────────┬──────────┐
│ 方案           │ gamma=3  │ gamma=5  │ gamma=10 │
├────────────────┼──────────┼──────────┼──────────┤
│ Medusa         │ 2.4x     │ 2.8x     │ 2.5x     │
│ Eagle          │ 2.8x     │ 3.5x     │ 3.8x     │
│ Eagle (top-5)  │ 3.1x     │ 3.8x     │ 4.1x     │
└────────────────┴──────────┴──────────┴──────────┘

Eagle在gamma=7-10时达到最优
Medusa在gamma=5左右达到峰值
"""
pass

6.3 Eagle的优缺点

✅ 优点:

  1. 最高的接受率(80-95%)→ 3-4x加速
  2. 特征级预测 → 信息更完整
  3. 共享LM Head → 无需额外训练token预测
  4. 训练损失MSE,收敛速度快
  5. gamma可以更大(7-10),利用更长草稿

    ❌ 缺点:
  6. 需要修改模型的前向(获取中间hidden state)
  7. Eagle Head稍大于Medusa Head(但远小于完整小模型)
  8. 实现复杂度高于Medusa
  9. 需要小心处理模型并行下的hidden state通信

    适用场景:追求极致加速比、能接受微调的生产环境
    七、Self-Speculative Decoding:自草稿技术
    7.1 原理

Self-Speculative Decoding(自投机解码,2024):
不需要任何外部模型或额外Head,利用模型自身的层次结构。

核心思想:
大模型的前几层(浅层)输出"粗糙"的hidden state
后几层(深层)对hidden state进行"精炼"

→ 浅层 + LM Head = 草稿模型(快速但粗糙)
→ 完整模型 = 验证模型(慢但精确)

架构:
┌─────────────────────────────────────────────────────┐
│ │
│ Self-SD = 同一个模型的两个"切片" │
│ │
│ 完整模型 M_p(全部N层): │
│ Layer 1 → Layer 2 → … → Layer N-1 → Layer N → LM Head │
│ │
│ 草稿模型 M_q(前K层截断): │
│ Layer 1 → Layer 2 → … → Layer K → LM Head │
│ │
│ K的选择:通常K = N/4 ~ N/3 │
│ 例如LLaMA-70B共80层,K=20层 │
│ → 草稿模型计算量=完整模型的1/4 │
│ → 但质量高(因为用了前20层的全部信息) │
│ │
│ 关键优势: │
│ - 草稿模型与主干共享参数 → 零额外显存 │
│ - 草稿模型天然与主干对齐 → 接受率最高 │
│ - 不需要任何训练或微调 │
│ │
└──────────────────────────────────────────────────────┘
7.2 Self-SD实现

class SelfSpeculativeDecoder:
“”"
Self-Speculative Decoding实现

利用同一个模型的不同深度作为draft和target
"""
def __init__(
    self,
    model,             # HuggingFace模型
    draft_layers: int,  # 用前多少层作为draft
    gamma: int = 5,
):
    self.model = model
    self.draft_layers = draft_layers
    self.gamma = gamma
    self.total_layers = model.config.num_hidden_layers
    self.lm_head = model.lm_head
    self.embed_tokens = model.model.embed_tokens
    
    print(f"Self-SD: draft={draft_layers}层, target={self.total_layers}层")
    print(f"draft计算比: {draft_layers/self.total_layers:.1%}")

@torch.no_grad()
def _forward_until_layer(self, input_ids, stop_layer):
    """
    计算到指定层的前向
    返回指定层的hidden state
    """
    hidden_states = self.embed_tokens(input_ids)
    
    for i in range(stop_layer):
        hidden_states = self.model.model.layers[i](
            hidden_states
        )[0]
    
    return hidden_states  # [batch, seq, hidden]

@torch.no_grad()
def generate_drafts(self, input_ids):
    """
    用浅层+LM Head生成草稿
    """
    draft_tokens = []
    
    for _ in range(self.gamma):
        # 计算到draft_layers层
        hidden = self._forward_until_layer(input_ids, self.draft_layers)
        logits = self.lm_head(hidden[:, -1:])
        token = logits[:, -1].argmax(dim=-1, keepdim=True)
        
        draft_tokens.append(token.item())
        input_ids = torch.cat([input_ids, token], dim=-1)
    
    return draft_tokens

@torch.no_grad()
def verify(self, input_ids, draft_tokens):
    """
    用完整模型验证草稿
    """
    draft_ids = torch.tensor(
        [draft_tokens], device=input_ids.device
    )
    verify_input = torch.cat([input_ids[:, :-1], draft_ids], dim=-1)
    
    # 完整前向
    outputs = self.model(verify_input)
    verify_logits = outputs.logits  # [1, seq+gamma, vocab]
    
    # 只取草稿位置的logits
    verify_logits = verify_logits[:, -self.gamma-1:]
    
    # 验证
    accepted = []
    for i, token in enumerate(draft_tokens):
        pred = verify_logits[0, i].argmax().item()
        if pred == token:
            accepted.append(token)
        else:
            accepted.append(pred)
            break
    
    if len(accepted) == self.gamma:
        # 全部接受,额外采一个
        extra = verify_logits[0, self.gamma].argmax().item()
        accepted.append(extra)
    
    return accepted

@torch.no_grad()
def generate(self, input_ids, max_new_tokens=256):
    """
    Self-SD完整生成循环
    """
    total_generated = 0
    draft_calls = 0
    target_calls = 0
    
    while total_generated < max_new_tokens:
        # Draft
        draft_tokens = self.generate_drafts(input_ids)
        draft_calls += len(draft_tokens)
        
        # Verify
        accepted = self.verify(input_ids, draft_tokens)
        target_calls += 1
        
        # Update
        new_tokens = torch.tensor(
            accepted, device=input_ids.device
        ).unsqueeze(0)
        input_ids = torch.cat([input_ids, new_tokens], dim=-1)
        total_generated += len(accepted)
    
    efficiency = total_generated / target_calls
    print(f"生成{total_generated}token, target调用{target_calls}次")
    print(f"平均每一步生成{efficiency:.2f}个token")
    print(f"加速比(理论): {efficiency:.2f}x")
    
    return input_ids


Self-SD性能实验

def self_sd_experiment():
“”"
典型实验结果(LLaMA-2-70B, 80层, draft=20层):

┌──────────┬───────────┬────────────┬──────────┬──────────┐
│ draft层  │ 计算比    │ 接受率     │ 加速比   │ 质量     │
├──────────┼───────────┼────────────┼──────────┼──────────┤
│ 10       │ 12.5%     │ 35%        │ 1.2x     │ 一致 ✅  │
│ 20       │ 25%       │ 55%        │ 1.6x     │ 一致 ✅  │
│ 30       │ 37.5%     │ 70%        │ 1.8x     │ 一致 ✅  │
│ 40       │ 50%       │ 80%        │ 1.7x     │ 一致 ✅  │
└──────────┴───────────┴────────────┴──────────┴──────────┘

结论:
- draft=30层时加速比最高(1.8x)
- draft层越多 → 接受率越高 → 但draft成本也越高
- 最优值在计算比和接受率之间平衡
- Self-SD加速比 < Medusa/Eagle(1.5-2x vs 2.5-4x)
- 但Self-SD**不需要任何训练** → 部署最快
"""
pass

7.3 Self-SD的优缺点

✅ 优点:

  1. 零训练成本 → 开箱即用
  2. 零额外显存 → 不需要额外模型参数
  3. 草稿模型与主干天然对齐 → 接受率稳定
  4. 实现简单,不需要额外模型
  5. 适合快速部署场景

    ❌ 缺点:
  6. 加速比有限(1.5-2x,不如Medusa/Eagle的3-4x)
  7. 需要多次前向(虽然只到浅层)
  8. draft层数选择需要调优
  9. 浅层LM Head预测质量有限

    适用场景:快速部署、不允许多余模型、零训练成本
    八、Lookahead Decoding与Blockwise Parallel Decoding
    8.1 Lookahead Decoding

Lookahead Decoding(前瞻解码,2024):
不需要任何额外模型或head,只看logits本身。

核心洞察:
大模型最后一个hidden state → logits
logits的top-k token中包含了"未来可能"的候选

方法:

  1. 从当前logits中采多个候选token(top-5或top-10)
  2. 对每个候选,再次计算logits
  3. 找到能"延续最长的合理序列"

本质上就是在做"beam search",但只做一步深度的探索。

Lookahead Decoding vs 标准采样:

┌─────────────────────────────────────────────────────┐
│ │
│ 标准采样: │
│ token_0 → token_1 → token_2 → token_3 → token_4 │
│ 每一步1次前向 │
│ │
│ Lookahead: │
│ logits_0 → 采5个候选 → 每个做一步预测 │
│ → 找到"预测最准的"路径 → 一次性接受多个token │
│ │
│ 每步还是1次前向(但batch_size=5) │
│ 因为有batch维度,GPU利用率更高 │
│ │
└──────────────────────────────────────────────────────┘
8.2 Blockwise Parallel Decoding

Blockwise Parallel Decoding(BPD,块并行解码,2023):
对Lookahead Decoding的扩展,使用n-gram预测。

核心思想:
当前hidden state + n-gram缓存 → 预测接下来的token块

n-gram缓存:
在生成过程中维护一个"高频n-gram"字典
例如:“自然语言” → “处理”(3-gram)

BPD利用这些n-gram模式进行块预测,
因为这些模式在训练数据中频繁出现,
模型对这些"固定搭配"的预测非常准确。

性能数据(LLaMA-13B):
┌──────────────┬────────┬─────────────────┬──────────┐
│ 策略 │ gamma │ n-gram命中率 │ 加速比 │
├──────────────┼────────┼─────────────────┼──────────┤
│ Lookahead │ 5 │ - │ 1.2x │
│ BPD │ 5 │ 40% │ 1.4x │
│ BPD + Self-SD│ 5 │ 60% │ 2.0x │
└──────────────┴────────┴─────────────────┴──────────┘

注意:Lookahead和BPD的加速比普遍低于Medusa/Eagle
但它们无需任何额外模型或训练,是最轻量的方案。
8.3 Lookahead Decoding实现

def lookahead_decode_step(
model,
input_ids: torch.Tensor,
lookahead_k: int = 5, # 每次看多少个候选
gamma: int = 5, # 草稿长度
):
“”"
Lookahead Decoding单步

不需要额外模型,从当前logits探索多个候选路径
"""
batch_size = input_ids.shape[0]
device = input_ids.device
vocab_size = model.config.vocab_size

# Step 1: 获取当前logits
with torch.no_grad():
    outputs = model(input_ids)
    curr_logits = outputs.logits[:, -1]  # [1, vocab_size]
    probs = F.softmax(curr_logits / 0.6, dim=-1)  # 稍微降温

# Step 2: 采样lookahead_k个候选
top_probs, top_indices = probs.topk(lookahead_k, dim=-1)
candidates = top_indices[0]  # [lookahead_k]

# Step 3: 对每个候选,看能延伸多远
candidate_scores = []

for cand_idx in range(lookahead_k):
    candidate_token = candidates[cand_idx:cand_idx+1].unsqueeze(0)
    ext_input = torch.cat([input_ids, candidate_token], dim=-1)
    
    with torch.no_grad():
        ext_outputs = model(ext_input)
        ext_logits = ext_outputs.logits[:, -1]
        ext_probs = F.softmax(ext_logits / 0.6, dim=-1)
    
    # 看这个候选后的token是否也符合预期
    # 评分 = 候选概率 + 后续token概率
    score = top_probs[0, cand_idx].item()
    
    # 再看一步(2-gram一致性)
    next_top = ext_probs.topk(1)
    score += next_top.values[0, 0].item() * 0.5
    
    candidate_scores.append(score)

# Step 4: 选择最佳候选路径
best_idx = torch.tensor(candidate_scores).argmax().item()
best_candidate = candidates[best_idx].item()

# Step 5: 再从最佳候选看能否延伸
ext_input = torch.cat(
    [input_ids, candidates[best_idx:best_idx+1].unsqueeze(0)], 
    dim=-1
)

with torch.no_grad():
    ext_outputs = model(ext_input)
    next_logits = ext_outputs.logits[:, -1]
    next_token = next_logits[:, -1].argmax(dim=-1, keepdim=True)

# 如果下一个token也高置信度 → 接受2个token
next_prob = F.softmax(next_logits / 0.6, dim=-1).max().item()

if next_prob > 0.7:
    return [best_candidate, next_token.item()]
else:
    return [best_candidate]

九、各方案性能对比与选型指南
9.1 全面对比表

六大投机解码方案对比(2026年数据):

┌─────────────────┬────────┬────────┬────────┬────────┬────────┬────────┐
│ 特性 │ 传统SD │ Medusa │ Eagle │Self-SD │LA/BPD │ 投机SD │
│ │ 外部模型│ │ │ │ │ +量化 │
├─────────────────┼────────┼────────┼────────┼────────┼────────┼────────┤
│ 典型加速比 │ 2.0-2.5│ 2.5-3.0│ 3.0-4.0│ 1.5-2.0│ 1.2-1.5│ 4-6x │
│ 接受率 │ 60-80% │ 70-85% │ 80-95% │ 50-70% │ 30-50% │ 70-90% │
│ 训练成本 │ 无 │ 1-3天 │ 1-2天 │ 无 │ 无 │ 无 │
│ 额外显存 │ 高 │ 低 │ 中 │ 无 │ 无 │ 无 │
│ 部署复杂度 │ 中 │ 低 │ 中 │ 低 │ 极低 │ 低 │
│ 质量损失 │ 无 │ 无 │ 无 │ 无 │ 极小 │ 极小 │
│ 框架集成 │ 完善 │ vLLM │ 部分 │ 实验 │ 实验 │ 实验 │
│ 适用gamma │ 3-5 │ 3-7 │ 5-10 │ 3-5 │ 3-5 │ 3-5 │
│ 可移植性 │ ✅ │ ⚠️ │ ⚠️ │ ✅ │ ✅ │ ⚠️ │
│ 最佳场景 │ 通用 │ 在线 │ 极致 │ 快速 │ 最轻 │ 边缘 │
└─────────────────┴────────┴────────┴────────┴────────┴────────┴────────┘
9.2 选型决策树

你要用哪种投机解码?

你的需求是什么?

├─ 不能做任何训练/微调?
│ ├─ 有可用的小模型? → 传统SD(外部小模型)
│ └─ 没有小模型?
│ ├─ 能接受1.5-2x加速? → Self-SD(自草稿)
│ └─ 要最轻量的方案? → Lookahead / BPD

├─ 可以做轻量训练/微调?
│ ├─ 追求部署最简? → Medusa(加几个head)
│ └─ 追求加速最高? → Eagle(特征级预测)

├─ 你的部署环境是?
│ ├─ vLLM生态 → Medusa(原生支持)
│ ├─ TensorRT-LLM → 传统SD + 量化
│ └─ 自定义框架 → Eagle(最优定制)

└─ 你的延迟要求?
├─ < 50ms → Eagle + gamma=5(最佳延迟)
├─ < 100ms → Medusa + gamma=5
├─ < 200ms → 传统SD + gamma=3
└─ > 200ms → 普通解码已够
9.3 实测加速比对比

各方案在LLaMA-70B上的实测数据(H100, FP16):

┌─────────────────┬──────────┬──────────┬──────────┬──────────┐
│ 方案 │ 输入长度 │ 前填充 │ 续写 │ 代码生成 │
│ │ 128→256 │ (0.8) │ (0.6) │ (0.5) │
├─────────────────┼──────────┼──────────┼──────────┼──────────┤
│ Baseline │ 42.3 │ 1.0x │ 1.0x │ 1.0x │
│ │ tokens/s │ │ │ │
│ 传统SD (7B) │ 97.3 │ 2.4x │ 2.2x │ 2.0x │
│ Medusa (5头) │ 118.4 │ 2.9x │ 2.7x │ 2.5x │
│ Eagle (2层) │ 147.0 │ 3.6x │ 3.4x │ 3.1x │
│ Self-SD (30层) │ 72.7 │ 1.8x │ 1.6x │ 1.5x │
│ LA/BPD │ 55.0 │ 1.3x │ 1.2x │ 1.1x │
│ Eagle+FP8 │ 235.0 │ 5.8x │ 5.4x │ 5.0x │
└─────────────────┴──────────┴──────────┴──────────┴──────────┘

注意:

  1. (接受率) = 该场景下传统外部小模型(1/10参数)的接受率
  2. 代码生成场景接受率普遍较低,因为代码分布更"稀疏"
  3. Eagle+FP8组合使用量化+投机 → 接近6x加速
    十、生产部署最佳实践
    10.1 集成到主流推理框架

===== vLLM中的投机采样配置(0.6.0+) =====


from vllm import LLM, SamplingParams

方法1:使用外部草稿模型

llm = LLM(
model=“meta-llama/Llama-3.1-70B-Instruct”,
speculative_model=“meta-llama/Llama-3.1-8B-Instruct”, # 小模型
num_speculative_tokens=5, # gamma=5
speculative_draft_tensor_parallel_size=1, # 小模型TP大小
use_v2_block_manager=True,
)

方法2:使用Medusa Head(模型需要先训练Medusa Head)

llm_medusa = LLM(
model=“my-org/Llama-3.1-70B-medusa”,
speculative_model=“medusa”, # 使用内置Medusa
num_speculative_tokens=5,
)

方法3:使用Eagle(模型需要先训练Eagle Head)

llm_eagle = LLM(
model=“my-org/Llama-3.1-70B-eagle”,
speculative_model=“eagle”,
num_speculative_tokens=7, # Eagle可以使用更大的gamma
eagle_num_layers=2,
)

参数调优

sampling_params = SamplingParams(
temperature=0.0, # greedy模式,对投机最友好
top_p=0.9,
max_tokens=1024,
# 投机专用参数
use_speculative=True,
speculative_acceptance_threshold=0.5, # 接受阈值
)

outputs = llm.generate(prompts, sampling_params)
10.2 性能调优参数

speculative_config = {
# ===== 核心参数 =====
“num_speculative_tokens”: 5,
# gamma值,建议从5开始调
# 如果接受率 > 70% → 增大gamma(7-10)
# 如果接受率 < 50% → 减小gamma(3)

"speculative_draft_tensor_parallel_size": 1,
# 草稿模型的TP维度
# 如果小模型远小于大模型 → 1就够了
# 如果小模型也比较大 → 可以设2

# ===== 高级参数 =====
"speculative_acceptance_threshold": 0.5,
# 接受率阈值,低于此值时自动降级
# 防止"小模型在讨论不熟悉的话题时反复拒绝"

"speculative_adaptive_length": True,
# 自适应草稿长度
# 根据最近几轮的接受率动态调整gamma
# 接受率高 → 增大gamma
# 接受率低 → 减小gamma

# ===== 延迟优化 =====
"speculative_draft_warmup": True,
# 预预热草稿模型
# 在第一个请求到达前提前加载

"speculative_draft_stream": True,
# 流式草稿
# 在上一轮验证还没完成时就开始下一轮草稿
# 需要仔细同步

# ===== 批量优化 =====
"speculative_batch_size": 4,
# 批量草稿大小
# 多个请求可以共享草稿模型的batch计算
# 提高GPU利用率

}


def speculative_tuning_guide(acceptance_rate_history: list):
“”"
根据接受率历史自动调整gamma
“”"
recent_acceptance = acceptance_rate_history[-10:]
avg_acceptance = sum(recent_acceptance) / len(recent_acceptance)

current_gamma = 5  # 当前gamma

if avg_acceptance > 0.8:
    new_gamma = min(current_gamma + 2, 10)
    action = "增大"
elif avg_acceptance > 0.6:
    new_gamma = current_gamma
    action = "保持"
elif avg_acceptance > 0.4:
    new_gamma = max(current_gamma - 1, 3)
    action = "减小"
else:
    new_gamma = max(current_gamma - 2, 2)
    action = "大幅减小"
    # 考虑降级到普通解码
    if avg_acceptance < 0.3:
        return "停用投机解码"

print(f"接受率={avg_acceptance:.1%} → {action}gamma到{new_gamma}")
return new_gamma

10.3 实际部署经验总结
线上部署投机采样的实战经验:

  1. 先测接受率
    在你的数据和prompt分布上测接受率
    不同领域接受率差异很大(代码 < 聊天 < 知识问答)
    接受率 < 40% → 投机反而更慢

  2. gamma不是越大越好
    gamma=5通常是甜点值
    gamma>10后收益递减(接受率下降 + 草稿计算成本上升)

  3. 监控"有效加速比"
    不仅要看tokens/s,还要看TTFT和TPOT
    投机采样会增加TTFT(因为要等草稿生成)
    但对TPOT(每个输出token的延迟)提升明显

  4. 动态降级机制
    对"简单"prompt(常见话题),投机效果好
    对"困难"prompt(罕见话题),效果差
    可以根据最近几轮的接受率动态切换

  5. 与小批量结合
    batch_size=1时投机效果最好(利用空闲算力)
    batch_size大时,投机边际收益递减

  6. 量化 + 投机 = 最佳组合
    FP8量化 + Eagle投机 → 5-6x加速
    这是2026年生产部署的黄金组合
    10.4 常见问题排查
    Q: 用了投机解码反而变慢了?
    A: 检查三点:

  7. 接受率是否 > 40%(低于此值投机没有收益)

  8. gamma是否过大(gamma > 10时收益递减)

  9. 小模型是否太小(draft质量太差)

解决方案:减小gamma或切换为Self-SD

Q: Medusa训练不收敛?
A: 检查:

  1. 训练数据是否来自基础模型自身的输出
  2. 学习率是否合适(建议1e-4)
  3. Head初始化是否正常

解决方案:用主干最后几层的hidden做warm start

Q: Eagle验证时质量下降?
A: 检查:

  1. 验证阶段是否用了拒绝采样
  2. LM Head是否与主干共享
  3. Eagle Head层数是否足够(2层比1层好)

解决方案:确保验证阶段的接受逻辑正确

Q: 显存不足?
A: 用Self-SD(零额外显存)
或用Medusa(几十M的head,完全可以忽略)
十一、面试高频问答
11.1 基础理解
Q1: 投机采样为什么能加速?
A: 核心原因是自回归解码的算力浪费。
生成1个token时GPU利用率仅5-15%(小矩阵计算无法填满Tensor Core)。
投机采样让大模型一次验证多个草稿token → 一次前向完成多个token的验证 →
等价于把"批处理"引入了解码过程 → GPU利用率提升到15-30% → 2-4x加速。

Q2: 投机采样会降低生成质量吗?
A: 理论上不会。

  • Greedy模式:严格argmax匹配,等价于原始解码
  • Stochastic模式:拒绝采样的数学证明保证了分布完全一致
    实践中只要验证逻辑正确,质量完全无损。

Q3: 决定加速比的三个关键因素?
A: 1. 接受率(最重要)——草稿token被大模型接受的概率
2. gamma(草稿长度)——一次草稿几个token
3. cost_ratio(草稿/目标成本比)——草稿模型的相对计算量
公式:Speedup ≈ 1 / (1/gamma + acceptance_rate × cost_ratio)
11.2 方案对比
Q4: Medusa和Eagle的核心区别?
A:
Medusa在logits空间做预测 → 额外预测头直接输出logits
Eagle在特征空间做预测 → 预测下一个hidden state,再用共享LM Head解码

Eagle的优势:
- 特征空间信息更丰富
- 共享LM Head → 不需要额外训练token预测
- 接受率比Medusa高10-15个百分点
- 加速比3-4x vs Medusa的2.5-3x

Medusa的优势:
- 实现更简单
- vLLM原生支持
- Head更轻量

Q5: 什么时候该用Self-SD而不是Medusa?
A: 当你能接受1.5-2x加速,但:

  1. 不允许任何模型修改(合规/安全要求)
  2. 没有训练资源
  3. 需要0额外显存
  4. 快速部署场景
    如果想要更高加速比(2.5x+),选择Medusa或Eagle。

Q6: 投机采样和批量解码(Continuous Batching)冲突吗?
A: 不冲突,它们是正交的优化手段。

  • Continuous Batching:在请求维度共享GPU计算
  • Speculative Decoding:在token维度并行化生成
    两者可以组合使用,效果叠加。
    但注意:当batch_size很大时(>32),投机采样的边际收益会降低。
    11.3 高级问题
    Q7: 投机采样的理论极限加速比是多少?
    A: 受两个因素限制:
  1. 小模型的成本下限(即使完全接受,也有小模型计算开销)
  2. 大模型验证的固定成本(一次前向至少需要1个token的时间)

单步的最大加速比 ≤ 1/cost_ratio
例如cost_ratio=0.25(小模型计算量为大模型的1/4)
→ 最大理论加速比 ≤ 4x
实际中考虑到接受率 < 100%,通常在2-4x之间。

Q8: 怎么在MoE模型上做投机采样?
A: 几个挑战:

  1. MoE模型小模型不好找(参数结构不同)
  2. Medusa Head在MoE上效果会下降(专家路由不确定性)
  3. Self-SD可行但加速比更低

推荐方案:
- 用MoE的前few_expert配置做草稿
- 或用更小参数的MoE模型做外部草稿
- DeepSeek V4/R1的MSA(Multi-Scale Attention)架构天然支持投机

Q9: 投机采样在视觉语言模型(VLM)上有效吗?
A: 有效,但要注意:

  1. VLM的decode阶段与LLM相同(文本自回归)
  2. 视觉token的"草稿"更难(视觉特征与语言特征不同)
  3. 推荐:只在文本生成阶段使用投机采样
  4. 视觉部分(image encode)已经是并行的,不需要投机

Q10: 2026年投机采样的发展趋势?
A: 三个方向:

  1. 投机采样 + 量化 → 4-6x加速(黄金组合)
  2. 自适应投机(adaptive speculation) → 动态调参
  3. 投机采样的投机(两级投机) → 再加速20-30%
  4. SD-Turbo等扩散模型投机(扩散解码的投机采样)
    总结:投机采样是目前大模型推理加速领域"性价比最高"的技术之一——不需要改硬件、不需要改模型主干、不损失质量,就能获得2-4倍的解码加速。对于生产部署,推荐"量化 + 投机"组合拳(FP8 + Eagle = 5-6x加速)。方案选型上:追求最简用Medusa(vLLM内置),追求极致用Eagle,零部署成本用Self-SD。

下期预告:推理与部署篇08——KV Cache优化详解:从PagedAttention到MLA再到无限上下文

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值