2536c937e3
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。 第01章 向量 | 第02章 矩阵 | 第03章 微积分 第04章 统计学 | 第05章 概率论 | 第06章 机器学习 第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音 第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络 第13章 计算与操作系统 | 第14章 数据结构与算法 第15章 生产级软件工程 | 第16章 SIMD与GPU编程 第17章 AI推理 | 第18章 ML系统设计 第19章 应用人工智能 | 第20章 前沿人工智能 翻译说明: - 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留 - mkdocs.yml 配置中文导航 + language: zh - README.md 已翻译为中文(兼 docs/index.md) - docs/ 目录包含指向各章文件的 symlink - 约 29,000 行中文内容,排除 .cache/ 构建缓存
265 lines
17 KiB
Markdown
265 lines
17 KiB
Markdown
# 分布式深度学习
|
||
|
||
*分布式训练将计算分散到多个GPU和机器上,以训练单个设备无法容纳或训练太慢的模型。本文件涵盖混合精度、数据并行、模型并行、流水线并行、ZeRO、FSDP、张量并行以及全规约等通信原语——这些对于大规模训练LLM至关重要。*
|
||
|
||
- 在单个GPU上训练大型神经网络最终会遇到瓶颈。模型可能无法放入内存,或者训练可能需要数月。分布式训练将工作分散到多个设备(GPU、TPU或整台机器)上,以更快地训练和训练更大的模型。本文件涵盖了实现这一目标的技术。
|
||
|
||
- 要理解为何分布式重要,从训练的**计算成本**开始。在一个包含 $d_{\text{in}}$ 个输入和 $d_{\text{out}}$ 个输出的密集层上,对一批 $B$ 个样本进行一次前向传播需要大约 $2 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP(浮点运算):对输出矩阵的每个元素进行一次乘法和一次加法。反向传播的成本大约是前向传播的两倍(计算相对于输入和权重的梯度),因此一个密集层的一个训练步骤约为 $6 \cdot B \cdot d_{\text{in}} \cdot d_{\text{out}}$ 次FLOP。
|
||
|
||
- 对于隐藏维度为 $d$ 的Transformer层,自注意力块涉及四个投影(Q、K、V和输出),每个的成本为 $O(B \cdot n \cdot d^2)$ 次FLOP(其中 $n$ 是序列长度),加上注意力矩阵计算 $O(B \cdot n^2 \cdot d)$。前馈块有两个密集层,通常扩展到 $4d$ 再回来:$O(B \cdot n \cdot 8d^2)$。每层总计:大约 $O(B \cdot n \cdot 12d^2 + B \cdot n^2 \cdot d)$。乘以层数,你就会明白为什么训练GPT规模的模型需要数千个GPU小时。
|
||
|
||
- **内存墙**通常是更严格的约束。在训练期间,GPU内存必须同时容纳四样东西:
|
||
|
||

|
||
|
||
- **参数**:模型权重。一个70亿参数的模型在FP32中(每个参数4字节)仅权重就需要28 GB。
|
||
- **梯度**:与参数大小相同。又是28 GB。
|
||
- **优化器状态**:Adam维护两个额外的缓冲区(一阶和二阶矩估计),每个与参数大小相同。即使模型使用较低精度,这些也以FP32格式保存以确保数值稳定性。对于我们的7B模型,那就是 $2 \times 28 = 56$ GB。
|
||
- **激活值**:在前向传播过程中保存下来供反向传播使用的中间值。大小取决于批量大小、序列长度和模型宽度。这通常是最主要的组成部分,并随批量大小线性增长。
|
||
|
||
- 对于使用FP32 Adam的7B模型:28(参数)+ 28(梯度)+ 56(优化器)= 112 GB,这还没算激活值。单个80 GB的A100 GPU无法容纳。这就是分布式策略至关重要的原因。
|
||
|
||
- **混合精度训练**是第一道防线。不是将所有内容存储在FP32(32位浮点)中,而是使用FP16或BF16(16位)进行前向和反向传播,同时将权重的FP32主副本保留给优化器更新。
|
||
|
||
- **FP16**具有高精度(10位尾数),但范围有限,可能导致上溢/下溢。损失缩放(在反向传播前将损失乘以一个大因子,然后将梯度除以相同因子)缓解了这个问题。
|
||
|
||
- **BF16**(脑浮点)具有与FP32相同的指数范围(8位指数),但精度较低(7位尾数)。它几乎从不溢出,很少需要损失缩放,因此使用更简单。BF16是现代Transformer训练的默认选择。
|
||
|
||
- 混合精度大致将激活值和梯度的内存减半(前向/反向传播期间的主要成本),同时将优化器状态保留在FP32中以确保数值稳定性。
|
||
|
||
- **数据并行**是最简单的分布式策略。你在 $N$ 个GPU上复制整个模型,将每个小批量分成 $N$ 个相等的块,并将一个块发送到每个GPU。每个GPU在其块上独立运行前向和反向传播。然后梯度在所有GPU上平均(使用全规约操作),每个GPU更新其本地模型副本。
|
||
|
||
- 从模型的角度来看,这相当于使用大了 $N$ 倍的小批量进行训练。如果每个GPU处理一个大小为 $B$ 的批次,则有效批量大小为 $N \cdot B$。
|
||
|
||

