Files
maths-cs-ai-compendium-zh/chapter 08: computer vision/04. vision transformers and generation.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

23 KiB
Raw Blame History

视觉Transformer与生成模型

视觉Transformer将自注意力应用于图像块,通过数据驱动的空间学习挑战了CNN的主导地位。本文涵盖ViT、DeiT、Swin Transformer、基于GAN的图像生成(StyleGAN)、VAE和扩散模型(DDPM、Stable Diffusion),以及超分辨率和神经风格迁移。

  • CNN(文件02)内置了很强的空间归纳偏置:局部连接、权重共享和平移等变性。视觉Transformer(ViT)提出了一个启发性的问题:如果我们完全抛弃这些偏置,仅使用第06章中的注意力机制,让模型从数据中学习空间结构,结果会怎样?

  • ViTVision TransformerDosovitskiy等人,2021)将标准的Transformer编码器直接应用于图像。其核心思想是将图像视为一个图像块序列,就像NLP将文本视为一个词元序列一样。

  • 其处理流程如下:

    1. 将图像(高度$H$,宽度$W$,通道数$C$)分割成$P \times P$大小的不重叠图像块网格。得到$N = HW / P^2$个图像块。
    2. 将每个图像块展平成长度为$P^2 \cdot C$的向量,并通过一个可学习的线性嵌入(单个矩阵乘法,第02章)将其投影到模型维度$D$。
    3. 在前面添加一个可学习的**[CLS]标记**嵌入(类似于BERT的[CLS],第07章)。该标记会关注所有图像块,其最终表示用于分类。
    4. 添加位置嵌入(每个位置一个可学习向量)以提供空间信息,因为注意力是置换等变的。
    5. 将$(N + 1)$个标记嵌入序列通过标准的Transformer编码器(多头自注意力 + FFN,第06章)。
    6. [CLS]标记的最终表示通过一个分类头(小型MLP)进行分类。

ViT流程:将图像分割为16x16图像块,每个块展平并线性投影,添加[CLS]标记,加上位置嵌入,然后由Transformer编码器块处理

  • 图像块嵌入等价于一个卷积核大小为$P$、步长为$P$(不重叠)的卷积操作。ViT将2D图像字面地转换为1D序列,然后用与处理语言相同的架构来处理它。

  • ViT的归纳偏置比CNN少:它不强制局部连接或平移等变性。这意味着它需要更多的训练数据才能从头学习空间结构。在小型数据集上,CNN优于ViT。但在非常大的数据集(JFT-300M,3亿张图像)上训练时,ViT达到或超过了最佳CNN的性能,这表明CNN的归纳偏置有助于数据效率,但对于最终性能并非必需。

  • ViT自注意力的复杂度为$O(N^2)$,其中N是图像块数量。对于224x224的图像和16x16的图像块,$N = 196$,这在可控范围内。但对于更高分辨率的图像或更小的图像块,二次成本变得难以承受。

  • DeiT(数据高效的图像TransformerTouvron等人,2021)表明,仅使用ImageNet(无需庞大的JFT数据集)并借助强数据增强、正则化(随机深度、标签平滑、dropout)和知识蒸馏,就可以有效训练ViT:一个预训练的CNN教师提供软标签,ViT学生学习匹配这些标签。DeiT在[CLS]标记旁边添加了一个蒸馏标记,训练用于预测教师的输出。

  • Swin Transformer(Liu等人,2021)解决了ViT的两个主要局限:随图像大小呈二次增长的计算成本,以及缺少层次化特征图(检测和分割需要层次化特征)。

  • Swin引入了移动窗口:不再对所有图像块进行全局自注意力,而是在局部窗口内(例如7x7个图像块)计算注意力。这使得计算成本与图像大小呈线性关系:$O(N)$而非$O(N^2)$。但仅靠局部窗口会阻止区域之间的信息流动。

  • 窗口移动解决了这个问题:在交替层中,窗口划分会偏移半个窗口大小。这创建了跨窗口连接,使得信息可以在所有图像部分之间流动,而无需全局注意力的成本。

