07. MoE Load Balancing Loss | MoE 进阶:负载均衡损失函数 (Load Balancing Loss)
难度: Hard | 标签: MoE, Loss Function, Mixtral | 目标人群: 核心 Infra 与算子开发
在上一节 06_MoE_Router 中,我们实现了 Top-K 路由。但在真实的 MoE 模型(如 Mixtral 8x7B, DeepSeek)训练中,会遇到一个非常严重的问题:路由崩塌 (Router Collapse)。
即门控网络“偷懒”,把所有的 Token 都发给了第 0 号和第 1 号专家,导致其他专家被饿死(闲置),不仅失去了 MoE 的意义,还会导致算力非常不均衡(OOM)。
因此,面试官非常爱考:如何用代码实现 MoE 的辅助损失函数 (Auxiliary Loss) 来强制负载均衡?
Step 1: 核心数学公式
为了让 NNN 个 Token 均匀地分配给 EEE 个专家,我们需要设计一个惩罚项,加到总的 CrossEntropy Loss 里。
Mixtral / Switch Transformer 使用的经典公式:
Laux=α⋅E∑i=1Efi⋅Pi L_{aux} = \alpha \cdot E \sum_{i=1}^E f_i \cdot P_i Laux=α⋅Ei=1∑Efi⋅Pi
- EEE: 专家总数。
- fif_ifi: 专家 iii 被路由到的 Token 比例(即选了专家 iii 的 token 数 / 总 token 数)。
- PiP_iPi: 专家 iii 在所有 Token 上的 平均路由概率得分(Softmax 之后的概率的均值)。
- α\alphaα: 辅助损失的权重系数(通常很小,如 0.01)。
为什么这个公式有效?
根据均值不等式,给定总和为 1 的 fff 和 PPP,当且仅当所有的 fi=1/Ef_i = 1/Efi=1/E 且 Pi=1/EP_i = 1/EPi=1/E 时(即绝对均匀分配),它们的内积(点积)之和最小。优化器为了降低这个 Loss,会拼命把 Token 往不同的专家那里赶!
Step 2: 代码实现框架
你需要统计在当前批次中每个专家实际被选中的次数(形成频率分布 fif_ifi),同时求出门控概率的均值分布(PiP_iPi)。将这两个分布点乘并乘以专家总数 EEE 和超参数 α\alphaα,即可得到最终的 Load Balancing Loss。
关键点:本实现支持 Top-K 路由(不仅限于 Top-1),通过 top_k 参数控制每个 Token 选择的专家数量。
Step 3: 动手实战
要求:请补全下方 compute_load_balancing_loss 的逻辑。
注意:本实现支持 Top-K 路由,即每个 Token 可以选择 K 个专家(通常 K=2)。
def compute_load_balancing_loss(
routing_weights: torch.Tensor,
selected_experts: torch.Tensor,
num_experts: int,
top_k: int,
alpha: float = 0.01
):
"""
计算 MoE 的负载均衡辅助损失(支持 Top-K 路由)
Args:
routing_weights: [batch_size * seq_len, top_k],每个 token 选中的 K 个专家的权重(已归一化)
selected_experts: [batch_size * seq_len, top_k],每个 token 选中的 K 个专家的索引
num_experts: 专家总数 E
top_k: 每个 token 选择的专家数量 K
alpha: 损失权重系数
Returns:
aux_loss: 标量,负载均衡损失
"""
batch_size_x_seq_len, _ = selected_experts.shape
total_tokens = batch_size_x_seq_len
# ==========================================
# TODO 1: 计算 P_i(每个专家的平均路由概率得分)
# ==========================================
# P_i = ???
P_i = torch.zeros(num_experts, dtype=routing_weights.dtype, device=routing_weights.device)
P_i.scatter_add_(0, selected_experts.flatten(), routing_weights.flatten())
P_i = P_i / (total_tokens * top_k)
# ==========================================
# TODO 2: 计算 f_i(每个专家实际分到的 Token 比例)
# ==========================================
# expert_mask = ???
# tokens_per_expert = ???
# f_i = ???
expert_mask = F.one_hot(selected_experts, num_classes=num_experts)
tokens_per_expert = expert_mask.sum(dim=(0, 1)).float()
f_i = tokens_per_expert / (total_tokens * top_k)
# ==========================================
# TODO 3: 计算最终的 auxiliary loss
# ==========================================
# aux_loss = ???
aux_loss = alpha * num_experts * (f_i * P_i).sum()
return aux_loss
解析
1. TODO 1: 计算 P_i(平均路由概率)
- 实现方式:
P_i = torch.zeros(num_experts, dtype=routing_weights.dtype, device=routing_weights.device) P_i.scatter_add_(0, selected_experts.flatten(), routing_weights.flatten()) P_i = P_i / (total_tokens * top_k) - 核心逻辑:使用
scatter_add_将每个 token 对选中专家的权重累加到对应专家的位置。 - 归一化:除以总的选择次数
(total_tokens * top_k)得到平均权重。 - 物理含义:PiP_iPi 表示专家 iii 在所有 token 上的平均被选中概率。
2. TODO 2: 计算 f_i(Token 分配比例)
- 实现方式:
expert_mask = F.one_hot(selected_experts, num_classes=num_experts) tokens_per_expert = expert_mask.sum(dim=(0, 1)).float() f_i = tokens_per_expert / (total_tokens * top_k) - 核心逻辑:
F.one_hot将专家索引转换为 one-hot 编码,形状为[batch_size_x_seq_len, top_k, num_experts]。 - 统计方法:沿前两个维度求和,统计每个专家被选中的总次数。
- 归一化:除以总的选择次数得到比例。
- 物理含义:fif_ifi 表示专家 iii 实际分到的 token 比例。
3. TODO 3: 计算辅助损失
- 实现方式:
aux_loss = alpha * num_experts * (f_i * P_i).sum() - 数学公式:Laux=α⋅E∑i=1Efi⋅PiL_{aux} = \alpha \cdot E \sum_{i=1}^E f_i \cdot P_iLaux=α⋅E∑i=1Efi⋅Pi
- 最小值分析:根据均值不等式,当 fi=Pi=1/Ef_i = P_i = 1/Efi=Pi=1/E 时(完全均匀),损失最小。对于 Top-K 路由,理论最小值为 α/K\alpha / Kα/K。
- 优化目标:优化器为了降低这个 Loss,会强制将 Token 均匀分配给所有专家,防止路由崩塌。
工程要点
- Top-K 兼容性:代码支持任意 K 值,通过
(total_tokens * top_k)归一化确保比例计算正确。 - 数值稳定性:使用
scatter_add_而非循环累加,提升计算效率和数值稳定性。 - 超参数调优:α\alphaα 通常设为 0.01,过大会影响主任务性能,过小则无法有效平衡负载。
- 与主损失结合:在实际训练中,将
aux_loss加到 CrossEntropy Loss 上:total_loss = ce_loss + aux_loss。
08. Architecture Tricks | 经典架构变体:Qwen 与 Gemma 的核心机制 (Architecture Tricks)
难度: Easy | 标签: 模型架构, Qwen, Gemma | 目标人群: 模型微调与工程部署
在 06_LLaMA3_Block_Tutorial 中我们搭建了 LLaMA 的骨架。但如果你去面试阿里云(通义千问团队)或者谷歌,他们必然会问自家模型与 LLaMA 的区别。
本节我们将以“打补丁”的方式,在 PyTorch 中快速实现 Qwen 的 Tie Word Embeddings 以及 Gemma 的带偏置 RMSNorm。
Step 1: 核心差异与机制
Trick 1: Tie Word Embeddings (权重绑定) - Qwen 系列 / GPT-2
- 做法:在绝大多数模型(如 LLaMA)中,最开始的
Token Embedding矩阵(把 ID 变向量)和最后的LM Head矩阵(把向量变概率)是两个独立的权重矩阵。但在 Qwen 中,这两个矩阵共享同一份物理内存的参数!- 意义:极大减少了参数量(词表动辄 15 万,非常占参数),并且在训练时能让 Embedding 获得更直接的梯度更新。
Trick 2: RMSNorm 的 “+1 缩放” - Gemma 系列
- 做法:标准的 RMSNorm 公式是 y=xRMS⋅wy = \frac{x}{RMS} \cdot wy=RMSx⋅w。而 Google 的 Gemma 把它改成了 y=xRMS⋅(1+w)y = \frac{x}{RMS} \cdot (1 + w)y=RMSx⋅(1+w)。
- 意义:在 PyTorch 中,权重的默认初始化通常是 0(或者很小的值)。Gemma 加上 1,使得在训练的极早期(wpprox0w pprox 0wpprox0 时),RMSNorm 直接等价于一个不做任何缩放的纯归一化层,这带来了非常平滑的梯度和非常稳定的早期训练!
Step 2: Weight Tying 与偏置项的权衡
Weight Tying(权重绑定)强制 Embedding 层和最终的 LM Head 线性层共享同一个权重矩阵。这种方法在早期的模型中很流行,因为它大幅减少了参数量。但在现代极大规模 LLM 中,解绑通常能获得更好的容量表达。此外,取消大部分 Linear 和 Norm 层中的 Bias 项,可以略微提高计算效率并防止显存浪费。
Step 3: 代码实现框架
要实现权重绑定,只需在网络初始化时将 LM Head 的 weight 引用直接指向 Embedding 层的 weight。注意,这意味着隐藏层维度必须与词表维度兼容(或者存在中间投影层)。
Step 4: 动手实战
要求:
- 补全
GemmaRMSNorm的公式。 - 补全
QwenTieEmbeddings中的参数共享逻辑。
# --- Trick 1: Gemma 风格的 RMSNorm ---
class GemmaRMSNorm(nn.Module):
def __init__(self, hidden_size: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
# weight 初始化为全 0
self.weight = nn.Parameter(torch.zeros(hidden_size))
def forward(self, x: torch.Tensor) -> torch.Tensor:
# 计算均方根
x_f32 = x.float()
variance = x_f32.pow(2).mean(-1, keepdim=True)
x_norm = x_f32 * torch.rsqrt(variance + self.eps)
# ==========================================
# TODO 1: 实现 Gemma 的 +1 缩放
# 注意类型转换回 x.dtype
# ==========================================
# output = ???
# 占位初始化(返回错误值,确保数值测试失败)
output = x_norm * (1 + self.weight)
return output
# --- Trick 2: Qwen 风格的权重绑定 ---
class QwenTieEmbeddings(nn.Module):
def __init__(self, vocab_size: int, hidden_size: int):
super().__init__()
# 1. 定义标准的 Embedding 层
self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
# 2. 定义最后的 LM Head 预测层,注意不要 bias
self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
# ==========================================
# TODO 2: 将 lm_head 的权重在内存级别绑定到 embed_tokens 上
# 提示: 在 PyTorch 中,可以直接赋值 nn.Parameter 或是底层 tensor
# self.lm_head.weight = ???
# ==========================================
# ???
self.lm_head.weight = self.embed_tokens.weight
def forward_embed(self, input_ids):
return self.embed_tokens(input_ids)
def forward_lm_head(self, hidden_states):
return self.lm_head(hidden_states)
解析
1. TODO 1: Gemma 的 +1 缩放机制
- 实现方式:
output = x_norm * (1 + self.weight) - 核心思想:在标准 RMSNorm 的基础上,将缩放因子从
w改为(1 + w)。 - 初始化优势:权重初始化为 0 时,
(1 + 0) = 1,此时 RMSNorm 等价于纯归一化层(无缩放),梯度非常平滑。 - 训练稳定性:在训练早期(权重接近 0),避免了因权重过小导致的梯度消失问题。随着训练进行,权重逐渐学习到合适的缩放值。
- 工程细节:必须先转换为 FP32 计算(
x.float()),最后再转回原始精度(type_as(x)),防止 FP16/BF16 下的数值不稳定。
2. TODO 2: Qwen 的权重绑定(Weight Tying)
- 实现方式:
self.lm_head.weight = self.embed_tokens.weight - 物理指针级共享:这不是复制权重,而是让两个模块的
weight参数指向同一块内存。修改其中一个,另一个自动同步。 - 参数量优势:词表通常很大(15万+),绑定后可以节省一半的参数量。例如,词表 150k、隐藏层 4096 的模型,可以节省 150k × 4096 × 4 bytes ≈ 2.4GB 显存。
- 梯度更新:训练时,Embedding 层和 LM Head 的梯度会累加到同一个权重上,使得 Embedding 获得更直接的监督信号。
- 适用场景:Qwen、GPT-2 等模型使用此技巧。但在超大规模模型(如 LLaMA 70B)中,解绑通常能获得更好的表达能力。
工程要点
- 内存验证:可以通过
data_ptr()检查两个权重是否指向同一内存地址。 - 训练同步:由于是物理指针共享,更新 Embedding 权重时,LM Head 权重会自动同步,无需手动处理。
- 架构权衡:权重绑定减少参数但可能限制表达能力;+1 缩放提升训练稳定性但增加计算量(需要额外的加法)。
909

被折叠的 条评论
为什么被折叠?



