# Triton与TPU *CUDA C功能强大但冗长。Triton让你用Python编写GPU核函数。TPU提供了GPU之外的选择,具有不同的权衡。本文涵盖Triton核函数编程、以Flash Attention为案例研究、TPU架构与JAX/Pallas,以及如何选择合适的工具。关于Vulkan和跨平台GPU计算,请参见文件07。* - 上篇文件教授了CUDA C中的GPU编程。本文更上一层抽象阶梯:Triton以20%的工作量提供CUDA 80%的性能,且用Python。TPU和Vulkan为特定用例提供替代硬件目标。 ## Triton:用Python编写GPU核函数 - **Triton**(OpenAI)是一种基于Python的GPU核函数编写语言。你不需要思考单个线程(CUDA),而是思考**块**级数据。Triton的编译器自动处理线程映射、内存合并、共享内存管理和许多优化。 - **为什么Triton重要**:CUDA C需要对线程束调度、共享内存存储体冲突、寄存器压力和合并模式有深入理解。Triton抽象了其中大部分内容,使GPU核函数开发对了解Python但不了解系统编程的ML研究人员可及。 ### 你的第一个Triton核函数 ```python import triton import triton.language as tl import torch @triton.jit def add_kernel( x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr, # 编译时常量 ): # 每个程序实例处理一个BLOCK_SIZE元素的块 pid = tl.program_id(axis=0) # 我是哪个块? block_start = pid * BLOCK_SIZE # 此块的偏移量 offsets = block_start + tl.arange(0, BLOCK_SIZE) # 掩码处理n_elements不是BLOCK_SIZE倍数的情况 mask = offsets < n_elements # 加载数据(带掩码:越界读取返回0) x = tl.load(x_ptr + offsets, mask=mask) y = tl.load(y_ptr + offsets, mask=mask) # 计算 output = x + y # 存储结果 tl.store(output_ptr + offsets, output, mask=mask) def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: output = torch.empty_like(x) n_elements = output.numel() # 启动:每个块一个程序 grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),) add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) return output # 使用 x = torch.randn(1000000, device='cuda') y = torch.randn(1000000, device='cuda') z = add(x, y) ``` - **与CUDA的关键区别**: - 无需显式线程管理。你思考**块**(程序),而非线程。 - `tl.arange(0, BLOCK_SIZE)` 为整个块创建一个偏移向量。此向量上的所有操作都隐式向量化。 - `mask` 处理边界条件(类似于AVX-512掩码寄存器,文件03)。无需标量清理循环。 - `tl.load` 和 `tl.store` 自动处理合并访问。 - `@triton.jit` 在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的核函数。 ### Triton Softmax核函数 - Softmax是一个很好的Triton示例,因为它需要对数据进行多次遍历(最大值、减去、指数、求和、除法),并且受益于在多次遍历之间将数据保留在SRAM(共享内存)中: ```python @triton.jit def softmax_kernel( output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): # 每个程序处理一行 row_idx = tl.program_id(0) row_start = input_ptr + row_idx * input_row_stride # 加载该行 col_offsets = tl.arange(0, BLOCK_SIZE) mask = col_offsets < n_cols row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf')) # Softmax:为数值稳定性取最大值,然后exp,然后归一化 row_max = tl.max(row, axis=0) numerator = tl.exp(row - row_max) denominator = tl.sum(numerator, axis=0) softmax_output = numerator / denominator # 存储结果 output_start = output_ptr + row_idx * output_row_stride tl.store(output_start + col_offsets, softmax_output, mask=mask) ``` - 在PyTorch中,`F.softmax(x, dim=-1)` 启动3个独立核函数(最大值、指数-求和、除法),每个都从全局内存读取和写入。Triton版本在一个核函数内完成所有操作,将数据保留在寄存器/SRAM中。这种**核函数融合**就是自定义Triton核函数可以比PyTorch内置操作快2-4倍的原因。 ### Triton自动调优 - Triton支持**自动调优**:尝试多种配置并选择最快的: ```python @triton.autotune( configs=[ triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}), triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}), triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}), ], key=['M', 'N', 'K'], # 当这些变化时重新调优 ) @triton.jit def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...): ... ``` - Triton在实际硬件上对每种配置进行基准测试并选择最快者。最优瓦片大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。 ### Triton vs CUDA:何时使用 | | Triton | CUDA C | |--|--------|--------| | 语言 | Python | C/C++ | | 抽象层级 | 块级 | 线程级 | | 开发速度 | 快(每核函数10-50行) | 慢(100-500行) | | 性能天花板 | 手工调优CUDA的约80-95% | 100%(完全硬件控制) | | 共享内存 | 自动 | 手动 | | 合并 | 自动 | 手动 | | 线程束级原语 | 有限 | 完整(shuffle、vote等) | | 硬件支持 | 仅NVIDIA(AMD实验性) | 仅NVIDIA | - **使用Triton**对于:融合核函数、自定义注意力模式、激活函数、大多数ML研究核函数需求。 - **使用CUDA C**对于:最高性能(最后5-20%)、线程束级原语、复杂数据相关并行性、当Triton无法表达你的模式。 ## 案例研究:Flash Attention - **Flash Attention**(Dao等人,2022)是近年来最具影响力的自定义核函数。它以 $O(n)$ 内存而非 $O(n^2)$ 计算注意力,使得更长的序列成为可能。 - **问题**:标准注意力计算 $\\text{softmax}(QK^T / \\sqrt{d}) \\cdot V$。$QK^T$ 矩阵是 $n \\times n$,其中 $n$ 是序列长度。对于 $n = 128K$,此矩阵为 $128K \\times 128K \\times 4$ 字节 = 64 GB。它无法放入GPU内存。 - **关键洞察**:你不需要具体化完整的 $n \\times n$ 矩阵。按**瓦片**计算注意力:加载一组 $Q$、一组 $K$,计算它们的部分注意力得分,累加,然后移动到下一个块。$n \\times n$ 矩阵从未完全具体化——每次只有一块存在于SRAM中。 - **在线softmax**:棘手的部分是softmax,它需要知道整个行上的最大值(为数值稳定性)。Flash Attention使用**在线softmax**技巧:维护一个运行中的最大值,当发现新的最大值时重新缩放先前计算的值。这允许softmax以增量方式逐块计算。 - 算法: ``` 对于每个Q行块: 对于每个K列块: 1. 将Q_block从HBM加载到SRAM 2. 将K_block从HBM加载到SRAM 3. 计算S_block = Q_block @ K_block.T(在SRAM中) 4. 更新运行中最大值,重新缩放先前结果 5. 计算exp(S_block - 运行中最大值) 6. 更新运行中求和和输出累加器 加载V_block并计算最终输出 将输出块写回HBM ``` - **为什么它快**:内循环完全在SRAM(共享内存)中操作。全局内存(HBM)仅用于加载Q、K、V块和写入最终输出。数据重用因子与SRAM大小成正比,而SRAM比HBM快约100倍。 - Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更具可读性和可修改性,这对研究新的注意力变体很重要。 ## TPU架构 - **TPU**(张量处理单元)是Google的自定义ML加速器。它们采用与GPU截然不同的方法: - **脉动阵列**:TPU的核心计算单元是**矩阵乘法单元(MXU)**,一个128×128或256×256的脉动阵列,通过让数据流经乘加单元网格来计算矩阵乘法。数据从边缘进入并通过阵列传播,每个单元执行一次乘加并将结果传递给下一个。 - 与GPU(调度数千个独立线程)不同,脉动阵列是单一的确定性数据流。没有线程调度、没有线程束分歧、没有分支预测。这种简朴性使MXU在矩阵乘法方面极其能效高效。 - **HBM**:TPU使用与GPU相同的高带宽内存。TPU v5e每芯片16 GB HBM2e;TPU v5p每芯片95 GB HBM2e。 - **ICI**(芯片间互连):TPU Pod用自定义高速网络连接数百个TPU。JAX原生支持跨TPU Pod的数据并行性和模型并行性(第6章)。 - **BFloat16**:TPU是首个使用bfloat16的(第13章文件02)。BF16具有与float32相同的指数范围(防止训练期间溢出),尾数精度较低。这种权衡对ML是理想的,其中梯度值范围广但不需要23位精度。 ### 编程TPU:JAX与Pallas - TPU通过**JAX**和**XLA**编程。你编写Python/JAX代码,`jax.jit` 将其编译为XLA HLO,XLA将HLO编译为TPU特定的指令。无需CUDA,无需C++。 ```python import jax import jax.numpy as jnp @jax.jit def matmul(a, b): return jnp.dot(a, b) # 这将根据设备在CPU、GPU或TPU上运行 a = jnp.ones((1024, 1024)) b = jnp.ones((1024, 1024)) c = matmul(a, b) ``` - **Pallas**是JAX的核函数编写API——JAX版的Triton。它让你编写低级核函数,XLA将其编译为GPU或TPU: ```python from jax.experimental import pallas as pl import jax.numpy as jnp def add_kernel(x_ref, y_ref, o_ref): o_ref[...] = x_ref[...] + y_ref[...] def add_pallas(x, y): return pl.pallas_call( add_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype), grid=(x.shape[0] // 128,), in_specs=[pl.BlockSpec((128,), lambda i: (i,)), pl.BlockSpec((128,), lambda i: (i,))], out_specs=pl.BlockSpec((128,), lambda i: (i,)), )(x, y) ``` - Pallas比Triton更新且不太成熟,但它是为TPU编写自定义核函数的唯一方式(因为TPU不支持CUDA)。 ### GPU vs TPU | | GPU(NVIDIA) | TPU(Google) | |--|-------------|--------------| | 可用性 | 任何云、本地部署 | 仅Google Cloud | | 编程 | CUDA C、Triton、PyTorch | JAX/XLA、Pallas | | 灵活性 | 通用计算 | 针对矩阵密集型ML优化 | | 峰值矩阵乘法FLOPS | 非常高(张量核心) | 非常高(MXU) | | 非矩阵乘法操作 | 好 | 较慢(通过向量单元路由,而非MXU) | | 多芯片扩展 | NVLink(8个GPU)、InfiniBand | ICI(数千个TPU,更紧密集成) | | 成本效率 | 有竞争力 | 大规模训练通常更便宜 | | 生态系统 | 最大(PyTorch、TensorFlow、JAX) | 面向JAX | - **使用GPU**对于:大多数ML工作负载、基于PyTorch的研究、推理服务、有大量非矩阵乘法计算的工作负载。 - **使用TPU**对于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。 ## 选择合适的工具 | 工作负载 | 最佳工具 | 为什么 | |----------|---------|-------| | ML训练(PyTorch) | NVIDIA GPU + CUDA/Triton | 最大生态系统、最佳工具链 | | ML训练(JAX,大规模) | TPU或NVIDIA GPU | TPU在Google规模下成本低,GPU灵活 | | 自定义融合核函数 | Triton(Python)或CUDA C | Triton开发速度快,CUDA峰值性能高 | | JAX自定义核函数 | Pallas | TPU唯一选项,也可在GPU上工作 | | 跨平台推理 | Vulkan(文件07)或ONNX Runtime | 运行在任何GPU供应商上 | | 移动/边缘推理 | Metal(Apple)、Vulkan(Android)、NNAPI | 平台特定的加速器 | | 浏览器推理 | WebGPU(文件07) | 浏览器中唯一选项 | | 仅CPU推理 | ONNX Runtime + AVX/NEON | 无需GPU,使用SIMD(文件02-03) | | 新型硬件 | 供应商专用SDK | 每个加速器有自己的工具链 | ## 编程任务(使用带GPU运行时的CoLab) 1. 编写并运行向量加法的Triton核函数。将其性能与PyTorch内置加法比较。 ```python import triton import triton.language as tl import torch import time @triton.jit def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) y = tl.load(y_ptr + offs, mask=mask) tl.store(out_ptr + offs, x + y, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') y = torch.randn(n, device='cuda') # Triton out_triton = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) add_kernel[grid](x, y, out_triton, n, BLOCK=1024) # PyTorch out_torch = x + y # 验证正确性 assert torch.allclose(out_triton, out_torch, atol=1e-5) # 基准测试 torch.cuda.synchronize() start = time.time() for _ in range(1000): add_kernel[grid](x, y, out_triton, n, BLOCK=1024) torch.cuda.synchronize() triton_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_torch = x + y torch.cuda.synchronize() torch_time = (time.time() - start) / 1000 print(f"Triton: {triton_time*1000:.3f} ms") print(f"PyTorch: {torch_time*1000:.3f} ms") print(f"比率: {torch_time/triton_time:.2f}x") ``` 2. 编写一个Triton融合核函数,在单次遍历中执行乘法+加法+ReLU。与三个独立的PyTorch操作比较。 ```python import triton import triton.language as tl import torch import time @triton.jit def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr): pid = tl.program_id(0) offs = pid * BLOCK + tl.arange(0, BLOCK) mask = offs < n x = tl.load(x_ptr + offs, mask=mask) w = tl.load(w_ptr + offs, mask=mask) b = tl.load(b_ptr + offs, mask=mask) result = tl.maximum(x * w + b, 0.0) # 融合:乘法 + 加法 + relu tl.store(out_ptr + offs, result, mask=mask) n = 10_000_000 x = torch.randn(n, device='cuda') w = torch.randn(n, device='cuda') b = torch.randn(n, device='cuda') # 融合(Triton) out_fused = torch.empty_like(x) grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),) fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) # 未融合(PyTorch) out_unfused = torch.relu(x * w + b) assert torch.allclose(out_fused, out_unfused, atol=1e-5) # 基准测试 torch.cuda.synchronize() start = time.time() for _ in range(1000): fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024) torch.cuda.synchronize() fused_time = (time.time() - start) / 1000 start = time.time() for _ in range(1000): out_unfused = torch.relu(x * w + b) torch.cuda.synchronize() unfused_time = (time.time() - start) / 1000 print(f"融合(Triton): {fused_time*1000:.3f} ms") print(f"未融合(PyTorch): {unfused_time*1000:.3f} ms") print(f"加速比: {unfused_time/fused_time:.2f}x") ``` 3. 测量JAX的XLA编译器如何自动融合操作。比较带和不带jit的操作链。 ```python import jax import jax.numpy as jnp import time def chain_ops(x): x = x * 2.0 x = x + 1.0 x = jnp.maximum(x, 0.0) # ReLU x = x / jnp.sum(x) return x chain_jit = jax.jit(chain_ops) x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000)) # 预热 _ = chain_jit(x) jax.block_until_ready(_) # 即时模式(每个操作是独立核函数启动) start = time.time() for _ in range(100): y = chain_ops(x) jax.block_until_ready(y) eager_time = (time.time() - start) / 100 # JIT(XLA融合操作) start = time.time() for _ in range(100): y = chain_jit(x) jax.block_until_ready(y) jit_time = (time.time() - start) / 100 print(f"即时: {eager_time*1000:.2f} ms") print(f"JIT: {jit_time*1000:.2f} ms") print(f"加速比: {eager_time/jit_time:.1f}x(XLA将4个操作融合为1个核函数)") ```