Files
maths-cs-ai-compendium-zh/chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

394 lines
15 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Triton与TPU
*CUDA C功能强大但冗长。Triton让你用Python编写GPU核函数。TPU提供了GPU之外的选择,具有不同的权衡。本文涵盖Triton核函数编程、以Flash Attention为案例研究、TPU架构与JAX/Pallas,以及如何选择合适的工具。关于Vulkan和跨平台GPU计算,请参见文件07。*
- 上篇文件教授了CUDA C中的GPU编程。本文更上一层抽象阶梯:Triton以20%的工作量提供CUDA 80%的性能,且用Python。TPU和Vulkan为特定用例提供替代硬件目标。
## Triton:用Python编写GPU核函数
- **Triton**OpenAI)是一种基于Python的GPU核函数编写语言。你不需要思考单个线程(CUDA),而是思考**块**级数据。Triton的编译器自动处理线程映射、内存合并、共享内存管理和许多优化。
- **为什么Triton重要**:CUDA C需要对线程束调度、共享内存存储体冲突、寄存器压力和合并模式有深入理解。Triton抽象了其中大部分内容,使GPU核函数开发对了解Python但不了解系统编程的ML研究人员可及。
### 你的第一个Triton核函数
```python
import triton
import triton.language as tl
import torch
@triton.jit
def add_kernel(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr, # 编译时常量
):
# 每个程序实例处理一个BLOCK_SIZE元素的块
pid = tl.program_id(axis=0) # 我是哪个块?
block_start = pid * BLOCK_SIZE
# 此块的偏移量
offsets = block_start + tl.arange(0, BLOCK_SIZE)
# 掩码处理n_elements不是BLOCK_SIZE倍数的情况
mask = offsets < n_elements
# 加载数据(带掩码:越界读取返回0
x = tl.load(x_ptr + offsets, mask=mask)
y = tl.load(y_ptr + offsets, mask=mask)
# 计算
output = x + y
# 存储结果
tl.store(output_ptr + offsets, output, mask=mask)
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
output = torch.empty_like(x)
n_elements = output.numel()
# 启动:每个块一个程序
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
return output
# 使用
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
```
- **与CUDA的关键区别**
- 无需显式线程管理。你思考**块**(程序),而非线程。
- `tl.arange(0, BLOCK_SIZE)` 为整个块创建一个偏移向量。此向量上的所有操作都隐式向量化。
- `mask` 处理边界条件(类似于AVX-512掩码寄存器,文件03)。无需标量清理循环。
- `tl.load``tl.store` 自动处理合并访问。
- `@triton.jit` 在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的核函数。
### Triton Softmax核函数
- Softmax是一个很好的Triton示例,因为它需要对数据进行多次遍历(最大值、减去、指数、求和、除法),并且受益于在多次遍历之间将数据保留在SRAM(共享内存)中:
```python
@triton.jit
def softmax_kernel(
output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
BLOCK_SIZE: tl.constexpr,
):
# 每个程序处理一行
row_idx = tl.program_id(0)
row_start = input_ptr + row_idx * input_row_stride
# 加载该行
col_offsets = tl.arange(0, BLOCK_SIZE)
mask = col_offsets < n_cols
row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))
# Softmax:为数值稳定性取最大值,然后exp,然后归一化
row_max = tl.max(row, axis=0)
numerator = tl.exp(row - row_max)
denominator = tl.sum(numerator, axis=0)
softmax_output = numerator / denominator
# 存储结果
output_start = output_ptr + row_idx * output_row_stride
tl.store(output_start + col_offsets, softmax_output, mask=mask)
```
- 在PyTorch中,`F.softmax(x, dim=-1)` 启动3个独立核函数(最大值、指数-求和、除法),每个都从全局内存读取和写入。Triton版本在一个核函数内完成所有操作,将数据保留在寄存器/SRAM中。这种**核函数融合**就是自定义Triton核函数可以比PyTorch内置操作快2-4倍的原因。
### Triton自动调优
- Triton支持**自动调优**:尝试多种配置并选择最快的:
```python
@triton.autotune(
configs=[
triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
],
key=['M', 'N', 'K'], # 当这些变化时重新调优
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
...
```
- Triton在实际硬件上对每种配置进行基准测试并选择最快者。最优瓦片大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。
### Triton vs CUDA:何时使用
| | Triton | CUDA C |
|--|--------|--------|
| 语言 | Python | C/C++ |
| 抽象层级 | 块级 | 线程级 |
| 开发速度 | 快(每核函数10-50行) | 慢(100-500行) |
| 性能天花板 | 手工调优CUDA的约80-95% | 100%(完全硬件控制) |
| 共享内存 | 自动 | 手动 |
| 合并 | 自动 | 手动 |
| 线程束级原语 | 有限 | 完整(shuffle、vote等) |
| 硬件支持 | 仅NVIDIAAMD实验性) | 仅NVIDIA |
- **使用Triton**对于:融合核函数、自定义注意力模式、激活函数、大多数ML研究核函数需求。
- **使用CUDA C**对于:最高性能(最后5-20%)、线程束级原语、复杂数据相关并行性、当Triton无法表达你的模式。
## 案例研究:Flash Attention
- **Flash Attention**Dao等人,2022)是近年来最具影响力的自定义核函数。它以 $O(n)$ 内存而非 $O(n^2)$ 计算注意力,使得更长的序列成为可能。
- **问题**:标准注意力计算 $\\text{softmax}(QK^T / \\sqrt{d}) \\cdot V$。$QK^T$ 矩阵是 $n \\times n$,其中 $n$ 是序列长度。对于 $n = 128K$,此矩阵为 $128K \\times 128K \\times 4$ 字节 = 64 GB。它无法放入GPU内存。
- **关键洞察**:你不需要具体化完整的 $n \\times n$ 矩阵。按**瓦片**计算注意力:加载一组 $Q$、一组 $K$,计算它们的部分注意力得分,累加,然后移动到下一个块。$n \\times n$ 矩阵从未完全具体化——每次只有一块存在于SRAM中。
- **在线softmax**:棘手的部分是softmax,它需要知道整个行上的最大值(为数值稳定性)。Flash Attention使用**在线softmax**技巧:维护一个运行中的最大值,当发现新的最大值时重新缩放先前计算的值。这允许softmax以增量方式逐块计算。
- 算法:
```
对于每个Q行块:
对于每个K列块:
1. 将Q_block从HBM加载到SRAM
2. 将K_block从HBM加载到SRAM
3. 计算S_block = Q_block @ K_block.T(在SRAM中)
4. 更新运行中最大值,重新缩放先前结果
5. 计算exp(S_block - 运行中最大值)
6. 更新运行中求和和输出累加器
加载V_block并计算最终输出
将输出块写回HBM
```
- **为什么它快**:内循环完全在SRAM(共享内存)中操作。全局内存(HBM)仅用于加载Q、K、V块和写入最终输出。数据重用因子与SRAM大小成正比,而SRAM比HBM快约100倍。
- Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更具可读性和可修改性,这对研究新的注意力变体很重要。
## TPU架构
- **TPU**(张量处理单元)是Google的自定义ML加速器。它们采用与GPU截然不同的方法:
- **脉动阵列**:TPU的核心计算单元是**矩阵乘法单元(MXU)**,一个128×128或256×256的脉动阵列,通过让数据流经乘加单元网格来计算矩阵乘法。数据从边缘进入并通过阵列传播,每个单元执行一次乘加并将结果传递给下一个。
- 与GPU(调度数千个独立线程)不同,脉动阵列是单一的确定性数据流。没有线程调度、没有线程束分歧、没有分支预测。这种简朴性使MXU在矩阵乘法方面极其能效高效。
- **HBM**TPU使用与GPU相同的高带宽内存。TPU v5e每芯片16 GB HBM2eTPU v5p每芯片95 GB HBM2e。
- **ICI**(芯片间互连):TPU Pod用自定义高速网络连接数百个TPU。JAX原生支持跨TPU Pod的数据并行性和模型并行性(第6章)。
- **BFloat16**TPU是首个使用bfloat16的(第13章文件02)。BF16具有与float32相同的指数范围(防止训练期间溢出),尾数精度较低。这种权衡对ML是理想的,其中梯度值范围广但不需要23位精度。
### 编程TPUJAX与Pallas
- TPU通过**JAX**和**XLA**编程。你编写Python/JAX代码,`jax.jit` 将其编译为XLA HLO,XLA将HLO编译为TPU特定的指令。无需CUDA,无需C++。
```python
import jax
import jax.numpy as jnp
@jax.jit
def matmul(a, b):
return jnp.dot(a, b)
# 这将根据设备在CPU、GPU或TPU上运行
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
```
- **Pallas**是JAX的核函数编写API——JAX版的Triton。它让你编写低级核函数,XLA将其编译为GPU或TPU:
```python
from jax.experimental import pallas as pl
import jax.numpy as jnp
def add_kernel(x_ref, y_ref, o_ref):
o_ref[...] = x_ref[...] + y_ref[...]
def add_pallas(x, y):
return pl.pallas_call(
add_kernel,
out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
grid=(x.shape[0] // 128,),
in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
pl.BlockSpec((128,), lambda i: (i,))],
out_specs=pl.BlockSpec((128,), lambda i: (i,)),
)(x, y)
```
- Pallas比Triton更新且不太成熟,但它是为TPU编写自定义核函数的唯一方式(因为TPU不支持CUDA)。
### GPU vs TPU
| | GPUNVIDIA | TPUGoogle |
|--|-------------|--------------|
| 可用性 | 任何云、本地部署 | 仅Google Cloud |
| 编程 | CUDA C、Triton、PyTorch | JAX/XLA、Pallas |
| 灵活性 | 通用计算 | 针对矩阵密集型ML优化 |
| 峰值矩阵乘法FLOPS | 非常高(张量核心) | 非常高(MXU) |
| 非矩阵乘法操作 | 好 | 较慢(通过向量单元路由,而非MXU) |
| 多芯片扩展 | NVLink8个GPU)、InfiniBand | ICI(数千个TPU,更紧密集成) |
| 成本效率 | 有竞争力 | 大规模训练通常更便宜 |
| 生态系统 | 最大(PyTorch、TensorFlow、JAX | 面向JAX |
- **使用GPU**对于:大多数ML工作负载、基于PyTorch的研究、推理服务、有大量非矩阵乘法计算的工作负载。
- **使用TPU**对于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。
## 选择合适的工具
| 工作负载 | 最佳工具 | 为什么 |
|----------|---------|-------|
| ML训练(PyTorch | NVIDIA GPU + CUDA/Triton | 最大生态系统、最佳工具链 |
| ML训练(JAX,大规模) | TPU或NVIDIA GPU | TPU在Google规模下成本低,GPU灵活 |
| 自定义融合核函数 | TritonPython)或CUDA C | Triton开发速度快,CUDA峰值性能高 |
| JAX自定义核函数 | Pallas | TPU唯一选项,也可在GPU上工作 |
| 跨平台推理 | Vulkan(文件07)或ONNX Runtime | 运行在任何GPU供应商上 |
| 移动/边缘推理 | MetalApple)、VulkanAndroid)、NNAPI | 平台特定的加速器 |
| 浏览器推理 | WebGPU(文件07) | 浏览器中唯一选项 |
| 仅CPU推理 | ONNX Runtime + AVX/NEON | 无需GPU,使用SIMD(文件02-03 |
| 新型硬件 | 供应商专用SDK | 每个加速器有自己的工具链 |
## 编程任务(使用带GPU运行时的CoLab)
1. 编写并运行向量加法的Triton核函数。将其性能与PyTorch内置加法比较。
```python
import triton
import triton.language as tl
import torch
import time
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n
x = tl.load(x_ptr + offs, mask=mask)
y = tl.load(y_ptr + offs, mask=mask)
tl.store(out_ptr + offs, x + y, mask=mask)
n = 10_000_000
x = torch.randn(n, device='cuda')
y = torch.randn(n, device='cuda')
# Triton
out_triton = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
# PyTorch
out_torch = x + y
# 验证正确性
assert torch.allclose(out_triton, out_torch, atol=1e-5)
# 基准测试
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
torch.cuda.synchronize()
triton_time = (time.time() - start) / 1000
start = time.time()
for _ in range(1000):
out_torch = x + y
torch.cuda.synchronize()
torch_time = (time.time() - start) / 1000
print(f"Triton: {triton_time*1000:.3f} ms")
print(f"PyTorch: {torch_time*1000:.3f} ms")
print(f"比率: {torch_time/triton_time:.2f}x")
```
2. 编写一个Triton融合核函数,在单次遍历中执行乘法+加法+ReLU。与三个独立的PyTorch操作比较。
```python
import triton
import triton.language as tl
import torch
import time
@triton.jit
def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK + tl.arange(0, BLOCK)
mask = offs < n
x = tl.load(x_ptr + offs, mask=mask)
w = tl.load(w_ptr + offs, mask=mask)
b = tl.load(b_ptr + offs, mask=mask)
result = tl.maximum(x * w + b, 0.0) # 融合:乘法 + 加法 + relu
tl.store(out_ptr + offs, result, mask=mask)
n = 10_000_000
x = torch.randn(n, device='cuda')
w = torch.randn(n, device='cuda')
b = torch.randn(n, device='cuda')
# 融合(Triton
out_fused = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
# 未融合(PyTorch
out_unfused = torch.relu(x * w + b)
assert torch.allclose(out_fused, out_unfused, atol=1e-5)
# 基准测试
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
torch.cuda.synchronize()
fused_time = (time.time() - start) / 1000
start = time.time()
for _ in range(1000):
out_unfused = torch.relu(x * w + b)
torch.cuda.synchronize()
unfused_time = (time.time() - start) / 1000
print(f"融合(Triton: {fused_time*1000:.3f} ms")
print(f"未融合(PyTorch: {unfused_time*1000:.3f} ms")
print(f"加速比: {unfused_time/fused_time:.2f}x")
```
3. 测量JAX的XLA编译器如何自动融合操作。比较带和不带jit的操作链。
```python
import jax
import jax.numpy as jnp
import time
def chain_ops(x):
x = x * 2.0
x = x + 1.0
x = jnp.maximum(x, 0.0) # ReLU
x = x / jnp.sum(x)
return x
chain_jit = jax.jit(chain_ops)
x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))
# 预热
_ = chain_jit(x)
jax.block_until_ready(_)
# 即时模式(每个操作是独立核函数启动)
start = time.time()
for _ in range(100):
y = chain_ops(x)
jax.block_until_ready(y)
eager_time = (time.time() - start) / 100
# JITXLA融合操作)
start = time.time()
for _ in range(100):
y = chain_jit(x)
jax.block_until_ready(y)
jit_time = (time.time() - start) / 100
print(f"即时: {eager_time*1000:.2f} ms")
print(f"JIT: {jit_time*1000:.2f} ms")
print(f"加速比: {eager_time/jit_time:.1f}xXLA将4个操作融合为1个核函数)")
```