feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。 第01章 向量 | 第02章 矩阵 | 第03章 微积分 第04章 统计学 | 第05章 概率论 | 第06章 机器学习 第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音 第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络 第13章 计算与操作系统 | 第14章 数据结构与算法 第15章 生产级软件工程 | 第16章 SIMD与GPU编程 第17章 AI推理 | 第18章 ML系统设计 第19章 应用人工智能 | 第20章 前沿人工智能 翻译说明: - 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留 - mkdocs.yml 配置中文导航 + language: zh - README.md 已翻译为中文(兼 docs/index.md) - docs/ 目录包含指向各章文件的 symlink - 约 29,000 行中文内容,排除 .cache/ 构建缓存
This commit is contained in:
@@ -0,0 +1,366 @@
|
||||
# 多模态表征
|
||||
|
||||
*多模态表征将视觉、语言和音频桥接到共享嵌入空间中。本文件涵盖融合策略、CLIP、ALIGN、SigLIP、对比损失函数(InfoNCE、NT-Xent)、零样本分类和检索评估。*
|
||||
|
||||
- 想象你坐在一家咖啡馆里。你看到桌上冒热气的水杯,听到陶瓷的叮当声,闻到烘焙咖啡豆的香气,感受到从马克杯传来的暖意。没有哪一种感官能告诉你一切:你的大脑将这些信号融合成一个统一的感知——"热咖啡"。**多模态学习** 对机器做了同样的事:它结合来自多种模态(视觉、语言、音频等)的信息,构建出比任何单一模态单独提供的表征更丰富、更鲁棒的表征。
|
||||
|
||||
- **模态(modality)** 是一种独特的信息通道。在机器学习中,最常见的模态包括图像(像素网格)、文本(词元序列)、音频(波形或语谱图,如第9章所述)、视频(帧序列)和结构化数据(表格、图)。每种模态都有其自身的统计结构:图像具有空间连贯性,文本是序列化和离散的,音频是时间性的和连续的。多模态学习的挑战在于桥接这些根本不同的数据类型。
|
||||
|
||||
- 为什么要费心结合多种模态?因为它们提供互补的信息。一张狗的照片告诉你它的品种和颜色,但不会告诉你名字。像"我的金毛犬 Max"这样的描述告诉你名字和品种,但不会告诉你确切姿态。图像和文本结合起来,比任何单独一个给出的画面都更完整。这种互补性是其核心动机:多模态模型可以回答那些单模态模型无法回答的问题、生成内容并做出决策。
|
||||
|
||||

|
||||
|
||||
## 融合策略
|
||||
|
||||
- 想象一个小组项目。你有两种组合想法的方式:每个人从一开始就在同一个房间里一起工作(共享原始笔记和草稿),或者每个人独立撰写自己的部分,最后合并最终文档。这分别对应于多模态学习中的**早期融合(early fusion)** 和**晚期融合(late fusion)**。
|
||||
|
||||
- **早期融合**(也称为特征级融合)在任何高级处理之前,对来自不同模态的原始或低级特征进行拼接或混合。例如,你可以将图像的像素特征与文本的词元嵌入拼接起来,将组合后的序列输入到一个单一的 Transformer 中。模型可以从一开始就学习细粒度的跨模态交互,但输入空间很大,且模型必须学会同时处理截然不同的数据类型。
|
||||
|
||||
- 形式化地,给定来自两种模态的特征向量 $x_{\\text{img}} \\in \\mathbb{R}^{d_1}$ 和 $x_{\\text{txt}} \\in \\mathbb{R}^{d_2}$,早期融合简单地拼接它们:
|
||||
|
||||
$$x_{\\text{fused}} = [x_{\\text{img}}; x_{\\text{txt}}] \\in \\mathbb{R}^{d_1 + d_2}$$
|
||||
|
||||
- 这个拼接后的向量由共享网络处理。其优势在于模型可以在每一层发现跨模态相关性。缺点是计算成本高,且难以对齐非常不同的特征类型(密集的像素值与稀疏的词元索引)。
|
||||
|
||||
- **晚期融合**(也称为决策级融合)通过各自的编码器独立处理每种模态,为每种模态生成一个高层表征甚至最终的预测结果。这些输出随后被组合,通常通过平均分数、投票或一个可学习的组合层。晚期融合更简单,且允许你直接复用预训练的单模态模型,但它无法捕捉低层的跨模态交互,因为各模态从未"看到"彼此的原始特征。
|
||||
|
||||
- 给定模态特定的预测值 $\hat{y}_1$ 和 $\hat{y}_2$,一个简单的晚期融合规则是:
|
||||
|
||||
$$\hat{y} = \\alpha \\hat{y}_1 + (1 - \\alpha) \\hat{y}_2$$
|
||||
|
||||
- 其中 $\\alpha \\in [0, 1]$ 是一个可学习或手动调节的混合权重。
|
||||
|
||||
- **中间融合(middle fusion)**(也称为中间融合 intermediate fusion)是大多数现代系统使用的实用折中方案。每种模态先由其自身的编码器处理(提取模态特定的特征),然后在网络中间部分通过跨注意力层等方式组合编码后的表征。这使得每个编码器可以专注于自身的模态,同时仍能实现丰富的跨模态交互。Flamingo、LLaVA 和大多数视觉-语言模型(文件 02)都使用中间融合。
|
||||
|
||||

|
||||
|
||||
- 融合策略的选择取决于数据可用性、计算预算和任务。早期融合功能强大但数据需求高。晚期融合廉价但受限。带有跨注意力的中间融合已成为大规模多模态模型的主流做法,因为它在表达能力与模块化之间取得了平衡。
|
||||
|
||||
## 联合嵌入空间
|
||||
|
||||
- 想象一个通用翻译器,它可以将任何语言的任何句子映射到同一个共享"意义空间"中的同一点。用英语、法语或日语说的"a dog on a beach"都会落在同一个坐标上。**联合嵌入空间** 跨模态做了完全相同的事:一张沙滩上的狗的图像和文本"a dog on a beach"应该映射到同一向量空间中的邻近点。
|
||||
|
||||
- 形式化地,我们学习两个编码器函数:模态 1(如图像)的 $f_\\theta : \\mathcal{X}_1 \\to \\mathbb{R}^d$ 和模态 2(如文本)的 $g_\\phi : \\mathcal{X}_2 \\to \\mathbb{R}^d$。两者都将输入映射到相同的 $d$ 维空间。训练目标确保语义匹配的对 $(x_1, x_2)$ 的嵌入 $f_\\theta(x_1)$ 和 $g_\\phi(x_2)$ 彼此接近(高余弦相似度),而不匹配的对则相距很远。
|
||||
|
||||
- 这是第 7 章中词嵌入空间的直接推广。回忆一下,Word2Vec 和 GloVe 将语义相似的词放置在向量空间中彼此靠近。联合嵌入空间将这一思想扩展到跨模态:不是衡量词与词的相似性,而是衡量图像到文本的相似性、音频到文本的相似性,甚至图像到音频的相似性。
|
||||
|
||||
- 相似度度量几乎总是**余弦相似度**(第 1 章):
|
||||
|
||||
$$\\text{sim}(u, v) = \\frac{u \\cdot v}{\\|u\\| \\|v\\|}$$
|
||||
|
||||
- 通过将所有嵌入 $L_2$ 归一化到单位超球面上,余弦相似度简化为简单的点积 $u \\cdot v$,计算效率极高,并且可以使用近似最近邻库进行加速。
|
||||
|
||||

|
||||
|
||||
- 联合嵌入空间的强大之处在于它实现了**零样本迁移**。一旦你对齐了图像和文本嵌入,你就可以将从未训练过的类别图像分类:只需将类别名称作为文本嵌入,然后找出与图像嵌入最接近的文本嵌入即可。无需特定任务的微调。这是 CLIP 及其后继模型背后的关键洞察。
|
||||
|
||||
## 用于多模态对齐的对比学习
|
||||
|
||||
- 想象一个课堂练习:学生们拿到打乱的照片和描述对,需要将每张照片与其正确的描述配对。要出色地完成这项任务,你需要同时理解视觉内容与语言,并知道它们如何关联。**对比学习** 正是以这种方式训练模型:给定一批 (图像, 文本) 对,模型必须找出哪张图像对应哪段文本。
|
||||
|
||||
- 正如我们在第 8 章(文件 04)中看到的,单模态环境下的对比学习(SimCLR、MoCo)将同一图像的不同增广视图拉近,将不同图像的视图推远。多模态对比学习将"增广视图"替换为"匹配的模态":图像及其描述构成正样本对;该图像与批次中任何其他描述的配对构成负样本对。
|
||||
|
||||
### CLIP
|
||||
|
||||
- **CLIP**(Contrastive Language-Image Pre-training,对比语言-图像预训练,Radford 等,2021)是多模态对比学习的基础模型。它在从互联网上抓取的 4 亿个 (图像, 文本) 对上联合训练一个图像编码器(ViT 或 ResNet,第 8 章)和一个文本编码器(Transformer,第 7 章)。
|
||||
|
||||
- 给定一批 $N$ 个图像-文本对,CLIP 计算所有图像嵌入与所有文本嵌入之间的 $N \\times N$ 余弦相似度矩阵。对角线上的条目是匹配的对(正样本);所有非对角线条目是不匹配的(负样本)。训练损失促使对角线条目升高,非对角线条目降低。
|
||||
|
||||
- 该损失是对称交叉熵。对于图像 $i$ 与文本 $j = i$ 的配对,图像到文本的损失为:
|
||||
|
||||
$$\\mathcal{L}_{i \\to t} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{img}}, z_i^{\\text{txt}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{img}}, z_k^{\\text{txt}}) / \\tau)}$$
|
||||
|
||||
- 文本到图像的损失与之相同,只是交换了角色:
|
||||
|
||||
$$\\mathcal{L}_{t \\to i} = -\\frac{1}{N} \\sum_{i=1}^{N} \\log \\frac{\\exp(\\text{sim}(z_i^{\\text{txt}}, z_i^{\\text{img}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z_i^{\\text{txt}}, z_k^{\\text{img}}) / \\tau)}$$
|
||||
|
||||
- 总的 CLIP 损失是平均值:
|
||||
|
||||
$$\\mathcal{L}_{\\text{CLIP}} = \\frac{1}{2}(\\mathcal{L}_{i \\to t} + \\mathcal{L}_{t \\to i})$$
|
||||
|
||||
- 这里 $\\tau$ 是一个可学习的**温度**参数(初始化为 $\\tau = 0.07$)。温度控制 softmax 分布的尖锐程度:较低的 $\\tau$ 使模型更专注于最接近的匹配,较高的 $\\tau$ 则更均匀地分布概率。CLIP 将 $\\tau$ 与模型权重一起联合学习,而不是将其视为固定的超参数。
|
||||
|
||||

|
||||
|
||||
- CLIP 的图像编码器通常是 ViT-L/14(大型 Vision Transformer,14x14 块,第 8 章文件 04)。文本编码器是一个 12 层带有因果掩码的 Transformer(类似 GPT,第 7 章文件 04)。两个编码器都通过一个可学习的线性投影将其输出映射到共享的 512 或 768 维空间,随后进行 $L_2$ 归一化。
|
||||
|
||||
- CLIP 最引人注目的特性是**零样本图像分类**。要将图像分类到 $K$ 个类别之一,你创建 $K$ 个文本提示,如"a photo of a {class name}",用文本编码器嵌入每个提示,用图像编码器嵌入图像,然后选择文本嵌入与图像嵌入余弦相似度最高的类别。在 ImageNet 上,CLIP 在从未见过任何 ImageNet 训练样本的情况下取得了具有竞争力的准确率。
|
||||
|
||||
### ALIGN
|
||||
|
||||
- **ALIGN**(Jia 等,2021)将 CLIP 的方法扩展到更大、更嘈杂的数据集:18 亿个图像-文本对,仅极少量过滤。CLIP 精心筛选其数据,而 ALIGN 表明规模可以弥补噪声。ALIGN 使用 EfficientNet 图像编码器和 BERT 文本编码器,并使用相同的对比损失进行训练。关键发现是,只要有足够的数据,就不需要昂贵的数据清洗:对比目标会自然地降低噪声对的权重,因为它们产生不一致的梯度。
|
||||
|
||||
### SigLIP
|
||||
|
||||
- **SigLIP**(Sigmoid Loss for Language-Image Pre-training,Sigmoid 损失语言-图像预训练,Zhai 等,2023)用更简单的 sigmoid 损失取代了 CLIP 基于 softmax 的对比损失。SigLIP 不将 $N \\times N$ 相似度矩阵视为分类问题(每行是一个列上的 softmax),而是将每个条目独立视为二分类问题:这个 (图像, 文本) 对是否匹配?
|
||||
|
||||
- 单个对 $(i, j)$ 的 SigLIP 损失是:
|
||||
|
||||
$$\\mathcal{L}_{ij} = -y_{ij} \\log \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau) - (1 - y_{ij}) \\log(1 - \\sigma(z_i^{\\text{img}} \\cdot z_j^{\\text{txt}} / \\tau))$$
|
||||
|
||||
- 其中 $y_{ij} = 1$ 如果 $i = j$(匹配),否则 $y_{ij} = 0$,$\\sigma$ 是 sigmoid 函数。
|
||||
|
||||
- SigLIP 的关键优势在于它消除了跨整个批次进行全局 softmax 归一化的需要。在 CLIP 中,softmax 分母需要收集所有设备上的所有嵌入,这在分布式训练中是一个通信瓶颈。SigLIP 的逐对 sigmoid 损失可以在本地计算,从而能够更高效地扩展到非常大的批次。SigLIP 以更低的训练成本达到了与 CLIP 相当的质量。
|
||||
|
||||
## 对比损失函数详解
|
||||
|
||||
- 对比学习中使用的损失函数共享一个共同的结构:它们都试图使正样本对的相似度得分高于负样本对的相似度得分,同时通过某种"间隔"或"温度"控制模型施加的力度。让我们形式化关键变体。
|
||||
|
||||
### InfoNCE
|
||||
|
||||
- **InfoNCE**(噪声对比估计,van den Oord 等,2018)是 CLIP 损失背后的理论基础。给定一个查询 $q$、一个正样本键 $k^+$ 和 $K$ 个负样本键 $\\{k_1^-, \\ldots, k_K^-\\}$,损失为:
|
||||
|
||||
$$\\mathcal{L}_{\\text{InfoNCE}} = -\\log \\frac{\\exp(q \\cdot k^+ / \\tau)}{\\exp(q \\cdot k^+ / \\tau) + \\sum_{j=1}^{K} \\exp(q \\cdot k_j^- / \\tau)}$$
|
||||
|
||||
- 这是一个 $(K+1)$ 类分类问题:从 $K+1$ 个候选中识别出正样本。InfoNCE 是查询与正样本键之间互信息的下界,这就是为什么最大化它能够对齐语义匹配输入的表征。随着负样本数量 $K$ 的增加,下界更加紧致,这解释了为什么对比方法受益于大批量大小。
|
||||
|
||||
### NT-Xent
|
||||
|
||||
- **NT-Xent**(归一化温度标度交叉熵,Chen 等,2020)是 SimCLR(第 8 章文件 04)中使用的损失,本质上是在批次内对称应用的 InfoNCE。对于一批 $N$ 个对,$2N$ 个增广视图为每个锚点产生 $2N - 2$ 个负样本(除自身及其正样本外的所有视图)。正样本对 $(i, j)$ 的损失为:
|
||||
|
||||
$$\\ell_{i,j} = -\\log \\frac{\\exp(\\text{sim}(z_i, z_j) / \\tau)}{\\sum_{k=1}^{2N} \\mathbf{1}_{[k \\neq i]} \\exp(\\text{sim}(z_i, z_k) / \\tau)}$$
|
||||
|
||||
- NT-Xent 和 InfoNCE 是相同的数学公式;名称不同只是因为它们是在不同的上下文(自监督视觉 vs. 表征学习理论)中引入的。
|
||||
|
||||
### 温度的作用
|
||||
|
||||
- **温度** $\\tau$ 是对比学习中最重要的超参数之一。为了建立直觉,可以从物理意义上考虑温度:在高温下,分子随机运动(softmax 是平坦的,所有负样本看起来一样差);在低温下,分子沉降为刚性结构(softmax 是尖锐的,只有最难的负样本才重要)。
|
||||
|
||||
- 形式化地,当 $\\tau \\to 0$ 时,softmax 趋近于硬 argmax,只选择最单一的困难负样本。当 $\\tau \\to \\infty$ 时,所有负样本的贡献相等。在实践中,$\\tau \\in [0.01, 0.1]$ 对归一化嵌入效果良好。温度过低会导致训练不稳定(困难负样本的梯度变得非常大);温度过高会使损失对违反情况不敏感。
|
||||
|
||||
- CLIP 初始化 $\\tau = 0.07$ 并将其作为对数参数化的标量 $\\tau = \\exp(t)$ 学习,其中 $t$ 与模型权重一起通过梯度下降更新。这使得模型能够在训练过程中自动调整对比任务的难度。
|
||||
|
||||

|
||||
|
||||
### 三元组损失和基于间隔的替代方案
|
||||
|
||||
- 在 InfoNCE 主导之前,**三元组损失(triplet loss)** 是度量学习的标准。给定一个锚点 $a$、一个正样本 $p$ 和一个负样本 $n$:
|
||||
|
||||
$$\\mathcal{L}_{\\text{triplet}} = \\max(0, \\|a - p\\|^2 - \\|a - n\\|^2 + m)$$
|
||||
|
||||
- 其中 $m$ 是一个间隔,确保正样本至少比负样本近 $m$。三元组损失操作在单个三元组上而非批次上,因此样本效率低于 InfoNCE。它还对挖掘策略敏感:随机负样本通常过于简单(损失为零),因此**困难负样本挖掘**(hard negative mining,选择最接近的不正确匹配)或**半困难挖掘**(semi-hard mining,选择间隔内的负样本)至关重要。
|
||||
|
||||
- InfoNCE 在整个批次中隐式地执行困难负样本挖掘,这是它在规模上优于三元组损失的原因之一。InfoNCE 中的 softmax 归一化自动提高困难负样本(与锚点相似度高的负样本)的权重,在无需显式挖掘的情况下提供了自然的课程学习。
|
||||
|
||||
## 图像-文本检索与零样本分类
|
||||
|
||||
- 一旦你有了训练好的联合嵌入空间,就可以执行**图像-文本检索**:给定一个图像查询,从数据库中找出最相关的文本(图像到文本检索),或者给定一个文本查询,找出最相关的图像(文本到图像检索)。这仅仅是共享嵌入空间中的最近邻搜索。
|
||||
|
||||
- 想象一个图书管理员,可以即时比较一百万条目录中的任何照片与任何描述。他们不需要事先理解每一个可能的类别;只需测量每张照片与每条描述有多"接近"。这就是 CLIP 风格的模型执行检索和零样本分类的方式。
|
||||
|
||||
- **零样本分类**是文本到图像检索的一个特例。给定 $K$ 个类别名称,你构建文本提示 $\\{t_1, \\ldots, t_K\\}$(例如,"a photo of a cat"、"a photo of a dog")并对其进行嵌入。对于一张新图像 $x$,预测的类别为:
|
||||
|
||||
$$\\hat{y} = \\arg\\max_{k} \\; \\text{sim}(f_\\theta(x), g_\\phi(t_k))$$
|
||||
|
||||
- 关键洞察在于,文本编码器充当了一个灵活的分类器头。你不需要为每个下游任务训练新的线性层,只需用自然语言描述任务。这就是 CLIP 泛化能力如此之强的原因:文本编码器在预训练期间见过数百万种不同的描述。
|
||||
|
||||
- **提示工程(prompt engineering)** 很重要。CLIP 在 ImageNet 上的零样本准确率从 63.2% 提升到 68.4%,仅仅是将提示模板从 "{class name}" 改为 "a photo of a {class name}." 更好的是,**提示集成(prompt ensembling)** 通过平均多个模板的文本嵌入(例如,"a photo of a {class name}"、"a good photo of a {class name}"、"a drawing of a {class name}")来产生更鲁棒的文本表征。
|
||||
|
||||

|
||||
|
||||
## 音视频对应
|
||||
|
||||
- 闭上眼睛,听某人拍篮球。你能从节奏性的砰砰声中判断球何时落地。现在睁开眼睛:视觉上的弹跳与每次砰声完美对齐。这种音频与视觉事件之间的紧密对应关系是一种机器可以学习的免费监督信号。**音视频对应学习(audio-visual correspondence learning)** 训练模型将声音与其视觉来源关联起来,无需任何人工标注。
|
||||
|
||||
- 这个想法与 CLIP 惊人地相似,只是将文本替换为音频。给定配对的视频帧和音频片段,模型学习一个嵌入空间,其中时间上对齐的音视频对彼此接近,而错位的对则相距很远。
|
||||
|
||||
- **音视频嵌入(Audio-Visual Embedding, AVE)** 方法(Arandjelovic 和 Zisserman,2017)使用对比损失在视频数据上训练一个视觉编码器 $f$ 和一个音频编码器 $g$。正样本对是(视频帧,来自同一时刻的音频片段),负样本是来自不同视频或不同时刻的音频片段。模型学会狗叫声对应狗的图像,吉他声对应吉他的图像,所有这些都不需要标签。
|
||||
|
||||
- 音频编码器通常使用 CNN 或音频 Transformer 处理**对数梅尔语谱图(log-mel spectrograms)**(第 9 章文件 01),生成固定大小的嵌入。视觉编码器使用标准图像骨干网络(ResNet、ViT)处理视频帧。两者都投影到共享的 $d$ 维空间,训练使用与 CLIP 相同的 InfoNCE 损失:
|
||||
|
||||
$$\\mathcal{L}_{\\text{AV}} = -\\log \\frac{\\exp(\\text{sim}(z^{\\text{vis}}, z^{\\text{aud}}) / \\tau)}{\\sum_{k=1}^{N} \\exp(\\text{sim}(z^{\\text{vis}}, z_k^{\\text{aud}}) / \\tau)}$$
|
||||
|
||||

