Files
maths-cs-ai-compendium-zh/chapter 16: SIMD and GPU programming/05. triton, TPUs and pallax.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

15 KiB
Raw Blame History

Triton与TPU

CUDA C功能强大但冗长。Triton让你用Python编写GPU核函数。TPU提供了GPU之外的选择,具有不同的权衡。本文涵盖Triton核函数编程、以Flash Attention为案例研究、TPU架构与JAX/Pallas,以及如何选择合适的工具。关于Vulkan和跨平台GPU计算,请参见文件07。

  • 上篇文件教授了CUDA C中的GPU编程。本文更上一层抽象阶梯:Triton以20%的工作量提供CUDA 80%的性能,且用Python。TPU和Vulkan为特定用例提供替代硬件目标。

Triton:用Python编写GPU核函数

  • TritonOpenAI)是一种基于Python的GPU核函数编写语言。你不需要思考单个线程(CUDA),而是思考级数据。Triton的编译器自动处理线程映射、内存合并、共享内存管理和许多优化。

  • 为什么Triton重要:CUDA C需要对线程束调度、共享内存存储体冲突、寄存器压力和合并模式有深入理解。Triton抽象了其中大部分内容,使GPU核函数开发对了解Python但不了解系统编程的ML研究人员可及。

你的第一个Triton核函数

import triton
import triton.language as tl
import torch

@triton.jit
def add_kernel(
    x_ptr, y_ptr, output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,  # 编译时常量
):
    # 每个程序实例处理一个BLOCK_SIZE元素的块
    pid = tl.program_id(axis=0)  # 我是哪个块?
    block_start = pid * BLOCK_SIZE

    # 此块的偏移量
    offsets = block_start + tl.arange(0, BLOCK_SIZE)

    # 掩码处理n_elements不是BLOCK_SIZE倍数的情况
    mask = offsets < n_elements

    # 加载数据(带掩码:越界读取返回0
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)

    # 计算
    output = x + y

    # 存储结果
    tl.store(output_ptr + offsets, output, mask=mask)


def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    n_elements = output.numel()

    # 启动:每个块一个程序
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)

    return output


# 使用
x = torch.randn(1000000, device='cuda')
y = torch.randn(1000000, device='cuda')
z = add(x, y)
  • 与CUDA的关键区别
    • 无需显式线程管理。你思考(程序),而非线程。
    • tl.arange(0, BLOCK_SIZE) 为整个块创建一个偏移向量。此向量上的所有操作都隐式向量化。
    • mask 处理边界条件(类似于AVX-512掩码寄存器,文件03)。无需标量清理循环。
    • tl.loadtl.store 自动处理合并访问。
    • @triton.jit 在首次调用时将函数编译为PTX(GPU汇编),然后缓存编译后的核函数。

Triton Softmax核函数

  • Softmax是一个很好的Triton示例,因为它需要对数据进行多次遍历(最大值、减去、指数、求和、除法),并且受益于在多次遍历之间将数据保留在SRAM(共享内存)中:
@triton.jit
def softmax_kernel(
    output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,
    BLOCK_SIZE: tl.constexpr,
):
    # 每个程序处理一行
    row_idx = tl.program_id(0)
    row_start = input_ptr + row_idx * input_row_stride

    # 加载该行
    col_offsets = tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    row = tl.load(row_start + col_offsets, mask=mask, other=-float('inf'))

    # Softmax:为数值稳定性取最大值,然后exp,然后归一化
    row_max = tl.max(row, axis=0)
    numerator = tl.exp(row - row_max)
    denominator = tl.sum(numerator, axis=0)
    softmax_output = numerator / denominator

    # 存储结果
    output_start = output_ptr + row_idx * output_row_stride
    tl.store(output_start + col_offsets, softmax_output, mask=mask)
  • 在PyTorch中,F.softmax(x, dim=-1) 启动3个独立核函数(最大值、指数-求和、除法),每个都从全局内存读取和写入。Triton版本在一个核函数内完成所有操作,将数据保留在寄存器/SRAM中。这种核函数融合就是自定义Triton核函数可以比PyTorch内置操作快2-4倍的原因。

