Files
maths-cs-ai-compendium-zh/chapter 06: machine learning/05. distributed deep learning.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

265 lines
17 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 分布式深度学习
*分布式训练将计算分散到多个GPU和机器上,以训练单个设备无法容纳或训练太慢的模型。本文件涵盖混合精度、数据并行、模型并行、流水线并行、ZeRO、FSDP、张量并行以及全规约等通信原语——这些对于大规模训练LLM至关重要。*
- 在单个GPU上训练大型神经网络最终会遇到瓶颈。模型可能无法放入内存,或者训练可能需要数月。分布式训练将工作分散到多个设备(GPU、TPU或整台机器)上,以更快地训练和训练更大的模型。本文件涵盖了实现这一目标的技术。
- 要理解为何分布式重要,从训练的**计算成本**开始。在一个包含 $d_{\text{in}}$ 个输入和 $d_{\text{out}}$ 个输出的密集层上,对一批 $B$ 个样本进行一次前向传播需要大约 $2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP(浮点运算):对输出矩阵的每个元素进行一次乘法和一次加法。反向传播的成本大约是前向传播的两倍(计算相对于输入和权重的梯度),因此一个密集层的一个训练步骤约为 $6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP。
- 对于隐藏维度为 $d$ 的Transformer层,自注意力块涉及四个投影(Q、K、V和输出),每个的成本为 $O(B \cdot n \cdot d^2)$ 次FLOP(其中 $n$ 是序列长度),加上注意力矩阵计算 $O(B \cdot n^2 \cdot d)$。前馈块有两个密集层,通常扩展到 $4d$ 再回来:$O(B \cdot n \cdot 8d^2)$。每层总计:大约 $O(B \cdot n \cdot 12d^2 + B \cdot n^2 \cdot d)$。乘以层数,你就会明白为什么训练GPT规模的模型需要数千个GPU小时。
- **内存墙**通常是更严格的约束。在训练期间,GPU内存必须同时容纳四样东西:
![堆叠柱状图展示训练内存分解:参数、梯度、优化器状态、激活值](../images/training_memory_breakdown.svg)
- **参数**:模型权重。一个70亿参数的模型在FP32中(每个参数4字节)仅权重就需要28 GB。
- **梯度**:与参数大小相同。又是28 GB。
- **优化器状态**:Adam维护两个额外的缓冲区(一阶和二阶矩估计),每个与参数大小相同。即使模型使用较低精度,这些也以FP32格式保存以确保数值稳定性。对于我们的7B模型,那就是 $2 \times 28 = 56$ GB。
- **激活值**:在前向传播过程中保存下来供反向传播使用的中间值。大小取决于批量大小、序列长度和模型宽度。这通常是最主要的组成部分,并随批量大小线性增长。
- 对于使用FP32 Adam的7B模型:28(参数)+ 28(梯度)+ 56(优化器)= 112 GB,这还没算激活值。单个80 GB的A100 GPU无法容纳。这就是分布式策略至关重要的原因。
- **混合精度训练**是第一道防线。不是将所有内容存储在FP32(32位浮点)中,而是使用FP16或BF16(16位)进行前向和反向传播,同时将权重的FP32主副本保留给优化器更新。
- **FP16**具有高精度(10位尾数),但范围有限,可能导致上溢/下溢。损失缩放(在反向传播前将损失乘以一个大因子,然后将梯度除以相同因子)缓解了这个问题。
- **BF16**(脑浮点)具有与FP32相同的指数范围(8位指数),但精度较低(7位尾数)。它几乎从不溢出,很少需要损失缩放,因此使用更简单。BF16是现代Transformer训练的默认选择。
- 混合精度大致将激活值和梯度的内存减半(前向/反向传播期间的主要成本),同时将优化器状态保留在FP32中以确保数值稳定性。
- **数据并行**是最简单的分布式策略。你在 $N$ 个GPU上复制整个模型,将每个小批量分成 $N$ 个相等的块,并将一个块发送到每个GPU。每个GPU在其块上独立运行前向和反向传播。然后梯度在所有GPU上平均(使用全规约操作),每个GPU更新其本地模型副本。
- 从模型的角度来看,这相当于使用大了 $N$ 倍的小批量进行训练。如果每个GPU处理一个大小为 $B$ 的批次,则有效批量大小为 $N \cdot B$。
![并排比较:数据并行复制模型并分割数据,模型并行分割模型并分享数据](../images/data_model_parallelism.svg)
- 梯度平均可以同步或异步进行。**同步SGD**等待所有GPU完成后再进行平均,确保与使用更大批量的单GPU训练数学上等价。缺点是,最慢的GPU("掉队者")会拖慢所有人。
- **异步SGD**让每个GPU独立地更新一个共享的参数服务器,无需等待。这消除了掉队者问题,但引入了"陈旧梯度":一个GPU可能基于略微过时的参数计算梯度。陈旧梯度增加了噪声,可能减缓收敛。在实践中,带高效通信的同步SGD更受青睐。
- **梯度累积**是一种软件技巧,用于在有限硬件上模拟更大的批量大小。不必每个小批量做一次更新,而是运行多次前向/反向传播并累积梯度,然后做一次更新。这与更大批量得到相同的结果,而无需更多GPU内存用于激活值(一次只有一个小批量的激活值在内存中)。
- 当模型本身太大无法放入单个GPU时,需要**模型并行**。有两种主要形式。
- **张量并行**将单个层分割到多个GPU上。一个大的矩阵乘法 $Y = XW$ 可以按列分割:将 $W$ 分区为 $[W_1, W_2]$ 分布在两个GPU上,并行计算 $Y_1 = XW_1$ 和 $Y_2 = XW_2$,然后拼接。这适用于注意力投影和前馈层。它需要GPU之间快速通信(通常是节点内的NVLink),因为每层都必须组合部分结果。
- **流水线并行**将不同的层分配到不同的GPU上。GPU 0运行第1-4层,GPU 1运行第5-8层,依此类推。数据像流水线一样流经整个管道。朴素的方法有一个"流水线气泡":当GPU 0处理微批次1的前向传播时,GPU 1-3处于空闲状态。**微批处理**通过将小批量分割成更小的微批次来缓解这个问题,这些微批次按顺序流经流水线,使所有GPU大部分时间保持忙碌。
- **混合并行**结合了数据并行、张量并行和流水线并行。一个典型的大模型设置可能使用节点内的张量并行(8个GPU通过快速NVLink连接)、跨节点的流水线并行以及跨节点组的数据并行。这就是GPT-4和Llama等模型的训练方式。
- 分布式训练的效率在很大程度上取决于**通信**。关键操作是**全规约(all-reduce**:给定 $N$ 个GPU上各有一个值,计算总和(或平均值)并将结果分发给所有GPU。
- 朴素的全规约将所有数据发送到一个GPU,求和,然后广播回来。通信量为 $O(N)$,并在根节点造成瓶颈。
- **环全规约(Ring all-reduce** 要高效得多。将 $N$ 个GPU排列成一个环。每个GPU将其数据分割成 $N$ 块。在 $N - 1$ 步中,每个GPU向邻居发送一块,并从另一个邻居接收一块,累加部分和。再经过 $N - 1$ 步后,完整的总和传播到所有GPU。每个GPU的总数据传输量:数据大小的 $2(N-1)/N$ 倍,随着 $N$ 的增长趋近于 $2\times$。关键在于,这不随 $N$ 增加,使其带宽最优。
![四个GPU排列成环,每个将梯度块传递给邻居,直到所有GPU都得到完整总和](../images/ring_allreduce.svg)
- **参数服务器**是一种替代架构,其中专用服务器节点保存模型参数。工作节点计算梯度并将其发送到服务器,服务器更新参数并将其发送回来。这更简单,但可能在服务器处造成通信瓶颈。
- **NCCL**NVIDIA集合通信库)是GPU间通信的标准库。它提供了全规约、全收集、广播和其他集合操作的高效实现,自动为网络拓扑选择最佳算法。
- **缩放定律**描述了模型性能如何随计算量、数据量和模型大小而提升。原始的Kaplan等人(2020)缩放定律发现,损失随每个因素以幂律方式下降:
$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$
- 其中 $N$ 是参数数量,$D$ 是数据集大小,$C$ 是计算预算。
- **Chinchilla缩放定律**Hoffmann等人,2022)表明大多数模型训练不足:对于给定的计算预算,应该训练一个更小的模型,使用比以前认为的更多的数据。最优比例大约是每参数20个token。一个7B模型应该看到大约140B个token,而不是Llama 1在65B模型上使用的300B个token。这一发现将领域转向了"计算最优"训练。
- **混合专家(MoE)** 是一种在不按比例增加计算量的情况下扩展模型容量的架构。每个Transformer层不是使用一个前馈网络,而是有 $N$ 个"专家"网络(每个都是一个标准FFN)。一个**门控网络**(路由器)检查每个token并将其发送到top-$K$个专家(通常 $K = 1$ 或 $K = 2$)。
![token通过门控网络路由到选定的专家,使用top-K稀疏路由和加权输出组合](../images/moe_routing.svg)
- 总参数量要大得多(因为有 $N$ 个专家),但每个token的FLOPs大致保持不变(因为每个token只有 $K$ 个专家激活)。例如,Mixtral 8x7B共有47B个参数,但每次前向传播只用大约13B,以较小模型的代价获得更大模型的性能。
- MoE带来了挑战。**负载均衡**:如果路由器将大多数token发送到同一个专家,其他专家就被浪费了。辅助损失鼓励均匀路由。**通信**:不同的专家可能位于不同的GPU上,因此路由token需要全对全通信,这很昂贵。
- **容错**在训练运行持续数周或数月、涉及数千个GPU时至关重要。如果单个GPU失效,你不想丢失所有进度。**检查点**定期将模型权重、优化器状态和训练状态(学习率、步数、数据位置)保存到磁盘。如果发生故障,你可以从最近的检查点重新开始。
- **梯度检查点**(也称为激活重计算)是一种内存优化,而非容错机制。在前向传播过程中,不是保存所有激活值供反向传播使用,而是只在某些检查点保存激活值。在反向传播过程中,从检查点重新计算缺失的激活值。这以计算换取内存:它使前向传播成本增加约33%,但可以将激活内存减少 $\sqrt{L}$ 倍(其中 $L$ 是层数)。
- 综合起来,训练前沿模型结合了所有这些技术:BF16混合精度、使用环全规约在数千个GPU上进行数据并行、节点内的张量并行、跨节点的流水线并行、减少内存的梯度检查点、提高参数效率的MoE,以及用于容错的定期检查点。系统工程与算法设计一样具有挑战性。
- 总结分布式训练工具包:
| 技术 | 作用 | 权衡 |
|---|---|---|
| 混合精度 (BF16) | 将激活值/梯度的内存减半 | 轻微数值差异 |
| 数据并行 | 在GPU间扩展批量大小 | 梯度同步的通信开销 |
| 张量并行 | 在GPU间分割层 | 需要快速互联 |
| 流水线并行 | 在GPU间分割模型阶段 | 流水线气泡(计算浪费) |
| 梯度累积 | 模拟大批量 | 更慢(多次前向/反向传播) |
| 梯度检查点 | 减少激活内存 | 约多33%计算 |
| 环全规约 | 高效的梯度平均 | 大模型受限于带宽 |
| MoE | 更多容量,相同FLOPs | 负载均衡、路由复杂性 |
| 缩放定律 | 指导计算分配 | 经验公式,未必在所有规模都成立 |
## 编程任务(使用CoLab或笔记本)
1. 计算Transformer层的FLOPs和内存需求。给定隐藏维度 $d$、序列长度 $n$、批量大小 $B$ 和层数,估计总训练成本。
```python
import jax.numpy as jnp
def transformer_layer_flops(d, n, B):
"""一个Transformer层前向传播的近似FLOPs。"""
# QKV投影:3 * (B * n * d * d) * 2(乘法-加法)
qkv_flops = 3 * 2 * B * n * d * d
# 注意力:(B * n * n * d) * 2 用于QK^T(B * n * n * d) * 2 用于attn*V
attn_flops = 2 * 2 * B * n * n * d
# 输出投影:(B * n * d * d) * 2
out_flops = 2 * B * n * d * d
# FFN:两层,d->4d 和 4d->d2 * (B * n * d * 4d) * 2
ffn_flops = 2 * 2 * B * n * d * 4 * d
return qkv_flops + attn_flops + out_flops + ffn_flops
def transformer_layer_memory(d, n, B, dtype_bytes=2):
"""一个层的近似激活内存(字节)。"""
# QKV3 * B * n * d
qkv_mem = 3 * B * n * d * dtype_bytes
# 注意力权重:B * heads * n * n(近似 B * n * n * sizeof
attn_mem = B * n * n * dtype_bytes
# FFN中间值:B * n * 4d
ffn_mem = B * n * 4 * d * dtype_bytes
return qkv_mem + attn_mem + ffn_mem
# 示例:GPT-2规模
d, n, B, L = 1024, 1024, 8, 24
fwd_flops = transformer_layer_flops(d, n, B)
total_flops = 3 * L * fwd_flops # 前向+反向的3倍
act_mem = L * transformer_layer_memory(d, n, B)
param_count = L * (12 * d * d + 13 * d) # 近似
print(f"模型:d={d}, n={n}, B={B}, L={L}")
print(f"参数:{param_count / 1e6:.0f}M")
print(f"每步FLOPs{total_flops / 1e12:.2f} TFLOPs")
print(f"激活内存:{act_mem / 1e9:.2f} GB (BF16)")
print(f"参数内存 (FP32){param_count * 4 / 1e9:.2f} GB")
print(f"Adam优化器内存:{param_count * 8 / 1e9:.2f} GB")
print(f"总训练内存:{(param_count * 16 + act_mem) / 1e9:.2f} GB")
```
2. 模拟数据并行训练。将数据集分割到多个"虚拟GPU"上,独立计算梯度,平均它们,并验证结果与单GPU训练匹配。
```python
import jax
import jax.numpy as jnp
# 简单线性模型:y = wx + b
key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (64, 4))
w_true = jnp.array([1.0, -2.0, 3.0, 0.5])
y = X @ w_true + 0.1 * jax.random.normal(key, (64,))
def loss_fn(w, X, y):
return jnp.mean((X @ w - y) ** 2)
grad_fn = jax.grad(loss_fn)
# 单GPU:全批量梯度
w = jnp.zeros(4)
grad_single = grad_fn(w, X, y)
# 数据并行:分割到4个"GPU"上
n_gpus = 4
chunk_size = len(X) // n_gpus
grads = []
for i in range(n_gpus):
X_chunk = X[i*chunk_size:(i+1)*chunk_size]
y_chunk = y[i*chunk_size:(i+1)*chunk_size]
grads.append(grad_fn(w, X_chunk, y_chunk))
# 全规约:平均梯度
grad_parallel = jnp.mean(jnp.stack(grads), axis=0)
print("单GPU梯度:", grad_single)
print("数据并行梯度(平均):", grad_parallel)
print(f"匹配:{jnp.allclose(grad_single, grad_parallel, atol=1e-5)}")
# 训练两者并比较
w_single, w_parallel = jnp.zeros(4), jnp.zeros(4)
lr = 0.1
for step in range(100):
w_single = w_single - lr * grad_fn(w_single, X, y)
grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size],
y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)]
avg_grad = jnp.mean(jnp.stack(grads), axis=0)
w_parallel = w_parallel - lr * avg_grad
print(f"\n100步之后:")
print(f"单GPU权重:{w_single}")
print(f"数据并行权重:{w_parallel}")
print(f"最大差异:{jnp.max(jnp.abs(w_single - w_parallel)):.2e}")
```
3. 实现一个简单的混合专家层。创建一个门控网络,将token路由到top-K个专家并组合它们的输出。
```python
import jax
import jax.numpy as jnp
def expert_fn(x, W1, b1, W2, b2):
"""简单的2层FFN专家。"""
h = jnp.maximum(0, x @ W1 + b1) # ReLU
return h @ W2 + b2
def moe_layer(x, gate_W, experts_params, top_k=2):
"""
MoE前向传播。
x: (batch, d_model)
gate_W: (d_model, n_experts)
experts_params: 每个专家的 (W1, b1, W2, b2) 列表
"""
n_experts = len(experts_params)
# 门控:计算路由分数
gate_logits = x @ gate_W # (batch, n_experts)
gate_probs = jax.nn.softmax(gate_logits, axis=-1)
# Top-K选择
top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k]
top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1)
# 重新归一化
top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)
# 计算专家输出(简化:运行所有专家,稍后掩码)
expert_outputs = jnp.stack([
expert_fn(x, *experts_params[i]) for i in range(n_experts)
], axis=1) # (batch, n_experts, d_model)
# 收集top-K专家输出并加权
batch_idx = jnp.arange(x.shape[0])[:, None]
selected_outputs = expert_outputs[batch_idx, top_k_indices] # (batch, top_k, d_model)
output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1)
return output, gate_probs
# 设置
key = jax.random.PRNGKey(42)
batch, d_model, d_ff, n_experts = 8, 16, 32, 4
# 初始化专家
experts_params = []
for i in range(n_experts):
k1, k2, key = jax.random.split(key, 3)[0], jax.random.split(key, 3)[1], jax.random.split(key, 3)[2]
experts_params.append((
jax.random.normal(k1, (d_model, d_ff)) * 0.1,
jnp.zeros(d_ff),
jax.random.normal(k2, (d_ff, d_model)) * 0.1,
jnp.zeros(d_model),
))
key, subkey = jax.random.split(key)
gate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1
x = jax.random.normal(key, (batch, d_model))
output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2)
print(f"输入形状:{x.shape}")
print(f"输出形状:{output.shape}")
print(f"门控概率(第一个样本):{gate_probs[0]}")
print(f"专家使用率(批量平均):")
for i in range(n_experts):
usage = jnp.mean(gate_probs[:, i])
print(f" 专家 {i}: {usage:.3f}")
```