|
||
|
||
- 梯度平均可以同步或异步进行。**同步SGD**等待所有GPU完成后再进行平均,确保与使用更大批量的单GPU训练数学上等价。缺点是,最慢的GPU("掉队者")会拖慢所有人。
|
||
|
||
- **异步SGD**让每个GPU独立地更新一个共享的参数服务器,无需等待。这消除了掉队者问题,但引入了"陈旧梯度":一个GPU可能基于略微过时的参数计算梯度。陈旧梯度增加了噪声,可能减缓收敛。在实践中,带高效通信的同步SGD更受青睐。
|
||
|
||
- **梯度累积**是一种软件技巧,用于在有限硬件上模拟更大的批量大小。不必每个小批量做一次更新,而是运行多次前向/反向传播并累积梯度,然后做一次更新。这与更大批量得到相同的结果,而无需更多GPU内存用于激活值(一次只有一个小批量的激活值在内存中)。
|
||
|
||
- 当模型本身太大无法放入单个GPU时,需要**模型并行**。有两种主要形式。
|
||
|
||
- **张量并行**将单个层分割到多个GPU上。一个大的矩阵乘法 $Y = XW$ 可以按列分割:将 $W$ 分区为 $[W_1, W_2]$ 分布在两个GPU上,并行计算 $Y_1 = XW_1$ 和 $Y_2 = XW_2$,然后拼接。这适用于注意力投影和前馈层。它需要GPU之间快速通信(通常是节点内的NVLink),因为每层都必须组合部分结果。
|
||
|
||
- **流水线并行**将不同的层分配到不同的GPU上。GPU 0运行第1-4层,GPU 1运行第5-8层,依此类推。数据像流水线一样流经整个管道。朴素的方法有一个"流水线气泡":当GPU 0处理微批次1的前向传播时,GPU 1-3处于空闲状态。**微批处理**通过将小批量分割成更小的微批次来缓解这个问题,这些微批次按顺序流经流水线,使所有GPU大部分时间保持忙碌。
|
||
|
||
- **混合并行**结合了数据并行、张量并行和流水线并行。一个典型的大模型设置可能使用节点内的张量并行(8个GPU通过快速NVLink连接)、跨节点的流水线并行以及跨节点组的数据并行。这就是GPT-4和Llama等模型的训练方式。
|
||
|
||
- 分布式训练的效率在很大程度上取决于**通信**。关键操作是**全规约(all-reduce)**:给定 $N$ 个GPU上各有一个值,计算总和(或平均值)并将结果分发给所有GPU。
|
||
|
||
- 朴素的全规约将所有数据发送到一个GPU,求和,然后广播回来。通信量为 $O(N)$,并在根节点造成瓶颈。
|
||
|
||
- **环全规约(Ring all-reduce)** 要高效得多。将 $N$ 个GPU排列成一个环。每个GPU将其数据分割成 $N$ 块。在 $N - 1$ 步中,每个GPU向邻居发送一块,并从另一个邻居接收一块,累加部分和。再经过 $N - 1$ 步后,完整的总和传播到所有GPU。每个GPU的总数据传输量:数据大小的 $2(N-1)/N$ 倍,随着 $N$ 的增长趋近于 $2\times$。关键在于,这不随 $N$ 增加,使其带宽最优。
|
||
|
||