Triton自动调优

  • Triton支持自动调优:尝试多种配置并选择最快的:
@triton.autotune(
    configs=[
        triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}),
        triton.Config({'BLOCK_SIZE_M': 256, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64}),
    ],
    key=['M', 'N', 'K'],  # 当这些变化时重新调优
)
@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, ...):
    ...
  • Triton在实际硬件上对每种配置进行基准测试并选择最快者。最优瓦片大小取决于GPU架构、矩阵维度和内存布局——自动调优无需手动实验即可找到它们。

Triton vs CUDA:何时使用

Triton CUDA C
语言 Python C/C++
抽象层级 块级 线程级
开发速度 快(每核函数10-50行) 慢(100-500行)
性能天花板 手工调优CUDA的约80-95% 100%(完全硬件控制)
共享内存 自动 手动
合并 自动 手动
线程束级原语 有限 完整(shuffle、vote等)
硬件支持 仅NVIDIAAMD实验性) 仅NVIDIA
  • 使用Triton对于:融合核函数、自定义注意力模式、激活函数、大多数ML研究核函数需求。
  • 使用CUDA C对于:最高性能(最后5-20%)、线程束级原语、复杂数据相关并行性、当Triton无法表达你的模式。

案例研究:Flash Attention

  • Flash Attention(Dao等人,2022)是近年来最具影响力的自定义核函数。它以 O(n) 内存而非 O(n^2) 计算注意力,使得更长的序列成为可能。

  • 问题:标准注意力计算 $\text{softmax}(QK^T / \sqrt{d}) \cdot V$。QK^T 矩阵是 $n \times n$,其中 n 是序列长度。对于 $n = 128K$,此矩阵为 128K \\times 128K \\times 4 字节 = 64 GB。它无法放入GPU内存。

  • 关键洞察:你不需要具体化完整的 n \\times n 矩阵。按瓦片计算注意力:加载一组 $Q$、一组 $K$,计算它们的部分注意力得分,累加,然后移动到下一个块。n \\times n 矩阵从未完全具体化——每次只有一块存在于SRAM中。

  • 在线softmax:棘手的部分是softmax,它需要知道整个行上的最大值(为数值稳定性)。Flash Attention使用在线softmax技巧:维护一个运行中的最大值,当发现新的最大值时重新缩放先前计算的值。这允许softmax以增量方式逐块计算。

  • 算法:

对于每个Q行块:
    对于每个K列块:
        1. 将Q_block从HBM加载到SRAM
        2. 将K_block从HBM加载到SRAM
        3. 计算S_block = Q_block @ K_block.T(在SRAM中)
        4. 更新运行中最大值,重新缩放先前结果
        5. 计算exp(S_block - 运行中最大值)
        6. 更新运行中求和和输出累加器
    加载V_block并计算最终输出
    将输出块写回HBM
  • 为什么它快:内循环完全在SRAM(共享内存)中操作。全局内存(HBM)仅用于加载Q、K、V块和写入最终输出。数据重用因子与SRAM大小成正比,而SRAM比HBM快约100倍。

  • Flash Attention在Triton和CUDA C中都有实现。CUDA版本更快(效率高约10%),但Triton版本更具可读性和可修改性,这对研究新的注意力变体很重要。

