feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)

翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
This commit is contained in:
2026-05-03 10:23:20 +08:00
commit 2536c937e3
400 changed files with 49040 additions and 0 deletions
@@ -0,0 +1,170 @@
# 几何深度学习
*几何深度学习是揭示CNN、Transformer和GNN皆遵循同一原理——利用对称性——的统一框架。本章涵盖对称群、群作用、不变性、等变性、五个几何域以及尺度分离*
- 在本书中,我们已经学习了多种架构:图像的CNN(第8章)、语言的Transformer(第7章)以及序列决策的RL策略(第6章)。它们看上去像是为完全不同的问题设计的完全不同的模型。但背后存在一个更深层的模式。
- **几何深度学习**揭示出所有这些架构都是同一个思想的实例:构建尊重数据**对称性**的网络。CNN利用图像中的平移对称性。Transformer利用序列中的置换对称性(注意力不依赖于绝对位置)。GNN利用图中的置换对称性。一旦看清这一点,众多架构就变成了一个统一的连贯框架。
## 对称性与群
- 一个对象的**对称性**是使其保持不变的变换。正方形有8种对称性:4种旋转(0°、90°、180°、270°)和4种反射。圆有无限多种:任何绕其中心的旋转。关键洞察在于,对称性告诉你什么是不重要的,而知道什么不重要的对于学习来说极为强大。
- 用机器学习的术语来说:如果一个任务具有对称性,那么无论看到输入的哪种"版本",模型都应给出相同的答案。猫检测器无论猫在图像的左上角还是右下角都应能工作。这就是平移对称性。
- 对称性通过**群**来形式化。一个群 $G$ 是一个具有四个性质的变换集合:
- **封闭性**:两个变换的组合产生集合中的另一个变换。先旋转90°再旋转90°得到180°,也属于该集合。
- **结合律**$(g_1 \circ g_2) \circ g_3 = g_1 \circ (g_2 \circ g_3)$。分组的顺序无关紧要(回顾第2章中矩阵乘法的结合律)。
- **单位元**:存在一个"什么也不做"的变换 $e$,使得 $e \circ g = g \circ e = g$。
- **逆元**:每个变换都有撤销操作:$g \circ g^{-1} = e$。
- 这些公理与向量空间(第1章)的公理相同,但应用于变换而非向量。其联系十分深刻:群作用于向量空间,而神经网络必须尊重这种作用。
- 深度学习中出现的关键群:
- **平移群** $(\mathbb{R}^n, +)$:平移图像或信号。这是CNN利用的对称性。
- **对称群** $S_n$:$n$ 个元素的所有置换。这是GNN和Transformer利用的对称性(重新排序节点或标记不应改变结果)。
- **旋转群** $SO(n)$$n$ 维空间中的所有旋转。$SO(2)$ 是平面旋转,$SO(3)$ 是三维旋转(对分子和3D视觉任务至关重要)。
- **欧几里得群** $E(n)$:所有旋转、反射和平移。物理空间的对称性。
- **特殊欧几里得群** $SE(n)$:旋转和平移(不含反射)。刚体运动的对称性。
- **群作用**描述了群如何变换数据。如果 $G$ 是一个群,$X$ 是数据空间,则作用 $\rho: G \times X \to X$ 将每个群元素 $g$ 和数据点 $x$ 映射到一个变换后的点 $\rho(g, x)$。对于图像,平移群通过平移像素坐标来作用。对于图,对称群通过重新标记节点来作用。
## 不变性与等变性
- 给定一个对称群,函数可以通过两种重要方式与之关联:
- 函数 $f$ 对群 $G$ 是**不变**的,如果输入变换后输出不变:
$$f(\rho(g, x)) = f(x) \quad \text{对于所有 } g \in G$$
- 示例:图像的总体亮度不因平移而改变。图像分类应是平移不变的:"猫"的类别无论猫在何处都是一样的。
- 函数 $f$ 对群 $G$ 是**等变**的,如果变换输入会对等地变换输出:
$$f(\rho_{\text{in}}(g, x)) = \rho_{\text{out}}(g, f(x)) \quad \text{对于所有 } g \in G$$
- 示例:如果将图像向右平移5个像素,CNN中的特征图也会向右平移5个像素。卷积操作是平移等变的:它保留了空间关系。目标检测应该是等变的:如果猫移动了,边界框也应随之移动。
![不变性:输出不随变换而改变。等变性:输出随之对应变换](../images/invariance_vs_equivariance.svg)
- 区分两者的重要性在于:**中间层**通常应是等变的(为下游层保留结构),而**最终输出**应是不变的(答案不应依赖于变换)。CNN通过堆叠等变卷积层,然后在末尾应用全局池化(它是不变的)来实现这一点。
- 将等变性构建到架构中比从数据中学习它要高效得多。一个具有权重共享的平移等变CNN所需的参数远少于一个必须独立学习"位置(10,10)处的猫"和"位置(200,150)处的猫"的全连接网络。对称性约束指数级地缩小了假设空间。
## 五个几何域
- 几何深度学习识别出数据的**五个基本域**,每个域都有其自己的对称群。每一个神经网络架构都可以被理解为利用其中某个域的对称性。
![五个几何域:网格、集合、序列、图和流形,各有其对称性和架构](../images/five_geometric_domains.svg)
- **1. 网格(欧几里得数据)**:图像、音频频谱图、体数据。底层结构是具有平移对称性的规则网格。群是平移群(可能再加上旋转和反射)。利用这种对称性的架构是**CNN**:卷积正是平移等变的操作。空间位置上的权重共享就是平移等变性的具体实现。
- **2. 集合(无序集合)**:点云、粒子系统。对称性是置换不变性:元素的顺序无关紧要。架构是**DeepSets**(以及第8章的PointNet):对每个元素应用共享函数,然后用置换不变操作(求和、均值或取最大值)进行聚合。形式上,$f(\{x_1, \ldots, x_n\}) = \phi\left(\sum_i \psi(x_i)\right)$。
- **3. 序列(有序数据)**:文本、时间序列。序列是一维网格,但有一个微妙之处:对称性更加细致。绝对位置可能重要也可能不重要。RNN以自回归方式处理序列。带位置编码的Transformer可以关注任何位置,其自注意力在加入位置编码之前是置换等变的。这就是Transformer泛化能力如此之强的原因:它们从置换等变开始,然后仅添加必要的位置结构。
- **4. 图(关系数据)**:社交网络、分子、知识图谱。对称性是节点的置换:重新标记节点不应改变图的性质。架构是**GNN**:连接节点之间传递消息,使用不依赖于节点顺序的共享函数。这是本章剩余部分的重点。
- **5. 流形和网格**:曲面、3D形状。对称性包括微分同胚(光滑变形)。架构使用内在算子(例如拉普拉斯-贝尔特拉米算子),这些算子由曲面几何本身定义,与曲面在空间中的嵌入方式无关。这联系到微分几何,并适用于形状分析、球面上的气候建模和蛋白质表面分析。
- 这个框架的强大之处在于其统一性。CNN是网格图上的GNN。Transformer是完全连接图上的GNN。DeepSets是没有边的GNN。将这些视为同一原理的实例,指导着新架构的设计:识别数据的对称性,然后构建一个尊重它的网络。
## 尺度分离与粗化
- 真实世界的数据具有多尺度结构。一幅图像有细粒度纹理(像素级)、局部模式(边缘、角点)、物体部件(车轮、窗户)和全局结构(整个场景)。一个分子有原子级特征、官能团和整体分子形状。
- **尺度分离**是这样一个原理:这些细节层次可以分层处理——先捕获局部结构,然后逐步聚合成更粗粒度的表示。这就是**粗化**或**池化**。
- 在CNN中,池化层(最大池化、平均池化)对空间分辨率进行下采样,迫使高层捕获更大尺度的模式。在感受野视角(第8章)中,更深层能"看到"更多的图像。这就是尺度分离的实际应用。
- 在图(graph)中,粗化意味着将节点群聚为"超节点",生成一个保留基本结构的更小图。这就是图池化,我们将在文件3中详细讨论。它与图像池化直接类似:降低分辨率的同时保留重要特征。
- 在序列中,分层处理(例如句子→段落→文档)在不同时间或语义尺度捕获结构。Swin Transformer(第8章)通过其移位窗口层次结构将这一思想应用于图像。
- 数学上,粗化定义了一个**逐渐抽象的表示层次**:
$$x \xrightarrow{\text{局部特征}} h^{(1)} \xrightarrow{\text{粗化}} h^{(2)} \xrightarrow{\text{粗化}} \cdots \xrightarrow{\text{全局}} y$$
- 在每个层次,表示相对于该层次的对称群是等变的。最后的全局表示是不变的,捕获了输入的本质而不受无关变换的影响。
- 这就是为什么对于结构化数据,深层网络比浅层网络效果更好:每一层增加一个抽象层次,多个等变层的组合从简单的局部特征构建出复杂的不变特征。
## 编程任务(使用CoLab或notebook
1. 验证卷积的平移等变性。对图像应用卷积,然后平移图像再次卷积。检查输出是否互为平移版本。
```python
import jax
import jax.numpy as jnp
# 一维信号和一个简单滤波器
signal = jnp.array([0, 0, 0, 1, 2, 3, 2, 1, 0, 0, 0], dtype=float)
kernel = jnp.array([1, 0, -1], dtype=float)
# 先卷积再平移
conv_result = jnp.convolve(signal, kernel, mode="same")
shifted_signal = jnp.roll(signal, 3)
conv_shifted = jnp.convolve(shifted_signal, kernel, mode="same")
shifted_conv = jnp.roll(conv_result, 3)
print(f"先卷积再平移: {shifted_conv}")
print(f"先平移再卷积: {conv_shifted}")
print(f"等变性: {jnp.allclose(shifted_conv, conv_shifted, atol=1e-5)}")
```
2. 验证DeepSets风格聚合的置换不变性。对集合中的每个元素应用共享函数,求和结果,并检查输出是否不依赖于元素顺序。
```python
import jax
import jax.numpy as jnp
# 4个向量的"集合"(顺序应无关紧要)
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
# 简单的共享函数:逐元素平方
psi = lambda v: v ** 2
# 通过求和聚合
def deepsets(points):
return jnp.sum(jax.vmap(psi)(points), axis=0)
# 原始顺序
result1 = deepsets(x)
# 置换后的顺序
perm = jnp.array([2, 0, 3, 1])
result2 = deepsets(x[perm])
print(f"原始顺序: {result1}")
print(f"置换顺序: {result2}")
print(f"不变性: {jnp.allclose(result1, result2)}")
```
3. 探索群结构。通过检查封闭性、结合律、单位元和逆元,验证二维旋转矩阵构成群。
```python
import jax.numpy as jnp
def rot2d(theta):
return jnp.array([[jnp.cos(theta), -jnp.sin(theta)],
[jnp.sin(theta), jnp.cos(theta)]])
R1 = rot2d(jnp.pi / 6)
R2 = rot2d(jnp.pi / 4)
R3 = rot2d(jnp.pi / 3)
# 封闭性:两个旋转的乘积还是一个旋转
R12 = R1 @ R2
print(f"封闭性 (行列式=1, 正交): det={jnp.linalg.det(R12):.4f}, "
f"R^T R = I: {jnp.allclose(R12.T @ R12, jnp.eye(2), atol=1e-5)}")
# 结合律
print(f"结合律: {jnp.allclose((R1 @ R2) @ R3, R1 @ (R2 @ R3), atol=1e-5)}")
# 单位元
I = rot2d(0.0)
print(f"单位元: {jnp.allclose(R1 @ I, R1, atol=1e-5)}")
# 逆元
R1_inv = rot2d(-jnp.pi / 6)
print(f"逆元: {jnp.allclose(R1 @ R1_inv, jnp.eye(2), atol=1e-5)}")
```
@@ -0,0 +1,236 @@
# 图论
*图论为描述实体间关系提供了数学语言。本章涵盖节点、边、邻接矩阵、图类型、度和连通性、图拉普拉斯算子、谱图理论以及现实世界的图应用。我们将在纯计算机科学章节中更深入地讨论图*
- 到目前为止,本书中的数据都存在于规则结构上:$\mathbb{R}^n$ 中的向量(第1章)、数字网格形式的矩阵(第2章)、像素网格形式的图像(第8章)、有序列表形式的序列(第7章)。但许多现实世界的系统是**不规则**的:社交网络没有网格结构,分子没有从左到右的顺序,道路网络也不能整齐地平铺成行和列。
- **图(Graph)** 是表示这些不规则关系结构的数学工具。图捕获了**实体**(节点)及它们之间的**关系**(边)。一旦数据被表示为图,我们就可以应用文件1中的几何深度学习原理来从中学习。
## 节点、边和邻接
- 一个**图** $G = (V, E)$ 由一组**节点**(或顶点)$V = \{v_1, v_2, \ldots, v_n\}$ 和一组连接节点对的**边** $E \subseteq V \times V$ 组成。
- 节点代表实体:人、原子、城市、网页、神经元。边代表关系:友谊、化学键、道路、超链接、突触。
- **邻接矩阵** $A$ 是图的矩阵表示。对于一个有 $n$ 个节点的图,$A$ 是一个 $n \times n$ 矩阵,其中如果存在从节点 $i$ 到节点 $j$ 的边,则 $A_{ij} = 1$,否则 $A_{ij} = 0$。
- 例如,一个三角形图(3个节点,全部相连):
```math
A = \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix}
```
![一个三角形图及其邻接矩阵:边存在处为1,否则为0](../images/graph_adjacency_matrix.svg)
- 对角线为零,因为节点默认不与自身相连(无自环)。邻接矩阵是我们在第2章中研究的布尔矩阵的直接应用:每个条目都是一个二元关系。
- 邻接矩阵完整地编码了图的结构。对 $A$ 的矩阵运算揭示了图的性质:$A^2_{ij}$ 计算节点 $i$ 和 $j$ 之间长度为2的路径数量(回顾第2章中的矩阵乘法:每个条目是经过中间节点的乘积之和)。更一般地,$A^k_{ij}$ 计算长度为 $k$ 的路径数量。
- 每个节点可以携带一个**特征向量** $\mathbf{x}_i \in \mathbb{R}^d$。对于社交网络,这可能是用户的个人信息。对于分子,它编码原子类型、电荷和其他属性。全部节点特征的集合是一个矩阵 $X \in \mathbb{R}^{n \times d}$,其中每一行是一个节点的特征。
- 边也可以携带特征:分子中的键类型、空间图中的距离、知识图谱中的关系类型。边 $(i, j)$ 的**边特征**是一个向量 $\mathbf{e}_{ij} \in \mathbb{R}^{d_e}$。
## 图类型
- **无向图**具有对称的边:如果 $i$ 连接到 $j$,则 $j$ 也连接到 $i$。邻接矩阵是对称的:$A = A^T$(一个对称矩阵,见第2章)。友谊和化学键是无向的。
- **有向图**(digraph)具有带方向的边:从 $i$ 到 $j$ 的边不意味着从 $j$ 到 $i$ 的边。邻接矩阵是非对称的。Twitter关注、网页超链接和引文网络是有向的。
- **加权图**为每条边分配一个数值权重。邻接矩阵具有实数值而非二进制值:$A_{ij} = w_{ij}$。道路网络中的距离、大脑连通性中的相关强度以及社交网络中的交互频率是加权的。
- **二分图**具有两个不相交的节点集合,边只存在于集合之间(集合内部没有边)。用户和产品构成一个二分图:用户评价产品,但用户之间不相互评价。二分图的邻接矩阵具有块结构:
```math
A = \begin{bmatrix} 0 & B \\ B^T & 0 \end{bmatrix}
```
- 其中 $B$ 是两个节点集之间的二分邻接矩阵。
- **多重图**允许同一对节点之间存在多条边和/或自环。知识图谱通常是多重图:两个实体之间可以有多种关系(例如"出生于"、"居住于"、"工作于")。
- **超图**将边推广为一次连接两个以上节点。一条**超边**连接一组节点,表示高阶关系。一篇由五人合著的研究论文是一条连接五个作者节点的超边。
- **完全图** $K_n$ 在每一对节点之间都有边。这是全连接层的图类比,也是Transformer操作的结构(每个标记关注每个其他标记)。
## 度、路径和连通性
- 一个**节点**的**度**是与它相连的边的数量。在无向图中,节点 $i$ 的度为 $d_i = \sum_j A_{ij}$。高度节点是拥有大量连接的"枢纽"。
- **度矩阵** $D$ 是一个对角线元素为度的对角矩阵:$D_{ii} = d_i$。这个矩阵出现在整个图论和GNN公式中。
- 两个节点之间的**路径**是连接它们的边序列。$i$ 和 $j$ 之间的**最短路径**(或测地线)是边数最少(或在加权图中总权重最小)的路径。**迪杰斯特拉算法**Dijkstra's algorithm)在 $O((|V| + |E|) \log |V|)$ 时间内找到最短路径。
- 如果每对节点之间都存在路径,则图是**连通的**。否则,图有多个**连通分量**:相互之间没有边的孤立子图。
- 图的**直径**是任意一对节点之间最长最短路径的长度。它衡量图"分散"的程度。社交网络以直径小而闻名("六度分隔")。
- **环**是起点和终点在同一节点的路径。没有环的图是**树**。树是最简单的连通图:$n$ 个节点和恰好 $n-1$ 条边。
- **中心性**衡量节点的重要性。**度中心性**就是度数。**介数中心性**计算通过一个节点的最短路径数量。**特征向量中心性**根据节点邻居的重要性分配重要性,得到特征向量方程 $A\mathbf{x} = \lambda \mathbf{x}$(第2章)。谷歌的PageRank是特征向量中心性在有向图上的变体。
## 图拉普拉斯算子
- **图拉普拉斯算子**也许是图论中最重要的矩阵。定义如下:
$$L = D - A$$
- 其中 $D$ 是度矩阵,$A$ 是邻接矩阵。对于我们的三角形示例:
```math
L = \begin{bmatrix} 2 & 0 & 0 \\ 0 & 2 & 0 \\ 0 & 0 & 2 \end{bmatrix} - \begin{bmatrix} 0 & 1 & 1 \\ 1 & 0 & 1 \\ 1 & 1 & 0 \end{bmatrix} = \begin{bmatrix} 2 & -1 & -1 \\ -1 & 2 & -1 \\ -1 & -1 & 2 \end{bmatrix}
```
- 拉普拉斯算子具有显著的性质:
- 它始终是**对称的**且**半正定的**(回顾第2章:所有特征值 $\geq 0$)。对于任意向量 $\mathbf{x}$
$$\mathbf{x}^T L \mathbf{x} = \sum_{(i,j) \in E} (x_i - x_j)^2$$
![图拉普拉斯算子度量信号平滑度:平滑信号在连接节点上具有相似值,非平滑信号变化剧烈](../images/graph_laplacian_smoothness.svg)
- 这个二次形式度量图上的信号 $\mathbf{x}$ 在边上的变化程度。如果相邻节点值相近,则 $\mathbf{x}^T L \mathbf{x}$ 较小。如果它们差异很大,则较大。拉普拉斯算子度量图上信号的**平滑度**。
- 最小特征值始终为0,特征向量为 $\mathbf{1} = [1, 1, \ldots, 1]^T$(常数信号没有变化)。零特征值的数量等于连通分量的数量。
- 第二小特征值 $\lambda_2$ 是**代数连通度**Fiedler值)。它衡量图的连通程度:$\lambda_2 = 0$ 表示图不连通,大的 $\lambda_2$ 表示图紧密连通。
- **归一化拉普拉斯算子**通过度进行缩放:
$$\hat{L} = D^{-1/2} L D^{-1/2} = I - D^{-1/2} A D^{-1/2}$$
- 这种归一化确保拉普拉斯算子的性质不依赖于节点度的绝对尺度。项 $D^{-1/2} A D^{-1/2}$ 是**对称归一化邻接矩阵**,它直接出现在GCN公式中(文件3)。
## 谱图理论
- 图拉普拉斯算子的特征值和特征向量定义了图的**谱**,它们充当图上的傅里叶变换的类似物。
- 在经典信号处理中,傅里叶变换将信号分解为频率分量(正弦和余弦)。在图上,拉普拉斯算子的特征向量扮演这些频率基的角色。小特征值的特征向量在图上变化缓慢(低频、平滑),而大特征值的特征向量变化迅速(高频、振荡)。
- 信号 $\mathbf{x}$ 在图上的**图傅里叶变换(GFT)** 为:
$$\hat{\mathbf{x}} = U^T \mathbf{x}$$
- 其中 $U$ 是拉普拉斯算子特征向量的矩阵(回顾第2章中的特征分解:$L = U \Lambda U^T$)。逆变换为 $\mathbf{x} = U \hat{\mathbf{x}}$。
- 谱域中的**图卷积**是频域中的逐点乘法,正如空间域中的卷积对应于傅里叶域中的乘法(卷积定理,见第8章):
$$g_\theta \star \mathbf{x} = U \left( (U^T g_\theta) \odot (U^T \mathbf{x}) \right) = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}$$
- 滤波器 $\hat{g}_\theta$ 是特征值的可学习函数。这是谱域GNN的基础,我们将在文件3中将其简化为实用的GCN。
- 计算瓶颈是对 $L$ 进行特征分解,对于有 $n$ 个节点的图需要 $O(n^3)$ 时间。这对于大型图(数百万节点)是不可行的。多项式近似(切比雪夫多项式)完全避免了特征分解,而这种近似直接导致了GCN。
## 社区检测
- 许多现实世界的图具有**社区结构**:紧密连接的节点簇,簇之间连接稀疏。社交网络有好友群组,生物网络有功能模块,引文网络有研究领域。
- **谱聚类**使用拉普拉斯算子特征向量来寻找社区。思路:使用 $L$ 的 $k$ 个最小的非平凡特征向量对每个节点进行嵌入,然后在这个嵌入空间中应用k-means(第6章)。同一社区中的节点在谱嵌入中最终彼此靠近。
- 这是可行的,因为Fiedler向量($\lambda_2$ 的特征向量)自然地将图分成两组:正值的节点和负值的节点,沿着最稀疏的连接切开。更高的特征向量进一步细分为更多组。
- **模块度** $Q$ 衡量社区划分的质量。它将社区内边的数量与随机图中的期望数量进行比较:
$$Q = \frac{1}{2|E|} \sum_{ij} \left( A_{ij} - \frac{d_i d_j}{2|E|} \right) \delta(c_i, c_j)$$
- 其中 $c_i$ 是节点 $i$ 的社区分配,如果节点在同一个社区则 $\delta$ 为1。$Q$ 的范围从 $-0.5$ 到 $1$,值越高表示社区结构越强。
## 现实世界中的图
- **社交网络**:节点是人,边是友谊或互动。Facebook有数十亿节点和数千亿条边。这些图通常是稀疏的(每个人有几百个朋友,而不是几十亿),具有小世界性质(短的平均路径长度),以及重尾度分布(少数拥有数百万连接的枢纽节点)。
- **分子图**:节点是原子,边是化学键。每个原子有特征(元素类型、电荷、杂化方式),每条键有特征(单键、双键、三键、芳香键)。分子图很小(数十到数百个节点)但高度结构化。从图结构预测分子性质是GNN的一个重要应用。
- **知识图谱**:节点是实体(人、地点、概念),边是类型化的关系("出生于"、"首都是"、"是……的实例")。知识图谱为搜索引擎、推荐系统和问答系统提供支持。它们通常是具有数百万实体和数十亿关系的有多重图。
- **引文网络**:节点是论文,边是引用(有向的)。聚类揭示研究社区。节点特征包括标题、摘要和出版年份。
- **蛋白质相互作用网络**:节点是蛋白质,边表示物理相互作用或功能关联。理解这些图有助于识别药物靶点和疾病机制。
- **道路网络与交通**:节点是交叉路口,边是具有距离/时间权重的道路段。这些图上的最短路径算法为导航系统提供动力。自动驾驶运动预测(第11章)将智能体交互表示为图。
## 编程任务(使用CoLab或notebook
1. 构建一个小型图的邻接矩阵,计算基本性质:每个节点的度、长度为2的路径数量以及图是否连通。
```python
import jax.numpy as jnp
# 一个简单图:5个节点
# 0-1, 0-2, 1-2, 2-3, 3-4
A = jnp.array([[0, 1, 1, 0, 0],
[1, 0, 1, 0, 0],
[1, 1, 0, 1, 0],
[0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]], dtype=float)
# 度
degrees = A.sum(axis=1)
print(f"度数: {degrees}")
# 长度为2的路径
A2 = A @ A
print(f"长度为2的路径(节点0到3: {int(A2[0, 3])}")
# 是否连通?检查 A^(n-1) 是否所有条目非零
An = jnp.linalg.matrix_power(A + jnp.eye(5), 4) # (A+I)^4 用于可达性
connected = jnp.all(An > 0)
print(f"连通: {connected}")
```
2. 计算图拉普拉斯算子及其特征值。验证最小特征值为0且对应的特征向量为常数。
```python
import jax.numpy as jnp
A = jnp.array([[0, 1, 1, 0, 0],
[1, 0, 1, 0, 0],
[1, 1, 0, 1, 0],
[0, 0, 1, 0, 1],
[0, 0, 0, 1, 0]], dtype=float)
D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)
print(f"特征值: {eigenvalues}")
print(f"最小特征向量: {eigenvectors[:, 0]}")
print(f"Fiedler值(代数连通度): {eigenvalues[1]:.4f}")
# 验证: x^T L x 度量平滑度
x = jnp.array([1.0, 1.0, 1.0, -1.0, -1.0]) # 两个组
smoothness = x @ L @ x
print(f"两组信号的平滑度: {smoothness:.2f}")
```
3. 对具有两个社区的图执行谱聚类。使用Fiedler向量嵌入节点,并按符号分离。
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 两个社区,各5个节点,弱连接
A = jnp.zeros((10, 10))
# 社区1:节点0-4(密集)
for i in range(5):
for j in range(i+1, 5):
A = A.at[i, j].set(1).at[j, i].set(1)
# 社区2:节点5-9(密集)
for i in range(5, 10):
for j in range(i+1, 10):
A = A.at[i, j].set(1).at[j, i].set(1)
# 一条桥接边
A = A.at[2, 7].set(1).at[7, 2].set(1)
D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)
# Fiedler向量(第二小特征值)
fiedler = eigenvectors[:, 1]
communities = (fiedler > 0).astype(int)
print(f"Fiedler向量: {fiedler}")
print(f"聚类: {communities}")
plt.bar(range(10), fiedler, color=["#3498db" if c == 0 else "#e74c3c" for c in communities])
plt.xlabel("节点"); plt.ylabel("Fiedler向量值")
plt.title("通过Fiedler向量进行谱聚类")
plt.show()
```
@@ -0,0 +1,271 @@
# 图神经网络
*图神经网络通过在连接节点之间传递消息来学习图结构数据。本章涵盖消息传递框架、GCN、GraphSAGE、GIN、过平滑、图池化以及节点/边/图级别的任务;支撑分子性质预测、社交网络分析和推荐系统的核心架构。*
- 在前面的文件中,我们建立了数学基础:几何深度学习(文件1)告诉我们利用对称性,图论(文件2)提供了节点、边和邻接的语言。现在我们构建直接在图(graph)上操作的神经网络。
- 核心挑战:图数据是**不规则**的。与图像(固定网格)或序列(固定顺序)不同,图具有可变数量的节点、可变的连通性,并且没有规范的节点顺序。用于图的神经网络必须处理所有这些情况,同时保持置换等变性(重新标记节点不应改变输出)。
## 消息传递框架
- 几乎所有的GNN都遵循同样的模式,称为**消息传递**(也称为邻域聚合)。这个想法简单而优雅:每个节点通过从邻居收集信息来更新其表示。
- 在每个层 $l$,每个节点 $i$ 做三件事:
1. **消息**:节点 $i$ 的每个邻居 $j$ 基于其当前特征计算一条消息 $\mathbf{m}_{j \to i}$。
2. **聚合**:节点 $i$ 收集所有传入消息,并使用置换不变函数(求和、均值或取最大值)将它们组合。
3. **更新**:节点 $i$ 将聚合的消息与其自身特征结合,产生一个新的表示。
- 形式上:
$$\mathbf{m}_i^{(l)} = \bigoplus_{j \in \mathcal{N}(i)} \phi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{h}_j^{(l)}, \mathbf{e}_{ij}\right)$$
$$\mathbf{h}_i^{(l+1)} = \psi^{(l)}\left(\mathbf{h}_i^{(l)}, \mathbf{m}_i^{(l)}\right)$$
- 其中 $\mathcal{N}(i)$ 是节点 $i$ 的邻居集合,$\bigoplus$ 是一个置换不变的聚合操作(求和、均值、取最大值),$\phi$ 是消息函数,$\psi$ 是更新函数,$\mathbf{e}_{ij}$ 是可选的边特征。
![消息传递:邻居发送消息,置换不变函数聚合它们,然后节点更新其特征](../images/message_passing_gnn.svg)
- 聚合操作 $\bigoplus$ 必须是置换不变的(邻居处理的顺序无关紧要),以确保整个函数是置换等变的。这直接实现了文件1中的对称性原理。
- 经过 $k$ 层消息传递后,每个节点的表示编码了其 **$k$ 跳邻域**的信息:所有在 $k$ 条边内可达的节点。第1层看到直接邻居,第2层看到邻居的邻居,依此类推。这就是局部信息传播以建立全局理解的方式。
- GNN的感受野随深度增长,就像CNN的感受野随层数增长一样(第8章)。但与规则网格上的CNN不同,感受野的形状根据图拓扑结构在每个节点上有所不同。
## 图卷积网络(GCN
- **GCN**Kipf & Welling2017)是基础性的GNN架构。它将谱域图卷积(来自文件2)简化为一个优雅、高效的公式。
- 从谱域卷积 $g_\theta \star \mathbf{x} = U \, \text{diag}(\hat{g}_\theta) \, U^T \mathbf{x}$ 出发,Kipf和Welling用一阶切比雪夫多项式近似谱域滤波器,这完全避免了计算特征分解。简化后,逐层更新变为:
$$H^{(l+1)} = \sigma\left(\hat{A} H^{(l)} W^{(l)}\right)$$
- 其中:
- $H^{(l)} \in \mathbb{R}^{n \times d}$ 是第 $l$ 层的节点特征矩阵
- $W^{(l)} \in \mathbb{R}^{d \times d'}$ 是可学习的权重矩阵
- $\hat{A} = \tilde{D}^{-1/2} \tilde{A} \tilde{D}^{-1/2}$ 是带自环的对称归一化邻接矩阵
- $\tilde{A} = A + I$ 添加了自环(因此每个节点也接收自己的消息)
- $\tilde{D}$ 是 $\tilde{A}$ 的度矩阵
- $\sigma$ 是一个非线性激活函数(ReLU,如第6章所述)
- 矩阵乘法 $\hat{A} H^{(l)}$ 是聚合步骤:对于每个节点,它计算其邻居特征(加上自身特征,通过自环)的加权平均。权重矩阵 $W^{(l)}$ 是可学习的变换,在所有节点间共享。激活函数增加了非线性。
- 这非常简单:它只是矩阵乘法后接一个学习到的线性映射和激活函数。整个GCN层可以用一行代码实现。通过 $\tilde{D}^{-1/2}$ 的归一化防止具有许多邻居的节点占主导地位:高度节点的消息被按比例缩小。
- 在消息传递框架中,GCN使用:
- 消息:$\phi(\mathbf{h}_j) = \mathbf{h}_j$(只发送你的特征)
- 聚合:归一化和(按度加权)
- 更新:线性变换 + 激活函数
## GraphSAGE
- GCN是**直推式**的:它在训练时需要完整的图,无法处理新出现的未知节点。如果新用户加入社交网络,GCN必须对整个图重新训练。**GraphSAGE**Hamilton等,2017)通过**归纳式**方法解决了这个问题。
- 关键思想是**邻域采样**:不是使用所有邻居,而是采样一个固定大小的子集。这使得计算独立于完整的图结构,并允许推广到未见过的节点和图。
- 节点 $i$ 的GraphSAGE更新:
$$\mathbf{h}_i^{(l+1)} = \sigma\left(W^{(l)} \cdot \text{CONCAT}\left(\mathbf{h}_i^{(l)}, \text{AGG}\left(\{\mathbf{h}_j^{(l)} : j \in \mathcal{S}(i)\}\right)\right)\right)$$
- 其中 $\mathcal{S}(i)$ 是一个**采样**的邻居子集(例如,从500个邻居中随机采样10个)。CONCAT操作显式地将节点自身的特征与聚合后的邻居特征分开,让网络学习"自身"和"邻域"的不同变换。
- GraphSAGE支持多种聚合函数:
- **均值(Mean**$\text{AGG} = \frac{1}{|\mathcal{S}|} \sum_{j \in \mathcal{S}} \mathbf{h}_j$(简单,有效)
- **LSTM**:将采样的邻居通过LSTM(但这引入了顺序依赖,一定程度上违反了置换不变性)
- **池化(Pool**$\text{AGG} = \max(\{\sigma(W_{\text{pool}} \mathbf{h}_j + \mathbf{b})\})$(非线性变换后取最大值)
- 采样策略使GraphSAGE可扩展到非常大的图。训练使用节点的小批量:对于每个目标节点,在第1层采样 $k_1$ 个邻居,然后对于其中每个邻居在第2层采样 $k_2$ 个邻居。使用 $k_1 = k_2 = 10$ 和2层,每个节点的计算树最多有 $10 \times 10 = 100$ 个节点,与图的大小无关。
## 图同构网络(GIN
- 不同的GNN架构具有不同的**表达能力**:它们区分结构不同之图的能力。GCN和GraphSAGE虽然在实践中有效,但理论上在能区分哪些图结构方面是受限的。
- 衡量GNN表达能力的理论工具是**Weisfeiler-LehmanWL)测试**,这是一个用于测试图同构(两个图是否结构相同)的经典算法。WL测试通过将每个节点的标签与其邻居标签的多重集一起哈希,迭代地精炼节点标签。
- **GIN**(Xu等,2019)被设计为具有与WL测试同等的表达能力,使其成为最强大的消息传递GNN(在消息传递的理论限制内)。关键洞察:聚合函数必须在多重集上是**单射**的(不同的邻居特征多重集必须产生不同的聚合值)。
- 求和聚合在多重集上是单射的(求和 $\{1, 1, 2\}$ 得到4,而 $\{1, 3\}$ 也得到4,但在具有足够维度的特征向量上,不同多重集的和一般而言是不同的)。均值和取最大值不是单射的:均值无法区分 $\{1, 1\}$ 和 $\{2, 2\}$,取最大值无法区分 $\{1, 2, 3\}$ 和 $\{1, 1, 3\}$。
- GIN更新:
$$\mathbf{h}_i^{(l+1)} = \text{MLP}^{(l)}\left((1 + \epsilon^{(l)}) \cdot \mathbf{h}_i^{(l)} + \sum_{j \in \mathcal{N}(i)} \mathbf{h}_j^{(l)}\right)$$
- 其中 $\epsilon$ 是一个可学习的标量(或固定为0),MLP提供非线性、单射的映射。求和聚合保留了多重集结构,MLP可以学会区分任意两个不同的聚合值。
## 过平滑
- GNN的一个主要挑战是**过平滑**:随着层数增加,所有节点表示收敛到相同的值,失去区分不同节点的能力。
![过平滑:在第1层各不相同的节点特征在更深层逐渐融合为统一特征](../images/over_smoothing_gnn.svg)
- 其机制是直观的。每个消息传递层将节点的特征与其邻居的特征进行平均。经过多轮平均后,每个节点已经"看到"(并混合了)其连通分量中的每个其他节点。这些特征变成了统一的平均值,相当于将图像模糊太多次直到变成纯色的图类比。
- 形式上,重复应用归一化邻接矩阵 $\hat{A}$ 收敛到一个秩为1的矩阵(每一行都变得与图上随机游走的平稳分布成正比)。这与幂迭代收敛到主特征向量的过程相同(第2章)。
- 过平滑将GNN限制在很浅的深度(通常2-4层),而CNN和Transformer可以从几十或数百层中受益。这意味着每个节点只能看到有限的邻域,这对于需要长距离信息的任务来说是有问题的。
- 缓解方法包括:
- **残差连接**(来自ResNet,第8章):$\mathbf{h}_i^{(l+1)} = \mathbf{h}_i^{(l+1)} + \mathbf{h}_i^{(l)}$,保留来自较早层的信息。
- **跳跃知识(Jumping Knowledge**:拼接或注意力池化来自所有层的表示,而不仅仅是最后一层。
- **DropEdge**:训练期间随机移除边,减缓信息传播。
- **图TransformerGraph Transformer**(文件4):用全局注意力绕过局部消息传递的瓶颈。
## 图池化
- 对于**图级别任务**(预测整个图的属性,如分子的毒性),我们需要将所有节点表示折叠成一个单一的图级别向量。这就是**图池化**,是CNN中全局平均池化的图类比(第8章)。
- 最简单的方法是**读出(readout)**:对所有节点特征应用一个置换不变函数:
$$\mathbf{h}_G = \text{READOUT}(\{\mathbf{h}_i^{(L)} : i \in V\}) = \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \frac{1}{|V|} \sum_i \mathbf{h}_i^{(L)} \quad \text{或} \quad \max_i \mathbf{h}_i^{(L)}$$
- 这就是文件1中的DeepSets聚合,应用于最终的GNN层之后。求和保留了大小信息(一个有100个节点的图会比只有10个节点的图具有更大的和),而均值对大小进行了归一化。
- **分层池化**逐步粗化图,模仿CNN逐步下采样图像的方式。在每个层级,节点组被合并为"超节点":
- **DiffPool**(可微分池化)学习一个软分配矩阵 $S^{(l)} \in \mathbb{R}^{n_l \times n_{l+1}}$,将每个节点分配到一个簇:
$$X^{(l+1)} = S^{(l)T} H^{(l)}, \quad A^{(l+1)} = S^{(l)T} A^{(l)} S^{(l)}$$
- 分配矩阵由一个单独的GNN预测,使聚类变得端到端可微分。这创建了一个层次结构:原始图 → 具有较少节点的粗化图 → 更粗的图 → 单个节点(图表示)。
- **TopKPool**采用更简单的方法:为每个节点学习一个标量分数,保留得分最高的 top-$k$ 个节点,丢弃其余节点。这是一种硬选择(而非软分配),计算上比DiffPool更廉价。
## 异构图
- 截至目前的所有GNN都假设一个**同构图**:一种节点类型,一种边类型。但大多数现实世界的图是**异构**的:多种节点类型和多种边类型。知识图谱有人物节点、组织节点和位置节点,由"工作于"、"出生于"和"位于"边连接。推荐系统有用户节点和物品节点,由"已购买"、"已浏览"和"已评价"边连接。
- 异构图有一个**模式**(也称为元图),定义了允许的节点类型和边类型。每个边类型连接特定的源类型到特定的目标类型。例如,"工作于"连接 Person → Organisation。
- **关系GCNR-GCN**Schlichtkrull等,2018)通过为每种边类型使用单独的权重矩阵来处理异构边:
$$\mathbf{h}_i^{(l+1)} = \sigma\left(\sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)} \frac{1}{|\mathcal{N}_r(i)|} W_r^{(l)} \mathbf{h}_j^{(l)} + W_0^{(l)} \mathbf{h}_i^{(l)}\right)$$
- 其中 $\mathcal{R}$ 是边类型的集合,$\mathcal{N}_r(i)$ 是通过关系 $r$ 连接到节点 $i$ 的邻居集合,$W_r$ 是关系 $r$ 特有的权重矩阵。自连接 $W_0$ 单独处理节点自身的特征。
- 问题:当关系类型很多时,参数数量爆炸(每种关系一个 $d \times d$ 矩阵)。R-GCN通过**基分解**缓解这一问题:$W_r = \sum_{b=1}^{B} a_{rb} V_b$,其中 $V_b$ 是共享的基矩阵,$a_{rb}$ 是每个关系的标量系数。这类似于低秩分解(第2章):关系特定的矩阵生活在一个低维子空间中。
- **异构图表TransformerHGT**(Hu等,2020)将注意力机制应用于异构图。关键洞察:注意力应同时依赖于节点类型和连接它们的边类型。HGT为查询、键和值使用类型特定的投影矩阵:
$$\text{Attention}(i, j) = \left(W_{\tau(i)}^Q \mathbf{h}_i\right)^T \cdot \frac{W_{\phi(i,j)}^{\text{ATT}}}{\sqrt{d}} \cdot \left(W_{\tau(j)}^K \mathbf{h}_j\right)$$
- 其中 $\tau(i)$ 是节点 $i$ 的类型,$\phi(i,j)$ 是它们之间的边类型。这确保了模型对不同的关系类型使用不同的注意力权重:一篇论文关注其作者时,应使用与关注其参考文献时不同的注意力权重。
- **基于元路径的方法**定义通过模式的含义路径(例如,作者 → 论文 → 作者表示合著关系),并沿着这些路径聚合信息。**HAN**(异构图注意力网络)在两个层次应用注意力:在每个元路径内(沿此路径哪些邻居重要?)和跨元路径(哪些关系模式重要?)。
## 链接预测与知识图谱补全
- **链接预测**提出的问题是:给定现有边,哪些缺失的边可能存在?这是知识图谱补全(预测缺失的事实)、推荐(预测用户会喜欢哪些物品)和社交网络分析(预测未来的友谊)的核心任务。
- **基于嵌入的方法**为每个实体学习一个向量,为每个关系学习一个变换,然后通过实体和关系的匹配程度对潜在边进行评分:
- **TransE**将关系建模为嵌入空间中的平移:如果 $(h, r, t)$ 是一个有效的三元组(头实体,关系,尾实体),那么 $\mathbf{h} + \mathbf{r} \approx \mathbf{t}$。评分函数为 $f(h, r, t) = -\|\mathbf{h} + \mathbf{r} - \mathbf{t}\|$。直观地说,关系向量在嵌入空间中将头实体"移动"到尾实体。
- **RotatE**将关系建模为复空间中的旋转:$\mathbf{t} = \mathbf{h} \circ \mathbf{r}$,其中 $\circ$ 是逐元素复数乘法,$|\mathbf{r}_i| = 1$(单位复数就是旋转)。这可以建模TransE无法处理的对称性、反对称性、反转和复合模式。
- **ComplEx**使用复数值嵌入和埃尔米特点积,使其能够建模非对称关系(如果A是B的老板,B不是A的老板)。
- 基于GNN的链接预测通过消息传递计算节点嵌入,然后使用端点嵌入对边进行评分。这结合了GNN的结构推理能力和嵌入方法的关系建模能力。GNN编码器捕获了单嵌入方法所遗漏的多跳邻域结构。
## 任务类型
- GNN解决三类任务:
- **节点级别任务**:为每个节点预测一个属性。示例:对社交网络中的用户进行分类(机器人还是人类),预测相互作用网络中每个蛋白质的功能,半监督节点分类(标记少数节点,预测其余节点)。输出是节点嵌入 $\mathbf{h}_i^{(L)}$ 经过一个分类器。
- **边级别任务**:为每条边预测一个属性或预测边是否存在。示例:链接预测(这两个用户会成为朋友吗?),知识图谱补全(这个关系在这些实体间成立吗?),药物-药物相互作用预测。输出通常使用两个端点节点的嵌入:$\hat{y}_{ij} = f(\mathbf{h}_i, \mathbf{h}_j)$,其中 $f$ 是点积、拼接+MLP或其他组合。
- **图级别任务**:为整个图预测一个属性。示例:分子性质预测(这个分子有毒吗?),图分类(这个社交网络是机器人网络吗?),图生成(设计一个具有期望性质的分子)。输出使用图池化产生 $\mathbf{h}_G$,然后进行分类或回归。
## 编程任务(使用CoLab或notebook
1. 使用归一化邻接矩阵从头实现一个单层GCN。应用于一个小型图,观察节点特征如何被平滑。
```python
import jax
import jax.numpy as jnp
# 图:5个节点,简单链带分支
A = jnp.array([[0, 1, 0, 0, 0],
[1, 0, 1, 0, 0],
[0, 1, 0, 1, 1],
[0, 0, 1, 0, 0],
[0, 0, 1, 0, 0]], dtype=float)
# 添加自环
A_hat = A + jnp.eye(5)
D_hat = jnp.diag(A_hat.sum(axis=1))
D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
# 节点特征:one-hot 单位阵
H = jnp.eye(5)
# 权重矩阵(随机初始化)
rng = jax.random.PRNGKey(0)
W = jax.random.normal(rng, (5, 3)) * 0.5
# GCN层:H' = ReLU(A_norm @ H @ W)
H_new = jax.nn.relu(A_norm @ H @ W)
print("原始特征(one-hot:")
print(H)
print("\n经过GCN层后:")
print(jnp.round(H_new, 3))
print("\n注意:连接的节点现在具有相似的表示")
```
2. 实现具有求和聚合(GIN风格)和均值聚合(GCN风格)的消息传递。展示求和能区分均值无法区分的多重集。
```python
import jax.numpy as jnp
# 两个具有相同均值的不同邻居多重集
# 节点A:邻居特征为 [1, 1, 1, 1] (四个邻居,都是1)
# 节点B:邻居特征为 [2, 2] (两个邻居,都是2)
neighbours_A = jnp.array([[1.0], [1.0], [1.0], [1.0]])
neighbours_B = jnp.array([[2.0], [2.0]])
# 均值聚合
mean_A = neighbours_A.mean(axis=0)
mean_B = neighbours_B.mean(axis=0)
print(f"均值 A: {mean_A}, 均值 B: {mean_B}, 相同: {jnp.allclose(mean_A, mean_B)}")
# 求和聚合
sum_A = neighbours_A.sum(axis=0)
sum_B = neighbours_B.sum(axis=0)
print(f"求和 A: {sum_A}, 求和 B: {sum_B}, 相同: {jnp.allclose(sum_A, sum_B)}")
print("\n求和能区分这些多重集;均值不能!")
```
3. 演示过平滑。重复应用归一化邻接矩阵,观察节点特征收敛。
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 随机图
A = jnp.array([[0,1,1,0,0,0],
[1,0,1,0,0,0],
[1,1,0,1,0,0],
[0,0,1,0,1,1],
[0,0,0,1,0,1],
[0,0,0,1,1,0]], dtype=float)
A_hat = A + jnp.eye(6)
D_inv_sqrt = jnp.diag(1.0 / jnp.sqrt(A_hat.sum(axis=1)))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
# 初始特征:每个节点各不相同
H = jnp.array([[1,0], [0,1], [1,1], [-1,0], [0,-1], [-1,-1]], dtype=float)
distances = []
for k in range(20):
H = A_norm @ H
# 衡量特征的区别程度(节点间的标准差)
spread = jnp.std(H, axis=0).mean()
distances.append(float(spread))
plt.plot(distances, "o-")
plt.xlabel("消息传递轮数")
plt.ylabel("特征分散度(节点间标准差)")
plt.title("过平滑:特征随深度增加而收敛")
plt.show()
```
@@ -0,0 +1,258 @@
# 图注意力网络
*图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性*
- 在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重($\approx 1/3$)。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。
- **图注意力网络**通过使用与Transformer(第7章)相同的注意力机制来学习**关注哪些邻居**,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。
## GAT:图注意力网络
- **GAT**Veličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点 $i$ 和邻居 $j$:
$$e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)$$
- 其中 $W \in \mathbb{R}^{d' \times d}$ 是共享的线性变换,$\|$ 表示拼接,$\mathbf{a} \in \mathbb{R}^{2d'}$ 是可学习的注意力向量。分数 $e_{ij}$ 衡量节点 $j$ 的特征对节点 $i$ 的重要程度。
- 原始分数使用softmax在所有邻居之间进行归一化:
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$
- 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为:
$$\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)$$
![GCN为所有邻居分配固定的等权重;GAT学习依赖数据的注意力权重](../images/gat_attention_weights.svg)
- 与GCN的关键区别:权重 $\alpha_{ij}$ 是**从数据中学习**的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。
- 注意,注意力仅在边上计算(节点 $i$ 只关注其邻居 $\mathcal{N}(i)$),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。
## 多头图注意力
- 正如在Transformer中(第7章),**多头注意力**并行运行 $K$ 个独立的注意力机制,每个都有自己的参数 $W^k$ 和 $\mathbf{a}^k$。结果在中间层进行拼接,在最终层取平均:
$$\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)$$
- 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。
- 使用 $K$ 个头和每个头输出维度 $d'$,拼接后的输出维度为 $K \times d'$。最后一层通常使用平均而不是拼接来产生固定大小的输出。
## GATv2:修复静态注意力
- 原始GAT有一个微妙的限制:其注意力函数是**静态的**(也称为基于排序的)。注意力分数取决于拼接 $[W\mathbf{h}_i \| W\mathbf{h}_j]$,但由于注意力向量 $\mathbf{a}$ 在拼接之后应用,它可以分解为两个独立的分量:$\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j$。
- 这意味着对于给定节点 $i$,邻居的排序完全由邻居的特征 $\mathbf{h}_j$ 决定(项 $\mathbf{a}_1^T W\mathbf{h}_i$ 在 $i$ 的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点 $i$ 和节点 $k$ 将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。
- **GATv2**(Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题:
$$e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)$$
- 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为**动态**:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。
## 图Transformer
- 标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过 $k$ 层后,来自 $k$ 跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。
- **图Transformer**通过将**全局自注意力**应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。
- 基本思想:将所有节点视为标记(token),应用Transformer自注意力:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
- 其中 $Q = XW_Q$$K = XW_K$$V = XW_V$ 是节点特征 $X$ 的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 $K_n$,文件2)上的GNN。
- 问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点:
- **Graphormer**(Ying等,2021)通过注意力分数中的**偏置项**将图结构注入Transformer
$$A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)$$
- 空间偏置 $b_{\text{spatial}}$ 编码节点 $i$ 和 $j$ 之间的最短路径距离。边偏置 $b_{\text{edge}}$ 编码沿最短路径的边特征。此外,Graphormer使用**中心性编码**,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。
- **GPS**(通用、强大、可扩展的图TransformerRampášek等,2022)在每一层中结合了局部消息传递和全局注意力:
$$\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)$$
- 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。
## 位置编码与结构编码
- 序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。
- **拉普拉斯特征向量编码**使用图拉普拉斯算子(文件2)的特征向量作为位置特征。$k$ 个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。
- 一个微妙之处:拉普拉斯特征向量有符号模糊性(如果 $\mathbf{u}$ 是特征向量,$-\mathbf{u}$ 也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。
- **随机游走编码**计算从节点 $i$ 开始的随机游走经过 $k$ 步后返回节点 $i$ 的概率,对于 $k = 1, 2, \ldots, K$。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 $p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}$,其中 $A_{\text{rw}} = D^{-1}A$ 是随机游走转移矩阵。
- **度数编码**简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。
- 这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。
## 可扩展性
- GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。
- GNN的**小批量训练**比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种**邻域爆炸**意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。
- **邻域采样**(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 $15^2 = 225$ 个节点,与完整图的大小无关。
- **Cluster-GCN**(Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。
- **图Transformer的可扩展性**更困难,因为全局注意力是 $O(n^2)$ 的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括:
- 稀疏注意力模式(只关注图中距离最近的 $k$ 个节点)
- 线性注意力近似
- 将局部消息传递(廉价,$O(|E|)$)与粗化图上的全局注意力(更少的节点)相结合
## 时序图与动态图
- 我们迄今为止研究的图是**静态**的:节点、边和特征都是固定的。但许多现实世界的图会**随时间演化**:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。
- **时序图**为每条边增加一个时间戳:$(i, j, t)$ 表示节点 $i$ 在时间 $t$ 与节点 $j$ 发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。
- 存在两种范式:
- **离散时间动态图(DTDG**:图被表示为一系列快照 $G_1, G_2, \ldots, G_T$,每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。
- **连续时间动态图(CTDG**:事件被建模为带时间戳的交互流。每个事件 $(i, j, t)$ 在其发生的准确时间更新节点 $i$ 和 $j$ 的表示。这保留了所有时序信息。
- **时序图网络(TGN**Rossi等,2020)是领先的CTDG架构。每个节点维护一个**记忆状态** $\mathbf{s}_i(t)$,每当节点参与交互时更新:
$$\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)$$
- 其中 $\mathbf{m}_i(t)$ 是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。
- **时间编码**表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征:
$$\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]$$
- 这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。
- **时序图注意力(TGAT)**在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。
- 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。
## 编程任务(使用CoLab或notebook
1. 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。
```python
import jax
import jax.numpy as jnp
rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
n_nodes, d_in, d_out = 5, 4, 3
# 随机节点特征
H = jax.random.normal(k1, (n_nodes, d_in))
# 可学习参数
W = jax.random.normal(k2, (d_in, d_out)) * 0.5
a = jax.random.normal(k3, (2 * d_out,)) * 0.5
# 邻接(节点0连接到1, 2, 3
neighbours_of_0 = [1, 2, 3]
# 变换特征
Wh = H @ W # (n_nodes, d_out)
# 计算节点0的注意力分数
h_i = Wh[0]
scores = []
for j in neighbours_of_0:
h_j = Wh[j]
e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
scores.append(float(e_ij))
scores = jnp.array(scores)
alpha = jax.nn.softmax(scores)
print(f"原始分数: {scores}")
print(f"注意力权重: {alpha}")
print(f"权重之和: {alpha.sum():.4f}")
# 加权聚合
h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
print(f"更新后的节点0特征: {h_new}")
```
2. 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。
```python
import jax
import jax.numpy as jnp
# 4个节点:节点0连接到1, 2, 3
A = jnp.array([[0,1,1,1],
[1,0,0,0],
[1,0,0,0],
[1,0,0,0]], dtype=float)
# 特征:节点1非常相关,节点2是噪声,节点3中等
H = jnp.array([[0.0, 0.0], # 节点0
[1.0, 0.0], # 节点1(信号)
[0.0, 0.0], # 节点2(噪声)
[0.5, 0.0]]) # 节点3(中等)
# GCN:归一化邻接权重
A_hat = A + jnp.eye(4)
D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
gcn_weights = (D_inv @ A_hat)[0] # 节点0的权重
print(f"GCN中节点0的权重: {gcn_weights}")
print(" → 所有邻居获得大致相等的权重")
# GAT:学习到的注意力(模拟)
# 假设注意力机制学会关注节点1
gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15]) # 学习到的
print(f"\nGAT中节点0的权重: {gat_weights}")
print(" → 最具信息量的节点1获得最多关注")
gcn_output = gcn_weights @ H
gat_output = gat_weights @ H
print(f"\nGCN输出: {gcn_output} (被噪声稀释)")
print(f"GAT输出: {gat_output} (聚焦于信号)")
```
3. 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 杠铃图:两个团由一条桥连接
n = 10
A = jnp.zeros((n, n))
# 团1:节点0-4
for i in range(5):
for j in range(i+1, 5):
A = A.at[i,j].set(1).at[j,i].set(1)
# 团2:节点5-9
for i in range(5, 10):
for j in range(i+1, 10):
A = A.at[i,j].set(1).at[j,i].set(1)
# 桥
A = A.at[4,5].set(1).at[5,4].set(1)
D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)
# 使用前3个非平凡特征向量作为位置编码
pe = eigenvectors[:, 1:4]
print("拉普拉斯位置编码:")
for i in range(n):
group = "团1" if i < 5 else "团2"
bridge = " (桥)" if i in [4, 5] else ""
print(f" 节点 {i} ({group}{bridge}): {pe[i]}")
plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
label="桥节点", zorder=5)
plt.legend(); plt.grid(True)
plt.title("拉普拉斯特征向量位置编码")
plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2")
plt.show()
```
@@ -0,0 +1,274 @@
# 3D图网络
*3D图网络将GNN扩展到具有空间几何的数据,其中必须正确处理旋转和平移。本章涵盖几何图、SE(3)/E(n)等变性、SchNet、DimeNet、EGNN、张量场网络以及分子性质预测、蛋白质结构、材料科学和药物发现中的应用——从3D物理世界中学习的架构。*
- 文件3和4中的GNN操作于抽象图:节点有特征,边编码连接性,但没有3D空间的概念。社交网络图没有几何结构。但许多最具影响力的GNN应用涉及存在于**物理3D空间**中的数据:分子、蛋白质、晶体、点云。对于这些数据,节点的空间位置携带了抽象GNN所忽略的关键信息。
- 挑战在于3D数据具有**几何对称性**(文件1):旋转分子不会改变其性质,平移也是如此。3D GNN必须尊重这些对称性。一个会在旋转分子时改变的能量预测在物理上是错误的。
## 几何图
- **几何图**是嵌入在3D空间中的图。每个节点 $i$ 除了其特征向量 $\mathbf{h}_i$ 之外,还有一个位置 $\mathbf{r}_i \in \mathbb{R}^3$。边可以基于空间邻近性(连接距离在 $r_{\text{cut}}$ 内的节点)而不是基于显式的化学键来定义。
- 对于分子,几何图以原子为节点(特征包括:元素类型、电荷等),化学键为边。3D位置 $\mathbf{r}_i$ 是原子坐标,由量子力学或实验测量(X射线晶体学、冷冻电镜)确定。
- 对于点云(来自LiDAR或3D扫描仪,第8章和第11章),每个点是一个节点,具有位置和可选特征(颜色、强度)。边连接附近的点,形成**k最近邻(kNN)图**或半径图。
- 用于消息传递的关键几何量:
- **原子间距离**$d_{ij} = \|\mathbf{r}_i - \mathbf{r}_j\|$。距离对旋转和平移保持不变。具有相同原子间距离的两个分子具有相同的形状,无论朝向如何。
- **键角**:节点 $i$ 处向量 $\mathbf{r}_j - \mathbf{r}_i$ 和 $\mathbf{r}_k - \mathbf{r}_i$ 之间的角度 $\theta_{ijk}$。角度捕获了超越成对距离的局部几何结构。
- **二面角(扭转角)**:由 $(i, j, k)$ 和 $(j, k, l)$ 定义的平面之间的角度 $\phi_{ijkl}$。二面角捕获结构在3D中的扭转方式,对蛋白质主链几何结构至关重要。
- **相对位置向量**$\mathbf{r}_{ij} = \mathbf{r}_j - \mathbf{r}_i$。这些是平移不变的,但不是旋转不变的。使用它们需要等变(而不仅仅是不变)的架构。
## SE(3) 和 E(n) 等变性
- 3D物理数据的对称群是**欧几里得群** $E(3)$,由所有旋转、反射和平移组成。子群 **$SE(3)$**(特殊欧几里得群)包括旋转和平移,但不包括反射。
- 3D GNN应该是:
- 对标量输出(能量、结合亲和力)**平移不变**:将所有原子平移相同向量不应改变预测。
- 对标量输出**旋转不变**:旋转分子不应改变其能量。
- 对向量/张量输出(力、偶极矩)**旋转等变**:旋转分子应使预测的力向量按相同旋转旋转。
![SE(3)等变性:旋转分子使标量预测(能量)保持不变,但使向量预测(力)相应旋转](../images/se3_equivariance.svg)
- 形式上,对标量预测 $f$ 和旋转 $R \in SO(3)$
$$f(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = f(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(不变性)}$$
- 对向量预测 $\mathbf{F}$
$$\mathbf{F}(R\mathbf{r}_1, R\mathbf{r}_2, \ldots) = R \cdot \mathbf{F}(\mathbf{r}_1, \mathbf{r}_2, \ldots) \quad \text{(等变性)}$$
- 这些约束直接反映了文件1中的不变性/等变性框架,现在专门应用于3D旋转和平移群。
- 存在两种设计方法:
1. **不变架构**:只使用不变几何特征(距离、角度)作为消息传递的输入。内部表示是标量(不变的)。简单高效,但不能在不破坏对称性的情况下产生向量输出。
2. **等变架构**:在整个网络中维护向量(以及更高阶张量)表示,确保每一层是等变的。表达能力更强,可以自然地预测向量和张量,但更加复杂。
## SchNet:基于距离的消息传递
- **SchNet**(Schütt等,2017)是基础性的不变3D GNN。其关键创新是**连续滤波器卷积**:不是使用固定的边类型集合(如分子GNN中的键类型),SchNet直接从原子间距离生成消息滤波器。
- 距离 $d_{ij}$ 首先使用**径向基函数(RBF)**扩展为特征向量:
$$\text{RBF}(d_{ij}) = \left[\exp\left(-\gamma_1 (d_{ij} - \mu_1)^2\right), \ldots, \exp\left(-\gamma_K (d_{ij} - \mu_K)^2\right)\right]$$
- 每个基函数是一个以 $\mu_k$ 为中心、宽度为 $\gamma_k$ 的高斯函数。这类似于距离的可学习位置编码:连续距离被映射到一个高维特征空间,网络可以在其中学习距离相关的交互。中心 $\mu_k$ 通常从0到截止半径均匀分布。
- SchNet从节点 $j$ 到节点 $i$ 的消息为:
$$\mathbf{m}_{j \to i} = \mathbf{h}_j \odot W_{\text{filter}}(\text{RBF}(d_{ij}))$$
- 其中 $W_{\text{filter}}$ 是一个将RBF扩展映射到滤波器向量的MLP,$\odot$ 是逐元素乘法(Hadamard乘积,第2章)。滤波器依赖于距离,因此附近的原子与远处的原子产生不同的交互。逐元素乘法类似于门控机制(第6章):依赖于距离的滤波器控制每个特征维度有多少通过。
- 由于SchNet只使用距离(不变的),整个模型自动对旋转和平移保持不变。除了这个设计选择之外,不需要对对称性进行特殊处理。
## DimeNet和SphereNet:角度和二面角
- 仅凭距离不能完全确定3D结构。两个不同的分子构象可以具有相同的成对距离但不同的键角(这就是"距离几何歧义"问题)。**DimeNet**Gasteiger等,2020)将**键角**纳入消息传递。
- DimeNet使用**定向消息传递**:消息沿有向边流动,边 $(j \to i)$ 上的消息受边 $(k \to j)$ 和 $(j \to i)$ 之间的角度影响:
$$\mathbf{m}_{kj \to ji} = f\left(\mathbf{m}_{kj}, d_{ji}, \theta_{kji}\right)$$
- 角度 $\theta_{kji}$ 使用球贝塞尔函数和球谐函数(球面上角度信息的自然基,类似于距离的RBF)进行扩展。这使模型在保持不变性的同时能够访问方向信息。
- **SphereNet**(Liu等,2022)更进一步,包含**二面角** $\phi_{lkji}$,捕获完整的3D扭转结构。层次结构为:
- 距离 → 捕获成对邻近性
- 角度 → 捕获局部几何结构(弯曲 vs. 线性)
- 二面角 → 捕获3D扭转(对蛋白质主链、药物结合至关重要)
- 每个层次增加了几何分辨率,但计算复杂度也随之增加(距离为 $O(|E|)$,角度为 $O(|E| \cdot k)$,二面角为 $O(|E| \cdot k^2)$,其中 $k$ 是平均度数)。
## E(n)等变GNNEGNN
- **EGNN**Satorras等,2021)采用等变方法:它不只使用不变特征,而是在每一层同时更新节点特征**和**节点位置,在整个过程中保持等变性。
- 节点 $i$ 的EGNN更新:
$$\mathbf{m}_{ij} = \phi_e\left(\mathbf{h}_i, \mathbf{h}_j, d_{ij}^2, a_{ij}\right)$$
$$\mathbf{r}_i' = \mathbf{r}_i + C \sum_{j \neq i} (\mathbf{r}_i - \mathbf{r}_j) \cdot \phi_r(\mathbf{m}_{ij})$$
$$\mathbf{h}_i' = \phi_h\left(\mathbf{h}_i, \sum_j \mathbf{m}_{ij}\right)$$
- 关键在于位置更新:节点位置通过相对位置向量 $(\mathbf{r}_i - \mathbf{r}_j)$ 的加权和进行调整。权重来自消息函数 $\phi_r$,该函数仅依赖于不变的量(特征和距离)。这种构造是**可证明等变的**:如果所有输入位置被旋转 $R$,则所有输出位置被相同的 $R$ 旋转。
- EGNN的优雅之处在于它不显式使用球谐函数或不可约表示就实现了等变性。相对位置向量携带方向信息,不变的消息函数控制如何使用该方向信息。
- 这种简洁性是有代价的:EGNN只使用向量表示(1阶)。它无法在未经扩展的情况下表示更高阶的张量,如四极矩或应力张量。
## 张量场网络与高阶表示
- **张量场网络**(Thomas等,2018)及其后继者(**SE(3)-Transformers**、**MACE**、**Equiformer**)使用旋转群的**不可约表示**的完整机制来构建等变层。
- 在表示论中(联系到第2章的线性代数),3D中的旋转可以分解为以整数阶 $\ell$ 为特征的不可约分量:
- $\ell = 0$:标量(1个分量,不变)。能量、电荷。
- $\ell = 1$:向量(3个分量,像位置向量一样旋转)。力、偶极矩。
- $\ell = 2$:秩2对称无迹张量(5个分量)。四极矩、应力张量。
- 更高的 $\ell$:捕获越来越复杂的角结构。
- 这些被称为**球面张量**,它们通过**Wigner-D矩阵** $D^\ell(R)$ 在旋转 $R$ 下变换:标量不变,向量由 $R$ 旋转,秩2张量由更复杂的矩阵旋转。
- 使用球面张量的**等变消息传递**使用**Clebsch-Gordan张量积**来组合不同阶的特征:
$$(\mathbf{f}^{\ell_1} \otimes \mathbf{f}^{\ell_2})^{\ell_{\text{out}}} = \sum_{m_1, m_2} C^{\ell_{\text{out}}, m_{\text{out}}}_{\ell_1, m_1, \ell_2, m_2} \cdot f^{\ell_1}_{m_1} \cdot f^{\ell_2}_{m_2}$$
- Clebsch-Gordan系数 $C$ 是固定的数学常数,确保张量积是等变的。这是SO(3)等变版本的矩阵乘法。
- **MACE**Batatia等,2022)使用高阶消息(多个邻居特征的乘积)以更少的消息传递层达到高精度。通过构建体序相互作用(距离的2体、角度的3体、张量积的多体),MACE高效地捕获了复杂的原子间相互作用。
- **Equiformer**Liao & Smidt2023)将等变球面张量特征与Transformer注意力机制(文件4)相结合,创建了SE(3)等变的图Transformer。注意力分数从不变量特征计算,而值聚合在等变张量特征上进行。
## 应用
- **分子性质预测**:给定分子的3D结构,预测性质如能量、力、偶极矩、HOMO-LUMO能隙、毒性、溶解度。这是3D GNN最成熟的应用。在量子化学数据集(QM9、OC20)上训练的模型在许多性质上达到了化学精度,实现了对数百万候选分子的虚拟筛选。
- **分子动力学加速**:使用量子力学(密度泛函理论,DFT)计算原子间的力极其昂贵(对 $n$ 个电子为 $O(n^3)$)。训练用于预测力的3D GNN可以在分子动力学模拟期间替代DFT,实现 $10^3$–$10^6$ 的加速,同时保持接近DFT的精度。这使得能够模拟更大的系统和更长的时间尺度,揭示传统方法无法观测的现象。
- **蛋白质结构**:蛋白质是折叠成复杂3D结构的氨基酸链。蛋白质主链是一个几何图,其中节点是残基,边连接空间上邻近的残基。3D GNN用于蛋白质功能预测、结合位点识别和蛋白质设计(逆折叠:给定期望结构,预测氨基酸序列)。**AlphaFold**使用几何和基于图的推理从序列预测蛋白质结构。
- **材料科学与催化**:晶体材料具有周期性的3D结构。GNN对重复晶胞进行建模并预测材料性质:带隙、形成能、机械强度。开放催化剂项目(OC20/OC22)对GNN进行基准测试,预测催化表面上的吸附能,加速寻找用于可再生能源的新型催化剂。
- **药物发现**:3D GNN预测药物分子如何与靶蛋白结合。结合亲和力取决于药物与蛋白质结合口袋之间的3D形状互补性和化学相互作用。**DiffDock**等模型使用等变GNN与扩散模型(第8章)来预测结合姿态(药物在蛋白质口袋中的3D朝向)。
## 图生成
- 上述所有架构**分析**现有图。**图生成**创建新的图:设计具有期望性质的分子、生成用于测试的合成社交网络或提出新的蛋白质结构。这是图级别预测的生成对应任务。
- 挑战在于图是离散的、大小可变且组合的。生成图意味着决定要创建多少个节点、它们具有什么特征以及哪些对要连接。可能的图空间随节点数量超指数增长。
- **自回归生成**一次构建一个节点(或一条边)。**GraphRNN**(You等,2018)顺序地生成图:RNN维护一个状态,每一步生成一个新节点,并决定将其连接到哪些现有节点。生成顺序为本来无序的图施加了人工序列,但BFS排序通过保持最近生成的节点相关性来帮助解决问题。
- **基于VAE的生成**将图编码到连续潜在空间(使用GNN编码器),然后从采样的潜在向量解码新图。**GraphVAE**一次性生成一个概率邻接矩阵 $\hat{A} \in [0, 1]^{n \times n}$,但这需要 $O(n^2)$ 规模并产生需要阈值化的密集输出。潜在空间允许平滑插值:在两个分子嵌入之间移动会产生化学上有效的中间结构。
- **基于扩散的生成**将扩散框架(第8章)应用于图。前向过程逐渐向节点特征和边结构添加噪声。反向过程学习去噪,从噪声中生成有效的图。**DiGress**(Vignac等,2023)对节点类型和边类型应用离散扩散,自然地处理图数据的分类性质。
- 对于**分子生成**,关键约束是**化学有效性**:生成的分子必须遵守化合价规则(碳形成4个键,氧形成2个,等等)。**Junction Tree VAEJT-VAE)**等方将分子分解为有效子结构(环、链、官能团),并通过组装这些构建块来生成,通过构造保证有效性。
- **目标导向生成**优化特定性质:生成对靶蛋白具有高结合亲和力、低毒性和良好溶解度的分子。这在一个循环中结合了图生成与性质预测(使用3D GNN作为性质评估器):生成 → 评估 → 精炼。强化学习(第6章)或贝叶斯优化指导着化学空间的搜索。
- **DiffDock**Corso等,2023)使用SE(3)等变扩散来预测药物分子如何对接入蛋白质结合口袋。该模型通过从随机位置去噪来生成3D结合姿态(药物相对于蛋白质的位置和朝向),将本文件中的3D等变网络与第8章的扩散框架相结合。
## 编程任务(使用CoLab或notebook
1. 构建一个使用原子间距离的简单不变3D消息传递层。将其应用于一个小分子(水:H-O-H),并验证输出对旋转是不变的。
```python
import jax
import jax.numpy as jnp
# 水分子:O在原点,两个H原子
positions = jnp.array([[0.0, 0.0, 0.0], # O
[0.96, 0.0, 0.0], # H1
[-0.24, 0.93, 0.0]]) # H2
# 节点特征:[原子序数]
features = jnp.array([[8.0], [1.0], [1.0]])
# 计算成对距离(不变的)
def pairwise_distances(pos):
diff = pos[:, None, :] - pos[None, :, :]
return jnp.sqrt(jnp.sum(diff**2, axis=-1) + 1e-8)
# 简单的基于距离的消息传递
def invariant_message_pass(features, positions):
dists = pairwise_distances(positions)
# 具有4个中心的RBF扩展
centres = jnp.array([0.5, 1.0, 1.5, 2.0])
rbf = jnp.exp(-5.0 * (dists[:, :, None] - centres[None, None, :]) ** 2)
# 消息:由距离相关滤波器加权的特征
messages = jnp.einsum("ij,jd->id", rbf.sum(axis=-1), features)
return messages
output1 = invariant_message_pass(features, positions)
# 将分子绕z轴旋转90度
R = jnp.array([[0, -1, 0], [1, 0, 0], [0, 0, 1]], dtype=float)
rotated_positions = (R @ positions.T).T
output2 = invariant_message_pass(features, rotated_positions)
print(f"原始输出:\n{output1}")
print(f"\n旋转后输出:\n{output2}")
print(f"\n不变性: {jnp.allclose(output1, output2, atol=1e-5)}")
```
2. 计算三个原子之间的键角,并验证其对旋转不变。
```python
import jax.numpy as jnp
def bond_angle(r_i, r_j, r_k):
"""节点j处边j->i和j->k之间的角度。"""
v1 = r_i - r_j
v2 = r_k - r_j
cos_angle = jnp.dot(v1, v2) / (jnp.linalg.norm(v1) * jnp.linalg.norm(v2))
return jnp.arccos(jnp.clip(cos_angle, -1, 1))
# 三个原子
r1 = jnp.array([1.0, 0.0, 0.0])
r2 = jnp.array([0.0, 0.0, 0.0])
r3 = jnp.array([0.0, 1.0, 0.0])
angle_original = bond_angle(r1, r2, r3)
print(f"原始角度: {jnp.degrees(angle_original):.1f}°")
# 应用随机旋转
R = jnp.array([[0.36, 0.48, -0.80],
[-0.80, 0.60, 0.00],
[0.48, 0.64, 0.60]])
r1_rot, r2_rot, r3_rot = R @ r1, R @ r2, R @ r3
angle_rotated = bond_angle(r1_rot, r2_rot, r3_rot)
print(f"旋转后角度: {jnp.degrees(angle_rotated):.1f}°")
print(f"不变性: {jnp.allclose(angle_original, angle_rotated, atol=1e-4)}")
```
3. 演示等变位置更新(EGNN风格)。使用距离加权的相对向量更新节点位置,并验证等变性。
```python
import jax
import jax.numpy as jnp
def egnn_position_update(positions, features):
"""简单的EGNN风格等变位置更新。"""
n = positions.shape[0]
new_positions = jnp.zeros_like(positions)
for i in range(n):
shift = jnp.zeros(3)
for j in range(n):
if i != j:
r_ij = positions[i] - positions[j]
d_ij = jnp.linalg.norm(r_ij)
# 基于距离的权重(简单:反比距离)
weight = 1.0 / (d_ij + 1.0)
# 按特征相似度缩放
feat_sim = jnp.dot(features[i], features[j])
shift = shift + weight * feat_sim * r_ij
new_positions = new_positions.at[i].set(positions[i] + 0.1 * shift)
return new_positions
# 3个原子
pos = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
feat = jnp.array([[1.0, 0.5], [0.5, 1.0], [0.8, 0.3]])
# 更新位置
pos_new = egnn_position_update(pos, feat)
# 现在旋转输入、更新,并检查输出是否一致地旋转
R = jnp.array([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])
pos_rot = (R @ pos.T).T
pos_new_from_rot = egnn_position_update(pos_rot, feat)
# 应与旋转原始输出相同
pos_new_then_rot = (R @ pos_new.T).T
print(f"先更新再旋转:\n{jnp.round(pos_new_then_rot, 4)}")
print(f"\n先旋转再更新:\n{jnp.round(pos_new_from_rot, 4)}")
print(f"\n等变性: {jnp.allclose(pos_new_then_rot, pos_new_from_rot, atol=1e-4)}")
```