|
||
|
||
- **参数服务器**是一种替代架构,其中专用服务器节点保存模型参数。工作节点计算梯度并将其发送到服务器,服务器更新参数并将其发送回来。这更简单,但可能在服务器处造成通信瓶颈。
|
||
|
||
- **NCCL**(NVIDIA集合通信库)是GPU间通信的标准库。它提供了全规约、全收集、广播和其他集合操作的高效实现,自动为网络拓扑选择最佳算法。
|
||
|
||
- **缩放定律**描述了模型性能如何随计算量、数据量和模型大小而提升。原始的Kaplan等人(2020)缩放定律发现,损失随每个因素以幂律方式下降:
|
||
|
||
$$L(N) \propto N^{-\alpha_N}, \quad L(D) \propto D^{-\alpha_D}, \quad L(C) \propto C^{-\alpha_C}$$
|
||
|
||
- 其中 $N$ 是参数数量,$D$ 是数据集大小,$C$ 是计算预算。
|
||
|
||
- **Chinchilla缩放定律**(Hoffmann等人,2022)表明大多数模型训练不足:对于给定的计算预算,应该训练一个更小的模型,使用比以前认为的更多的数据。最优比例大约是每参数20个token。一个7B模型应该看到大约140B个token,而不是Llama 1在65B模型上使用的300B个token。这一发现将领域转向了"计算最优"训练。
|
||
|
||
- **混合专家(MoE)** 是一种在不按比例增加计算量的情况下扩展模型容量的架构。每个Transformer层不是使用一个前馈网络,而是有 $N$ 个"专家"网络(每个都是一个标准FFN)。一个**门控网络**(路由器)检查每个token并将其发送到top-$K$个专家(通常 $K = 1$ 或 $K = 2$)。
|
||
|
||