TPU架构

  • TPU(张量处理单元)是Google的自定义ML加速器。它们采用与GPU截然不同的方法:

  • 脉动阵列TPU的核心计算单元是矩阵乘法单元(MXU,一个128×128或256×256的脉动阵列,通过让数据流经乘加单元网格来计算矩阵乘法。数据从边缘进入并通过阵列传播,每个单元执行一次乘加并将结果传递给下一个。

  • 与GPU(调度数千个独立线程)不同,脉动阵列是单一的确定性数据流。没有线程调度、没有线程束分歧、没有分支预测。这种简朴性使MXU在矩阵乘法方面极其能效高效。

  • HBM:TPU使用与GPU相同的高带宽内存。TPU v5e每芯片16 GB HBM2eTPU v5p每芯片95 GB HBM2e。

  • ICI(芯片间互连):TPU Pod用自定义高速网络连接数百个TPU。JAX原生支持跨TPU Pod的数据并行性和模型并行性(第6章)。

  • BFloat16TPU是首个使用bfloat16的(第13章文件02)。BF16具有与float32相同的指数范围(防止训练期间溢出),尾数精度较低。这种权衡对ML是理想的,其中梯度值范围广但不需要23位精度。

编程TPUJAX与Pallas

  • TPU通过JAXXLA编程。你编写Python/JAX代码,jax.jit 将其编译为XLA HLO,XLA将HLO编译为TPU特定的指令。无需CUDA,无需C++。
import jax
import jax.numpy as jnp

@jax.jit
def matmul(a, b):
    return jnp.dot(a, b)

# 这将根据设备在CPU、GPU或TPU上运行
a = jnp.ones((1024, 1024))
b = jnp.ones((1024, 1024))
c = matmul(a, b)
  • Pallas是JAX的核函数编写API——JAX版的Triton。它让你编写低级核函数,XLA将其编译为GPU或TPU:
from jax.experimental import pallas as pl
import jax.numpy as jnp

def add_kernel(x_ref, y_ref, o_ref):
    o_ref[...] = x_ref[...] + y_ref[...]

def add_pallas(x, y):
    return pl.pallas_call(
        add_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        grid=(x.shape[0] // 128,),
        in_specs=[pl.BlockSpec((128,), lambda i: (i,)),
                  pl.BlockSpec((128,), lambda i: (i,))],
        out_specs=pl.BlockSpec((128,), lambda i: (i,)),
    )(x, y)
  • Pallas比Triton更新且不太成熟,但它是为TPU编写自定义核函数的唯一方式(因为TPU不支持CUDA)。

GPU vs TPU

GPUNVIDIA TPUGoogle
可用性 任何云、本地部署 仅Google Cloud
编程 CUDA C、Triton、PyTorch JAX/XLA、Pallas
灵活性 通用计算 针对矩阵密集型ML优化
峰值矩阵乘法FLOPS 非常高(张量核心) 非常高(MXU
非矩阵乘法操作 较慢(通过向量单元路由,而非MXU
多芯片扩展 NVLink8个GPU)、InfiniBand ICI(数千个TPU,更紧密集成)
成本效率 有竞争力 大规模训练通常更便宜
生态系统 最大(PyTorch、TensorFlow、JAX 面向JAX
  • 使用GPU对于:大多数ML工作负载、基于PyTorch的研究、推理服务、有大量非矩阵乘法计算的工作负载。
  • 使用TPU对于:大规模JAX训练(数千芯片)、Google Cloud上的成本敏感训练、以矩阵乘法为主的工作负载。

选择合适的工具

工作负载 最佳工具 为什么
ML训练(PyTorch NVIDIA GPU + CUDA/Triton 最大生态系统、最佳工具链
ML训练(JAX,大规模) TPU或NVIDIA GPU TPU在Google规模下成本低,GPU灵活
自定义融合核函数 TritonPython)或CUDA C Triton开发速度快,CUDA峰值性能高
JAX自定义核函数 Pallas TPU唯一选项,也可在GPU上工作
跨平台推理 Vulkan(文件07)或ONNX Runtime 运行在任何GPU供应商上
移动/边缘推理 MetalApple)、VulkanAndroid)、NNAPI 平台特定的加速器
浏览器推理 WebGPU(文件07 浏览器中唯一选项
仅CPU推理 ONNX Runtime + AVX/NEON 无需GPU,使用SIMD(文件02-03
新型硬件 供应商专用SDK 每个加速器有自己的工具链

编程任务(使用带GPU运行时的CoLab)

  1. 编写并运行向量加法的Triton核函数。将其性能与PyTorch内置加法比较。
import triton
import triton.language as tl
import torch
import time

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

n = 10_000_000
x = torch.randn(n, device='cuda')
y = torch.randn(n, device='cuda')

# Triton
out_triton = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
add_kernel[grid](x, y, out_triton, n, BLOCK=1024)

# PyTorch
out_torch = x + y

# 验证正确性
assert torch.allclose(out_triton, out_torch, atol=1e-5)

# 基准测试
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    add_kernel[grid](x, y, out_triton, n, BLOCK=1024)
torch.cuda.synchronize()
triton_time = (time.time() - start) / 1000

start = time.time()
for _ in range(1000):
    out_torch = x + y
torch.cuda.synchronize()
torch_time = (time.time() - start) / 1000

print(f"Triton:  {triton_time*1000:.3f} ms")
print(f"PyTorch: {torch_time*1000:.3f} ms")
print(f"比率:   {torch_time/triton_time:.2f}x")
  1. 编写一个Triton融合核函数,在单次遍历中执行乘法+加法+ReLU。与三个独立的PyTorch操作比较。
import triton
import triton.language as tl
import torch
import time

@triton.jit
def fused_mul_add_relu_kernel(x_ptr, w_ptr, b_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    w = tl.load(w_ptr + offs, mask=mask)
    b = tl.load(b_ptr + offs, mask=mask)
    result = tl.maximum(x * w + b, 0.0)  # 融合:乘法 + 加法 + relu
    tl.store(out_ptr + offs, result, mask=mask)

n = 10_000_000
x = torch.randn(n, device='cuda')
w = torch.randn(n, device='cuda')
b = torch.randn(n, device='cuda')

# 融合(Triton
out_fused = torch.empty_like(x)
grid = lambda meta: (triton.cdiv(n, meta['BLOCK']),)
fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)

# 未融合(PyTorch
out_unfused = torch.relu(x * w + b)

assert torch.allclose(out_fused, out_unfused, atol=1e-5)

# 基准测试
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    fused_mul_add_relu_kernel[grid](x, w, b, out_fused, n, BLOCK=1024)
torch.cuda.synchronize()
fused_time = (time.time() - start) / 1000

start = time.time()
for _ in range(1000):
    out_unfused = torch.relu(x * w + b)
torch.cuda.synchronize()
unfused_time = (time.time() - start) / 1000

print(f"融合(Triton:    {fused_time*1000:.3f} ms")
print(f"未融合(PyTorch: {unfused_time*1000:.3f} ms")
print(f"加速比:           {unfused_time/fused_time:.2f}x")
  1. 测量JAX的XLA编译器如何自动融合操作。比较带和不带jit的操作链。
import jax
import jax.numpy as jnp
import time

def chain_ops(x):
    x = x * 2.0
    x = x + 1.0
    x = jnp.maximum(x, 0.0)  # ReLU
    x = x / jnp.sum(x)
    return x

chain_jit = jax.jit(chain_ops)
x = jax.random.normal(jax.random.PRNGKey(0), (10000, 1000))

# 预热
_ = chain_jit(x)
jax.block_until_ready(_)

# 即时模式(每个操作是独立核函数启动)
start = time.time()
for _ in range(100):
    y = chain_ops(x)
jax.block_until_ready(y)
eager_time = (time.time() - start) / 100

# JITXLA融合操作)
start = time.time()
for _ in range(100):
    y = chain_jit(x)
jax.block_until_ready(y)
jit_time = (time.time() - start) / 100

print(f"即时: {eager_time*1000:.2f} ms")
print(f"JIT:   {jit_time*1000:.2f} ms")
print(f"加速比: {eager_time/jit_time:.1f}x(XLA将4个操作融合为1个核函数)")