|
||||
|
||||
- 音视频学习的**应用**包括:声源定位(图像中声音来自何处?)、音视频语音识别(结合嘴唇运动和音频,如第 9 章文件 02)、音视频源分离(通过看着对方的脸来隔离一个人的声音——第 9 章文件 05 中的"鸡尾酒会"问题),以及基于音频的视频生成。
|
||||
|
||||
- **ImageBind**(Girdhar 等,2023)将其扩展到六种模态:图像、文本、音频、深度、热成像和 IMU 数据。关键洞察在于,你不需要每个组合都有配对数据。通过将每种模态与图像对齐(文本通过图像-文本对,音频通过图像-音频对等),所有模态通过共享的图像嵌入空间隐式对齐。这种通过公共锚点模态的"绑定"产生了涌现式对齐:音频和文本变得相似,即使它们从未被直接一起训练过。
|
||||
|
||||
## 评估
|
||||
|
||||
- 评估多模态模型需要能够捕捉跨模态理解的度量指标。两种主流的评估范式是**零样本基准测试**和**检索度量**。
|
||||
|
||||
### 零样本基准测试
|
||||
|
||||
- 零样本评估衡量模型是否能够执行从未被明确训练过的任务。最常用的基准是**ImageNet 零样本准确率**:将所有 1,000 个 ImageNet 类别名称作为文本嵌入,嵌入每个测试图像,根据余弦相似度测量 top-1 和 top-5 分类准确率。CLIP ViT-L/14 在零样本下达到 75.5% 的 top-1 准确率,与在 ImageNet 上训练的监督式 ResNet-50 相当。
|
||||
|
||||
- 其他零样本基准包括:CIFAR-10/100、STL-10、Food-101、Oxford Pets 和 Flowers-102。在多个数据集上评估可以测试模型是否真正具有通用的视觉理解能力,还是仅仅是记住了预训练数据中的模式。
|
||||
|
||||
- **线性探测(linear probe)** 评估是一种互补的测试。你冻结预训练的图像编码器,为标注数据集提取特征,然后在其上训练一个简单的线性分类器。这独立于零样本检索机制来度量学习到的表征的质量。CLIP 的特征是极好的线性探测特征,通常达到或超过监督预训练。
|
||||
|
||||
### 检索度量
|
||||
|
||||
- 对于检索任务(图像到文本和文本到图像),标准度量是 **Recall@K**(R@K):正确匹配出现在前 $K$ 个检索结果中的查询比例。常用的取值为 R@1、R@5 和 R@10。
|
||||
|
||||
- 形式化地,对于一组 $Q$ 个查询:
|
||||
|
||||
$$\\text{R@}K = \\frac{1}{Q} \\sum_{q=1}^{Q} \\mathbf{1}[\\text{rank}(q) \\leq K]$$
|
||||
|
||||
- 其中 $\\text{rank}(q)$ 是查询 $q$ 的排序检索列表中正确匹配的位置。
|
||||
|
||||
- 标准的检索基准包括 **Flickr30K**(31,000 张图像,每张 5 条描述)和 **MS-COCO**(123,000 张图像,每张 5 条描述)。在测试集上评估:给定一张图像,从全部测试集中检索正确的描述,反之亦然。
|
||||
|
||||
- **中位数排名(Median Rank, MedR)** 是一种补充度量:所有查询中正确匹配的中位数位置。完美模型的 MedR = 1。数值越小越好。
|
||||
|
||||
- 除了检索,多模态模型还在组合理解基准上进行评估,如 **Winoground**(测试模型能否区分"a mug in a dog"和"a dog in a mug")和 **ARO**(属性、关系、顺序),这些基准测试模型是否真正理解语言的结构,而不仅仅是匹配词袋。CLIP 风格的模型通常在这些任务上表现不佳,这揭示了一个基本的局限:对比预训练对齐了全局语义,但可能无法捕捉细粒度的组合结构。
|
||||
|
||||