|
||
|
||
- 总参数量要大得多(因为有 $N$ 个专家),但每个token的FLOPs大致保持不变(因为每个token只有 $K$ 个专家激活)。例如,Mixtral 8x7B共有47B个参数,但每次前向传播只用大约13B,以较小模型的代价获得更大模型的性能。
|
||
|
||
- MoE带来了挑战。**负载均衡**:如果路由器将大多数token发送到同一个专家,其他专家就被浪费了。辅助损失鼓励均匀路由。**通信**:不同的专家可能位于不同的GPU上,因此路由token需要全对全通信,这很昂贵。
|
||
|
||
- **容错**在训练运行持续数周或数月、涉及数千个GPU时至关重要。如果单个GPU失效,你不想丢失所有进度。**检查点**定期将模型权重、优化器状态和训练状态(学习率、步数、数据位置)保存到磁盘。如果发生故障,你可以从最近的检查点重新开始。
|
||
|
||
- **梯度检查点**(也称为激活重计算)是一种内存优化,而非容错机制。在前向传播过程中,不是保存所有激活值供反向传播使用,而是只在某些检查点保存激活值。在反向传播过程中,从检查点重新计算缺失的激活值。这以计算换取内存:它使前向传播成本增加约33%,但可以将激活内存减少 $\sqrt{L}$ 倍(其中 $L$ 是层数)。
|
||
|
||
- 综合起来,训练前沿模型结合了所有这些技术:BF16混合精度、使用环全规约在数千个GPU上进行数据并行、节点内的张量并行、跨节点的流水线并行、减少内存的梯度检查点、提高参数效率的MoE,以及用于容错的定期检查点。系统工程与算法设计一样具有挑战性。
|
||
|
||
- 总结分布式训练工具包:
|
||
|
||
| 技术 | 作用 | 权衡 |
|
||
|---|---|---|
|
||
| 混合精度 (BF16) | 将激活值/梯度的内存减半 | 轻微数值差异 |
|
||
| 数据并行 | 在GPU间扩展批量大小 | 梯度同步的通信开销 |
|
||
| 张量并行 | 在GPU间分割层 | 需要快速互联 |
|
||
| 流水线并行 | 在GPU间分割模型阶段 | 流水线气泡(计算浪费) |
|
||
| 梯度累积 | 模拟大批量 | 更慢(多次前向/反向传播) |
|
||
| 梯度检查点 | 减少激活内存 | 约多33%计算 |
|
||
| 环全规约 | 高效的梯度平均 | 大模型受限于带宽 |
|
||
| MoE | 更多容量,相同FLOPs | 负载均衡、路由复杂性 |
|
||
| 缩放定律 | 指导计算分配 | 经验公式,未必在所有规模都成立 |
|
||
|
||
## 编程任务(使用CoLab或笔记本)
|
||
|
||
1. 计算Transformer层的FLOPs和内存需求。给定隐藏维度 $d$、序列长度 $n$、批量大小 $B$ 和层数,估计总训练成本。
|
||
```python
|
||
import jax.numpy as jnp
|
||
|
||
def transformer_layer_flops(d, n, B):
|
||
"""一个Transformer层前向传播的近似FLOPs。"""
|
||
# QKV投影:3 * (B * n * d * d) * 2(乘法-加法)
|
||
qkv_flops = 3 * 2 * B * n * d * d
|
||
# 注意力:(B * n * n * d) * 2 用于QK^T,(B * n * n * d) * 2 用于attn*V
|
||
attn_flops = 2 * 2 * B * n * n * d
|
||
# 输出投影:(B * n * d * d) * 2
|
||
out_flops = 2 * B * n * d * d
|
||
# FFN:两层,d->4d 和 4d->d:2 * (B * n * d * 4d) * 2
|
||
ffn_flops = 2 * 2 * B * n * d * 4 * d
|
||
return qkv_flops + attn_flops + out_flops + ffn_flops
|
||
|
||
def transformer_layer_memory(d, n, B, dtype_bytes=2):
|
||
"""一个层的近似激活内存(字节)。"""
|
||
# QKV:3 * B * n * d
|
||
qkv_mem = 3 * B * n * d * dtype_bytes
|
||
# 注意力权重:B * heads * n * n(近似 B * n * n * sizeof)
|
||
attn_mem = B * n * n * dtype_bytes
|
||
# FFN中间值:B * n * 4d
|
||
ffn_mem = B * n * 4 * d * dtype_bytes
|
||
return qkv_mem + attn_mem + ffn_mem
|
||
|
||
# 示例:GPT-2规模
|
||
d, n, B, L = 1024, 1024, 8, 24
|
||
fwd_flops = transformer_layer_flops(d, n, B)
|
||
total_flops = 3 * L * fwd_flops # 前向+反向的3倍
|
||
act_mem = L * transformer_layer_memory(d, n, B)
|
||
param_count = L * (12 * d * d + 13 * d) # 近似
|
||
|
||
print(f"模型:d={d}, n={n}, B={B}, L={L}")
|
||
print(f"参数:{param_count / 1e6:.0f}M")
|
||
print(f"每步FLOPs:{total_flops / 1e12:.2f} TFLOPs")
|
||
print(f"激活内存:{act_mem / 1e9:.2f} GB (BF16)")
|
||
print(f"参数内存 (FP32):{param_count * 4 / 1e9:.2f} GB")
|
||
print(f"Adam优化器内存:{param_count * 8 / 1e9:.2f} GB")
|
||
print(f"总训练内存:{(param_count * 16 + act_mem) / 1e9:.2f} GB")
|
||
```
|
||
|
||
2. 模拟数据并行训练。将数据集分割到多个"虚拟GPU"上,独立计算梯度,平均它们,并验证结果与单GPU训练匹配。
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
|
||
# 简单线性模型:y = wx + b
|
||
key = jax.random.PRNGKey(0)
|
||
X = jax.random.normal(key, (64, 4))
|
||
w_true = jnp.array([1.0, -2.0, 3.0, 0.5])
|
||
y = X @ w_true + 0.1 * jax.random.normal(key, (64,))
|
||
|
||
def loss_fn(w, X, y):
|
||
return jnp.mean((X @ w - y) ** 2)
|
||
|
||
grad_fn = jax.grad(loss_fn)
|
||
|
||
# 单GPU:全批量梯度
|
||
w = jnp.zeros(4)
|
||
grad_single = grad_fn(w, X, y)
|
||
|
||
# 数据并行:分割到4个"GPU"上
|
||
n_gpus = 4
|
||
chunk_size = len(X) // n_gpus
|
||
grads = []
|
||
for i in range(n_gpus):
|
||
X_chunk = X[i*chunk_size:(i+1)*chunk_size]
|
||
y_chunk = y[i*chunk_size:(i+1)*chunk_size]
|
||
grads.append(grad_fn(w, X_chunk, y_chunk))
|
||
|
||
# 全规约:平均梯度
|
||
grad_parallel = jnp.mean(jnp.stack(grads), axis=0)
|
||
|
||
print("单GPU梯度:", grad_single)
|
||
print("数据并行梯度(平均):", grad_parallel)
|
||
print(f"匹配:{jnp.allclose(grad_single, grad_parallel, atol=1e-5)}")
|
||
|
||
# 训练两者并比较
|
||
w_single, w_parallel = jnp.zeros(4), jnp.zeros(4)
|
||
lr = 0.1
|
||
for step in range(100):
|
||
w_single = w_single - lr * grad_fn(w_single, X, y)
|
||
|
||
grads = [grad_fn(w_parallel, X[i*chunk_size:(i+1)*chunk_size],
|
||
y[i*chunk_size:(i+1)*chunk_size]) for i in range(n_gpus)]
|
||
avg_grad = jnp.mean(jnp.stack(grads), axis=0)
|
||
w_parallel = w_parallel - lr * avg_grad
|
||
|
||
print(f"\n100步之后:")
|
||
print(f"单GPU权重:{w_single}")
|
||
print(f"数据并行权重:{w_parallel}")
|
||
print(f"最大差异:{jnp.max(jnp.abs(w_single - w_parallel)):.2e}")
|
||
```
|
||
|
||
3. 实现一个简单的混合专家层。创建一个门控网络,将token路由到top-K个专家并组合它们的输出。
|
||
```python
|
||
import jax
|
||
import jax.numpy as jnp
|
||
|
||
def expert_fn(x, W1, b1, W2, b2):
|
||
"""简单的2层FFN专家。"""
|
||
h = jnp.maximum(0, x @ W1 + b1) # ReLU
|
||
return h @ W2 + b2
|
||
|
||
def moe_layer(x, gate_W, experts_params, top_k=2):
|
||
"""
|
||
MoE前向传播。
|
||
x: (batch, d_model)
|
||
gate_W: (d_model, n_experts)
|
||
experts_params: 每个专家的 (W1, b1, W2, b2) 列表
|
||
"""
|
||
n_experts = len(experts_params)
|
||
|
||
# 门控:计算路由分数
|
||
gate_logits = x @ gate_W # (batch, n_experts)
|
||
gate_probs = jax.nn.softmax(gate_logits, axis=-1)
|
||
|
||
# Top-K选择
|
||
top_k_indices = jnp.argsort(-gate_probs, axis=-1)[:, :top_k]
|
||
top_k_probs = jnp.take_along_axis(gate_probs, top_k_indices, axis=-1)
|
||
# 重新归一化
|
||
top_k_probs = top_k_probs / jnp.sum(top_k_probs, axis=-1, keepdims=True)
|
||
|
||
# 计算专家输出(简化:运行所有专家,稍后掩码)
|
||
expert_outputs = jnp.stack([
|
||
expert_fn(x, *experts_params[i]) for i in range(n_experts)
|
||
], axis=1) # (batch, n_experts, d_model)
|
||
|
||
# 收集top-K专家输出并加权
|
||
batch_idx = jnp.arange(x.shape[0])[:, None]
|
||
selected_outputs = expert_outputs[batch_idx, top_k_indices] # (batch, top_k, d_model)
|
||
output = jnp.sum(selected_outputs * top_k_probs[:, :, None], axis=1)
|
||
|
||
return output, gate_probs
|
||
|
||
# 设置
|
||
key = jax.random.PRNGKey(42)
|
||
batch, d_model, d_ff, n_experts = 8, 16, 32, 4
|
||
|
||
# 初始化专家
|
||
experts_params = []
|
||
for i in range(n_experts):
|
||
k1, k2, key = jax.random.split(key, 3)[0], jax.random.split(key, 3)[1], jax.random.split(key, 3)[2]
|
||
experts_params.append((
|
||
jax.random.normal(k1, (d_model, d_ff)) * 0.1,
|
||
jnp.zeros(d_ff),
|
||
jax.random.normal(k2, (d_ff, d_model)) * 0.1,
|
||
jnp.zeros(d_model),
|
||
))
|
||
|
||
key, subkey = jax.random.split(key)
|
||
gate_W = jax.random.normal(subkey, (d_model, n_experts)) * 0.1
|
||
x = jax.random.normal(key, (batch, d_model))
|
||
|
||
output, gate_probs = moe_layer(x, gate_W, experts_params, top_k=2)
|
||
|
||
print(f"输入形状:{x.shape}")
|
||
print(f"输出形状:{output.shape}")
|
||
print(f"门控概率(第一个样本):{gate_probs[0]}")
|
||
print(f"专家使用率(批量平均):")
|
||
for i in range(n_experts):
|
||
usage = jnp.mean(gate_probs[:, i])
|
||
print(f" 专家 {i}: {usage:.3f}")
|
||
```
|