Files
maths-cs-ai-compendium-zh/chapter 12: graph neural networks/04. graph attention networks.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

259 lines
15 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# 图注意力网络
*图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性*
- 在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重($\approx 1/3$)。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。
- **图注意力网络**通过使用与Transformer(第7章)相同的注意力机制来学习**关注哪些邻居**,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。
## GAT:图注意力网络
- **GAT**Veličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点 $i$ 和邻居 $j$:
$$e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)$$
- 其中 $W \in \mathbb{R}^{d' \times d}$ 是共享的线性变换,$\|$ 表示拼接,$\mathbf{a} \in \mathbb{R}^{2d'}$ 是可学习的注意力向量。分数 $e_{ij}$ 衡量节点 $j$ 的特征对节点 $i$ 的重要程度。
- 原始分数使用softmax在所有邻居之间进行归一化:
$$\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}$$
- 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为:
$$\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)$$
![GCN为所有邻居分配固定的等权重;GAT学习依赖数据的注意力权重](../images/gat_attention_weights.svg)
- 与GCN的关键区别:权重 $\alpha_{ij}$ 是**从数据中学习**的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。
- 注意,注意力仅在边上计算(节点 $i$ 只关注其邻居 $\mathcal{N}(i)$),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。
## 多头图注意力
- 正如在Transformer中(第7章),**多头注意力**并行运行 $K$ 个独立的注意力机制,每个都有自己的参数 $W^k$ 和 $\mathbf{a}^k$。结果在中间层进行拼接,在最终层取平均:
$$\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)$$
- 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。
- 使用 $K$ 个头和每个头输出维度 $d'$,拼接后的输出维度为 $K \times d'$。最后一层通常使用平均而不是拼接来产生固定大小的输出。
## GATv2:修复静态注意力
- 原始GAT有一个微妙的限制:其注意力函数是**静态的**(也称为基于排序的)。注意力分数取决于拼接 $[W\mathbf{h}_i \| W\mathbf{h}_j]$,但由于注意力向量 $\mathbf{a}$ 在拼接之后应用,它可以分解为两个独立的分量:$\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j$。
- 这意味着对于给定节点 $i$,邻居的排序完全由邻居的特征 $\mathbf{h}_j$ 决定(项 $\mathbf{a}_1^T W\mathbf{h}_i$ 在 $i$ 的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点 $i$ 和节点 $k$ 将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。
- **GATv2**Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题:
$$e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)$$
- 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为**动态**:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。
## 图Transformer
- 标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过 $k$ 层后,来自 $k$ 跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。
- **图Transformer**通过将**全局自注意力**应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。
- 基本思想:将所有节点视为标记(token),应用Transformer自注意力:
$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
- 其中 $Q = XW_Q$$K = XW_K$$V = XW_V$ 是节点特征 $X$ 的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 $K_n$,文件2)上的GNN。
- 问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点:
- **Graphormer**Ying等,2021)通过注意力分数中的**偏置项**将图结构注入Transformer
$$A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)$$
- 空间偏置 $b_{\text{spatial}}$ 编码节点 $i$ 和 $j$ 之间的最短路径距离。边偏置 $b_{\text{edge}}$ 编码沿最短路径的边特征。此外,Graphormer使用**中心性编码**,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。
- **GPS**(通用、强大、可扩展的图TransformerRampášek等,2022)在每一层中结合了局部消息传递和全局注意力:
$$\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)$$
- 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。
## 位置编码与结构编码
- 序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。
- **拉普拉斯特征向量编码**使用图拉普拉斯算子(文件2)的特征向量作为位置特征。$k$ 个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。
- 一个微妙之处:拉普拉斯特征向量有符号模糊性(如果 $\mathbf{u}$ 是特征向量,$-\mathbf{u}$ 也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。
- **随机游走编码**计算从节点 $i$ 开始的随机游走经过 $k$ 步后返回节点 $i$ 的概率,对于 $k = 1, 2, \ldots, K$。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 $p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}$,其中 $A_{\text{rw}} = D^{-1}A$ 是随机游走转移矩阵。
- **度数编码**简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。
- 这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。
## 可扩展性
- GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。
- GNN的**小批量训练**比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种**邻域爆炸**意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。
- **邻域采样**(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 $15^2 = 225$ 个节点,与完整图的大小无关。
- **Cluster-GCN**Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。
- **图Transformer的可扩展性**更困难,因为全局注意力是 $O(n^2)$ 的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括:
- 稀疏注意力模式(只关注图中距离最近的 $k$ 个节点)
- 线性注意力近似
- 将局部消息传递(廉价,$O(|E|)$)与粗化图上的全局注意力(更少的节点)相结合
## 时序图与动态图
- 我们迄今为止研究的图是**静态**的:节点、边和特征都是固定的。但许多现实世界的图会**随时间演化**:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。
- **时序图**为每条边增加一个时间戳:$(i, j, t)$ 表示节点 $i$ 在时间 $t$ 与节点 $j$ 发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。
- 存在两种范式:
- **离散时间动态图(DTDG)**:图被表示为一系列快照 $G_1, G_2, \ldots, G_T$,每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。
- **连续时间动态图(CTDG)**:事件被建模为带时间戳的交互流。每个事件 $(i, j, t)$ 在其发生的准确时间更新节点 $i$ 和 $j$ 的表示。这保留了所有时序信息。
- **时序图网络(TGN**Rossi等,2020)是领先的CTDG架构。每个节点维护一个**记忆状态** $\mathbf{s}_i(t)$,每当节点参与交互时更新:
$$\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)$$
- 其中 $\mathbf{m}_i(t)$ 是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。
- **时间编码**表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征:
$$\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]$$
- 这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。
- **时序图注意力(TGAT)**在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。
- 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。
## 编程任务(使用CoLab或notebook
1. 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。
```python
import jax
import jax.numpy as jnp
rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
n_nodes, d_in, d_out = 5, 4, 3
# 随机节点特征
H = jax.random.normal(k1, (n_nodes, d_in))
# 可学习参数
W = jax.random.normal(k2, (d_in, d_out)) * 0.5
a = jax.random.normal(k3, (2 * d_out,)) * 0.5
# 邻接(节点0连接到1, 2, 3
neighbours_of_0 = [1, 2, 3]
# 变换特征
Wh = H @ W # (n_nodes, d_out)
# 计算节点0的注意力分数
h_i = Wh[0]
scores = []
for j in neighbours_of_0:
h_j = Wh[j]
e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
scores.append(float(e_ij))
scores = jnp.array(scores)
alpha = jax.nn.softmax(scores)
print(f"原始分数: {scores}")
print(f"注意力权重: {alpha}")
print(f"权重之和: {alpha.sum():.4f}")
# 加权聚合
h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
print(f"更新后的节点0特征: {h_new}")
```
2. 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。
```python
import jax
import jax.numpy as jnp
# 4个节点:节点0连接到1, 2, 3
A = jnp.array([[0,1,1,1],
[1,0,0,0],
[1,0,0,0],
[1,0,0,0]], dtype=float)
# 特征:节点1非常相关,节点2是噪声,节点3中等
H = jnp.array([[0.0, 0.0], # 节点0
[1.0, 0.0], # 节点1(信号)
[0.0, 0.0], # 节点2(噪声)
[0.5, 0.0]]) # 节点3(中等)
# GCN:归一化邻接权重
A_hat = A + jnp.eye(4)
D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
gcn_weights = (D_inv @ A_hat)[0] # 节点0的权重
print(f"GCN中节点0的权重: {gcn_weights}")
print(" → 所有邻居获得大致相等的权重")
# GAT:学习到的注意力(模拟)
# 假设注意力机制学会关注节点1
gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15]) # 学习到的
print(f"\nGAT中节点0的权重: {gat_weights}")
print(" → 最具信息量的节点1获得最多关注")
gcn_output = gcn_weights @ H
gat_output = gat_weights @ H
print(f"\nGCN输出: {gcn_output} (被噪声稀释)")
print(f"GAT输出: {gat_output} (聚焦于信号)")
```
3. 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。
```python
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 杠铃图:两个团由一条桥连接
n = 10
A = jnp.zeros((n, n))
# 团1:节点0-4
for i in range(5):
for j in range(i+1, 5):
A = A.at[i,j].set(1).at[j,i].set(1)
# 团2:节点5-9
for i in range(5, 10):
for j in range(i+1, 10):
A = A.at[i,j].set(1).at[j,i].set(1)
# 桥
A = A.at[4,5].set(1).at[5,4].set(1)
D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)
# 使用前3个非平凡特征向量作为位置编码
pe = eigenvectors[:, 1:4]
print("拉普拉斯位置编码:")
for i in range(n):
group = "团1" if i < 5 else "团2"
bridge = " (桥)" if i in [4, 5] else ""
print(f" 节点 {i} ({group}{bridge}): {pe[i]}")
plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
label="桥节点", zorder=5)
plt.legend(); plt.grid(True)
plt.title("拉普拉斯特征向量位置编码")
plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2")
plt.show()
```