Swin Transformer:第l层在常规窗口内计算注意力,第l+1层将窗口划分偏移一半,创建跨窗口连接

  • Swin还通过跨阶段合并图像块来构建层次化表示。每个阶段之后,相邻的2x2图像块被拼接并投影,使通道维度加倍、空间分辨率减半。这产生了多尺度特征图,类似于CNN和FPN(文件03)中的特征图,使得Swin可以直接兼容Faster R-CNN等检测头和U-Net等分割头。

  • PVT(金字塔视觉Transformer)采用了类似的层次化方法,具有空间缩减注意力:在每个阶段,键和值在计算注意力之前先进行空间下采样,从而在保持全局感受野的同时降低二次成本。

  • 自监督视觉学习从未标注的图像中训练表示。标注成本高,但图像资源丰富。目标是在没有任何人工标注的情况下,学习能很好地迁移到下游任务的特征。

  • 对比学习训练模型识别:同一张图像的两个增广视图("正样本对")应具有相似的表示,而不同图像的视图("负样本对")应具有不相似的表示。

  • SimCLR(Chen等人,2020)对一个批次中的每张图像创建两个增广视图,用共享主干网络+投影头对两者进行编码,并应用NT-Xent损失(归一化温度标度交叉熵):

\ell_{i,j} = -\log \frac{\exp(\text{sim}(z_i, z_j) / \tau)}{\sum_{k \neq i} \exp(\text{sim}(z_i, z_k) / \tau)}
  • 其中$\text{sim}$是余弦相似度(第01章),$\tau$是温度参数。分子将正样本对拉近;分母将负样本对推远。SimCLR需要大批量大小(4,096+)来提供足够的负样本。

  • MoCo(动量对比,He等人,2020)通过维护一个动量更新的负嵌入队列来解决大批量需求。查询编码器通过梯度下降更新;键编码器作为查询编码器的指数移动平均(EMA,第04章)进行更新:$\theta_k \leftarrow m \theta_k + (1 - m) \theta_q$,其中$m = 0.999$。队列存储最近的键嵌入,提供了大量且一致的负样本集,无需巨大的批次。

  • BYOL(自举你自己的隐空间,Grill等人,2020)完全消除了负样本对。它使用两个网络:"在线"网络和"目标"网络(在线的EMA)。在线网络预测目标网络对另一增广视图的表示。无需负样本,BYOL通过预测头的不对称性和EMA目标避免了坍塌问题(模型对所有输入输出相同向量)。

  • DINO(无标签自蒸馏,Caron等人,2021)将自蒸馏应用于ViT。学生网络预测教师网络(学生的EMA)在不同增广视图下的输出。教师使用更大的裁剪区域;学生使用更小的裁剪区域。DINO产生的特征包含关于场景布局的显式信息:DINO训练的ViT的自注意力图自然地对物体进行分割,无需任何分割监督。

  • 掩码图像建模是BERT掩码语言建模(第07章)在视觉领域的类比。输入图像块的一大部分被掩码,模型学习重建它们。

  • MAE(掩码自编码器,He等人,2022)掩码了75%的图像块,并训练一个ViT编码器-解码器来重建缺失的像素值。只有未掩码的图像块由编码器处理(在预训练期间节省4倍计算量),轻量级解码器从编码后的可见图像块加上可学习的掩码标记重建完整图像。

  • BEiT(图像Transformer的BERT预训练,Bao等人,2022)掩码图像块并预测离散的视觉标记(从预训练的dVAE分词器获得),而不是原始像素。这类似于BERT预测离散词标记,避免了像素重建的低层细节。

  • 图像生成旨在生成训练集中不存在的新颖、逼真的图像。核心挑战是对自然图像的高维概率分布进行建模。

  • 生成对抗网络(GANGoodfellow等人,2014)使用两个相互竞争的网络:一个生成器$G$从随机噪声中创建假图像,和一个判别器$D$试图区分真实图像和假图像。它们通过对抗性训练:$G$试图欺骗$D$,而$D$试图抓住$G$。

\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p(z)}[\log(1 - D(G(z)))]
  • 生成器接收随机隐向量$z$(从高斯分布等简单分布中采样),通过一系列转置卷积将其映射生成图像。判别器是一个标准的CNN分类器。在均衡状态下,$G$生成的图像与真实数据无法区分,$D$对所有输入输出0.5。

  • 模式坍塌是GAN的主要失败模式:生成器学会只生成少数几种能欺骗判别器的图像,忽略了训练数据的多样性。生成器找到一小部分"安全"输出,而不是覆盖完整的数据分布。

  • 稳定GAN的训练技巧包括:谱归一化(约束判别器的Lipschitz常数)、渐进式增长(先在低分辨率训练,然后逐步提高)、特征匹配(匹配中间判别器特征的统计量而非最终输出),以及使用Wasserstein距离替代原始的JS散度目标。

  • StyleGAN(Karras等人,2019)是最具影响力的高质量图像合成GAN架构。其关键创新是基于风格的生成器:不是将隐向量$z$直接输入生成器,而是先通过一个映射网络(8层MLP)生成风格向量$w$。该风格向量通过**自适应实例归一化(AdaIN)**注入到生成器的每一层,调节特征图的统计量:

\text{AdaIN}(x, y) = y_{s} \cdot \frac{x - \mu(x)}{\sigma(x)} + y_{b}
  • 其中$y_s$和$y_b$是从$w$推导出的缩放和偏置。不同层控制不同方面:早期层控制粗粒度特征(姿态、脸型),中间层控制中粒度特征(发型、眼睛),后期层控制细粒度细节(雀斑、发质纹理)。StyleGAN能以1024x1024分辨率生成照片级逼真的人脸。

  • 变分自编码器(VAE(第06章)提供了另一种生成方法。与GAN不同,VAE有一个原则性的概率框架,具有清晰的训练目标(ELBO)。它们生成的图像通常比GAN模糊,但提供了更平滑、更结构化的隐空间。VAE是隐扩散模型中用于将图像压缩到隐空间和从隐空间重建的编码器-解码器对。

  • 扩散模型已成为图像生成的主导范式,在质量和多样性上都超越了GAN。其思想概念上很简单:逐步向数据添加噪声直到变成纯高斯噪声(前向过程),然后学习逐步逆转这一过程(反向过程)。

  • 前向过程在$T$个时间步中添加高斯噪声:

q(x_t | x_{t-1}) = \mathcal{N}(x_t; \sqrt{1 - \beta_t} \, x_{t-1}, \beta_t I)
  • 其中$\beta_t$是一个随时间递增的噪声调度。经过足够多的步骤后,无论原始图像$x_0$如何,$x_T$都近似于纯高斯噪声。利用重参数化技巧(第06章),设$\alpha_t = 1 - \beta_t$$\bar{\alpha}t = \prod{s=1}^{t} \alpha_s$,我们可以直接从$x_0$采样$x_t$:
x_t = \sqrt{\bar{\alpha}_t} \, x_0 + \sqrt{1 - \bar{\alpha}_t} \, \epsilon, \quad \epsilon \sim \mathcal{N}(0, I)
  • 反向过程学习去噪:从纯噪声$x_T$开始,模型预测每一步添加的噪声$\epsilon$并将其减去以恢复$x_{t-1}$。这由一个神经网络$\epsilon_\theta$(通常是U-Net,来自文件03)参数化,使用简单的MSE损失训练:
\mathcal{L} = \mathbb{E}_{t, x_0, \epsilon}\left[\|\epsilon - \epsilon_\theta(x_t, t)\|^2\right]

扩散前向和反向过程:干净图像在T步中逐渐被噪声破坏(前向),神经网络学习逆转每一步(反向),从纯噪声开始生成干净图像

  • DDPM(去噪扩散概率模型,Ho等人,2020)建立了这个框架。采样需要迭代所有$T$步(通常为1,000步),这很慢。DDIM(去噪扩散隐式模型,Song等人,2021)将采样过程重新表述为确定性映射,允许大跨度跳过(例如50步代替1,000步)且质量损失极小。

  • 基于分数的模型Song和Ermon,2019)提供了另一种视角。该模型不是预测噪声$\epsilon$,而是估计分数函数$\nabla_{x_t} \log p(x_t)$,即对数概率相对于含噪图像的梯度。该梯度指向数据分布中更高概率(更干净)的区域。采样使用Langevin动力学沿着该梯度进行。基于分数的模型和DDPM在**随机微分方程(SDE)**的框架下被统一:前向过程是添加噪声的SDE,反向过程是时间反转的SDE。

  • 无分类器引导Ho和Salimans,2022)控制样本质量和多样性之间的权衡。模型同时进行条件训练(使用文本提示或类别标签)和无条件训练(条件随机丢弃)。在采样时,预测是加权组合:

\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))
  • 其中$c$是条件,$\varnothing$是空条件,$s > 1$是引导尺度。$s$越高,生成的图像越符合条件,但多样性越低。$s = 1$是无引导模型;$s = 7.5$是常见的默认值。

  • 隐扩散Rombach等人,2022Stable Diffusion)将扩散过程从像素空间转移到学习的隐空间中。一个预训练的VAE编码器将图像压缩为较低维度的隐空间表示(通常空间下采样4倍或8倍),扩散在这个压缩空间中进行,VAE解码器从去噪后的隐变量重建像素。这大大提高了效率:在像素空间扩散512x512图像需要处理$512 \times 512 \times 3$的张量,但在隐空间中仅需处理$64 \times 64 \times 4$的张量。

  • 隐扩散中的去噪U-Net接收含噪隐变量、时间步(编码为正弦嵌入,类似于Transformer中的位置编码)和条件信号(来自冻结的CLIP或T5文本编码器的文本嵌入)。文本条件通过U-Net内的交叉注意力层进入:文本嵌入作为键和值,图像特征作为查询。这使得模型在每个空间位置都能关注文本提示的相关部分。

  • 流匹配是扩散模型的一个新兴替代方案,它学习噪声和数据之间的直接传输路径,而不是DDPM的迭代去噪。

  • **连续归一化流(CNF)**定义了一个时间相关的速度场$v_\theta(x, t)$,沿着平滑轨迹将样本从简单分布$p_0$(噪声)推送到数据分布$p_1$。该变换遵循一个常微分方程(ODE):

\frac{dx}{dt} = v_\theta(x, t), \quad t \in [0, 1]
  • 从$x_0 \sim \mathcal{N}(0, I)$开始,将ODE向前积分到$t = 1$即可得到数据分布中的样本。速度场由神经网络参数化,训练目标是匹配目标条件流。

  • 最优传输(OT)流匹配(Lipman等人,2023)使用噪声和数据之间的直线路径作为目标流:从噪声样本$x_0$到数据样本$x_1$的条件路径简单地是$x_t = (1 - t) x_0 + t x_1$,目标速度为$v = x_1 - x_0$。训练损失变为:

\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \left[\|v_\theta(x_t, t) - (x_1 - x_0)\|^2\right]
  • 整流流(Liu等人,2022)通过迭代方式拉直学习到的流路径。在初始训练后,模型通过模拟ODE生成(噪声,数据)对。这些比随机配对更紧密对齐的对用于重新训练模型。重复此过程会产生越来越直的路径,可以通过更少的ODE步骤(甚至单步)来遍历,从而实现极快速的生成。

  • 流匹配相比扩散有几个优势:训练目标更简单(直接的速度回归,无需噪声调度),采样ODE更平滑(需要的积分步骤更少),与最优传输的联系提供了理论依据。Stable Diffusion 3和Flux使用流匹配替代了传统的DDPM。