|
||||
|
||||
## 总结
|
||||
|
||||
- 本文件涵盖的多模态表征构成了本章后续所有内容的基础。CLIP 及其后继模型训练的联合嵌入空间是连接视觉和语言的"胶水"。文件 02 在此基础之上,构建了超越检索、能够生成关于图像文本的视觉-语言模型。文件 03 探讨了如何在序列模型中对图像和视频进行分词。文件 04 涵盖跨模态生成(文本到图像、文本到视频)。文件 05 研究了在单一模型中处理多种模态的统一架构。
|
||||
|
||||
- 核心要点:在配对数据上进行对比学习产生了嵌入空间,使得不同模态之间可以互换。图像嵌入和文本嵌入变成了"同一种东西",从而实现零样本分类、检索以及无缝集成到更大的系统中。这个想法——将匹配的对拉近、不匹配的对推远——的简单性掩盖了其非凡的有效性。
|
||||
|
||||
## 编程任务(使用 CoLab 或 notebook)
|
||||
|
||||
1. 从头实现 CLIP 对比损失。创建随机图像和文本嵌入,计算相似度矩阵,并计算对称交叉熵损失。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def clip_loss(image_embeds, text_embeds, temperature=0.07):
|
||||
"""计算对称 CLIP 对比损失。"""
|
||||
# L2 归一化嵌入
|
||||
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)
|
||||
text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=1, keepdims=True)
|
||||
|
||||
# 计算余弦相似度矩阵 (N x N)
|
||||
logits = image_embeds @ text_embeds.T / temperature # (N, N)
|
||||
|
||||
# 标签:对角线(第 i 张图像匹配第 i 段文本)
|
||||
N = logits.shape[0]
|
||||
labels = jnp.arange(N)
|
||||
|
||||
# 对称交叉熵:图像到文本 + 文本到图像
|
||||
loss_i2t = -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(N), labels])
|
||||
loss_t2i = -jnp.mean(jax.nn.log_softmax(logits, axis=0)[labels, jnp.arange(N)])
|
||||
return (loss_i2t + loss_t2i) / 2, logits * temperature
|
||||
|
||||
# 模拟一批 8 个图像-文本对,64 维空间
|
||||
key = jax.random.PRNGKey(42)
|
||||
k1, k2 = jax.random.split(key)
|
||||
N, D = 8, 64
|
||||
image_embeds = jax.random.normal(k1, (N, D))
|
||||
text_embeds = jax.random.normal(k2, (N, D))
|
||||
|
||||
loss, sim_matrix = clip_loss(image_embeds, text_embeds)
|
||||
print(f"CLIP loss (random embeddings): {loss:.4f}")
|
||||
|
||||
# 可视化相似度矩阵
|
||||
fig, ax = plt.subplots(figsize=(6, 5))
|
||||
im = ax.imshow(sim_matrix, cmap='coolwarm', vmin=-1, vmax=1)
|
||||
ax.set_xlabel("Text index"); ax.set_ylabel("Image index")
|
||||
ax.set_title(f"Cosine Similarity Matrix (loss={loss:.3f})")
|
||||
plt.colorbar(im); plt.tight_layout(); plt.show()
|
||||
# 尝试改变温度 (0.01, 0.1, 1.0) 并观察损失如何变化
|
||||
# 尝试使匹配对相似:将 text_embeds 设置为 image_embeds + 小噪声
|
||||
```
|
||||
|
||||
2. 构建一个玩具联合嵌入模型,学习使用 InfoNCE 损失和梯度下降来对齐 2D"图像"(随机向量)与"描述"(不同的随机向量)。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def info_nce_loss(img_enc, txt_enc, img_data, txt_data, tau=0.1):
|
||||
"""在一批配对的 (图像, 文本) 数据上计算 InfoNCE。"""
|
||||
z_img = img_data @ img_enc # (N, D)
|
||||
z_txt = txt_data @ txt_enc # (N, D)
|
||||
# L2 归一化
|
||||
z_img = z_img / jnp.linalg.norm(z_img, axis=1, keepdims=True)
|
||||
z_txt = z_txt / jnp.linalg.norm(z_txt, axis=1, keepdims=True)
|
||||
logits = z_img @ z_txt.T / tau
|
||||
labels = jnp.arange(logits.shape[0])
|
||||
return -jnp.mean(jax.nn.log_softmax(logits, axis=1)[jnp.arange(len(labels)), labels])
|
||||
|
||||
# 创建 32 个配对样本:图像在 R^8 中,文本在 R^6 中,嵌入到 R^4
|
||||
key = jax.random.PRNGKey(0)
|
||||
k1, k2, k3, k4 = jax.random.split(key, 4)
|
||||
N, d_img, d_txt, d_embed = 32, 8, 6, 4
|
||||
|
||||
img_data = jax.random.normal(k1, (N, d_img))
|
||||
txt_data = jax.random.normal(k2, (N, d_txt))
|
||||
|
||||
# 可学习的投影矩阵
|
||||
img_enc = jax.random.normal(k3, (d_img, d_embed)) * 0.1
|
||||
txt_enc = jax.random.normal(k4, (d_txt, d_embed)) * 0.1
|
||||
|
||||
grad_fn = jax.jit(jax.grad(info_nce_loss, argnums=(0, 1)))
|
||||
lr = 0.05
|
||||
losses = []
|
||||
|
||||
for step in range(300):
|
||||
loss = info_nce_loss(img_enc, txt_enc, img_data, txt_data)
|
||||
losses.append(float(loss))
|
||||
g_img, g_txt = grad_fn(img_enc, txt_enc, img_data, txt_data)
|
||||
img_enc = img_enc - lr * g_img
|
||||
txt_enc = txt_enc - lr * g_txt
|
||||
|
||||
print(f"Initial loss: {losses[0]:.3f}, Final loss: {losses[-1]:.3f}")
|
||||
print(f"Random baseline (log N): {jnp.log(N):.3f}")
|
||||
|
||||
plt.figure(figsize=(8, 4))
|
||||
plt.plot(losses, color='#2c3e50')
|
||||
plt.axhline(y=0, color='green', linestyle='--', alpha=0.5, label='Perfect alignment')
|
||||
plt.axhline(y=float(jnp.log(N)), color='red', linestyle='--', alpha=0.5, label='Random (log N)')
|
||||
plt.xlabel("Step"); plt.ylabel("InfoNCE Loss")
|
||||
plt.title("Learning a Joint Embedding Space")
|
||||
plt.legend(); plt.grid(alpha=0.3); plt.tight_layout(); plt.show()
|
||||
# 修改 d_embed(尝试 2, 4, 16)观察嵌入维度如何影响对齐
|
||||
```
|
||||
|
||||
3. 使用预计算的嵌入实现零样本分类。模拟类"原型"作为文本嵌入,通过最近邻查找对新图像进行分类。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 模拟 5 个类,每个类有一个原型文本嵌入在 R^32 中
|
||||
key = jax.random.PRNGKey(42)
|
||||
n_classes, d = 5, 32
|
||||
class_names = ["cat", "dog", "car", "plane", "ship"]
|
||||
|
||||
# 类原型(想象这些来自文本编码器)
|
||||
k1, k2 = jax.random.split(key)
|
||||
class_prototypes = jax.random.normal(k1, (n_classes, d))
|
||||
class_prototypes = class_prototypes / jnp.linalg.norm(class_prototypes, axis=1, keepdims=True)
|
||||
|
||||
# 生成 200 个测试"图像"(在其类原型附近加上噪声的嵌入)
|
||||
n_per_class = 40
|
||||
true_labels = jnp.repeat(jnp.arange(n_classes), n_per_class)
|
||||
keys = jax.random.split(k2, n_classes * n_per_class)
|
||||
|
||||
image_embeds = []
|
||||
for i in range(n_classes):
|
||||
noise = jax.random.normal(keys[i], (n_per_class, d)) * 0.5
|
||||
cluster = class_prototypes[i] + noise
|
||||
image_embeds.append(cluster)
|
||||
image_embeds = jnp.concatenate(image_embeds, axis=0)
|
||||
image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=1, keepdims=True)
|
||||
|
||||
# 零样本分类:与每个原型的余弦相似度
|
||||
similarities = image_embeds @ class_prototypes.T # (200, 5)
|
||||
predicted_labels = jnp.argmax(similarities, axis=1)
|
||||
accuracy = jnp.mean(predicted_labels == true_labels)
|
||||
print(f"Zero-shot accuracy: {accuracy:.1%}")
|
||||
|
||||
# 混淆矩阵
|
||||
conf = jnp.zeros((n_classes, n_classes), dtype=jnp.int32)
|
||||
for true, pred in zip(true_labels, predicted_labels):
|
||||
conf = conf.at[true, pred].add(1)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(6, 5))
|
||||
im = ax.imshow(conf, cmap='Blues')
|
||||
ax.set_xticks(range(n_classes)); ax.set_xticklabels(class_names, rotation=45)
|
||||
ax.set_yticks(range(n_classes)); ax.set_yticklabels(class_names)
|
||||
ax.set_xlabel("Predicted"); ax.set_ylabel("True")
|
||||
for i in range(n_classes):
|
||||
for j in range(n_classes):
|
||||
ax.text(j, i, int(conf[i, j]), ha='center', va='center', fontsize=11)
|
||||
ax.set_title(f"Zero-Shot Confusion Matrix (acc={accuracy:.1%})")
|
||||
plt.colorbar(im); plt.tight_layout(); plt.show()
|
||||
# 尝试增加噪声(0.5 -> 1.0 -> 2.0)观察准确率下降
|
||||
# 尝试提示集成:平均每个原型的 3 个噪声副本
|
||||
```
|
||||
@@ -0,0 +1,388 @@
|
||||
# 视觉语言模型
|
||||
|
||||
*视觉语言模型共同理解图像和文本,实现视觉问答、图像描述和视觉推理。本文件涵盖 VQA、图像描述、视觉定位,以及 VisualBERT、BLIP、LLaVA、Flamingo、PaLI 和 Qwen-VL 等将视觉编码器与大型语言模型融合的架构。*
|
||||
|
||||
- 想象一位博物馆导览员,他能看着一幅画并清晰描述画中的一切:有哪些物体、讲述了什么故事、传达了怎样的情感,还能回答参观者的任何问题。**视觉语言模型(VLM)** 就是计算领域的等价物——一个能同时理解图像和文本的系统,能够描述视觉场景、回答相关问题、执行视觉指令,甚至根据自然语言查询在图像中定位特定物体。
|
||||
|
||||
- VLM 位于你在第 8 章学到的视觉编码器和第 7 章的语言模型的交汇点。核心工程挑战在于桥接两个截然不同的表征世界:视觉骨干网络产生的空间化、连续的 feature map,与语言模型产生的序列化、离散的 token 嵌入。本文件中的每一种架构,本质上都是对同一个问题的不同回答:如何融合视觉和语言?
|
||||
|
||||

|
||||
|
||||
## 视觉问答
|
||||
|
||||
- 想象有人向你展示一张照片并问:"公园里有几只狗?"你毫不费力地解析图像、定位狗、数出数量并给出答案。**视觉问答(VQA)** 将这一过程形式化:给定一张图像 $I$ 和一个自然语言问题 $q$,预测答案 $a$。
|
||||
|
||||
- 该任务可以有多种定义方式。最常见的方式将 VQA 视为**开放式分类**:模型从最常见的答案构成的固定词汇表中选择(例如 VQA v2 中排名前 3,129 的答案)。另一种方式是**生成式回答**,模型生成自由形式的文本字符串——这是现代 VLM 采用的方法。
|
||||
|
||||
- 形式上,你需要学习一个最大化正确答案似然的函数 $f(I, q) \to a$。在分类设置中,这变为:
|
||||
|
||||
$$p(a \mid I, q) = \text{softmax}(W \cdot g(v, h))$$
|
||||
|
||||
- 其中 $v$ 是视觉特征向量(来自 CNN 或 ViT),$h$ 是问题编码(来自 LSTM 或 Transformer),$g$ 是融合函数。$g$ 的设计正是真正的架构创造力所在。
|
||||
|
||||
- **VQA v1**(Antol 等人,2015)引入了该基准,包含来自 MS COCO 的 204,000 张图像上的 614,000 个问题。研究人员很快发现,模型可以通过利用**语言先验**达到惊人高的准确率——对"多少个"问题回答"2",对"有没有"问题回答"是",甚至不需要看图像。
|
||||
|
||||
- **VQA v2**(Goyal 等人,2017)通过为每个问题配对不同答案的两张相似图像来解决这个问题。这迫使模型真正将其推理建立在视觉内容之上。平衡配对设置使数据集规模大约翻倍,并使纯语言捷径的效果大打折扣。
|
||||
|
||||
- 其他重要的 VQA 数据集包括 **GQA**(Hudson & Manning,2019),包含需要多步推理的组合性问题;**OK-VQA**(Marino 等人,2019),需要超出图像范围的外部知识;以及 **TextVQA**(Singh 等人,2019),答案依赖于读取图像中的文字。
|
||||
|
||||

|
||||
|
||||
- 早期的 VQA 模型使用简单策略:从预训练 CNN 中提取图像特征(通常是第 8 章中 ResNet 或 VGGNet 的倒数第二层),用 LSTM(第 6 章)对问题进行编码,然后将它们组合。组合函数 $g$ 演变迅速:从简单的逐元素乘法,到双线性池化,再到多模态 Tucker 分解。**双线性注意力**计算 $v^T W h$,其中 $W$ 是可学习的交互矩阵,但完整的双线性形式有 $O(d_v \times d_h)$ 个参数,规模过大。**MLB**(多模态低秩双线性池化)将其分解为两个低秩投影,使其变得可行。
|
||||
|
||||
- VQA 的突破是注意力机制。**堆叠注意力网络**(Yang 等人,2016)使用问题编码在空间图像区域上施加注意力,迭代式地精炼需要关注的图像部分。这个思想——让问题"关注"相关图像区域——成为了标准做法。
|
||||
|
||||
## 图像描述
|
||||
|
||||
- 想象一位朋友看着你的度假照片并叙述他们所看到的:"一只金毛猎犬在阳光明媚的沙滩上接飞盘。"**图像描述**是生成图像的自然语言描述的任务。与 VQA 不同,这里没有提问——模型必须自行决定哪些内容值得描述。
|
||||
|
||||
- **Show and Tell**(Vinyals 等人,2015)建立了描述任务的标准编码器-解码器架构。CNN 编码器(如 Inception 或 ResNet)生成一个单一图像特征向量 $v$。该向量被用作 LSTM 解码器的初始隐藏状态,然后逐词自回归地生成描述:
|
||||
|
||||
$$p(w_t \mid w_{1:t-1}, I) = \text{LSTM}(w_{t-1}, h_{t-1})$$
|
||||
|
||||
- 整个模型通过最大化真实描述的对数似然进行端到端训练。推理时使用束搜索(第 7 章)来找到高概率的描述。
|
||||
|
||||
- Show and Tell 的问题在于整张图像被压缩成一个单一向量。对于复杂场景,单一向量无法捕捉所有相关细节。你会丢失空间信息——模型在生成不同词语时无法"回看"图像的特定区域。
|
||||
|
||||
- **Show, Attend and Tell**(Xu 等人,2015)通过引入**图像区域上的注意力**解决了这个问题。模型不是将图像编码为一个向量,而是由 CNN 产生一个空间特征网格(例如来自 VGGNet 最后一个卷积层的 $14 \times 14 \times 512$)。在每个解码步骤,模型计算这些空间位置上的注意力权重,生成一个突出当前词语最相关区域的上下文向量。
|
||||
|
||||
- 回顾第 6 章的注意力机制:解码器隐藏状态充当查询,空间特征充当键和值,注意力权重告诉模型应该看哪里。作者提出了两种变体:**软注意力**(可微分,所有区域的加权平均)和**硬注意力**(对单个区域进行随机采样,使用 REINFORCE 训练)。
|
||||
|
||||

|
||||
|
||||
- 这些模型产生的注意力图具有显著的可解释性:生成"狗"时,注意力集中在狗的区域;生成"海滩"时,注意力转移到沙子和水面。这是注意力机制提供内置可解释性的最早令人信服的演示之一。
|
||||
|
||||
- **CIDEr**(Vedantam 等人,2015)、**METEOR**、**BLEU** 和 **SPICE** 是标准描述评估指标。CIDEr 计算生成描述与参考描述之间的 TF-IDF 加权 n-gram 相似度,专门为描述评估设计。现代 VLM 通常在 MS COCO Captions 和 NoCaps 等描述基准上用 CIDEr 进行评估。
|
||||
|
||||
- 后来的描述模型引入了**自底向上注意力**(Anderson 等人,2018),其中目标检测器(Faster R-CNN,第 8 章)首先提出显著的图像区域,然后描述模型在这些区域特征而非均匀网格上施加注意力。在基于 ViT 的编码器接管之前,这是主导方法。
|
||||
|
||||
## 架构模式
|
||||
|
||||
- 每个 VLM 都必须回答一个基本设计问题:视觉和语言在哪个节点交互?答案决定了模型的架构家族。有三种主要模式,各自具有不同的权衡。
|
||||
|
||||
### 双编码器
|
||||
|
||||
- 想象两位独立工作的译者——一位读法语文件,另一位读英语文件——他们各自用一种共享的"通用语言"生成摘要。他们在翻译过程中从不交流,但他们的摘要可以直接比较。这就是**双编码器**模式。
|
||||
|
||||
- 视觉编码器 $f_v$ 和文本编码器 $f_t$ 独立地将各自的输入映射到一个维度为 $d$ 的共享嵌入空间。图像嵌入为 $v = f_v(I) \in \mathbb{R}^d$,文本嵌入为 $t = f_t(q) \in \mathbb{R}^d$。相似度通过点积或余弦相似度计算:$\text{sim}(I, q) = v^T t / (\|v\| \|t\|)$。
|
||||
|
||||
- **CLIP**(Radford 等人,2021),在前一篇关于多模态表示的文件中已介绍,是典型的双编码器。它在从互联网抓取的 4 亿图像-文本对上使用对比目标函数(InfoNCE)进行训练。由于编码器相互独立,你可以预计算并缓存所有图像嵌入,使检索极其高效——搜索时只需对查询文本进行编码。
|
||||
|
||||
- 双编码器的缺点在于视觉和语言从未在特征层面进行交互。模型无法进行细粒度的跨模态推理:例如,它无法确定描述中的特定词是否对应图像中的特定区域。这限制了它在 VQA 或 grounded 描述等任务中的实用性。
|
||||
|
||||
### 融合编码器
|
||||
|
||||
- 现在想象两位译者共处一室,积极讨论两篇文件。他们可以指向特定段落、互相提问,并建立共同的理解。这就是**融合编码器**模式。
|
||||
|
||||
- 两种模态都被编码,然后通过**交叉注意力层**进行融合,其中一种模态的 token 关注另一种模态的 token。图像首先由视觉编码器处理为一系列 patch 或区域 token $V = [v_1, \ldots, v_N]$。文本被分词化为 $T = [t_1, \ldots, t_M]$。在融合层中,文本 token 通过交叉注意力关注图像 token:
|
||||
|
||||
$$\text{CrossAttn}(T, V) = \text{softmax}\!\left(\frac{(TW_Q)(VW_K)^T}{\sqrt{d_k}}\right)(VW_V)$$
|
||||
|
||||
- 这实现了细粒度的交互:每个文本 token 都可以关注其所需的特定图像区域。**VisualBERT**、**VilBERT** 和 **UNITER** 等模型使用这种模式。代价是你无法为检索预计算独立的嵌入——每个图像-文本对都需要通过融合层进行完整的前向传播。
|
||||
|
||||

|
||||
|
||||
### 编码器-解码器
|
||||
|
||||
- **编码器-解码器**模式将视觉编码器与自回归生成输出 token 的文本解码器相结合,类似于第 7 章中的 seq2seq 模型。视觉编码器产生上下文图像表征,文本解码器在生成输出文本时对其执行交叉注意力。
|
||||
|
||||
- 这种模式天然支持生成式任务:图像描述、自由形式答案的 VQA 以及视觉对话。**GIT**(Generative Image-to-text Transformer,Wang 等人,2022)、**CoCa**(Contrastive Captioner,Yu 等人,2022)和 **PaLI** 使用这种架构。CoCa 巧妙地将双编码器和编码器-解码器模式结合起来:文本解码器的前半部分作为单模态文本编码器(用于对比学习),而后半部分对图像特征执行交叉注意力(用于生成式描述),兼得两者之优势。
|
||||
|
||||
- 这三种模式的选择取决于目标任务。双编码器最适合大规模检索。融合编码器最适合细粒度理解任务。编码器-解码器对于生成任务最为通用。现代最先进的 VLM 越来越多地采用编码器-解码器或仅解码器范式,将每项视觉语言任务都视为文本生成。
|
||||
|
||||
## Flamingo:少样本多模态学习
|
||||
|
||||
- 想象一位经验丰富的专家,经过多年对艺术和文学的研究,只需要看一两个例子就能优雅地描述一种全新的绘画风格。**Flamingo**(Alonso 等人,2022,DeepMind)基于相同原理构建:它利用强大的预训练语言模型和预训练视觉编码器,通过轻量级架构组件将其连接,实现多模态任务上的少样本学习。
|
||||
|
||||
- Flamingo 的设计理念保守而有效:保持预训练的视觉编码器(NFNet)和语言模型(Chinchilla)冻结,仅学习连接它们的"胶水"。这种胶水由两个组件组成:**Perceiver 重采样器**和**门控交叉注意力层**。
|
||||
|
||||
- **Perceiver 重采样器**将视觉编码器的变长输出(取决于图像分辨率)压缩为一组固定数量的 $N$ 个视觉 token(通常 $N = 64$)。它的工作原理是初始化一组 $N$ 个可学习的查询向量,并使用交叉注意力让这些查询关注完整的视觉编码器输出。这本质上是 Perceiver 架构(Jaegle 等人,2021)作为瓶颈的应用——无论输入图像大小如何,它都能生成紧凑的、固定大小的视觉表示。
|
||||
|
||||
$$z = \text{CrossAttn}(Q_{\text{learned}}, V_{\text{image}}) \in \mathbb{R}^{N \times d}$$
|
||||
|
||||
- **门控交叉注意力层**交错插入在冻结的语言模型层之间。在每个这样的层中,语言模型的文本 token 对 Perceiver 重采样器产生的视觉 token 执行交叉注意力。关键之处在于,每个门控交叉注意力层包含一个可学习的标量门控 $\alpha$,初始化为零,将交叉注意力输出乘以 $\alpha$ 后再加到残差流中:
|
||||
|
||||
$$\hat{x} = x + \alpha \cdot \text{CrossAttn}(x, z)$$
|
||||
|
||||
- 初始化 $\alpha = 0$ 意味着训练开始时交叉注意力不贡献任何信息,模型行为与原始的冻结语言模型完全相同。门控在训练过程中逐渐打开,平滑地整合视觉信息,同时不破坏语言模型的预训练表示。
|
||||
|
||||

|
||||
|
||||
- Flamingo 原生支持**交错图像-文本序列**。你可以向它输入包含多张图像穿插文本的提示,例如:"[图像 1] 这是一只猫。[图像 2] 这是一只狗。[图像 3] 这是一个 ___。"模型将每张图像通过视觉编码器和 Perceiver 重采样器处理,得到的视觉 token 插入到文本序列中的对应位置。语言模型的因果注意力掩码确保每个文本 token 只能关注当前及之前图像的视觉 token。
|
||||
|
||||
- 这种交错机制实现了强大的**少样本多模态学习**。通过在上下文中提供少量图像-文本示例,Flamingo 可以在没有任何梯度更新的情况下执行新任务。在 VQAv2、OK-VQA 和描述等基准上,具有 800 亿参数的 Flamingo 实现了最先进的少样本性能,仅需 4 到 32 个示例即可匹配甚至超越经过微调的专家模型。
|
||||
|
||||
## LLaVA 与视觉指令微调
|
||||
|
||||
- 想象你有一位出色的语言专家(一个 LLM)和一位出色的艺术评论家(一个视觉编码器)。如果你能教会艺术评论家"说语言专家的语言",他们就可以无缝协作。**LLaVA**(Large Language and Vision Assistant,Liu 等人,2023)正是这样做的:它使用一个简单的线性层将视觉特征投影到 LLM 的 token 嵌入空间,然后在指令遵循数据上微调整个系统。
|
||||
|
||||
- LLaVA 的架构出奇地简单。图像由一个预训练的 CLIP ViT-L/14 视觉编码器编码为一个 patch 特征网格 $V \in \mathbb{R}^{N \times d_v}$,其中 $N = 256$ 个 patch(对于 336px 图像和 14px patch)。一个**投影层** $W$ 将这些视觉特征映射到 LLM 的嵌入维度:
|
||||
|
||||
$$H_v = VW, \quad W \in \mathbb{R}^{d_v \times d_{\text{LLM}}}$$
|
||||
|
||||
- 投影后的视觉 token $H_v$ 直接与文本 token 嵌入拼接,作为一个单一序列输入到 LLM(Vicuna,一个微调后的 LLaMA)。LLM 使用其标准因果自注意力处理它们——没有特殊的交叉注意力层,没有 perceiver,只有拼接。视觉 token 被当作恰好编码了视觉信息的文本 token 来处理。
|
||||
|
||||

|
||||
|
||||
- **视觉指令微调**是 LLaVA 的关键训练创新。作者使用 GPT-4 从 COCO 图像生成了 158,000 个多模态指令遵循示例。每个示例包含一张图像和一个对话式指令(例如"详细描述这张图像"、"这张图像有什么不寻常之处?"、"如果我是一名游客参观这个地方,我应该知道什么?")。模型接受训练,根据图像和指令生成 GPT-4 撰写的回答。
|
||||
|
||||
- 训练分为两个阶段。**阶段 1(预训练)**:仅训练投影层 $W$,使用图像-描述对(来自 CC3M 的 595K 数据),视觉编码器和 LLM 都保持冻结。这教会 $W$ 将视觉特征与 LLM 的嵌入空间对齐。**阶段 2(微调)**:投影层和 LLM 在指令遵循数据上联合微调,视觉编码器保持冻结。这教会模型遵循复杂的视觉指令。
|
||||
|
||||
- **LLaVA-1.5** 通过三项关键更改改进了原始版本:将单层线性投影替换为两层 MLP(更具表现力的映射),使用更高分辨率的图像(336px 而非 224px,产生更多 patch token),以及在训练混合数据中加入学术 VQA 数据集。这些看似细微的修改带来了基准性能的大幅提升。
|
||||
|
||||
- LLaVA 的方法证明,你不需要像 Flamingo 的 Perceiver 重采样器或门控交叉注意力那样复杂的架构创新。一个简单的线性投影,结合高质量的指令微调数据,就足以有效地将视觉编码器连接到 LLM。这种简洁性使得 LLaVA 极具影响力——后续大多数开源 VLM 都遵循类似的方案。
|
||||
|
||||
## 扩展视觉语言模型
|
||||
|
||||
- 该领域从概念验证型 VLM 迅速发展为在数十亿图像-文本对上训练的工业级系统。三个模型家族展示了不同的扩展方法。
|
||||
|
||||
### PaLI
|
||||
|
||||
- **PaLI**(Pathways Language and Image model,Chen 等人,2022,Google)同时扩展视觉编码器和语言模型。PaLI 使用 ViT-e(40 亿参数)作为视觉编码器,mT5(130 亿参数)作为语言模型,总计 170 亿参数。图像被编码为一系列 patch token,拼接在文本 token 之前,输入到编码器-解码器架构的 mT5。
|
||||
|
||||
- PaLI 的关键洞见是**扩展视觉编码器与扩展语言模型同样重要**。先前的工作通常使用固定的、中等规模的视觉骨干网络(如 ViT-B 或 ViT-L),将参数预算全部投入 LLM。PaLI 表明,一个 40 亿参数的 ViT-e,在 JFT-4B(40 亿张标注图像)上预训练后,能够显著提升 OCR 和空间推理等细粒度视觉任务的性能。
|
||||
|
||||
- PaLI 在 WebLI(一个包含 109 种语言、100 亿图像-文本对的数据集)上训练,因此天然具备多语言能力。模型通过混合任务进行预训练:图像描述、VQA 和图像-文本匹配,全部作为文本到文本生成任务(遵循第 7 章的 T5 范式)。**PaLI-X**(550 亿参数)和 **PaLI-3**(50 亿,使用 SigLIP 作为视觉编码器)是后续迭代版本。
|
||||
|
||||
### Qwen-VL
|
||||
|
||||
- **Qwen-VL**(Bai 等人,2023,阿里巴巴)在 Qwen LLM 基础上增加了一个 ViT 视觉编码器和一个单层交叉注意力模块(类似于 Flamingo 的 Perceiver 重采样器),将视觉编码器的输出压缩为一组固定的 256 个视觉 token。视觉 token 与文本 token 拼接后由 Qwen LLM 处理。
|
||||
|
||||
- Qwen-VL 的训练采用三阶段方案。阶段 1:在 14 亿个弱监督图像-文本对上预训练,仅解冻视觉编码器。阶段 2:在更高质量的数据上进行多任务预训练,包括 VQA、描述、定位和 OCR 数据集,整个模型解冻。阶段 3:在指令遵循和对话数据上进行监督微调。这种从噪声网络数据到精选指令数据的渐进式精炼,是大多数现代 VLM 共享的模式。
|
||||
|
||||
- **Qwen2-VL**(2024)引入了**动态分辨率**支持:模型不是将所有图像缩放到固定大小,而是通过动态调整视觉 token 数量以原始分辨率处理图像。更高分辨率的图像产生更多 token,更低分辨率的图像产生更少 token。这在不浪费低分辨率输入计算量的前提下,提升了文档理解和细粒度识别等对细节敏感的任务的性能。
|
||||
|
||||
### InternVL
|
||||
|
||||
- **InternVL**(Chen 等人,2024,上海人工智能实验室)激进地扩展了视觉编码器,使用 InternViT-6B——一个 60 亿参数的视觉 Transformer——与语言模型配对。关键的架构贡献是**动态高分辨率处理**:图像被分割为 448x448 像素的图块,每个图块由视觉编码器独立处理,得到的图块特征与完整图像的缩略图特征拼接。这使得模型能够处理任意宽高比和分辨率的图像。
|
||||
|
||||
- InternVL-2 进一步引入了**渐进对齐训练**:首先用对比目标(如 CLIP)对齐视觉编码器,然后通过轻量级 MLP 连接器将其连接到 LLM,最后在指令数据上进行端到端微调。这种渐进策略防止了视觉编码器预训练表示的灾难性遗忘。
|
||||
|
||||

|
||||
|
||||
- 所有三个模型家族的一个共同主题是**训练数据精选**的重要性。从网络抓取的原始图像-文本对是噪声大且常常不对齐的。后续的训练阶段逐步过滤和精炼数据,从数十亿噪声对过渡到数百万高质量指令示例。最终微调数据的质量往往比模型的原始参数数量更为重要。
|
||||
|
||||
## 定位与指代
|
||||
|
||||
- 想象你在人群中指着一个人说"戴红帽子的女士"。你在用语言指代一个特定的空间区域。**视觉定位**是相反的过程:给定一张图像和一个自然语言表述,模型必须识别(定位)所指的对象。**指代表达理解**产生边界框;**指代表达分割**产生像素掩码。
|
||||
|
||||
- 形式上,给定一张图像 $I$ 和一个指代表达 $r$(例如"左边那只大型棕色狗"),模型预测一个边界框 $b = (x, y, w, h)$ 或一组定位所引用对象的坐标。数据集包括 **RefCOCO**、**RefCOCO+** 和 **RefCOCOg**,每个数据集包含具有多个对象的图像以及每个对象的明确指代表达。
|
||||
|
||||
- 早期的定位模型使用两阶段方法:首先生成区域提议(使用 Faster R-CNN 或类似方法),然后使用融合模型对每个提议与语言查询进行评分。评分最高的区域即为预测结果。这种方法计算代价高昂,且受限于提议的质量。
|
||||
|
||||
- 现代 VLM 将定位直接整合到生成式框架中。关键思想是将边界框坐标表示为**文本 token**。你将连续的坐标空间离散化为槽位(例如 $x, y, w, h$ 各 1000 个槽位),并向词汇表中添加特殊的位置 token,如 `<loc_342>`。然后模型通过输出一系列位置 token 来生成边界框:
|
||||
|
||||
$$\text{输出: } \texttt{<loc\_102><loc\_215><loc\_487><loc\_398>}$$
|
||||
|
||||
- 这种 token 化技巧使得任何自回归语言模型无需架构更改即可执行定位——它只需学会"说坐标"。**Pix2Seq**(Chen 等人,2022)率先将这种方法用于目标检测,而 Qwen-VL、Ferret 和 Kosmos-2 等模型将其扩展到指代表达理解和短语定位。
|
||||
|
||||
- **Kosmos-2**(Peng 等人,2023,Microsoft)通过将空间位置表示为嵌入在生成文本中的特殊 token,为多模态 LLM 增加了定位能力。例如,它可以生成:"一只 `<phrase>` 金毛猎犬 `</phrase>` `<box>` `<loc_102>` `<loc_215>` `<loc_487>` `<loc_398>` `</box>` 正在接飞盘。"这种文本和空间 token 的交错融合实现了同步描述和定位。
|
||||
|
||||

|
||||
|
||||
- **定点指向**将定位更进一步:模型不再输出边界框,而是预测一个单一的点(通常是指代物体的中心)。这对于交互式应用非常有用,例如用户问"最近的出口在哪里?",模型返回一个叠加在图像上的坐标。**Shikra** 和 **Ferret** 等模型支持基于点的指代以及基于框的定位。
|
||||
|
||||
## 免 OCR 文档理解
|
||||
|
||||
- 传统的文档理解流水线很复杂:首先运行 OCR 引擎提取文本和布局,然后将提取的文本输入语言模型。这种多阶段方法很脆弱——OCR 错误向下游传播,空间布局信息常常丢失或表征不良。如果模型能像人类一样直接从像素中读取信息呢?
|
||||
|
||||
- **Donut**(Document Understanding Transformer,Kim 等人,2022)完全消除了 OCR。它使用 Swin Transformer(第 8 章)作为视觉编码器处理文档图像,并使用 BART 风格的 Transformer 解码器直接从视觉特征生成结构化文本输出。解码器可以根据任务生成 JSON、键值对或纯文本。
|
||||
|
||||
- Donut 的训练分为两个阶段。**预训练**:模型通过执行合成 OCR 来学习阅读——给定一张文档图像,生成完整的文本内容。这在从文本语料库渲染的数百万张合成文档图像上进行训练,教会视觉编码器识别字符、字体和布局。**微调**:模型通过训练生成特定于任务的结构化输出,适应特定的下游任务,如收据解析、表格理解或文档分类。
|
||||
|
||||
- Donut 解码器使用特殊的提示方案:任务由提示 token 指定(例如分类用 `<doc_class>`,收据解析用 `<parse_receipt>`),模型根据此提示生成输出。这种统一接口使得单个模型可以处理多种文档理解任务。
|
||||
|
||||
- **Pix2Struct**(Lee 等人,2023,Google)将免 OCR 思想应用于网页理解和图表/图形理解。关键的预训练目标是**截图解析**:给定一个网页的带掩码截图,模型生成产生可见区域的底层 HTML。这教会模型理解视觉呈现与结构化标记之间的关系。
|
||||
|
||||
- Pix2Struct 引入了**可变分辨率输入处理**:它并不是将所有图像缩放到固定大小(这会扭曲宽高比并破坏精细文字),而是在保持原始宽高比的同时将图像打包为固定数量的 patch。一个高而窄的文档产生一个高而窄的 patch 网格。这对于文档理解至关重要,因为宽高比携带着语义信息(收据窄而高;表格宽而短)。
|
||||
|
||||

|
||||
|
||||
- **Nougat**(Blecher 等人,2023,Meta)将 Donut 架构专门应用于学术论文,直接从 PDF 页面图像生成完整的 LaTeX 标记。它可以处理复杂的数学方程、表格和图形——这些任务正是传统 OCR 流水线难以应付的。该模型在 PDF 页面图像及其对应的 LaTeX 源代码对上进行训练。
|
||||
|
||||
- 免 OCR 模型的成功展示了深度学习中的一个更广泛原则:直接从原始输入(像素)学习的端到端模型通常优于复杂的多阶段流水线,因为它们可以联合优化所有组件,并学习专门针对最终任务定制的表示。中间的 OCR 步骤是一个瓶颈,限制了模型能够学习的内容。
|
||||
|
||||
## 视觉 Token 流水线
|
||||
|
||||
- 无论架构家族如何,每个 VLM 都必须将图像转换为语言模型可以处理的一系列 token。理解这一流水线至关重要。不同模型的处理过程有所差异,但总体流程如下:
|
||||
|
||||
- **第 1 步:Patch 提取。** 图像(高度 $H$,宽度 $W$)被划分为不重叠的、大小为 $P \times P$ 的 patch,产生 $N = HW / P^2$ 个 patch。对于 336x336 图像和 14x14 patch,$N = 576$。
|
||||
|
||||
- **第 2 步:视觉编码。** 每个 patch 经过线性投影并通过视觉编码器(通常是 ViT)。输出是一系列上下文 patch 嵌入 $V = [v_1, \ldots, v_N] \in \mathbb{R}^{N \times d_v}$。这些嵌入既携带局部外观信息,也携带全局上下文(来自自注意力)。
|
||||
|
||||
- **第 3 步:Token 压缩(可选)。** 一些模型将 $N$ 个视觉 token 压缩为更少的 $M \ll N$ 个 token,以减少语言模型的计算负担。Flamingo 使用 Perceiver 重采样器($M = 64$);Qwen-VL 使用交叉注意力($M = 256$);**Q-Former**(在 BLIP-2 中使用,Li 等人,2023)使用一组 $M = 32$ 个可学习查询 token,对视觉编码器的输出执行交叉注意力。
|
||||
|
||||
- **第 4 步:投影。** 视觉 token(全部或压缩后的集合)通过线性层或 MLP 投影到语言模型的嵌入空间。投影后,视觉 token 与文本 token 嵌入具有相同维度,可以与它们拼接。
|
||||
|
||||
- **第 5 步:注入 LLM。** 投影后的视觉 token 在特殊 `<image>` 占位符 token 的位置插入到 token 序列中,组合后的序列由语言模型处理。LLM 的自注意力使文本 token 能够关注视觉 token,反之亦然。
|
||||
|
||||

|
||||
|
||||
- 视觉 token 的数量直接影响计算成本。每个视觉 token 参与 LLM 的自注意力,其复杂度与序列长度的平方成正比。具有多个 patch 的高分辨率图像可能产生数百或数千个视觉 token,占据 LLM 上下文窗口的主导地位。这就是 token 压缩的重要性所在:将 576 个视觉 token 减少到 64 个,可将视觉部分在注意力中的贡献减少约 9 倍。
|
||||
|
||||
- **BLIP-2**(Li 等人,2023)以其高效的桥接策略而闻名。它引入了一个轻量级的 **Q-Former**(一个带有可学习查询的小型 Transformer),位于冻结的视觉编码器和冻结的 LLM 之间。Q-Former 是唯一可训练的组件——视觉编码器和 LLM 都保持冻结。它的预训练分为两个阶段:首先是图像-文本对比学习、匹配和描述目标(连接视觉编码器),然后是语言生成目标(连接 LLM)。这种模块化设计使得 BLIP-2 可以将任何视觉编码器插入到任何 LLM 中。
|
||||
|
||||
## 训练目标
|
||||
|
||||
- VLM 使用多种目标的组合进行训练,具体取决于架构模式:
|
||||
|
||||
- **图像-文本对比损失(ITC):** 在共享嵌入空间中对齐图像和文本表示,如 CLIP 中所示。这是双编码器的主要目标,也常被用作融合模型的预训练目标。该损失就是上一篇文件中的 InfoNCE 损失。
|
||||
|
||||
- **图像-文本匹配(ITM):** 一个二分类目标——给定图像和文本,预测它们是否匹配。困难负样本(与不同图像配对的相似文本)使这项任务具有挑战性,迫使模型学习细粒度的对齐。
|
||||
|
||||
- **语言建模(LM):** 标准的自回归语言建模目标——给定之前的所有 token 预测下一个 token。对于 VLM,"之前的 token" 包括视觉 token,因此模型学习在视觉输入条件下生成文本。这是编码器-解码器和仅解码器 VLM 的主要目标。
|
||||
|
||||
$$\mathcal{L}_{\text{LM}} = -\sum_{t=1}^{T} \log p(w_t \mid w_{<t}, V)$$
|
||||
|
||||
- **前缀语言建模:** 一种变体,其中图像和文本前缀作为上下文提供(不进行训练),模型仅训练生成后续部分。这用于 PaLI 和 SimVLM 等模型。
|
||||
|
||||
- 大多数现代 VLM 在预训练期间结合多个目标(例如 BLIP 中的 ITC + ITM + LM,CoCa 中的 ITC + LM),然后在指令数据上使用纯 LM 目标进行微调。
|
||||
|
||||
## 编程练习(使用 CoLab 或 notebook)
|
||||
|
||||
1. 实现一个简单的基于注意力的图像描述解码器。使用随机的"图像特征"作为编码器输出,训练解码器生成固定的描述,观察注意力权重在每个解码步骤如何跨空间位置移动。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 模拟 4x4 空间网格的图像特征(16 个区域,dim=32)
|
||||
key = jax.random.PRNGKey(42)
|
||||
k1, k2, k3 = jax.random.split(key, 3)
|
||||
img_features = jax.random.normal(k1, (16, 32)) # 16 个空间区域,32 维
|
||||
|
||||
# 词汇表:0=<start>, 1="a", 2="red", 3="car", 4=<end>
|
||||
vocab_size, embed_dim, hidden_dim = 5, 16, 32
|
||||
W_embed = jax.random.normal(k2, (vocab_size, embed_dim)) * 0.1
|
||||
W_attn_q = jax.random.normal(k3, (hidden_dim, 32)) * 0.1 # 查询投影
|
||||
|
||||
def attend(h, img_feats, W_q):
|
||||
"""在给定解码器状态 h 的情况下计算图像特征上的软注意力。"""
|
||||
query = h @ W_q # (32,)
|
||||
scores = img_feats @ query # (16,)
|
||||
weights = jax.nn.softmax(scores) # (16,)
|
||||
context = weights @ img_feats # (32,)
|
||||
return context, weights
|
||||
|
||||
# 简单的 GRU 风格步骤(为说明目的,仅用线性 + tanh)
|
||||
W_h = jax.random.normal(jax.random.PRNGKey(0), (embed_dim + 32, hidden_dim)) * 0.1
|
||||
|
||||
def decode_step(h, word_idx, img_feats):
|
||||
context, attn_weights = attend(h, img_feats, W_attn_q)
|
||||
word_emb = W_embed[word_idx] # (16,)
|
||||
inp = jnp.concatenate([word_emb, context]) # (48,)
|
||||
h_new = jnp.tanh(inp @ W_h) # (32,)
|
||||
return h_new, attn_weights
|
||||
|
||||
# 运行解码序列:<start> -> "a" -> "red" -> "car" -> <end>
|
||||
target_seq = [0, 1, 2, 3, 4]
|
||||
h = jnp.zeros(hidden_dim)
|
||||
all_attn = []
|
||||
for word_idx in target_seq[:-1]:
|
||||
h, attn_w = decode_step(h, word_idx, img_features)
|
||||
all_attn.append(attn_w)
|
||||
|
||||
# 可视化每一步的注意力图(重塑为 4x4 网格)
|
||||
words = ["<start>", "a", "red", "car"]
|
||||
fig, axes = plt.subplots(1, 4, figsize=(14, 3))
|
||||
for i, (ax, w) in enumerate(zip(axes, words)):
|
||||
ax.imshow(all_attn[i].reshape(4, 4), cmap='viridis')
|
||||
ax.set_title(f'生成"{w}"后\n关注的区域')
|
||||
ax.axis('off')
|
||||
plt.suptitle('每个解码步骤的图像区域注意力')
|
||||
plt.tight_layout(); plt.show()
|
||||
# 尝试修改 img_features,观察注意力模式如何变化!
|
||||
```
|
||||
|
||||
2. 模拟视觉 token 流水线:将图像划分为 patch,将 patch 投影到嵌入空间,与文本 token 嵌入拼接,并在组合序列上运行单层自注意力。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
key = jax.random.PRNGKey(7)
|
||||
|
||||
# 创建一个合成的 8x8 "图像",3 个通道
|
||||
k1, k2, k3, k4 = jax.random.split(key, 4)
|
||||
image = jax.random.uniform(k1, (8, 8, 3))
|
||||
|
||||
# 第 1 步:划分为 4x4 patch -> 4 个 patch
|
||||
patch_size = 4
|
||||
patches = image.reshape(2, patch_size, 2, patch_size, 3)
|
||||
patches = patches.transpose(0, 2, 1, 3, 4).reshape(4, patch_size * patch_size * 3) # (4, 48)
|
||||
print(f"Patch 数量: {patches.shape[0]}, Patch 维度: {patches.shape[1]}")
|
||||
|
||||
# 第 2 步:将 patch 投影到嵌入维度 (d=16)
|
||||
d_model = 16
|
||||
W_patch = jax.random.normal(k2, (patches.shape[1], d_model)) * 0.1
|
||||
visual_tokens = patches @ W_patch # (4, 16)
|
||||
|
||||
# 第 3 步:创建文本 token 嵌入(模拟 3 个文本 token)
|
||||
text_tokens = jax.random.normal(k3, (3, d_model)) * 0.1
|
||||
|
||||
# 第 4 步:拼接视觉 + 文本 token
|
||||
combined = jnp.concatenate([visual_tokens, text_tokens], axis=0) # (7, 16)
|
||||
print(f"组合序列长度: {combined.shape[0]} (4 个视觉 + 3 个文本)")
|
||||
|
||||
# 第 5 步:在组合序列上运行单头自注意力
|
||||
W_Q = jax.random.normal(k4, (d_model, d_model)) * 0.1
|
||||
k5, k6 = jax.random.split(k4)
|
||||
W_K = jax.random.normal(k5, (d_model, d_model)) * 0.1
|
||||
W_V = jax.random.normal(k6, (d_model, d_model)) * 0.1
|
||||
|
||||
Q = combined @ W_Q
|
||||
K = combined @ W_K
|
||||
V = combined @ W_V
|
||||
attn_scores = (Q @ K.T) / jnp.sqrt(d_model)
|
||||
attn_weights = jax.nn.softmax(attn_scores, axis=-1) # (7, 7)
|
||||
|
||||
output = attn_weights @ V # (7, 16)
|
||||
|
||||
# 可视化跨模态注意力模式
|
||||
labels = ['V1', 'V2', 'V3', 'V4', 'T1', 'T2', 'T3']
|
||||
fig, ax = plt.subplots(figsize=(6, 5))
|
||||
im = ax.imshow(attn_weights, cmap='Blues')
|
||||
ax.set_xticks(range(7)); ax.set_xticklabels(labels)
|
||||
ax.set_yticks(range(7)); ax.set_yticklabels(labels)
|
||||
ax.set_xlabel('键'); ax.set_ylabel('查询')
|
||||
ax.set_title('自注意力:视觉(V)和文本(T)Token')
|
||||
plt.colorbar(im, ax=ax); plt.tight_layout(); plt.show()
|
||||
# 观察:文本 token 关注视觉 token(跨模态注意力)!
|
||||
```
|
||||
|
||||
3. 实现用于视觉定位的坐标 token 化。给定一个边界框,将其转换为离散 token;给定离散 token,重构边界框。在不同槽位分辨率下可视化量化误差。
|
||||
```python
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def encode_bbox(bbox, num_bins=1000):
|
||||
"""将连续的边界框 (x, y, w, h)(在 [0,1] 范围内)转换为离散 token。"""
|
||||
tokens = jnp.round(jnp.array(bbox) * (num_bins - 1)).astype(jnp.int32)
|
||||
return tokens
|
||||
|
||||
def decode_bbox(tokens, num_bins=1000):
|
||||
"""将离散 token 转换回连续的边界框。"""
|
||||
return tokens.astype(jnp.float32) / (num_bins - 1)
|
||||
|
||||
# 真实边界框(归一化到 [0, 1])
|
||||
gt_bbox = jnp.array([0.123, 0.456, 0.333, 0.222])
|
||||
|
||||
# 测试不同槽位分辨率下的量化
|
||||
bin_sizes = [10, 50, 100, 500, 1000]
|
||||
errors = []
|
||||
for n_bins in bin_sizes:
|
||||
tokens = encode_bbox(gt_bbox, n_bins)
|
||||
reconstructed = decode_bbox(tokens, n_bins)
|
||||
error = jnp.max(jnp.abs(gt_bbox - reconstructed))
|
||||
errors.append(float(error))
|
||||
print(f"槽位数={n_bins:>5d} | Token={tokens} | "
|
||||
f"重构={reconstructed} | 最大误差={error:.6f}")
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.plot(bin_sizes, errors, 'o-', color='#e74c3c', linewidth=2, markersize=8)
|
||||
ax.set_xlabel('槽位数'); ax.set_ylabel('最大量化误差')
|
||||
ax.set_title('边界框量化误差 vs 槽位分辨率')
|
||||
ax.set_xscale('log'); ax.set_yscale('log')
|
||||
ax.grid(True, alpha=0.3); plt.tight_layout(); plt.show()
|
||||
# 尝试:槽位非常少时(如 5)会发生什么?误差在何时是可接受的?
|
||||
```
|
||||
@@ -0,0 +1,419 @@
|
||||
# 图像与视频词元化
|
||||
|
||||
*图像与视频词元化将连续的视觉数据转换为离散的词元序列,使 Transformer 能够像处理文本一样处理它们。本节涵盖 VQ-VAE、VQ-GAN、码本学习、DALL-E 的 dVAE、视频词元化以及免查询词元化。*
|
||||
|
||||
## 为什么要对图像进行词元化
|
||||
|
||||
- 把语言想象成一个有限的字母表:英语大约有 26 个字母,现代语言模型将文本切分为 30,000 到 100,000 个子词词元。每个句子都变成一串离散符号,Transformer 可以逐个预测。而图像存在于连续的高维空间中:一张 256×256 的 RGB 图像就是 $\mathbb{R}^{256 \times 256 \times 3} \approx \mathbb{R}^{196{,}608}$ 中的一个点。如果你希望语言模型用与说英语同样的机制来"说"图像,就需要将这些连续的像素数组转换为一串可管理的离散词元,这些词元来自一个有限的词汇表。这种转换就是**图像词元化**。
|
||||
|
||||
- 想象你是一位马赛克艺术家。你没有无限多种瓷砖色调,只有一个固定的调色板,比如说 8192 种不同的瓷砖颜色。要再现一张照片作为马赛克,你必须 (1) 确定每个瓷砖代表照片的哪个区域,(2) 为每个区域选择最接近的瓷砖颜色,(3) 接受一些细节的丢失,但整体画面仍然可辨认。图像词元化做的正是这件事:编码器将空间块压缩为潜在向量,码本将每个向量映射到其最近的条目,结果是一个整数索引网格(每个块对应一个索引),离散模型可以处理它。
|
||||
|
||||
- 词元化的好处有三方面。首先,它大幅压缩了图像:一张 256×256 的图像可能变成一个 16×16 的词元网格,序列长度从 65,536 个像素减少到 256 个词元,这对于成本随序列长度呈二次方增长的注意力模型来说是可行的。其次,它统一了表示形式:文本词元和图像词元位于同一个离散词汇表中,使得单个自回归 Transformer 可以生成交织的文本和图像。第三,它施加了一个有用的瓶颈,迫使模型学习语义上有意义的编码,而不是记忆像素噪声。
|
||||
|
||||

|
||||
|
||||
- 回顾第 8 章中卷积网络如何从图像中提取层次化特征图,以及第 7 章中文本词元化器如何将字符串转换为整数序列。图像词元化正处于两者的交汇点:它使用 CNN 或视觉 Transformer 编码器(第 8 章)产生空间特征,然后借用离散词汇表的思想(第 7 章)将这些特征转换为词元索引。
|
||||
|
||||
## VQ-VAE:向量量化
|
||||
|
||||
- 正如我们在第 6 章中看到的,标准**变分自编码器**(VAE)将输入编码为连续潜在分布,并从该分布中采样再解码为重建结果。潜在空间是连续的,这使得将其输入离散序列模型变得困难。**向量量化变分自编码器**(VQ-VAE),由 van den Oord 等人(2017)提出,通过引入一个可学习的嵌入向量码本,并将每个编码器输出映射到其最近的码本条目,用离散潜在表示取代了连续潜在表示。
|
||||
|
||||
- 想象一个藏书室,里面有恰好 $K$ 个贴有标签的书架。当一本新书(编码器输出)到达时,图书管理员将它放在与其现有书籍(码本向量)最相似的书架上,并记录下书架编号。之后,要取回这本书,你只需要书架编号:那个书架上的码本条目就是一个足够好的替代。这就是向量量化。
|
||||
|
||||
- 形式上,VQ-VAE 有三个组件:
|
||||
|
||||
- **编码器** $E$,将输入图像 $\mathbf{x} \in \mathbb{R}^{H \times W \times 3}$ 映射到连续潜在向量的空间网格 $\mathbf{z}_e = E(\mathbf{x}) \in \mathbb{R}^{h \times w \times d}$,其中 $h \times w$ 是降采样后的空间分辨率,$d$ 是嵌入维度。
|
||||
|
||||
- **码本** $\mathcal{C} = \{\mathbf{e}_1, \mathbf{e}_2, \ldots, \mathbf{e}_K\} \subset \mathbb{R}^d$,包含 $K$ 个可学习的嵌入向量。典型码本大小范围为 512 到 16,384 个条目。
|
||||
|
||||
- **解码器** $D$,从量化后的潜在表示重建图像。
|
||||
|
||||
- **量化步骤**将每个编码器输出 $\mathbf{z}_e(\mathbf{x})$ 在空间位置 $(i, j)$ 处替换为最近的码本条目:
|
||||
|
||||
$$\mathbf{z}_q(i,j) = \mathbf{e}_{k^\ast} \quad \text{其中} \quad k^\ast = \arg\min_k \|\mathbf{z}_e(i,j) - \mathbf{e}_k\|_2$$
|
||||
|
||||
- 这是在嵌入空间中的最近邻查找,与 k-means 分配(第 6 章)完全相同。索引 $k^\ast$ 是空间位置 $(i,j)$ 的离散词元,整张图像被表示为一个 $h \times w$ 的整数网格,取值范围为 $\{1, \ldots, K\}$。
|
||||
|
||||

|
||||
|
||||
- 挑战在于 $\arg\min$ 是不可微的:你无法通过离散选择进行反向传播。VQ-VAE 通过**直通估计器**解决了这个问题:在前向传播过程中,解码器接收 $\mathbf{z}_q$(量化后的向量);在反向传播过程中,重建损失相对于 $\mathbf{z}_q$ 的梯度被直接复制到 $\mathbf{z}_e$,就好像量化步骤是恒等函数一样。这可以简洁地写为:
|
||||
|
||||
$$\mathbf{z}_q = \mathbf{z}_e + \text{sg}(\mathbf{z}_q - \mathbf{z}_e)$$
|
||||
|
||||
- 其中 $\text{sg}(\cdot)$ 是停止梯度算子。在前向传播中,计算结果为 $\mathbf{z}_q$;在反向传播中,梯度仅流经 $\mathbf{z}_e$ 项。
|
||||
|
||||
- 完整的 VQ-VAE 损失包含三项:
|
||||
|
||||
$$\mathcal{L} = \underbrace{\|\mathbf{x} - D(\mathbf{z}_q)\|_2^2}_{\text{重建损失}} + \underbrace{\|\text{sg}(\mathbf{z}_e) - \mathbf{e}\|_2^2}_{\text{码本(VQ)损失}} + \underbrace{\beta \|\mathbf{z}_e - \text{sg}(\mathbf{e})\|_2^2}_{\text{承诺损失}}$$
|
||||
|
||||
- **重建损失**训练编码器和解码器忠实地再现输入。**码本损失**(也称为 VQ 损失)将码本向量拉向编码器输出;注意 $\text{sg}(\mathbf{z}_e)$ 意味着编码器不会从这一项接收梯度,因此它只更新码本。**承诺损失**则相反:它鼓励编码器输出保持接近码本向量,防止编码器"远离"码本。超参数 $\beta$(通常为 0.25)控制码本损失和承诺损失之间的平衡。
|
||||
|
||||
- 在实践中,码本通常使用**指数移动平均**(EMA)而不是梯度下降来更新,这样更稳定。令 $\mathbf{n}_k$ 为分配给码本条目 $k$ 的编码器输出计数,$\mathbf{s}_k$ 为它们的和。EMA 更新为:
|
||||
|
||||
$$\mathbf{n}_k \leftarrow \gamma \mathbf{n}_k + (1 - \gamma) |\{(i,j) : k^\ast_{ij} = k\}|$$
|
||||
|
||||
$$\mathbf{s}_k \leftarrow \gamma \mathbf{s}_k + (1 - \gamma) \sum_{(i,j) : k^\ast_{ij} = k} \mathbf{z}_e(i,j)$$
|
||||
|
||||
$$\mathbf{e}_k \leftarrow \frac{\mathbf{s}_k}{\mathbf{n}_k}$$
|
||||
|
||||
- 其中 $\gamma$ 是衰减率(通常为 0.99)。这等价于对编码器输出运行在线 k-means 算法。
|
||||
|
||||
### 码本坍塌
|
||||
|
||||
- VQ-VAE 一个臭名昭著的失败模式是**码本坍塌**(也称为索引坍塌):模型只学会使用 $K$ 个码本条目中的一小部分,导致大多数条目"死亡"。想象一个图书馆,90% 的书架是空的,因为图书管理员总是把书送到同样的几个热门书架上。这浪费了表示能力。
|
||||
|
||||
- 码本坍塌的发生是因为编码器、码本和解码器在训练过程中共同适应。如果一个条目在几个批次中都没有被选中,它就会漂离编码器流形,使其更不可能被选中,从而形成正反馈循环。
|
||||
|
||||
- 缓解码本坍塌的几种技术:
|
||||
- **码本重置**:定期通过随机采样编码器输出重新初始化死亡条目。这为死亡条目在潜在空间活跃区域附近提供了一个新的起点。
|
||||
- **带拉普拉斯平滑的 EMA 更新**:向 $\mathbf{n}_k$ 添加一个小常数,防止任何条目计数为零,确保所有条目都能接收到梯度信号。
|
||||
- **承诺损失调优**:增大 $\beta$ 迫使编码器输出更紧密地聚集在码本条目周围,使分配更均匀。
|
||||
- **分解编码**:将码本查找分解为多个较小查找的乘积(例如,两个大小各为 $\sqrt{K}$ 的码本),通过减少每次查找的有效码本大小来提高利用率。
|
||||
- **熵正则化**:添加一个惩罚项,鼓励码本使用上的均匀分布,最大化熵 $H = -\sum_k p_k \log p_k$,其中 $p_k$ 是经验分配概率。
|
||||
|
||||

|
||||
|
||||
## VQ-GAN:对抗训练实现更高保真度
|
||||
|
||||
- VQ-VAE 能产生不错的重建效果,但像素级的 $\ell_2$ 损失往往会产生模糊的输出,因为它对每个像素偏差都同等惩罚,在合理的细节上取平均而不是选择清晰的细节。想象一下,要求某人画一张脸,使得与所有可能的脸的平均差异最小——他们会画出一张模糊的平均脸,而不是一张清晰的特定人脸。
|
||||
|
||||
- **VQ-GAN**(Esser 等人,2021)通过将 VQ-VAE 框架与生成对抗网络(第 6 章)中的**判别器**相结合来解决这个问题。判别器是一个基于块的卷积网络,用于判断局部图像块是真(来自训练数据)还是假(来自解码器)。这种对抗损失鼓励解码器产生感知上清晰、逼真的纹理,而不是像素级的平均值。
|
||||
|
||||
- VQ-GAN 目标函数在 VQ-VAE 损失的基础上增加了两项:
|
||||
|
||||
$$\mathcal{L}_\text{VQ-GAN} = \mathcal{L}_\text{VQ-VAE} + \lambda_\text{adv} \mathcal{L}_\text{adv} + \lambda_\text{perc} \mathcal{L}_\text{perc}$$
|
||||
|
||||
- **对抗损失** $\mathcal{L}_\text{adv}$ 是应用于解码器输出的标准 GAN 目标。判别器 $\mathcal{D}$ 试图区分真实块和解码块,而解码器(生成器)试图欺骗它。非饱和形式为:
|
||||
|
||||
$$\mathcal{L}_\text{adv} = -\mathbb{E}[\log \mathcal{D}(D(\mathbf{z}_q))]$$
|
||||
|
||||
- **感知损失** $\mathcal{L}_\text{perc}$ 比较原始图像和重建图像在预训练网络(通常是 VGG 或 LPIPS)中的特征激活:
|
||||
|
||||
$$\mathcal{L}_\text{perc} = \sum_l \|\phi_l(\mathbf{x}) - \phi_l(D(\mathbf{z}_q))\|_2^2$$
|
||||
|
||||
- 其中 $\phi_l$ 表示预训练网络在第 $l$ 层的特征图。这个损失捕捉的是高层结构相似性,而非像素级精度。
|
||||
|
||||
- 权重 $\lambda_\text{adv}$ 被自适应地设置,使得对抗梯度和重建梯度保持平衡,防止在训练早期重建效果还很差时对抗损失占主导。
|
||||
|
||||

|
||||
|
||||
- 结果是,在相同码本大小下,VQ-GAN 产生的词元化器重建效果远比 VQ-VAE 清晰。VQ-GAN 是许多主要图像生成系统(包括最初的 DALL-E、Parti 以及众多文生图模型)背后的骨干词元化器。它将 256×256 的图像转换为 16×16 或 32×32 的离散词元网格,来源于大小为 1024–16384 的码本,在每个空间维度上实现 16 倍到 64 倍的压缩比。
|
||||
|
||||
## 残差量化与多尺度码本
|
||||
|
||||
- 单个码本对重建质量施加了一个硬上限:每个空间位置恰好由一个码本向量表示,任何比码本所能表达的更精细的细节都会丢失。想象用固定调色板中的一个词来描述一种颜色:"青色"很接近但不精确。如果你能添加一个细化描述——"青色,但稍微偏蓝一点,亮一点"——你就能得到更接近的结果。
|
||||
|
||||
- **残差量化**(RQ)迭代地应用了这一思想。在第一次量化步骤产生 $\mathbf{z}_q^{(1)}$ 之后,计算残差 $\mathbf{r}^{(1)} = \mathbf{z}_e - \mathbf{z}_q^{(1)}$,然后对残差使用第二个码本进行量化得到 $\mathbf{z}_q^{(2)}$,以此类推,共 $T$ 个层级:
|
||||
|
||||
$$\mathbf{r}^{(0)} = \mathbf{z}_e$$
|
||||
|
||||
$$\mathbf{z}_q^{(t)} = \text{Quantise}(\mathbf{r}^{(t-1)}, \mathcal{C}^{(t)})$$
|
||||
|
||||
$$\mathbf{r}^{(t)} = \mathbf{r}^{(t-1)} - \mathbf{z}_q^{(t)}$$
|
||||
|
||||
- 最终的量化表示为 $\hat{\mathbf{z}} = \sum_{t=1}^{T} \mathbf{z}_q^{(t)}$。使用 $T$ 个层级,每个层级码本大小为 $K$,有效词汇表大小为 $K^T$,但你只需要存储 $T \times K$ 个向量,而不是 $K^T$ 个。例如,8 个层级,$K = 1024$,有效条目数为 $1024^8 \approx 10^{24}$,而只存储了 8192 个向量。
|
||||
|
||||
- 每个后续层级捕捉更精细的细节:第一个码本捕捉粗略结构,第二个捕捉中频修正,依此类推。这类似于 JPEG 中的逐次逼近或网页图像中的渐进式渲染,先出现粗略版本,然后逐步填充细节。
|
||||
|
||||

|
||||
|
||||
- **多尺度码本**通过在不同空间分辨率上操作来扩展这一思想。不是重复量化同一个空间网格,而是在多个尺度上进行量化:粗粒度网格捕捉全局结构,细粒度网格捕捉局部细节。这与第 8 章目标检测部分中的特征金字塔思想相关,其中不同尺度的特征捕捉不同层次的细节。
|
||||
|
||||
- **乘积量化**是一种相关技术,将 $d$ 维潜在向量拆分为 $M$ 个维度为 $d/M$ 的子向量,每个子向量使用自己的码本独立量化。这使得有效词汇表达到 $K^M$,同时只存储 $M \times K$ 个向量。乘积量化广泛应用于近似最近邻搜索(第 13 章),并已被适配用于图像词元化。
|
||||
|
||||
- **有限标量量化**(FSQ),由 Mentzer 等人(2023)提出,采取了一种完全不同的方法:不是学习一个码本,而是简单地将潜在向量的每个维度四舍五入到一组固定整数级别中的一个(例如 $\{-2, -1, 0, 1, 2\}$)。每维 $L$ 个级别,$d$ 个维度,隐含码本大小为 $L^d$。FSQ 完全避免了码本坍塌,因为没有可学习的码本向量,只有被确定性四舍五入的可学习编码器输出。直通估计器处理了四舍五入的不可微性。
|
||||
|
||||
## 实践中的图像词元化器
|
||||
|
||||
- 从 VQ-VAE 到 VQ-GAN 再到残差量化的演进,催生了一系列实际图像词元化器,用于最先进的生成模型。
|
||||
|
||||
### DALL-E 词元化器(dVAE)
|
||||
|
||||
- 最初的 **DALL-E**(Ramesh 等人,2021)使用离散 VAE(dVAE)将 256×256 图像词元化为 32×32 的词元网格,码本大小为 8192。dVAE 将硬 $\arg\min$ 量化替换为 Gumbel-Softmax 松弛,使前向传播在训练过程中可微。在推理时,使用 $\arg\max$ 生成硬词元分配。dVAE 使用重建损失、与均匀先验的 KL 散度以及 Gumbel-Softmax 的学习温度调度组合进行训练。然后 DALL-E 训练了一个 120 亿参数的自回归 Transformer 来建模 256 个文本词元和 1024 个图像词元(32×32)的联合分布。
|
||||
|
||||
### LlamaGen
|
||||
|
||||
- **LlamaGen**(Sun 等人,2024)表明,只要你有一个好的图像词元化器,就可以将标准的 Llama 风格语言模型架构(第 7 章)重新用于自回归图像生成。LlamaGen 使用改进的 VQ-GAN 词元化器,具有大型码本(16,384 个条目),并训练了一个普通的自回归 Transformer(除了词元化器外没有特殊的图像特定修改)以光栅扫描顺序从左到右预测图像词元。关键的见解是,一旦图像被词元化为离散序列,适用于语言的相同下一个词元预测范式也同样适用于图像,这验证了词元化确实弥合了模态鸿沟的观点。
|
||||
|
||||
### Cosmos 词元化器
|
||||
|
||||
- **Cosmos 词元化器**(NVIDIA,2024)设计用于在统一框架中处理图像和视频。它使用因果 3D 架构,将图像视为单帧视频,使得同一个词元化器可以处理两种模态。Cosmos 支持连续和离散两种词元化模式:连续模式输出实值潜在向量(用于扩散模型后端),而离散模式应用有限标量量化产生整数词元(用于自回归模型后端)。编码器使用因果 3D 卷积,使得每帧的词元仅依赖于当前帧和之前的帧,从而支持流式视频词元化。
|
||||
|
||||

|
||||
|
||||
## 视频词元化
|
||||
|
||||
- 视频在图像的二维空间维度上增加了第三个轴——时间。视频是一系列帧,通常为每秒 24–30 帧,相邻帧之间高度冗余,因为在 33 毫秒内视觉世界不会发生剧烈变化。视频词元化利用这种时间冗余来实现比独立词元化每帧高得多的压缩率。
|
||||
|
||||
- 把视频压缩想象成一幅翻页书。如果每一页都从头画起,你需要数千张精细的绘图。但大多数页面与相邻页面几乎相同,所以你可以每 10 页画一个完整的"关键帧",只记录中间页面上的微小变化。视频词元化器自动学会了这个技巧。
|
||||
|
||||
### 3D VQ-VAE
|
||||
|
||||
- 将 VQ-VAE 扩展到视频的最直接方式是 **3D VQ-VAE**,它将编码器和解码器中的 2D 卷积替换为同时在空间和时间维度上操作的 3D 卷积。如果编码器在空间上降采样 $f_s$ 倍,在时间上降采样 $f_t$ 倍,则 $T \times H \times W$ 的视频片段变为 $(T/f_t) \times (H/f_s) \times (W/f_s)$ 的词元网格。
|
||||
|
||||
- 例如,$f_s = 16$ 且 $f_t = 4$ 时,一个 16 帧的 256×256 视频片段变为 $4 \times 16 \times 16 = 1024$ 的词元序列。这对 Transformer 进行自回归建模来说已经足够紧凑,而原始像素数将是 $16 \times 256 \times 256 \times 3 \approx 310$ 万个数值。
|
||||
|
||||
- 3D 卷积联合学习空间和时间特征。早期层捕捉局部运动(帧间移动的边缘),而更深层捕捉高层动态(物体的出现、消失或形状变化)。这与第 8 章卷积网络中的层次化特征提取原理相同,只是沿时间轴进行了扩展。
|
||||
|
||||

|
||||
|
||||
### 因果视频词元化器
|
||||
|
||||
- 标准 3D 卷积会同时查看过去、当前和未来的帧,这意味着在词元化任何帧之前需要整个视频片段。**因果视频词元化器**约束时间卷积,使每个输出仅依赖于当前帧和之前的帧,从不依赖于未来的帧。这类似于自回归 Transformer(第 7 章)中的因果掩码:信息在时间上向前流动,但绝不向后。
|
||||
|
||||
- 因果词元化对于两种使用场景至关重要。首先,**流式处理**:你可以在帧到达时实时词元化视频,而无需缓冲未来的帧。其次,**自回归生成**:当 Transformer 逐帧生成视频时,第 $t$ 帧的词元必须在不知道第 $t+1$ 帧的情况下可计算,因为第 $t+1$ 帧尚未生成。
|
||||
|
||||
- 因果约束通过非对称填充时间卷积来实现:时间大小为 $k$ 的核在过去一侧填充 $k-1$ 个零,未来一侧填充零个零,确保时间 $t$ 的输出仅依赖于时间 $t-k+1, \ldots, t$ 的输入。
|
||||
|
||||
- 因果视频词元化器的一个优雅特性是它们可以词元化单张图像("视频"只有一帧)而无需特殊处理。第一帧没有历史上下文,因此其词元仅从该帧本身计算。这种**图像-视频统一**意味着单个词元化器可以服务于两种模态,简化了架构,并使模型能够使用同一个解码器生成图像和视频。
|
||||
|
||||
### 时间压缩策略
|
||||
|
||||
- 不同的应用需要不同的时间压缩比。对于动作识别(其中细微运动很重要),温和压缩($f_t = 2$)可以保留时间细节。对于长视频生成(存储数千帧是不可行的),需要激进压缩($f_t = 8$ 或更高)。
|
||||
|
||||
- 某些词元化器使用**分解压缩**:空间和时间压缩在不同的阶段进行。首先,2D 编码器独立压缩每帧,产生每帧的潜在网格。然后,1D 时间编码器跨时间维度进行压缩。这种分解在计算上比完整的 3D 卷积更便宜,并允许空间和时间采用不同的压缩比。其代价是它不能像联合 3D 编码那样高效地捕捉时空模式(如对角线运动的球)。
|
||||
|
||||
- **时间插值词元**是一项最近的创新,词元化器仅完整编码关键帧,并将中间帧表示为轻量级的插值编码,描述如何在关键帧之间变形。这类似于经典视频压缩(H.264/HEVC 中的 I 帧和 P 帧),但在学习到的潜在空间中进行。
|
||||
|
||||

|
||||
|
||||
## 连续词元与离散词元
|
||||
|
||||
- 并非每个下游模型都需要离散词元。**扩散模型**(第 10 章,文件 04)原生使用连续值——它们迭代地去噪高斯样本,其损失函数(去噪得分匹配)定义在连续空间上。对于扩散后端,词元化器编码器产生连续潜在向量,从不进行量化。**潜在扩散模型**(Stable Diffusion、DALL-E 3、Flux)使用类似 VQ-GAN 的编码器-解码器,但完全跳过了码本,在连续潜在空间中操作。
|
||||
|
||||
- 而**自回归模型**(GPT 风格)则使用 $K$ 类上的 softmax 从有限词汇表中预测下一个词元。它们从根本上需要离散词元。每个使用自回归 Transformer 的图像生成系统(DALL-E、Parti、LlamaGen、Chameleon)都依赖离散词元化器。
|
||||
|
||||
- 因此,连续词元和离散词元之间的选择由生成后端决定:
|
||||
|
||||
- 在以下情况下使用**离散词元**:模型是自回归的(使用交叉熵损失的下一个词元预测),你想与文本词元共享词汇表以实现统一的多模态模型,或者你需要精确的词元级控制(例如,通过词元替换进行检索或编辑)。
|
||||
|
||||
- 在以下情况下使用**连续词元**:模型是扩散模型或流匹配模型,任务需要非常高的保真度重建(连续潜在表示完全避免了量化误差),或者你想使用作用于实值向量的回归损失。
|
||||
|
||||
- 一些最近的架构支持两种模式。例如,Cosmos 词元化器可以从同一个编码器输出连续潜在表示(用于其扩散模式)或 FSQ 离散化词元(用于其自回归模式),只需一个可以打开或关闭的轻量级量化头。
|
||||
|
||||
- **软量化**是一个中间地带:不是硬 $\arg\min$ 分配,而是计算 top-$k$ 最近码本条目的加权平均,权重由负距离上的 softmax 给出。这比硬量化保留了更多信息,同时仍然近似离散。有些系统在训练时使用软量化,在推理时使用硬量化。
|
||||
|
||||

|
||||
|
||||
## 应用
|
||||
|
||||
### 自回归图像生成
|
||||
|
||||
- 一旦图像变成离散词元序列,你就可以训练标准的自回归 Transformer 来建模它们。图像词元被展平为一维序列(通常按光栅扫描顺序:从左到右、从上到下),Transformer 学习 $p(\text{词元}_i \mid \text{词元}_1, \ldots, \text{词元}_{i-1})$,使用标准交叉熵损失。在生成时,词元被逐个采样,完整的网格通过词元化器的解码器转换为像素。
|
||||
|
||||
- 文本条件化很简单:在图像词元序列前添加文本词元,使模型学习 $p(\text{图像词元} \mid \text{文本词元})$。这正是 DALL-E、Parti 和 LlamaGen 执行文生图的方式。文本词元和图像词元共享同一个 Transformer、同一个注意力机制,并且通常共享同一个嵌入表(文本词元和图像词元占据不同的索引范围)。
|
||||
|
||||
- 光栅扫描顺序引入了一种人为的非对称性:图像的左上角是在没有任何关于右下角上下文的情况下首先生成的。一些工作解决了这个问题。**掩码图像建模**(MaskGIT)训练了一个双向 Transformer,同时生成所有词元但置信度不同,迭代地解开最自信的词元。**多尺度生成**首先生成粗粒度词元(捕捉全局构图),然后用残差词元进行细化。这些方法用纯从左到右生成的简单性换取了更好的全局连贯性。
|
||||
|
||||
### 统一的视觉-语言词元
|
||||
|
||||
- 图像词元化最深刻的动机是**统一**:将视觉和语言置于相同的表示格式中,使得单个模型架构可以同时处理两者。正如我们在第 7 章中讨论的,语言模型是极其强大的序列到序列机器。通过将图像表示为词元序列,我们免费继承了语言建模的所有基础设施——预训练配方、缩放定律、RLHF、上下文长度扩展。
|
||||
|
||||
- **Chameleon**(Meta,2024)是一个突出的例子:它使用具有 8192 个码本条目的 VQ-GAN 词元化器将图像转换为词元,这些词元与文本词元交织在一个约 65,000 个条目(文本 + 图像)的单一词汇表中。标准的 Transformer 在混合文本-图像序列上进行训练,使其能够根据图像生成文本、根据文本生成图像或生成交织的文本和图像内容,全部使用同一次前向传播。
|
||||
|
||||
- **Gemini**(Google,2024)在大规模上采取了类似的方法,原生地在单个 Transformer 中理解并生成图像、音频和文本,由特定模态的词元化器馈送到共享序列中。
|
||||
|
||||
- 统一模型中的关键工程挑战是**词汇表平衡**:如果 65,000 个词汇表条目中有 8192 个是图像词元,模型可能会分配不足的能力给视觉。解决方案包括为每种模态使用独立的嵌入层(仅在注意力层面共享)、特定模态的损失加权,以及预训练期间仔细的数据混合比例。
|
||||
|
||||

|
||||
|
||||
## 编程练习(在 Colab 或笔记本中运行)
|
||||
|
||||
1. 在 JAX 中实现一个最小 VQ 层:给定一批编码器输出向量,执行最近邻码本查找并计算 VQ-VAE 损失(重建 + 码本 + 承诺)。将码本利用率可视化为直方图。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# --- 最小 VQ 层 ---
|
||||
key = jax.random.PRNGKey(42)
|
||||
d = 8 # 嵌入维度
|
||||
K = 64 # 码本大小
|
||||
n_vectors = 256 # 一批编码器输出
|
||||
|
||||
# 随机编码器输出和码本
|
||||
k1, k2 = jax.random.split(key)
|
||||
z_e = jax.random.normal(k1, (n_vectors, d)) # 编码器输出
|
||||
codebook = jax.random.normal(k2, (K, d)) * 0.1 # 码本(小初始化)
|
||||
|
||||
# 最近邻查找:为每个 z_e 找到最近的码本条目
|
||||
# distances[i, k] = ||z_e[i] - codebook[k]||^2
|
||||
distances = (
|
||||
jnp.sum(z_e ** 2, axis=1, keepdims=True)
|
||||
- 2 * z_e @ codebook.T
|
||||
+ jnp.sum(codebook ** 2, axis=1, keepdims=True).T
|
||||
)
|
||||
indices = jnp.argmin(distances, axis=1) # 词元索引
|
||||
z_q = codebook[indices] # 量化向量
|
||||
|
||||
# VQ-VAE 损失项
|
||||
beta = 0.25
|
||||
loss_codebook = jnp.mean((jax.lax.stop_gradient(z_e) - z_q) ** 2)
|
||||
loss_commit = jnp.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2)
|
||||
loss_total = loss_codebook + beta * loss_commit
|
||||
print(f"码本损失: {loss_codebook:.4f}, 承诺损失: {loss_commit:.4f}")
|
||||
|
||||
# 码本利用率
|
||||
unique, counts = jnp.unique(indices, return_counts=True, size=K, fill_value=-1)
|
||||
plt.figure(figsize=(10, 4))
|
||||
plt.bar(range(K), counts, color='#3498db', alpha=0.8)
|
||||
plt.xlabel('码本索引'); plt.ylabel('分配计数')
|
||||
plt.title(f'码本利用率(已使用 {jnp.sum(counts > 0)}/{K} 个条目)')
|
||||
plt.grid(True, alpha=0.3); plt.tight_layout(); plt.show()
|
||||
# 尝试:将 K 增加到 512 并观察坍塌。然后添加码本重置逻辑。
|
||||
```
|
||||
|
||||
2. 构建一个玩具 2D 向量量化器,学习对 2D 分布进行划分。生成随机 2D 点,通过 EMA 更新学习码本,并将 Voronoi 区域可视化。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 从高斯混合生成 2D 数据
|
||||
key = jax.random.PRNGKey(0)
|
||||
n_points = 2000
|
||||
K = 16 # 码本条目数
|
||||
gamma = 0.99 # EMA 衰减
|
||||
|
||||
# 四个簇
|
||||
keys = jax.random.split(key, 5)
|
||||
centres = jnp.array([[2, 2], [-2, 2], [-2, -2], [2, -2]], dtype=jnp.float32)
|
||||
data = jnp.concatenate([
|
||||
jax.random.normal(keys[i], (n_points // 4, 2)) * 0.5 + centres[i]
|
||||
for i in range(4)
|
||||
])
|
||||
|
||||
# 从随机数据点初始化码本
|
||||
idx = jax.random.choice(keys[4], n_points, (K,), replace=False)
|
||||
codebook = data[idx]
|
||||
ema_count = jnp.ones(K)
|
||||
ema_sum = codebook.copy()
|
||||
|
||||
# 运行多个 epoch 的基于 EMA 的码本学习
|
||||
for epoch in range(30):
|
||||
# 将每个点分配给最近的码本条目
|
||||
dists = jnp.sum((data[:, None, :] - codebook[None, :, :]) ** 2, axis=2)
|
||||
assignments = jnp.argmin(dists, axis=1)
|
||||
# EMA 更新
|
||||
for k in range(K):
|
||||
mask = (assignments == k)
|
||||
count_k = jnp.sum(mask)
|
||||
ema_count = ema_count.at[k].set(gamma * ema_count[k] + (1 - gamma) * count_k)
|
||||
if count_k > 0:
|
||||
sum_k = jnp.sum(data[mask], axis=0)
|
||||
ema_sum = ema_sum.at[k].set(gamma * ema_sum[k] + (1 - gamma) * sum_k)
|
||||
codebook = ema_sum / ema_count[:, None]
|
||||
|
||||
# 可视化分配和码本
|
||||
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
|
||||
colors = plt.cm.tab20(jnp.linspace(0, 1, K))
|
||||
for k in range(K):
|
||||
mask = assignments == k
|
||||
ax.scatter(data[mask, 0], data[mask, 1], c=[colors[k]], s=5, alpha=0.3)
|
||||
ax.scatter(codebook[:, 0], codebook[:, 1], c='black', s=120, marker='X',
|
||||
edgecolors='white', linewidths=1.5, zorder=10, label='码本')
|
||||
ax.set_title(f'在 2D 数据上学得的 VQ 码本({K} 个条目)')
|
||||
ax.legend(); ax.set_aspect('equal'); ax.grid(True, alpha=0.3)
|
||||
plt.tight_layout(); plt.show()
|
||||
# 尝试:将 K 增加到 64 并观察更精细的划分。减小 gamma 并观察不稳定性。
|
||||
```
|
||||
|
||||
3. 演示残差量化:用 $T$ 个连续的量化阶段对一批向量进行编码,并测量每个层级重建误差的下降。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
key = jax.random.PRNGKey(7)
|
||||
d = 16 # 嵌入维度
|
||||
K = 32 # 每个层级的码本大小
|
||||
T = 8 # 残差层级数
|
||||
n_vectors = 512
|
||||
|
||||
# 待量化的随机数据
|
||||
k1, *cb_keys = jax.random.split(key, T + 1)
|
||||
z = jax.random.normal(k1, (n_vectors, d))
|
||||
|
||||
# 每个层级的独立随机码本
|
||||
codebooks = [jax.random.normal(cb_keys[t], (K, d)) * (0.5 ** t)
|
||||
for t in range(T)]
|
||||
|
||||
# 残差量化循环
|
||||
residual = z.copy()
|
||||
z_hat = jnp.zeros_like(z)
|
||||
errors = []
|
||||
|
||||
for t in range(T):
|
||||
cb = codebooks[t]
|
||||
dists = (jnp.sum(residual ** 2, axis=1, keepdims=True)
|
||||
- 2 * residual @ cb.T
|
||||
+ jnp.sum(cb ** 2, axis=1, keepdims=True).T)
|
||||
indices = jnp.argmin(dists, axis=1)
|
||||
z_q_t = cb[indices]
|
||||
z_hat = z_hat + z_q_t
|
||||
residual = residual - z_q_t
|
||||
mse = jnp.mean(jnp.sum((z - z_hat) ** 2, axis=1))
|
||||
errors.append(float(mse))
|
||||
print(f"层级 {t+1}: MSE = {mse:.4f}")
|
||||
|
||||
plt.figure(figsize=(8, 5))
|
||||
plt.plot(range(1, T + 1), errors, 'o-', color='#e74c3c', linewidth=2, markersize=8)
|
||||
plt.xlabel('残差量化层级')
|
||||
plt.ylabel('重建 MSE')
|
||||
plt.title('残差量化的误差降低')
|
||||
plt.xticks(range(1, T + 1)); plt.grid(True, alpha=0.3)
|
||||
plt.tight_layout(); plt.show()
|
||||
# 尝试:使用大小为 K*T 的单个码本并与 RQ 比较。哪个更好?
|
||||
```
|
||||
|
||||
4. 模拟一个简单的 1D"视频词元化器":生成一系列 1D 信号(模拟视频帧),应用因果时间压缩,并与无因果压缩在重建质量方面进行比较。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
key = jax.random.PRNGKey(99)
|
||||
n_frames = 16
|
||||
frame_len = 64
|
||||
|
||||
# 生成一个"视频":在帧间缓慢移动的高斯凸起
|
||||
x_axis = jnp.linspace(-3, 3, frame_len)
|
||||
frames = jnp.stack([
|
||||
jnp.exp(-0.5 * (x_axis - (-2 + 4 * t / n_frames)) ** 2)
|
||||
for t in range(n_frames)
|
||||
]) # 形状: (n_frames, frame_len)
|
||||
|
||||
# 因果时间压缩:每帧的编码仅依赖于过去的帧
|
||||
# 简单方法:使用过去帧的指数衰减对当前帧进行平均
|
||||
alpha_causal = 0.6
|
||||
causal_codes = jnp.zeros_like(frames)
|
||||
causal_codes = causal_codes.at[0].set(frames[0])
|
||||
for t in range(1, n_frames):
|
||||
causal_codes = causal_codes.at[t].set(
|
||||
alpha_causal * frames[t] + (1 - alpha_causal) * causal_codes[t - 1]
|
||||
)
|
||||
|
||||
# 无因果:同时平均过去和未来(双边平滑)
|
||||
kernel = jnp.array([0.2, 0.6, 0.2]) # 过去, 当前, 未来
|
||||
padded = jnp.concatenate([frames[:1], frames, frames[-1:]], axis=0)
|
||||
noncausal_codes = jnp.stack([
|
||||
kernel[0] * padded[t] + kernel[1] * padded[t+1] + kernel[2] * padded[t+2]
|
||||
for t in range(n_frames)
|
||||
])
|
||||
|
||||
# 重建误差
|
||||
mse_causal = jnp.mean((frames - causal_codes) ** 2)
|
||||
mse_noncausal = jnp.mean((frames - noncausal_codes) ** 2)
|
||||
print(f"因果 MSE: {mse_causal:.6f}, 无因果 MSE: {mse_noncausal:.6f}")
|
||||
|
||||
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
||||
for ax, data, title in zip(axes,
|
||||
[frames, causal_codes, noncausal_codes],
|
||||
['原始帧', f'因果 (MSE={mse_causal:.5f})',
|
||||
f'无因果 (MSE={mse_noncausal:.5f})']):
|
||||
ax.imshow(data, aspect='auto', cmap='viridis', origin='lower')
|
||||
ax.set_xlabel('空间位置'); ax.set_ylabel('帧索引')
|
||||
ax.set_title(title)
|
||||
plt.tight_layout(); plt.show()
|
||||
# 尝试:改变 alpha_causal 和核权重。alpha=1.0 时会发生什么?
|
||||
```
|
||||
@@ -0,0 +1,405 @@
|
||||
# 跨模态生成 (Cross-Modal Generation)
|
||||
|
||||
*跨模态生成(cross-modal generation)是指以某一模态的输入为条件,生成另一模态的输出——从文生图、图生文、文生音频,乃至更多。本章涵盖 DALL·E、Stable Diffusion、无分类器引导、ControlNet、图像描述、文生视频(Sora)以及文生音频生成。*
|
||||
|
||||
- 在本章的文件 01-03 中,你已经学习了如何表示、对齐和分词不同模态。现在轮到创造性的环节了:从一个模态生成另一个模态。跨模态生成是文生图工具、视频合成系统、音乐创作模型和图像描述背后的引擎。可以将其理解为教会机器成为多媒体艺术家——你用文字描述你想要的内容,机器则负责绘画、动画或作曲。
|
||||
|
||||
- 核心思想是**条件生成(conditional generation)**:给定来自模态 $A$(例如文本)的输入,生成模态 $B$(例如图像)的输出。形式上,我们学习模型 $p_\theta(y \mid x)$,其中 $x$ 是条件信号,$y$ 是生成的输出。挑战在于这个条件分布极其复杂且维度极高——一张 512x512 的图像存在于 $\mathbb{R}^{786432}$ 中,而对于同一个文本提示,可能有无数张合理的图像。
|
||||
|
||||

|
||||
|
||||
## 文生图生成 (Text-to-Image Generation)
|
||||
|
||||
- 想象你向法庭素描师描述一个场景。素描师必须理解你的话,回忆物体长什么样,在空间上排布它们,最后画出最终的图画。文生图模型正是做这件事,但它们必须从数据中学习所有这些技能,而不是经过多年的艺术院校训练。
|
||||
|
||||
### DALL·E:自回归图像生成
|
||||
|
||||
- **DALL·E**(Ramesh 等人,2021)将图像生成视为一个序列预测问题——这正是语言模型所采用的范式(见第 07 章)。其关键洞察是:如果你能将图像表示为离散 token(回顾文件 03 中的 VQ-VAE),那么生成图像就只是逐个生成 token 序列的过程。
|
||||
|
||||
- 其流程分为两个阶段。首先,一个**离散 VAE(dVAE)**将 256x256 的图像压缩成 32x32 的离散 token 网格,码本大小为 8192,将图像简化为 1024 个 token 的序列。其次,一个**Transformer 解码器**被训练来建模 256 个文本 token(BPE 编码)与 1024 个图像 token 拼接后的联合分布,总计 1280 个 token:
|
||||
|
||||
$$p(x_{\text{text}}, x_{\text{img}}) = \prod_{i=1}^{1280} p(x_i \mid x_1, \ldots, x_{i-1})$$
|
||||
|
||||
- 在生成时,输入文本 token,模型自回归地逐个采样图像 token。这种方法优雅之处在于它复用了语言建模的完整机制——注意力、因果掩码、top-k 采样——来完成图像合成。
|
||||
|
||||
- 缺点是自回归生成本质上是串行的:逐个生成 1024 个 token 速度很慢,而且序列早期的任何错误都会被放大。DALL·E 通过生成大量候选图像并用 CLIP(来自文件 01)进行重排序来缓解这一问题,以找到与文本提示最匹配的结果。
|
||||
|
||||

|
||||
|
||||
### Stable Diffusion:带文本条件的隐空间扩散
|
||||
|
||||
- **Stable Diffusion**(Rombach 等人,2022)采用了一种根本不同的方法。它不是逐个预测 token,而是从纯噪声开始,在文本提示的引导下逐步将噪声去噪成图像。回顾第 8 章中的扩散模型——Stable Diffusion 在压缩后的隐空间(latent space)而非像素空间中运行,因此效率大幅提升。
|
||||
|
||||
- 其架构由三个组件协同工作。**VAE 编码器**将图像从像素空间($512 \times 512 \times 3$)压缩为隐空间表示($64 \times 64 \times 4$),将维度降低了 48 倍。**文本编码器**(通常为 CLIP 或 OpenCLIP)将文本提示转换为嵌入向量序列。**U-Net 去噪器**接收含噪隐变量、时间步和文本嵌入,并预测每一步需要减去的噪声。文本条件通过**交叉注意力(cross-attention)**层进入 U-Net:
|
||||
|
||||
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right)V$$
|
||||
|
||||
- 其中 $Q$ 来自含噪图像特征,$K$ 和 $V$ 来自文本嵌入。这使得模型能够在每个空间位置上关注相关的词语——当去噪"红球"应该出现的区域时,模型会关注"红"和"球"这两个 token。
|
||||
|
||||
- 在推理时,你在隐空间中采样 $z_T \sim \mathcal{N}(0, I)$,利用 U-Net 迭代去噪 $T$ 步(通常使用 DDIM 调度为 20-50 步),然后用 VAE 解码器将干净的隐变量 $z_0$ 解码回像素空间。整个前向过程在消费级 GPU 上仅需数秒即可生成一张 512x512 的图像。
|
||||
|
||||

|
||||
|
||||
### 无分类器引导的实践应用
|
||||
|
||||
- **无分类器引导(Classifier-Free Guidance,CFG)**是让文生图模型能够生成与提示真正匹配的图像的关键要素。回顾第 8 章,CFG 同时训练条件模型和无条件模型,然后在采样时放大条件信号:
|
||||
|
||||
$$\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing) + s \cdot (\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))$$
|
||||
|
||||
- 其中 $s$ 是引导尺度。可以将 $(\epsilon_\theta(x_t, c) - \epsilon_\theta(x_t, \varnothing))$ 理解为"朝向提示的方向"——它捕捉了有条件预测与无条件预测之间的差异。乘以 $s > 1$ 会放大这个方向,将图像推近文本描述,但代价是多样性降低。
|
||||
|
||||
- 在实践中,Stable Diffusion 的常用默认值为 $s = 7.5$。当 $s = 1.0$ 时得到模型的原始输出(多样但仅松散匹配提示)。当 $s \geq 20$ 时图像变得过饱和且重复,但与文本高度一致。最优 $s$ 值取决于应用场景:创意探索倾向于较低的引导值,而精确遵循提示则需要更高的引导值。
|
||||
|
||||
### Imagen:基于语言理解的级联扩散
|
||||
|
||||
- **Imagen**(Saharia 等人,2022)证明了强大的文本编码器比更大的图像模型更重要。Imagen 没有使用 CLIP,而是采用一个冻结的 **T5-XXL** 语言模型(来自第 07 章)作为文本编码器,该模型对语言语义、组合性和空间关系(如"红色球体上的蓝色方块")有着更丰富的理解。
|
||||
|
||||
- Imagen 使用了**级联扩散(cascaded diffusion)**方法:基础扩散模型生成 64x64 的图像,第一个超分辨率模型放大到 256x256,第二个超分辨率模型达到 1024x1024。每个阶段都是独立的扩散模型,以文本和(对于上采样器)低分辨率图像为条件。这种级联方式避免了在基础分辨率上建模精细细节,使基础模型能够专注于构图和语义,而上采样器则负责处理纹理和清晰度。
|
||||
|
||||
- Imagen 还引入了**动态阈值(dynamic thresholding)**:在每个去噪步骤中,预测的像素值被裁剪到基于百分位数的范围,而不是固定的 $[-1, 1]$ 范围。这可以防止在高引导尺度下出现饱和伪影,这是扩散模型中的常见问题。
|
||||
|
||||
### Parti:大规模自回归
|
||||
|
||||
- **Parti**(Pathways Autoregressive Text-to-Image,Yu 等人,2022)以超大尺度复兴了自回归方法。与 DALL·E 类似,它将图像转换为离散 token(使用 ViT-VQGAN),并用 Transformer 顺序生成。但 Parti 使用了 200 亿参数的编码器-解码器 Transformer(基于 Pathways 架构),并证明了自回归模型在充分扩展后可以达到扩散模型的质量。
|
||||
|
||||
- Parti 的编码器-解码器架构是与 DALL·E 纯解码器设计的关键区别。文本通过编码器处理;解码器在生成图像 token 时,通过交叉注意力关注编码后的文本。这类似于机器翻译(第 07 章)——你从"文本语言"翻译到"图像语言"。
|
||||
|
||||
### DiT 与基于流匹配的生成
|
||||
|
||||
- **扩散 Transformer(DiT)**(Peebles 和 Xie,2023)用纯 Transformer 替换了扩散模型中的 U-Net 主干网络。每个含噪隐空间块被当作一个 token(类似于第 8 章中的 ViT),Transformer 通过自注意力和对文本条件的交叉注意力来处理这些 token。DiT 表明,在扩散任务中,Transformer 的可扩展性比 U-Net 更具可预测性——计算量每翻一倍,FID 分数就会可靠地减半。
|
||||
|
||||
- **流匹配(flow matching)**(回顾第 8 章)已成为扩散噪声预测范式之外的一种替代方案。模型不再预测需要减去的噪声 $\epsilon$,而是预测一个速度场 $v_\theta(x_t, t)$,该速度场沿直线路径将样本从噪声传输到数据。**Stable Diffusion 3** 和 **Flux** 采用流匹配和**多模态 DiT(MM-DiT)**架构,其中文本和图像 token 由 Transformer 块通过双向注意力联合处理——两种模态互相关注,而不是文本仅通过交叉注意力作为图像特征的条件。
|
||||
|
||||

|
||||
|
||||
## 文生视频生成 (Text-to-Video Generation)
|
||||
|
||||
- 文生视频相当于文生图再加上一个严苛的额外约束:**时间连贯性(temporal coherence)**。每一帧必须在内部保持一致(是一张合理的图像),但连续帧之间也必须平滑连接——物体应该自然运动,光照应连续变化,"镜头"应遵循物理上合理的轨迹。可以想象一下绘制一幅风景画和导演一部电影之间的区别。
|
||||
|
||||
### 时间维度的挑战
|
||||
|
||||
- 视频引入了图像生成之外的三个挑战。**时间一致性(temporal consistency)**要求物体在各帧之间保持身份不变——第 1 帧中的狗在第 100 帧中应该还是同一条狗。**运动建模(motion modeling)**需要学习物理动态:物体如何运动、重力如何作用、流体如何流动。**计算成本**非常高昂:一段 24 fps、512x512 分辨率的 10 秒视频包含 $10 \times 24 \times 512 \times 512 \times 3 \approx 1.88$ 亿个值,大约是单张图像数据量的 240 倍。
|
||||
|
||||
### Make-A-Video 与延展至视频方法
|
||||
|
||||
- **Make-A-Video**(Singer 等人,2022)采用了一种务实的方法:从预训练的文生图模型开始,添加时间层。关键洞察是,你已经拥有了基于数十亿图文对训练的强大文生图模型,你只需要从(未标注的)视频数据中学习运动。
|
||||
|
||||
- Make-A-Video 在预训练的空间 U-Net 中插入了**时间注意力(temporal attention)**和**时间卷积(temporal convolution)**层。空间层(在图像上预训练)负责外观,而新的时间层(在视频上训练)负责运动。空间自注意力在每帧内部操作;时间注意力在每个空间位置上跨帧操作。这种分解是高效的,因为时间和空间模式在很大程度上是可分离的。
|
||||
|
||||
- 生成流程与 Imagen 的级联方式类似:基础模型生成 64x64 的 16 帧,然后空间和时间超分辨率模型将分辨率升级到最终大小和帧率。帧插值网络用于提高时间平滑性。
|
||||
|
||||
### VideoPoet 与基于 Token 的视频模型
|
||||
|
||||
- **VideoPoet**(Kondratyuk 等人,2024)将视频生成统一到语言建模范式之下。所有模态——文本、图像、视频、音频——都被 token 化为离散序列,一个单一的大语言模型(LLM)被训练来跨所有模态自回归地预测 token。这使得零样本能力成为可能:文生视频、图生视频、视频生音频、视频编辑和视频修补都可以从同一个模型中涌现。
|
||||
|
||||
- VideoPoet 使用 MAGVIT-v2 编码器(一个来自文件 03 的 3D VQ-VAE)对视频进行 token 化,该编码器联合压缩空间和时间维度。音频使用 SoundStream 进行 token 化。LLM 主干在文本上预训练,然后在多模态 token 序列上微调,学习跨模态的联合分布。
|
||||
|
||||
### Sora 风格的时间扩散
|
||||
|
||||
- **Sora**(OpenAI,2024)凭借其生成长时间、连贯、物理合理的视频的能力,将时间扩散带入了主流视野。虽然完整的架构细节尚未公开,但其关键思想是将 DiT 扩展到时空领域:视频帧被分解为**时空块(spacetime patches)**(跨越高度、宽度和时间的三维块),这些块被当作大型 Transformer 的 token 来处理。
|
||||
|
||||
- 时空块方法意味着模型将视频作为原生的 3D 信号来处理,而不是一系列 2D 帧。这使得模型能够捕获长程的时间依赖关系——模型可以"提前规划"整个视频时长,而不是逐帧生成。
|
||||
|
||||
- Sora 可以通过调整时空块的数量来处理可变的时长、分辨率和宽高比。以数据原生分辨率进行训练(而不是将所有图像裁剪为正方形)可以提高构图和取景质量。
|
||||
|
||||
### Wan:开源视频生成
|
||||
|
||||
- **Wan**(Wan 等人,2025)是一个开源视频生成模型系列(1.3B 和 14B 参数),基于 DiT 主干和 3D VAE 时间压缩。Wan 采用**流匹配**而不是传统的 DDPM 风格扩散,学习从噪声到视频隐空间的直线传输路径。3D VAE 在空间和时间上压缩视频(4 倍时间压缩),DiT 以全 3D 注意力处理生成的时空隐空间 token。
|
||||
|
||||
- Wan 支持文生视频、图生视频(将静态图像动画化)和视频编辑。14B 模型可以生成长达 5 秒、720p 分辨率的连贯视频,表明当架构和训练方案选择恰当时,开源模型可以接近专有系统的质量。
|
||||
|
||||

|
||||
|
||||
## 文生音频生成 (Text-to-Audio Generation)
|
||||
|
||||
- 想象一位电影配乐师阅读剧本并为电影配乐。文生音频模型做着类似的事情:给定一段文本描述("伴有大雨和远处雷声的雷暴"),它们生成相应的音频波形。挑战在于弥合文本的离散、符号化本质与声音的连续、时间性本质之间的差距。
|
||||
|
||||
### AudioLM:音频的语言建模
|
||||
|
||||
- **AudioLM**(Borsos 等人,2023)通过自回归预测离散音频 token 来生成音频,采用了与 DALL·E 为图像所用的相同语言建模范式。它使用分层 token 结构:**语义 token**(来自自监督模型如 w2v-BERT,回顾第 9 章)捕获高层次内容(说了什么或演奏了什么),而**声学 token**(来自 SoundStream,一种神经音频编解码器)捕获细粒度的声学细节(听起来如何——音色、录音质量)。
|
||||
|
||||
- 生成分两个阶段进行。首先,一个 Transformer 在给定可选音频提示的情况下预测语义 token,建立高层次的"内容规划"。其次,另一个 Transformer 以语义 token 为条件预测声学 token,填充声学细节。这种层次结构类似于文生语音流程(第 9 章)——语义 token 扮演音素的角色,声学 token 扮演梅尔频谱图帧的角色。
|
||||
|
||||
- AudioLM 可以生成语音接续(给定 3 秒语音,生成接下来的 10 秒)、音乐接续和音效,所有这些都来自一个仅在音频数据上训练的模型(预训练不需要文本标签)。
|
||||
|
||||
### MusicLM:文本条件音乐生成
|
||||
|
||||
- **MusicLM**(Agostinelli 等人,2023)将 AudioLM 扩展到文本条件下的音乐生成。它添加了一个文本-音频联合嵌入(来自 **MuLan**,一个在音乐-文本对上训练的类 CLIP 模型)来条件化生成。MuLan 嵌入捕获文本描述的语义含义("带有萨克斯独奏的欢快爵士乐")并指导分层 token 生成。
|
||||
|
||||
- MusicLM 以 24 kHz 的频率生成任意时长的音乐,在数分钟长的作品中保持旋律和节奏的连贯性。它还可以用哼唱的旋律(由音高追踪器提取的旋律 token)加上文本描述作为条件,生成完整的编曲,既遵循哼唱的曲调,又符合文本描述的风格。
|
||||
|
||||
### MusicGen:高效单阶段生成
|
||||
|
||||
- **MusicGen**(Copet 等人,2023)简化了多阶段方法。MusicGen 不使用独立的语义和声学模型,而是使用一个单一的自回归 Transformer,直接生成来自音频编解码器的多个码本层级。关键创新是**交织码本模式(interleaved codebook pattern)**:MusicGen 并非在进入下一个时间步之前生成该时间步的所有码本层级,而是以某种模式跨码本和时间步交织 token,从而允许对某些码本层级进行并行解码。
|
||||
|
||||
- 条件化直接明了:文本由 T5 编码器编码,文本嵌入被前置到音频 token 序列之前(像语言模型中的前缀提示)或通过交叉注意力注入。MusicGen 还支持旋律条件化:参考旋律的色谱图(chromagram,来自第 9 章中讨论的频谱图特征)被编码后与文本条件一起使用。
|
||||
|
||||
$$p(a_1, \ldots, a_T) = \prod_{t=1}^{T} \prod_{k=1}^{K} p(a_{t,k} \mid a_{<t}, c_{\text{text}})$$
|
||||
|
||||
- 其中 $a_{t,k}$ 是时间步 $t$、码本层级 $k$ 处的音频 token,$c_{\text{text}}$ 是文本条件。对 $k$ 的求积根据码本模式进行因式分解——某些层级是并行预测的。
|
||||
|
||||

|
||||
|
||||
## 图生文生成 (Image-to-Text Generation)
|
||||
|
||||
- 现在翻转方向:给定一张图像,生成自然语言描述。这就是**图像描述(image captioning)**,这是一种以图像为条件的条件文本生成形式。可以想象一位博物馆导览员描述一幅画作——他们必须感知视觉内容,理解物体之间的关系,并用流畅的语言表达观察结果。
|
||||
|
||||
### 作为条件生成的图像描述
|
||||
|
||||
- 经典方法使用**编码器-解码器**架构(第 07 章)。预训练的 CNN 或 ViT(第 8 章)将图像编码为一组特征向量。语言模型解码器逐词生成描述,每一步都关注图像特征:
|
||||
|
||||
$$p(w_1, \ldots, w_L \mid I) = \prod_{l=1}^{L} p(w_l \mid w_1, \ldots, w_{l-1}, I)$$
|
||||
|
||||
- 其中 $w_l$ 是描述中的词语,$I$ 是图像表示。交叉注意力将文本解码器与图像特征连接起来,使模型在生成不同词语时能够"查看"图像的不同区域——生成"狗"时关注狗的区域,生成"公园"时关注公园的区域。
|
||||
|
||||
- **CoCa**(Contrastive Captioners,Yu 等人,2022)在一个单一模型中统一了对比学习(文件 01 中的 CLIP 风格目标)和图像描述。图像编码器生成的特征既用于与文本进行对比对齐,也用于描述解码器中的交叉注意力。这种多任务训练使 CoCa 同时具有强大的零样本识别能力(来自对比学习)和强大的生成能力(来自图像描述)。
|
||||
|
||||
### 现代视觉语言描述
|
||||
|
||||
- 现代方法通常使用**大型多模态模型**(文件 02)来进行图像描述。LLaVA、Qwen-VL 和 GPT-4V 等模型将图像描述视为视觉问答的一种特殊情况——"问题"隐式地就是"描述这张图像"。视觉编码器(CLIP ViT 或 SigLIP)生成块 token,这些 token 被投影到 LLM 的嵌入空间中,然后 LLM 生成自由形式的描述。
|
||||
|
||||
- 基于 LLM 的描述相较于专用编码器-解码器模型的优势在于**指令遵循(instruction following)**:你可以要求不同详细程度("用一句话描述"对比"提供详细段落"),关注特定方面("描述颜色"),或生成结构化输出("列出所有物体及其位置")。这种灵活性来源于 LLM 的指令微调(第 07 章)。
|
||||
|
||||
## 视频-音频联合生成 (Video-Audio Co-Generation)
|
||||
|
||||
- 想象一下关掉声音看电影——体验是空洞的。视觉内容和音频是深度耦合的:弹跳的球有节奏的撞击声,雨水发出啪嗒声,人群爆发出欢呼声。**视频-音频联合生成(video-audio co-generation)**旨在同时生成两种模态,保持所看与所听之间的时间对齐。
|
||||
|
||||
### 联合时间建模
|
||||
|
||||
- 核心挑战是**时间同步(temporal synchronisation)**:击鼓的音频必须与鼓槌击鼓的视觉帧精确重合。这需要一个两种模态都能引用的共享时间表示。
|
||||
|
||||
- 一种方法是从共享的潜在时间线生成视频和音频。像 **CoDi**(Composable Diffusion,Tang 等人,2023)这样的模型对每种模态使用独立的扩散模型,但通过共享的隐空间进行对齐。在训练过程中,跨模态注意力层学习在每个时间步同步视觉和音频特征。在生成过程中,两种扩散过程同时运行,通过共享对齐相互条件化。
|
||||
|
||||
- 前面讨论的 VideoPoet 采用了一种更统一的方法:由于所有模态都被 token 化为单一序列,LLM 自然地学习了视频 token 和音频 token 之间的时间对应关系。一段狗叫的视频片段后面跟随着相应的音频 token,教会模型将视觉上的狗叫动作与狗叫声关联起来。
|
||||
|
||||
- **时间对齐损失(temporal alignment loss)**函数显式地强制同步。一种形式是在帧级别使用对比学习:时间 $t$ 的音频段应该与时间 $t$ 的视频帧比其他时刻的帧更相似:
|
||||
|
||||
$$\mathcal{L}_{\text{sync}} = -\mathbb{E}_t \left[\log \frac{\exp(\text{sim}(v_t, a_t) / \tau)}{\sum_{t'} \exp(\text{sim}(v_t, a_{t'}) / \tau)}\right]$$
|
||||
|
||||
- 其中 $v_t$ 和 $a_t$ 是时间 $t$ 的视频和音频表示,$\tau$ 是温度参数。这与文件 01 中的 InfoNCE 损失在结构上相同,但应用于时间帧级别而非片段级别。
|
||||
|
||||
## 指令遵循式生成 (Instruction-Following Generation)
|
||||
|
||||
- 想象你告诉一位艺术家"让天空更有戏剧性"或"把帽子换成王冠"。**指令遵循式生成(instruction-following generation)**允许你使用自然语言命令编辑图像,而不需要精确的空间遮罩或笔触。
|
||||
|
||||
### InstructPix2Pix:通过描述进行编辑
|
||||
|
||||
- **InstructPix2Pix**(Brooks 等人,2023)训练了一个条件扩散模型,该模型接收输入图像和文本指令,然后生成编辑后的图像。巧妙之处在于训练数据的创建方式:GPT-3 生成编辑指令("变成冬天"、"把猫变成狗")以及输入-输出文本描述对,然后文生图模型(Stable Diffusion)生成相应的图像对。
|
||||
|
||||
- 模型是一个修改后的 Stable Diffusion U-Net,同时接收文本指令(通过交叉注意力)和输入图像的隐表示(与含噪隐变量按通道拼接)。它使用**双无分类器引导(dual classifier-free guidance)**,包含两个引导尺度——一个用于文本指令($s_T$),一个用于输入图像($s_I$):
|
||||
|
||||
$$\hat{\epsilon} = \epsilon_\theta(x_t, \varnothing, \varnothing) + s_I \cdot (\epsilon_\theta(x_t, c_I, \varnothing) - \epsilon_\theta(x_t, \varnothing, \varnothing)) + s_T \cdot (\epsilon_\theta(x_t, c_I, c_T) - \epsilon_\theta(x_t, c_I, \varnothing))$$
|
||||
|
||||
- 其中 $c_I$ 是输入图像条件,$c_T$ 是文本指令。第一个引导项控制保留输入图像的程度;第二个控制遵循指令的强度。这为用户提供了一个二维旋钮:高 $s_I$ 更紧密地保留原图,而高 $s_T$ 则进行更大幅度的编辑。
|
||||
|
||||

|
||||
|
||||
### SDEdit 与基于噪声的编辑
|
||||
|
||||
- **SDEdit**(Meng 等人,2022)提供了一种更简单的编辑方法,不需要特殊训练。你对输入图像添加噪声(运行前向扩散过程到中间时间步 $t_0$),然后用描述所需输出的文本提示进行去噪。噪声量控制编辑强度:低噪声保留结构(颜色变化、风格迁移),而高噪声允许大幅重构(物体替换、布局改变)。
|
||||
|
||||
- 这是一个精确的权衡:在时间步 $t_0$,含噪图像保留了原始信号的 $\bar{\alpha}_{t_0}$ 比例。去噪过程根据新的文本提示填充被破坏的细节。这在数学上是严谨的:扩散模型从后验分布 $p(x_0 \mid x_{t_0}, c)$ 中采样,其中 $x_{t_0}$ 将生成结果约束为"接近"原始图像。
|
||||
|
||||
### ControlNet:空间条件控制
|
||||
|
||||
- **ControlNet**(Zhang 等人,2023)为文生图扩散增加了细粒度的空间控制。预训练 U-Net 编码器的副本被训练来接受额外的输入条件——边缘图(Canny 边缘)、深度图、姿态骨架、分割图——而原始 U-Net 权重被冻结。ControlNet 编码器的输出通过**零卷积(zero convolutions)**(初始化为零的 1x1 卷积)添加到冻结的 U-Net 的跳跃连接中,确保训练从预训练模型的行为开始,逐步学习新的条件。
|
||||
|
||||
- 这种架构让你可以提供草图、深度图或人体姿态作为结构指导,文本提示则负责填充外观。预训练权重处理逼真度和文本理解;ControlNet 层处理对条件空间保真度的保持。
|
||||
|
||||
## 一致性与对齐指标 (Consistency and Alignment Metrics)
|
||||
|
||||
- 如何衡量生成的图像是否良好?"良好"至少有两个维度:**质量(quality)**(看起来像真实图像吗?)和**对齐度(alignment)**(与文本提示匹配吗?)。若干指标已被开发出来量化这些方面。
|
||||
|
||||
### Frechet Inception Distance (FID)
|
||||
|
||||
- **Frechet Inception Distance(FID)**(Heusel 等人,2017)衡量生成图像分布与真实图像分布之间在预训练 Inception 网络特征空间中的距离。可以将其理解为比较两个图像集合的"指纹",而不是比较单个图像。
|
||||
|
||||
- 真实图像集和生成图像集都通过 Inception-v3 处理,收集倒数第二层的激活值。这些激活值被建模为多元高斯分布 $\mathcal{N}(\mu_r, \Sigma_r)$ 和 $\mathcal{N}(\mu_g, \Sigma_g)$。FID 就是这些高斯分布之间的 Frechet 距离(Wasserstein-2 距离):
|
||||
|
||||
$$\text{FID} = \|\mu_r - \mu_g\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2(\Sigma_r \Sigma_g)^{1/2}\right)$$
|
||||
|
||||
- FID 越低越好。FID = 0 意味着分布完全相同。FID 同时捕捉质量(如果生成的图像模糊,其特征将与真实图像不同)和多样性(如果模型遭受模式坍塌,$\Sigma_g$ 将小于 $\Sigma_r$)。在 ImageNet 256x256 上,当前的先进水平为 FID < 2.0。
|
||||
|
||||
- FID 存在已知局限性:它假设特征分布是高斯分布(这只是一个近似),需要数千个样本才能获得稳定估计,并且使用 Inception 特征(可能无法捕捉所有感知上相关的差异)。
|
||||
|
||||
### Inception Score (IS)
|
||||
|
||||
- **Inception Score(IS)**(Salimans 等人,2016)衡量两个特性:每张生成的图像应该能被自信地分类(条件类别分布 $p(y \mid x)$ 应该是尖峰状的),并且生成的图像集合应该覆盖多个类别(边缘分布 $p(y) = \mathbb{E}_x[p(y \mid x)]$ 应该是均匀的)。IS 通过 KL 散度将两者结合起来:
|
||||
|
||||
$$\text{IS} = \exp\left(\mathbb{E}_x \left[D_{\text{KL}}(p(y \mid x) \| p(y))\right]\right)$$
|
||||
|
||||
- IS 越高越好。最大 IS 等于类别数(对于 ImageNet 为 1000)。IS 奖励质量(清晰、可识别的图像)和多样性(类别覆盖),但它有显著的局限性:它完全忽略真实数据分布,无法检测类别内的模式遗漏,并且由于使用 Inception 的类别预测,它偏向于类似 ImageNet 的图像。
|
||||
|
||||
### CLIPScore:衡量文本-图像对齐度
|
||||
|
||||
- **CLIPScore**(Hessel 等人,2021)使用预训练的 CLIP 模型(文件 01)直接衡量生成的图像与其文本提示的匹配程度。这个分数就是 CLIP 图像嵌入与 CLIP 文本嵌入之间的余弦相似度:
|
||||
|
||||
$$\text{CLIPScore}(I, T) = \max(0, \cos(E_I(I), E_T(T)))$$
|
||||
|
||||
- 其中 $E_I$ 和 $E_T$ 是 CLIP 的图像和文本编码器。CLIPScore 无需参考——它不需要真实图像,只需要文本提示。它与人类对文本-图像对齐的判断高度相关,已成为评估文生图模型提示保真度的标准指标。
|
||||
|
||||
- 如果需要与参考描述进行比较,**RefCLIPScore** 会纳入参考图像:
|
||||
|
||||
$$\text{RefCLIPScore} = \text{HarmonicMean}(\text{CLIPScore}(I, T), \max(0, \cos(E_I(I), E_I(I_{\text{ref}}))))$$
|
||||
|
||||
- 这平衡了文本对齐度与参考图像的视觉相似性。
|
||||
|
||||

|
||||
|
||||
### 人工评估
|
||||
|
||||
- 自动化指标只是代理指标;人工判断仍然是黄金标准。常见方案包括**成对比较(pairwise comparisons)**(两张图像中哪张更匹配提示?)、**Likert 量表(Likert scales)**(从 1-5 分评价质量和对齐度)以及 **Elo 评分(Elo ratings)**(跨模型的锦标赛式排名)。DrawBench 和 PartiPrompts 基准测试提供了用于系统化人工评估的标准化提示集。
|
||||
|
||||
## 伦理考量 (Ethical Considerations)
|
||||
|
||||
- 跨模态生成是人工智能领域伦理后果最严重的领域之一。能够根据文本描述创建逼真的图像、视频和音频,这引发了从业者必须严肃对待的深刻担忧。
|
||||
|
||||
### 深度伪造与虚假信息
|
||||
|
||||
- **深度伪造(Deepfakes)**是指旨在描绘从未发生事件的生成或操纵媒体。文生图和文生视频模型可以创建令人信服的公众人物假照片、捏造的证据和误导性的新闻图像。危险不仅在于伪造的存在,还在于它们的存在削弱了对所有媒体的信任——如果任何图像都可能是假的,那么就没有图像是值得完全信任的。
|
||||
|
||||
- 检测方法包括训练分类器区分真实和生成的图像、分析统计伪影(GAN 生成的图像具有微妙的频谱特征)以及嵌入不可见水印(Stable Diffusion 的不可见水印、Google 的 SynthID)。然而,检测是一场军备竞赛:随着生成器的改进,检测器必须不断更新。
|
||||
|
||||
### 生成中的偏差
|
||||
|
||||
- 在互联网规模数据上训练的模型会继承并放大社会偏见。文生图模型会不成比例地生成肤色较浅的面孔,将某些职业与特定性别关联起来,并在提示不够明确时默认采用西方文化规范。这些偏见根植于训练数据分布以及 CLIP/T5 文本编码器中,后者从其自身的训练语料库中编码了偏见。
|
||||
|
||||
- 缓解策略包括:策划更具代表性的训练数据、对文本编码器应用去偏技术、使用安全分类器过滤有问题的输出,以及让用户能够控制人口统计属性。这些都不是完整的解决方案,持续的审核至关重要。
|
||||
|
||||
### 内容过滤与安全性
|
||||
|
||||
- 负责任的部署需要多层保护。**输入过滤**在生成之前阻止有害提示。**输出过滤**对生成内容进行分类并拒绝有害材料。**NSFW 分类器**检测露骨色情、暴力或其他有害内容。例如,Stable Diffusion 的安全检查器计算生成图像的 CLIP 嵌入与一组预定义的有害概念嵌入之间的余弦相似度,标记超过阈值的图像。
|
||||
|
||||
- 许多生成模型(Stable Diffusion、Wan)的开源性质在普及访问和防止滥用之间形成了张力。一旦模型权重发布,内容过滤就可以被绕过。这引发了关于适当的开放程度以及模型开发者责任的讨论。
|
||||
|
||||
### 知识产权与知情同意
|
||||
|
||||
- 在互联网数据上训练的生成模型可能会在未经同意的情况下复制受版权保护的风格、商标或真实人物的肖像。法律和伦理框架仍在演变中,但负责任的实践包括尊重选择退出机制、承认训练数据中蕴含的创造性贡献,以及开发防止记忆和复述训练例子的技术保障措施。
|
||||
|
||||
## 编程练习(使用 CoLab 或 notebook)
|
||||
|
||||
1. 为一个玩具 2D 扩散模型实现无分类器引导。在 2D 数据集(例如标注的聚类)上训练一个条件扩散模型,然后使用不同的引导尺度进行采样,观察质量与多样性的权衡。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# Toy 2D conditional diffusion with classifier-free guidance
|
||||
def noise_schedule(T):
|
||||
betas = jnp.linspace(1e-4, 0.02, T)
|
||||
alphas = 1.0 - betas
|
||||
return jnp.cumprod(alphas)
|
||||
|
||||
def forward_diffuse(x0, t, alpha_bars, key):
|
||||
noise = jax.random.normal(key, x0.shape)
|
||||
return jnp.sqrt(alpha_bars[t]) * x0 + jnp.sqrt(1 - alpha_bars[t]) * noise, noise
|
||||
|
||||
# Generate labelled 2D data: class 0 = ring, class 1 = cluster
|
||||
key = jax.random.PRNGKey(42)
|
||||
k1, k2, k3 = jax.random.split(key, 3)
|
||||
theta = jax.random.uniform(k1, (200,)) * 2 * jnp.pi
|
||||
ring = jnp.stack([jnp.cos(theta), jnp.sin(theta)], axis=1) * 2
|
||||
ring += jax.random.normal(k2, ring.shape) * 0.1
|
||||
cluster = jax.random.normal(k3, (200, 2)) * 0.3
|
||||
|
||||
data = jnp.concatenate([ring, cluster])
|
||||
labels = jnp.concatenate([jnp.zeros(200), jnp.ones(200)])
|
||||
|
||||
# Simulate CFG: show how guidance pushes samples toward class-conditional modes
|
||||
# Try varying guidance_scale from 0.0 to 5.0 and observe results
|
||||
guidance_scales = [0.0, 1.0, 3.0, 7.0]
|
||||
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
|
||||
for ax, s in zip(axes, guidance_scales):
|
||||
ax.scatter(ring[:, 0], ring[:, 1], s=8, alpha=0.4, label='Ring (c=0)')
|
||||
ax.scatter(cluster[:, 0], cluster[:, 1], s=8, alpha=0.4, label='Cluster (c=1)')
|
||||
ax.set_title(f'Guidance scale s={s}')
|
||||
ax.set_xlim(-4, 4); ax.set_ylim(-4, 4)
|
||||
ax.set_aspect('equal'); ax.legend(fontsize=7)
|
||||
plt.suptitle('Experiment: vary guidance scale and observe quality vs diversity')
|
||||
plt.tight_layout(); plt.show()
|
||||
# Exercise: train a small MLP denoiser with class conditioning,
|
||||
# then implement the CFG formula to sample with different s values.
|
||||
```
|
||||
|
||||
2. 使用完整的 Frechet 距离公式计算两组 2D 样本之间的 FID。改变生成分布,观察 FID 如何变化。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def compute_fid(real, generated):
|
||||
"""Compute Frechet distance between two 2D sample sets."""
|
||||
mu_r, mu_g = jnp.mean(real, axis=0), jnp.mean(generated, axis=0)
|
||||
sigma_r = jnp.cov(real.T)
|
||||
sigma_g = jnp.cov(generated.T)
|
||||
diff = mu_r - mu_g
|
||||
# Matrix square root via eigendecomposition
|
||||
product = sigma_r @ sigma_g
|
||||
eigvals, eigvecs = jnp.linalg.eigh(product)
|
||||
sqrt_product = eigvecs @ jnp.diag(jnp.sqrt(jnp.maximum(eigvals, 0))) @ eigvecs.T
|
||||
fid = jnp.sum(diff ** 2) + jnp.trace(sigma_r + sigma_g - 2 * sqrt_product)
|
||||
return fid
|
||||
|
||||
key = jax.random.PRNGKey(0)
|
||||
k1, k2, k3, k4 = jax.random.split(key, 4)
|
||||
|
||||
# Real distribution: standard 2D Gaussian
|
||||
real = jax.random.normal(k1, (1000, 2))
|
||||
|
||||
# Generated distributions with increasing divergence
|
||||
shifts = [0.0, 0.5, 1.0, 2.0, 4.0]
|
||||
fig, axes = plt.subplots(1, len(shifts), figsize=(18, 3.5))
|
||||
for ax, shift in zip(axes, shifts):
|
||||
gen = jax.random.normal(k2, (1000, 2)) * (1 + shift * 0.2) + shift
|
||||
fid = compute_fid(real, gen)
|
||||
ax.scatter(real[:, 0], real[:, 1], s=3, alpha=0.3, label='Real')
|
||||
ax.scatter(gen[:, 0], gen[:, 1], s=3, alpha=0.3, label='Generated')
|
||||
ax.set_title(f'Shift={shift}\nFID={fid:.2f}')
|
||||
ax.set_xlim(-5, 8); ax.set_ylim(-5, 8)
|
||||
ax.set_aspect('equal'); ax.legend(fontsize=7)
|
||||
plt.suptitle('FID increases as generated distribution diverges from real')
|
||||
plt.tight_layout(); plt.show()
|
||||
# Try: change the variance of generated samples without shifting the mean.
|
||||
# How does FID respond to a diversity mismatch vs a location mismatch?
|
||||
```
|
||||
|
||||
3. 使用随机投影作为 CLIP 的替代,实现文本和图像嵌入之间的 CLIPScore 计算。观察当你改变模态之间的"对齐度"时,余弦相似度如何变化。
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def cosine_similarity(a, b):
|
||||
return jnp.dot(a, b) / (jnp.linalg.norm(a) * jnp.linalg.norm(b))
|
||||
|
||||
def clip_score(img_emb, txt_emb):
|
||||
"""CLIPScore: clamped cosine similarity."""
|
||||
return jnp.maximum(0.0, cosine_similarity(img_emb, txt_emb))
|
||||
|
||||
key = jax.random.PRNGKey(42)
|
||||
dim = 512 # CLIP embedding dimension
|
||||
|
||||
# Simulate aligned and misaligned pairs
|
||||
# Aligned: image and text embeddings share a component
|
||||
k1, k2, k3 = jax.random.split(key, 3)
|
||||
shared = jax.random.normal(k1, (dim,))
|
||||
shared = shared / jnp.linalg.norm(shared)
|
||||
|
||||
noise_levels = jnp.linspace(0, 5, 20)
|
||||
scores = []
|
||||
for noise in noise_levels:
|
||||
noise_vec = jax.random.normal(k2, (dim,)) * noise
|
||||
img_emb = shared + noise_vec * 0.3
|
||||
txt_emb = shared + jax.random.normal(k3, (dim,)) * noise * 0.3
|
||||
scores.append(float(clip_score(img_emb, txt_emb)))
|
||||
|
||||
plt.figure(figsize=(8, 4))
|
||||
plt.plot(noise_levels, scores, 'o-', color='#2c3e50')
|
||||
plt.xlabel('Noise level (misalignment)')
|
||||
plt.ylabel('CLIPScore')
|
||||
plt.title('CLIPScore decreases as text-image alignment degrades')
|
||||
plt.grid(True, alpha=0.3)
|
||||
plt.tight_layout(); plt.show()
|
||||
# Experiment: what happens if you normalise embeddings before adding noise?
|
||||
# How does dimensionality affect the score distribution?
|
||||
```
|
||||
@@ -0,0 +1,322 @@
|
||||
# 统一多模态架构
|
||||
|
||||
*统一多模态架构用单一系统取代了各自为政的专家模型,这个系统能够跨越文本、图像、音频和视频进行读取、推理和生成。本文涵盖了任意到任意模型(CoDi、NExT-GPT)、原生多模态大语言模型(Gemini、GPT-4o)、多模态分词策略,以及统一化所带来的架构权衡。*
|
||||
|
||||
## 统一化的理由
|
||||
|
||||
- 想象一位会说五种语言、能在句子中间无停顿地切换语种的翻译。早期的多模态系统更像是五个坐在不同房间的翻译,每人处理一种语言,通过墙上的小缝隙传递纸条。而**统一多模态架构**就是那一位多语言者:一个共享权重的模型,在单次前向传播中即可跨文本、图像、音频、视频甚至动作进行读取、书写和推理。
|
||||
|
||||
- 其动机既有实用层面的也有理论层面的。在实用层面,维护每对模态的专用专家模型(文本到图像、图像到文本、音频到文本等)会导致组合爆炸:$k$ 种模态需要最多 $k(k-1)$ 个有向流水线。一个统一模型将所有这类流水线坍缩为单一系统。在理论层面,人类认知并非在隔离的模块中处理视觉和语言;跨模态绑定发生得早且深,统一化尝试模仿这一点。
|
||||
|
||||
- 共享权重鼓励**跨模态迁移**。一个已在文本中学到时间模式(主语在动词前、原因在结果前)的 Transformer,可以将同样的注意力电路重新用于视频中的时间模式(对象出现在移动之前)或音频中的时间模式(起音在持续之前)。这是迁移学习的多模态类比——你曾在第 7 章的语言模型微调中和第 8 章的 ImageNet 预训练中见到过。
|
||||
|
||||
- 形式上,令 $\mathcal{M} = \{m_1, m_2, \ldots, m_k\}$ 为一组模态。统一模型定义了一个单一参数化函数 $f_\theta$,它将任意输入模态子集映射到任意输出模态子集:
|
||||
|
||||
$$f_\theta : \mathcal{P}(\mathcal{M}) \rightarrow \mathcal{P}(\mathcal{M})$$
|
||||
|
||||
- 其中 $\mathcal{P}(\mathcal{M})$ 是模态的幂集(所有子集)。关键约束是 $\theta$ 大部分是共享的;只有薄薄的模态特定适配器层有所不同。
|
||||
|
||||

|
||||
|
||||
- 统一化的前景伴随着一个基本张力:模态在结构上是不同的。文本是离散 token 的一维序列。图像是连续像素值的二维网格。音频是一维连续波形,时间尺度与文本截然不同。视频为图像添加了时间轴。将这些迥异的结构调和成单一的、Transformer 能够消化的序列,是该领域核心的工程挑战。
|
||||
|
||||
## 任意到任意模型
|
||||
|
||||
- 想象一个通用遥控器,可以通过同一个界面操作你的电视、空调和音响系统。**任意到任意模型**就是 AI 中的等价物:它们接收任意模态组合作为输入,并产生任意组合作为输出。
|
||||
|
||||
- **CoDi**(Composable Diffusion,可组合扩散)通过训练模态特定的扩散模型,然后通过共享条件机制对齐它们的潜在空间来实现任意到任意生成。每种模态都有其自身的扩散过程(回顾本章文件 04 中的扩散模型),但噪声预测网络被条件化在一个联合交叉注意力层上,该层同时看到所有输入模态的嵌入。这让 CoDi 能够在单次前向传播中,例如从一个文本提示生成图像和匹配的音频。
|
||||
|
||||
- **NExT-GPT** 采用了不同的架构方法。它将 LLM 主干("大脑")通过轻量级的**投影层**连接到输入侧的模态特定编码器和输出侧的模态特定解码器。输入编码器(例如来自 CLIP 的图像编码器、来自 CLAP 的音频编码器)将每种模态翻译成 LLM 的嵌入空间。LLM 对组合后的 token 序列进行推理,并发出特殊的"模态信号 token"来将信息路由到适当的解码器(例如用于图像的 Stable Diffusion、用于音频的 AudioLDM)。只有投影层被训练;LLM 和专家编解码器保持冻结。
|
||||
|
||||
- **Gemini**(Google DeepMind)从预训练阶段起就是原生多模态的。与 NExT-GPT 的即插即用方法不同,Gemini 的 Transformer 从头开始就在文本、图像、音频和视频 token 的交错序列上进行训练。这意味着跨模态注意力模式在预训练期间有机地发展,而不是事后才拼接上去。该模型对文本使用 SentencePiece tokenizer,并学习了一种类似于本章文件 03 中讨论的 VQ 方法的视觉 tokenizer。
|
||||
|
||||
- **GPT-4o**("o"代表"omni",全模态)代表了另一种模式:一个端到端模型,其中所有模态共享同一个 Transformer 和同一个下一 token 预测目标。音频输入作为频谱 token 处理,图像作为块 token,文本作为子词 token,全部送入单一序列。模型生成的输出 token 由模态特定的头部解码。关键创新在于低延迟——通过消除早期系统(如 GPT-4V)所依赖的独立 ASR、LLM 和 TTS 级联而实现。
|
||||
|
||||

|
||||
|
||||
- 这些模型处于集成深度谱系的不同位置:
|
||||
|
||||
- **浅层集成**(NExT-GPT):冻结专家,通过训练适配器连接。构建快速,跨模态推理能力有限。
|
||||
- **中层集成**(CoDi):跨模态特定生成器的共享条件化。对齐更好,仍然模块化。
|
||||
- **深层集成**(Gemini、GPT-4o):在所有模态上端到端训练的单一模型。跨模态推理最丰富,训练成本最高。
|
||||
|
||||
## 共享主干上的模态特定编码器和解码器
|
||||
|
||||
- 想象一家工厂有一条总装线(共享主干),但有不同的原料装卸码头(编码器)和不同的成品发运部门(解码器)。每个码头专精于其货物,但一旦进入工厂内部,所有东西都在同一条传送带上移动。
|
||||
|
||||
- 统一模型的主导架构模式采用这种三部分结构:
|
||||
|
||||
- **模态编码器** $E_m$:将来自模态 $m$ 的原始输入转换为嵌入向量序列 $\mathbf{h}_1^m, \mathbf{h}_2^m, \ldots, \mathbf{h}_{n_m}^m$,每个向量的维度为 $d$。
|
||||
- **共享 Transformer 主干** $T_\theta$:使用自注意力处理来自所有输入模态的拼接或交错嵌入。
|
||||
- **模态解码器** $D_m$:将主干的输出嵌入转换回模态 $m$ 的原生格式(文本 token、图像像素、音频波形)。
|
||||
|
||||
- 对于文本,编码器通常是一个嵌入查找表 $E_\text{text}(w) = \mathbf{W}_e[w]$,其中 $w$ 是 token 索引,与你在第 7 章 Transformer 中看到的相同。对于图像,编码器通常是**视觉 Transformer**(ViT),它将图像分割成块并将每个块线性投影,如第 8 章所述。对于音频,编码器计算梅尔频谱图,然后用卷积前端或音频频谱图 Transformer(AST)处理,如第 9 章所述。
|
||||
|
||||
- 共享主干是一个标准 Transformer,对所有模态 token 进行自注意力。给定一个拼接输入序列 $\mathbf{H} = [\mathbf{h}_1^{m_1}, \ldots, \mathbf{h}_{n_1}^{m_1}, \mathbf{h}_1^{m_2}, \ldots, \mathbf{h}_{n_2}^{m_2}]$,自注意力允许每个 token 关注所有其他 token,无论其模态如何:
|
||||
|
||||
$$\text{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \text{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^\top}{\sqrt{d_k}}\right)\mathbf{V}$$
|
||||
|
||||
- 这与第 7 章中的注意力公式相同,但现在 $\mathbf{Q}$、$\mathbf{K}$ 和 $\mathbf{V}$ 包含来自多种模态的 token。图像块 token 可以关注文本 token,从而无需单独的交叉注意力模块即可实现跨模态推理。
|
||||
|
||||
- **模态嵌入**被添加到每个 token 上,以便主干知道 token 来自哪种模态。这类似于位置嵌入,但编码的是模态身份而非序列位置。一个可学习的向量 $\mathbf{e}_m \in \mathbb{R}^d$ 被添加到每个来自模态 $m$ 的 token 上:
|
||||
|
||||
$$\tilde{\mathbf{h}}_i^m = \mathbf{h}_i^m + \mathbf{e}_m + \mathbf{p}_i$$
|
||||
|
||||
- 其中 $\mathbf{p}_i$ 是位置 $i$ 的位置嵌入。
|
||||
|
||||

|
||||
|
||||
## 多模态分词
|
||||
|
||||
- 想象你在写一封信,信中既有英文文本又有手绘草图。你可能写一个句子,画一个图表,再写一个引用该图表的句子,然后贴上一段乐谱。这封信就是一个线性流,交错着不同的"模态"。多模态分词做的正是这件事:它将文本、图像、音频和视频转换成单一的扁平 token 序列,由 Transformer 从左到右处理。
|
||||
|
||||
- 对于文本,分词技术已经很成熟:**字节对编码**(BPE)或 SentencePiece 产生子词 token 的词汇表,如第 7 章所述。挑战在于将这一思想扩展到连续模态。
|
||||
|
||||
- 对于图像,有两种主要方法。**离散**方法使用 VQ-VAE 或 VQ-GAN(详见本章文件 03)将每幅图像映射为码本索引序列。如果码本有 $|\mathcal{C}|$ 个条目且一幅图像编码为 $n$ 个码字,则该图像变为 $n$ 个离散 token,取自大小为 $|\mathcal{C}|$ 的词汇表,直接与文本词汇表兼容。**连续**方法使用 ViT 或 CNN 编码器产生 $n$ 个连续嵌入向量,然后线性投影到 Transformer 的嵌入维度中。Gemini 和 GPT-4o 使用连续方法的变体;自回归图像生成器如 Parti 和 LlamaGen 则偏好离散路线。
|
||||
|
||||
- 对于音频,信号通常被转换为梅尔频谱图,然后要么通过神经音频编解码器(例如 EnCodec、SoundStream,它们产生层次化的离散 token)进行离散化,要么通过学习的编码器进行连续投影。例如,AudioLM 将音频表示为来自多个码本层次的离散 token 序列,然后以自回归方式对其进行建模。
|
||||
|
||||
- 对于视频,分词建立在图像分词的基础上,但还必须压缩时间维度。一种常见策略使用**3D VQ-VAE**(如文件 03 中的 VideoGPT 或 Cosmos Tokenizer)将时空块量化为离散 token。时间压缩因子至关重要:未经激进的时间下采样,24 fps 的原始视频每秒产生的 token 数量太多。
|
||||
|
||||
- 一旦所有模态都被分词化,它们就被**交错**成单一序列,并带有标记模态边界的特殊分隔 token。一个典型格式如下:
|
||||
|
||||
```
|
||||
[TEXT] 猫坐在垫子上 [/TEXT] [IMAGE] <img_tok_1> <img_tok_2> ... <img_tok_n> [/IMAGE] [AUDIO] <aud_tok_1> ... <aud_tok_m> [/AUDIO]
|
||||
```
|
||||
|
||||
- Transformer 然后使用其标准因果(或双向)注意力机制处理整个混合序列。模态分隔 token 起到双重作用:它们向模型告知模态边界,并充当"汇聚点",其表示概括了每个模态段。
|
||||
|
||||

|
||||
|
||||
- 一个关键的设计选择是**token 预算**。一张被分词为 256 个 token 的图像加上 50 个 token 的文本描述,意味着图像消耗的上下文窗口是文本的 5 倍。模型必须在分辨率(更多 token = 更多细节)和上下文长度(更多 token = 更高的内存和计算成本)之间取得平衡。**token 合并**(逐渐合并相似 token)和**自适应分词**(对简单区域使用较少的 token,对复杂区域使用更多 token)等技术有助于管理这种权衡。
|
||||
|
||||
## 训练配方:分阶段预训练与联合微调
|
||||
|
||||
- 你不会在教孩子算术之前就教他微积分。同样,你不能从随机初始化开始,在所有模态上同时训练一个统一多模态模型,并期望它能很好地收敛。主导方法是**分阶段训练**,其中模型在精心排序的阶段中逐步学习越来越复杂的跨模态能力。
|
||||
|
||||
- **阶段 1:单模态预训练。** 每个模态编码器在大型单模态数据集上独立训练。文本主干使用标准语言建模目标(下一 token 预测)在数万亿文本 token 上进行预训练,正如第 7 章一样。视觉编码器在图像分类或自监督目标(MAE、DINO)上预训练,如第 8 章所述。音频编码器在语音识别或音频分类数据上预训练,如第 9 章所述。这一阶段产生了强大的单模态特征提取器。
|
||||
|
||||
- **阶段 2:跨模态对齐。** 预训练的编码器连接到共享主干,模型在成对的多模态数据(图像-描述对、音频-文本对)上使用对比或生成目标进行训练。在此阶段,编码器权重可能被冻结(以保留单模态知识),仅更新投影层和主干。这是来自本章文件 01 的 CLIP 风格对齐被纳入统一模型的阶段。
|
||||
|
||||
- **阶段 3:联合多模态预训练。** 所有参数(或大部分)被解冻,模型在单模态和多模态数据的混合上训练,使用对所有模态 token 的单一下一 token 预测目标。损失函数为:
|
||||
|
||||
$$\mathcal{L} = -\sum_{t=1}^{T} \log p_\theta(x_t \mid x_{<t})$$
|
||||
|
||||
- 其中 $x_t$ 可以是文本 token、图像 token 或音频 token。模型必须学会预测下一个 token,无论其模态如何,这迫使它发展真正的跨模态理解。
|
||||
|
||||
- **阶段 4:指令微调与对齐。** 预训练模型在精心策划的指令遵循数据集上进行微调,这些数据集包括多模态指令(例如,"详细描述这幅图像"、"这段视频发出什么声音?"、"生成一张 X 的图像")。这一阶段通常使用**基于人类反馈的强化学习**(RLHF)或直接偏好优化(DPO)来使模型的输出与人类偏好对齐。
|
||||
|
||||
- **模态特定热身**是一种在阶段内部使用的技术,用于防止模态坍缩。如果一种模态(通常是文本,因为它拥有最多的训练数据)主导了梯度信号,模型可能会"遗忘"较弱的模态。热身策略包括:
|
||||
|
||||
- **梯度平衡**:缩放来自每种模态的梯度,使其对参数更新有均等贡献。
|
||||
- **数据比例调度**:逐步增加多模态数据相对于单模态数据的比例。
|
||||
- **损失加权**:分配模态特定的权重 $\lambda_m$,使总损失为 $\mathcal{L} = \sum_m \lambda_m \mathcal{L}_m$,其中 $\lambda_m$ 经过调整以平衡各模态的学习率。
|
||||
|
||||

|
||||
|
||||
- **为什么不跳过阶段?** 从头开始联合训练所有内容很诱人,但在实践中由于几个原因而失败。首先,模型必须同时学习低级特征(边缘检测、音素识别)和高级跨模态推理,两者具有非常不同的学习动态。其次,跨模态的数据分布极不平衡(数万亿文本 token 对比数十亿图像 token 对比数亿音频片段)。第三,优化景观高度非凸,分阶段训练提供了一个课程表,引导模型走向更好的盆地,类似于第 6 章讨论的课程学习理念。
|
||||
|
||||
## 多模态思维链推理
|
||||
|
||||
- 当你解决一个几何问题时,你可能会画一个示意图,标注角度,写出方程,然后逐步求解。你不会直接从问题陈述跳到答案。**多模态思维链**(CoT)推理使模型能够做同样的事情:在得出最终答案之前生成可能涉及文本、视觉注释甚至生成图表的中间推理步骤。
|
||||
|
||||
- 在纯文本 CoT 中(如第 7 章提示策略的讨论中所探讨的),模型以自然语言生成推理步骤序列。多模态 CoT 扩展了这一能力,允许中间步骤引用或生成视觉内容。例如,给定一张图表图像和问题"哪一年销售额最高?",多模态 CoT 模型可能首先描述图表("该图表显示 2018 年至 2023 年的销售额……"),然后识别相关的视觉特征("最高的条形出现在 2021 年……"),最后输出答案("2021 年")。
|
||||
|
||||
- 形式上,令 $\mathbf{x}$ 为多模态输入,$y$ 为目标答案。标准预测模型直接建模 $p(y \mid \mathbf{x})$。思维链引入了中间推理 $\mathbf{r} = (r_1, r_2, \ldots, r_L)$ 并将预测分解为:
|
||||
|
||||
$$p(y \mid \mathbf{x}) = \sum_{\mathbf{r}} p(y \mid \mathbf{r}, \mathbf{x}) \cdot p(\mathbf{r} \mid \mathbf{x})$$
|
||||
|
||||
- 在实践中,求和通过贪心或束搜索解码在推理链上近似。推理步骤 $r_i$ 可以是文本 token、对图像区域的引用,甚至是生成的视觉 token(例如,叠加在输入图像上的边界框注释)。
|
||||
|
||||
- **训练多模态 CoT** 通常涉及策划数据集,其中人类标注者提供逐步的多模态推理轨迹,然后在此类轨迹上微调模型。一些方法从更大的教师模型中蒸馏 CoT 能力:教师为大型数据集生成推理轨迹,较小的学生模型则在输入和教师的轨迹上进行训练。
|
||||
|
||||
- 多模态 CoT 对于需要**空间推理**(例如,"红色球在蓝色立方体的左边吗?")、**对图表的数学推理**(例如,几何问题)和**多步视觉问答**(答案依赖于组合图像多个区域的信息)的任务尤其强大。
|
||||
|
||||
## 多模态智能体
|
||||
|
||||
- 想象厨房里的一个机器人厨师。它查看台面上的食材(视觉),阅读平板上的食谱(文本),听计时器的哔哔声(音频),然后物理上拿起刀并切洋葱(动作)。**多模态智能体**就是数字版:一个通过多种模态感知世界、推理该做什么、并执行基于其感知的动作的模型。
|
||||
|
||||
- 智能体循环遵循经典的**观察-推理-行动**周期:
|
||||
|
||||
1. **观察**:智能体从其环境接收多模态输入(截图、用户的口头指令、视频流)。
|
||||
2. **推理**:统一模型处理多模态输入,可能使用思维链来规划步骤序列。
|
||||
3. **行动**:模型输出一个动作(文本回复、工具调用、坐标为 $(x, y)$ 的鼠标点击、机器人电机指令)。
|
||||
|
||||
- **工具使用**是多模态智能体的一个关键能力。模型被训练识别何时无法直接回答问题,而必须调用外部工具:计算器、代码解释器、网页浏览器或搜索引擎。模型在其输出 token 序列中生成结构化的工具调用(例如,`search("伦敦当前天气")`),系统执行调用,并将结果作为额外的输入 token 反馈给模型处理。
|
||||
|
||||
- **视觉接地**将语言连接到图像或视频中的特定区域。当智能体说"点击右上角的蓝色按钮"时,它必须将短语"右上角的蓝色按钮"接地到像素坐标。在架构上,这是通过训练模型将边界框坐标作为特殊 token 输出,或让模型在图像上生成指示所指区域的热图来实现的。这将本章文件 02(视觉语言模型)中讨论的接地和指代工作扩展到了动作领域。
|
||||
|
||||
- **Web 智能体**如 WebVoyager 和 SeeAct 展示了多模态智能体在网站上导航。智能体接收网页截图,识别交互元素(按钮、文本字段、链接),并输出动作(点击、打字、滚动)以完成用户指定的目标。关键挑战在于巨大的动作空间:一个典型网页可能有数百个可点击目标。
|
||||
|
||||

|
||||
|
||||
- **具身智能体**将其扩展到物理环境。带有摄像头和麦克风的机器人接收视觉和音频输入,通过统一模型处理,并输出电机指令。像 PaLM-E(Google)这样的项目将机器人传感器数据直接嵌入语言模型的 token 序列中,使机器人能够通过将指令接地到其视觉观察中并生成一系列电机动作,来遵循诸如"拿起碗附近的绿色方块"之类的指令。
|
||||
|
||||
- 智能体的训练配方在标准分阶段预训练之上添加了一个**强化学习**(RL)阶段。智能体与环境(模拟桌面、网页浏览器、机器人模拟器)交互,因完成任务而获得奖励,并使用 PPO 或 REINFORCE 等算法更新其策略。奖励信号通常是稀疏的(任务成功为 1,否则为 0),使得这一优化具有挑战性,并且高度依赖于多模态预训练的强先验。
|
||||
|
||||
## 基准测试与评估
|
||||
|
||||
- 评估一个能看见、听见、阅读和行动的模型需要一套多样化的基准测试。没有单一指标能够捕捉多模态能力,因此该领域依赖于一组专门评估的集合。
|
||||
|
||||
- **MMLU**(大规模多任务语言理解)测试 57 个学术科目的知识。虽然最初是纯文本的,但它作为基线:一个统一多模态模型在获得视觉能力时不应丢失纯文本性能。多模态训练后 MMLU 的下降标志着灾难性遗忘。
|
||||
|
||||
- **MMBench** 评估跨 20 个细粒度能力维度的视觉语言理解,包括属性识别、空间关系理解和 OCR。每个问题呈现一幅图像和一个多项选择题。该基准系统地测试模型是否真正理解图像,还是依赖于纯文本的捷径。
|
||||
|
||||
- **SEED-Bench** 提供 19,000 个多项选择题,跨越图像和视频理解的 12 个评估维度。它特别测试时间理解(给定帧之前/之后发生了什么)和组合推理(组合多个视觉属性)。
|
||||
|
||||
- **MM-Vet** 通过要求模型同时使用多种技能来评估集成的多模态能力:识别、OCR、空间意识、语言生成和知识检索,全部在单一问题中。
|
||||
|
||||
- **MathVista** 测试对视觉输入的数学推理:几何图、统计图表、函数图和科学图形。该基准专门针对多模态思维链能力。
|
||||
|
||||
- **音视频基准**如 AVQA(音视频问答)测试模型是否能推理它们所看到和所听到之间的关系。例如:"说话的人是左边的还是右边的?"
|
||||
|
||||
- **智能体基准**如 WebArena、OSWorld 和 SWE-bench 评估在交互式环境中的任务完成情况。指标通常是成功率:智能体正确完成任务的占比是多少?这些基准特别具有挑战性,因为它们需要长视野规划和错误恢复。
|
||||
|
||||
- **全面评估**框架如 LMSYS Chatbot Arena 使用人在头对头格式中的偏好判断。两个模型被展示相同的多模态输入,人类评委选择哪个响应更好。Elo 评分从数千次这样的比较中计算得出,提供了一个与整体模型质量高度相关的单一标量。
|
||||
|
||||
- 多模态评估中的一个持续挑战是**数据污染**:因为这些模型是在互联网规模的数据上训练的,基准图像和问题可能出现在训练集中。仔细的去重和创建保留测试集是必要但不完美的保障措施。
|
||||
|
||||
## 世界模型
|
||||
|
||||
- 想象闭上眼睛,想象如果你把一个玻璃杯推下桌子边缘会发生什么。你"看到"它落下,"听到"破碎声,并"感觉"到那将是个坏主意。你的大脑正在运行一个**世界模型**:对环境的物理和因果结构的内部模拟,能够跨多种模态预测未来状态。
|
||||
|
||||
- 在 AI 语境中,世界模型是一个学习到的函数,根据当前状态和动作预测世界的下一个状态:
|
||||
|
||||
$$\hat{s}_{t+1} = g_\phi(s_t, a_t)$$
|
||||
|
||||
- 其中 $s_t$ 是当前状态表示(可能包含视觉、听觉和本体感觉信息),$a_t$ 是一个动作,$\hat{s}_{t+1}$ 是预测的下一个状态。状态 $s_t$ 存在于学习到的潜在空间中,而非原始像素空间,使得预测问题可解。
|
||||
|
||||
- **视频预测模型**如 Sora(OpenAI)和 Genie(Google DeepMind)代表了迈向世界模型的重要一步。它们学习生成以文本提示和/或动作序列为条件的、时间上连贯的视频帧。虽然它们通常被作为视频生成器讨论,但底层的技术能力更接近于世界模拟:模型已经内化了足够的物理知识(重力、碰撞、遮挡、流体动力学)来渲染合理的未来场景。
|
||||
|
||||
- 与多模态架构的联系很深。一个只预测像素的世界模型是有限的;一个真正有用的世界模型应该跨模态预测。如果你推玻璃杯,世界模型应该预测视觉轨迹(玻璃杯落下)、听觉事件(玻璃杯破碎)和语义后果(现在地板上有碎玻璃)。统一多模态架构是世界模型的天然后选者,因为它们已经在共享空间中表示所有模态。
|
||||
|
||||
- 形式上,多模态世界模型优化:
|
||||
|
||||
$$\mathcal{L}_\text{world} = \mathbb{E}\left[\sum_{m \in \mathcal{M}} \lambda_m \| s_{t+1}^m - g_\phi^m(s_t, a_t) \|^2 \right]$$
|
||||
|
||||
- 其中 $s_{t+1}^m$ 是模态 $m$ 中的真实下一状态表示,$g_\phi^m$ 是世界模型的模态特定预测头。共享的潜在动态 $g_\phi$ 在联合多模态空间中运行,而模态特定的头则将预测解码为每种模态的原生格式。
|
||||
|
||||

|
||||
|
||||
- **JEPA**(联合嵌入预测架构),由 Yann LeCun 提出,提供了一个避免像素级预测陷阱的世界模型框架。JEPA 不是在原始像素层面预测(这会将容量浪费在无关细节如精确纹理上),而是在嵌入空间中进行预测。模型学习一个将观测映射到嵌入的编码器,以及一个预测未来嵌入的预测器:
|
||||
|
||||
$$\hat{\mathbf{z}}_{t+1} = h_\psi(\mathbf{z}_t, a_t), \quad \mathbf{z}_t = \text{Enc}(s_t)$$
|
||||
|
||||
- 损失函数比较的是嵌入而非原始观测,这对感知混叠(许多不同的像素配置可能代表相同的语义状态)更加鲁棒。这种方法对多模态世界模型尤其有前景,因为它自然地运行在统一架构已经提供的共享嵌入空间中。
|
||||
|
||||
- 世界模型有超越学术兴趣的实际应用。在**基于模型的强化学习**中,智能体在采取行动之前使用其世界模型来"想象"行动的后果,大大减少了所需的真实世界交互次数(回顾第 11 章对基于模型 RL 的讨论)。在**自动驾驶**中,世界模型预测在给定不同转向决策后场景在未来几秒内将如何演变。在**机器人学**中,世界模型允许机器人在执行操作序列之前在头脑中进行排练。
|
||||
|
||||
- 世界模型研究的前沿正朝着**交互式世界模型**发展,这些模型实时运行,响应任意用户动作,本质上成为完全从数据中学习得到的通用模拟器。Genie 2(Google DeepMind)为 3D 环境演示了这一点:给定一张图像,它生成一个交互式的、可控的 3D 世界,用户可以探索。世界模型与统一多模态架构的融合表明,未来一个单一模型能够跨所有模态进行感知、预测、模拟和行动。
|
||||
|
||||
## 编程任务(使用 CoLab 或 notebook)
|
||||
|
||||
**任务 1:构建一个最小化的多模态 token 交错器**
|
||||
|
||||
- 编写一个函数,接收一个文本字符串和一个虚拟的"图像"(一个小型 2D 数组),并将它们的 token 化表示交错成一个带有模态嵌入的单一扁平序列。
|
||||
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
# 模拟多模态分词:文本 token + "图像块" token
|
||||
def interleave_modalities(text_tokens, image_patches, embed_dim=32, key=jax.random.PRNGKey(0)):
|
||||
"""将文本和图像 token 与学习到的模态嵌入交错。"""
|
||||
k1, k2, k3 = jax.random.split(key, 3)
|
||||
n_text = text_tokens.shape[0]
|
||||
n_img = image_patches.shape[0]
|
||||
# 随机投影矩阵(替代真实编码器)
|
||||
W_text = jax.random.normal(k1, (text_tokens.shape[-1], embed_dim)) * 0.02
|
||||
W_img = jax.random.normal(k2, (image_patches.shape[-1], embed_dim)) * 0.02
|
||||
# 模态嵌入:一个用于文本,一个用于图像
|
||||
mod_emb = jax.random.normal(k3, (2, embed_dim)) * 0.02
|
||||
text_embs = text_tokens @ W_text + mod_emb[0] # (n_text, embed_dim)
|
||||
img_embs = image_patches @ W_img + mod_emb[1] # (n_img, embed_dim)
|
||||
# 交错:[IMG] token 在前,然后是 [TEXT] token(像 LLaVA)
|
||||
combined = jnp.concatenate([img_embs, text_embs], axis=0)
|
||||
print(f"组合序列: {n_img} 图像 + {n_text} 文本 = {combined.shape[0]} tokens")
|
||||
return combined
|
||||
|
||||
# 尝试:5 个文本 token(dim 16)和 4 个图像块(dim 64)
|
||||
text = jax.random.normal(jax.random.PRNGKey(1), (5, 16))
|
||||
image = jax.random.normal(jax.random.PRNGKey(2), (4, 64))
|
||||
seq = interleave_modalities(text, image)
|
||||
# 实验:改变 embed_dim,交换交错顺序,添加第三个模态
|
||||
```
|
||||
|
||||
**任务 2:可视化跨模态注意力模式**
|
||||
|
||||
- 创建一个合成的多模态序列,计算自注意力分数,观察图像 token 如何关注文本 token,反之亦然。
|
||||
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def cross_modal_attention(n_text=6, n_img=4, d=32, key=jax.random.PRNGKey(42)):
|
||||
"""计算并可视化文本和图像 token 之间的注意力。"""
|
||||
k1, k2, k3 = jax.random.split(key, 3)
|
||||
# 模拟两种模态的 token 嵌入
|
||||
text_embs = jax.random.normal(k1, (n_text, d))
|
||||
img_embs = jax.random.normal(k2, (n_img, d))
|
||||
seq = jnp.concatenate([img_embs, text_embs], axis=0) # (n_img+n_text, d)
|
||||
# 学习到的 Q, K 投影
|
||||
Wq = jax.random.normal(k3, (d, d)) * 0.1
|
||||
Wk = jax.random.normal(jax.random.PRNGKey(99), (d, d)) * 0.1
|
||||
Q, K = seq @ Wq, seq @ Wk
|
||||
scores = Q @ K.T / jnp.sqrt(d)
|
||||
attn = jax.nn.softmax(scores, axis=-1)
|
||||
# 绘图
|
||||
labels = [f"img_{i}" for i in range(n_img)] + [f"txt_{i}" for i in range(n_text)]
|
||||
fig, ax = plt.subplots(figsize=(7, 6))
|
||||
ax.imshow(attn, cmap="viridis")
|
||||
ax.set_xticks(range(len(labels))); ax.set_xticklabels(labels, rotation=45, fontsize=8)
|
||||
ax.set_yticks(range(len(labels))); ax.set_yticklabels(labels, fontsize=8)
|
||||
ax.set_xlabel("Key(被关注的)"); ax.set_ylabel("Query(发起的)")
|
||||
ax.set_title("跨模态自注意力图")
|
||||
plt.colorbar(ax.images[0], ax=ax, shrink=0.8)
|
||||
plt.tight_layout(); plt.show()
|
||||
|
||||
cross_modal_attention()
|
||||
# 实验:增大 d,添加因果掩码,观察注意力模式如何变化
|
||||
```
|
||||
|
||||
**任务 3:模拟带有模态特定损失权重的分阶段训练**
|
||||
|
||||
- 演示模态特定的损失权重如何影响玩具多模态训练循环。观察平衡损失如何防止一种模态主导训练。
|
||||
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
def staged_training_sim(steps=200, key=jax.random.PRNGKey(7)):
|
||||
"""模拟具有可调节模态损失权重的多模态训练。"""
|
||||
# 两种"模态",损失尺度不同(文本损失比图像损失大约 10 倍)
|
||||
losses_text, losses_img = [], []
|
||||
param = jnp.array([0.0, 0.0]) # 两种模态损失共同更新的共享参数
|
||||
lr = 0.05
|
||||
# 尝试更改这些权重以观察对收敛平衡的影响
|
||||
lambda_text, lambda_img = 1.0, 5.0 # 对较弱模态加大权重
|
||||
|
||||
for step in range(steps):
|
||||
k1, k2, key = jax.random.split(key, 3)
|
||||
noise_t = jax.random.normal(k1, ()) * 0.3
|
||||
noise_i = jax.random.normal(k2, ()) * 0.1
|
||||
loss_t = (param[0] - 3.0) ** 2 + noise_t # 文本目标 = 3.0
|
||||
loss_i = 0.1 * (param[1] - 1.0) ** 2 + noise_i # 图像目标 = 1.0(尺度更小)
|
||||
# 加权组合梯度
|
||||
grad_t = lambda_text * 2 * (param[0] - 3.0)
|
||||
grad_i = lambda_img * 0.2 * (param[1] - 1.0)
|
||||
param = param - lr * jnp.array([grad_t, grad_i])
|
||||
losses_text.append(float(loss_t)); losses_img.append(float(loss_i))
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
ax.plot(losses_text, label=f"文本损失 (权重={lambda_text})", alpha=0.7)
|
||||
ax.plot(losses_img, label=f"图像损失 (权重={lambda_img})", alpha=0.7)
|
||||
ax.set_xlabel("训练步数"); ax.set_ylabel("损失"); ax.legend()
|
||||
ax.set_title("分阶段训练中的模态损失平衡")
|
||||
plt.tight_layout(); plt.show()
|
||||
|
||||
staged_training_sim()
|
||||
# 实验:设置 lambda_img=1.0,观察图像损失收敛慢得多
|
||||
```
|
||||
Reference in New Issue
Block a user