编程练习(使用CoLab或notebook

  1. 从头实现ViT图像块嵌入。将图像分割成图像块,展平,投影到模型维度,添加位置嵌入,并前置[CLS]标记。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def create_patch_embedding(image, patch_size, d_model, params):
    """将图像转换为图像块嵌入序列。"""
    H, W, C = image.shape
    n_patches_h = H // patch_size
    n_patches_w = W // patch_size
    n_patches = n_patches_h * n_patches_w

    # 提取图像块
    patches = []
    for i in range(n_patches_h):
        for j in range(n_patches_w):
            patch = image[i*patch_size:(i+1)*patch_size,
                          j*patch_size:(j+1)*patch_size, :]
            patches.append(patch.ravel())
    patches = jnp.stack(patches)  # (N, P*P*C)

    # 线性投影到d_model
    embeddings = patches @ params['proj_w'] + params['proj_b']  # (N, d_model)

    # 前置CLS标记
    cls_token = params['cls_token']  # (1, d_model)
    embeddings = jnp.concatenate([cls_token, embeddings], axis=0)  # (N+1, d_model)

    # 添加位置嵌入
    embeddings = embeddings + params['pos_embed']  # (N+1, d_model)

    return embeddings, patches

# 设置
H, W, C = 32, 32, 3
patch_size = 8
d_model = 64
n_patches = (H // patch_size) * (W // patch_size)  # 16

key = jax.random.PRNGKey(42)
keys = jax.random.split(key, 5)

# 创建具有不同象限的合成图像
image = jnp.zeros((H, W, C))
image = image.at[:16, :16, 0].set(1.0)   # 红色 左上
image = image.at[:16, 16:, 1].set(1.0)   # 绿色 右上
image = image.at[16:, :16, 2].set(1.0)   # 蓝色 左下
image = image.at[16:, 16:, :2].set(1.0)  # 黄色 右下

params = {
    'proj_w': jax.random.normal(keys[0], (patch_size**2 * C, d_model)) * 0.02,
    'proj_b': jnp.zeros(d_model),
    'cls_token': jax.random.normal(keys[1], (1, d_model)) * 0.02,
    'pos_embed': jax.random.normal(keys[2], (n_patches + 1, d_model)) * 0.02,
}

embeddings, patches = create_patch_embedding(image, patch_size, d_model, params)

print(f"图像形状: {image.shape}")
print(f"图像块大小: {patch_size}x{patch_size}")
print(f"图像块数量: {n_patches}")
print(f"图像块向量长度: {patch_size**2 * C}")
print(f"嵌入形状: {embeddings.shape}  (CLS + {n_patches} 个图像块)")

# 可视化图像块
fig, axes = plt.subplots(2, 5, figsize=(14, 6))
axes[0, 0].imshow(image); axes[0, 0].set_title('完整图像'); axes[0, 0].axis('off')
for idx in range(min(9, n_patches)):
    ax = axes[(idx+1) // 5, (idx+1) % 5]
    patch_img = patches[idx].reshape(patch_size, patch_size, C)
    ax.imshow(patch_img); ax.set_title(f'图像块 {idx}'); ax.axis('off')
plt.suptitle('ViT 图像块分解')
plt.tight_layout(); plt.show()
  1. 实现一个简单的GAN训练循环。在二维数据上训练生成器和判别器,并可视化生成分布逐渐收敛到真实分布。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def generator(z, params):
    h = jnp.tanh(z @ params['g_w1'] + params['g_b1'])
    h = jnp.tanh(h @ params['g_w2'] + params['g_b2'])
    return h @ params['g_w3'] + params['g_b3']

def discriminator(x, params):
    h = jax.nn.leaky_relu(x @ params['d_w1'] + params['d_b1'], 0.2)
    h = jax.nn.leaky_relu(h @ params['d_w2'] + params['d_b2'], 0.2)
    return jax.nn.sigmoid(h @ params['d_w3'] + params['d_b3'])

def init_params(key):
    keys = jax.random.split(key, 6)
    z_dim, h_dim, data_dim = 2, 32, 2
    scale = 0.1
    return {
        'g_w1': jax.random.normal(keys[0], (z_dim, h_dim)) * scale,
        'g_b1': jnp.zeros(h_dim),
        'g_w2': jax.random.normal(keys[1], (h_dim, h_dim)) * scale,
        'g_b2': jnp.zeros(h_dim),
        'g_w3': jax.random.normal(keys[2], (h_dim, data_dim)) * scale,
        'g_b3': jnp.zeros(data_dim),
        'd_w1': jax.random.normal(keys[3], (data_dim, h_dim)) * scale,
        'd_b1': jnp.zeros(h_dim),
        'd_w2': jax.random.normal(keys[4], (h_dim, h_dim)) * scale,
        'd_b2': jnp.zeros(h_dim),
        'd_w3': jax.random.normal(keys[5], (h_dim, 1)) * scale,
        'd_b3': jnp.zeros(1),
    }

def d_loss(params, real_data, fake_data):
    real_score = discriminator(real_data, params)
    fake_score = discriminator(fake_data, params)
    return -jnp.mean(jnp.log(real_score + 1e-7) + jnp.log(1 - fake_score + 1e-7))

def g_loss(params, fake_data):
    fake_score = discriminator(fake_data, params)
    return -jnp.mean(jnp.log(fake_score + 1e-7))

# 真实数据:环形分布
key = jax.random.PRNGKey(42)
theta = jax.random.uniform(key, (512,)) * 2 * jnp.pi
real_data = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1)
real_data = real_data + jax.random.normal(key, real_data.shape) * 0.05

params = init_params(jax.random.PRNGKey(0))
d_grad = jax.grad(d_loss)
g_grad = jax.grad(g_loss)
lr = 0.001

snapshots = []
for step in range(3000):
    key, k1 = jax.random.split(key)
    z = jax.random.normal(k1, (512, 2))
    fake_data = generator(z, params)

    # 更新判别器
    grads = d_grad(params, real_data, fake_data)
    for k in ['d_w1', 'd_b1', 'd_w2', 'd_b2', 'd_w3', 'd_b3']:
        params[k] = params[k] - lr * grads[k]

    # 更新生成器
    fake_data = generator(z, params)
    grads = g_grad(params, fake_data)
    for k in ['g_w1', 'g_b1', 'g_w2', 'g_b2', 'g_w3', 'g_b3']:
        params[k] = params[k] - lr * grads[k]

    if step in [0, 500, 1500, 2999]:
        snapshots.append((step, fake_data.copy()))

fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (step, fake) in zip(axes, snapshots):
    ax.scatter(real_data[:, 0], real_data[:, 1], s=5, alpha=0.3, c='#3498db', label='真实')
    ax.scatter(fake[:, 0], fake[:, 1], s=5, alpha=0.3, c='#e74c3c', label='生成')
    ax.set_title(f'步骤 {step}'); ax.set_xlim(-2, 2); ax.set_ylim(-2, 2)
    ax.set_aspect('equal'); ax.legend(markerscale=3)
plt.suptitle('GAN训练:生成器学习环形分布')
plt.tight_layout(); plt.show()
  1. 实现扩散前向过程:在不同时间步向图像添加噪声,并可视化逐步破坏过程。然后实现单步去噪。
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt

def noise_schedule(T, beta_start=0.0001, beta_end=0.02):
    """线性噪声调度。"""
    betas = jnp.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    alpha_bars = jnp.cumprod(alphas)
    return betas, alphas, alpha_bars

def forward_diffusion(x0, t, alpha_bars, key):
    """在时间步t向x0添加噪声。"""
    alpha_bar_t = alpha_bars[t]
    noise = jax.random.normal(key, x0.shape)
    xt = jnp.sqrt(alpha_bar_t) * x0 + jnp.sqrt(1 - alpha_bar_t) * noise
    return xt, noise

# 创建简单的2D"图像"(棋盘格)
img = jnp.zeros((32, 32))
for i in range(4):
    for j in range(4):
        if (i + j) % 2 == 0:
            img = img.at[i*8:(i+1)*8, j*8:(j+1)*8].set(1.0)

T = 1000
betas, alphas, alpha_bars = noise_schedule(T)

# 可视化前向过程
timesteps = [0, 50, 200, 500, 999]
key = jax.random.PRNGKey(42)

fig, axes = plt.subplots(1, len(timesteps), figsize=(16, 3.5))
for ax, t in zip(axes, timesteps):
    key, subkey = jax.random.split(key)
    xt, noise = forward_diffusion(img, t, alpha_bars, subkey)
    ax.imshow(xt, cmap='gray', vmin=-2, vmax=2)
    ax.set_title(f't={t}\n$\\bar{{\\alpha}}$={alpha_bars[t]:.3f}')
    ax.axis('off')
plt.suptitle('扩散前向过程:逐步添加噪声')
plt.tight_layout(); plt.show()

# 简单去噪:训练小型网络在t=200时预测噪声
t_denoise = 200
key, k1 = jax.random.split(key)
xt, true_noise = forward_diffusion(img, t_denoise, alpha_bars, k1)

# 小型"去噪器":仅学习恒定的噪声估计(用于演示)
noise_estimate = jnp.zeros_like(img)
lr = 0.01
for step in range(100):
    residual = noise_estimate - true_noise
    noise_estimate = noise_estimate - lr * residual

# 反向一步
alpha_bar_t = alpha_bars[t_denoise]
x_denoised = (xt - jnp.sqrt(1 - alpha_bar_t) * noise_estimate) / jnp.sqrt(alpha_bar_t)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
axes[0].imshow(img, cmap='gray'); axes[0].set_title('原始 $x_0$'); axes[0].axis('off')
axes[1].imshow(xt, cmap='gray', vmin=-2, vmax=2)
axes[1].set_title(f'含噪 $x_{{200}}$'); axes[1].axis('off')
axes[2].imshow(x_denoised, cmap='gray')
axes[2].set_title('去噪后(单步)'); axes[2].axis('off')
plt.tight_layout(); plt.show()

mse = jnp.mean((x_denoised - img)**2)
print(f"去噪MSE: {mse:.4f}")