Files
maths-cs-ai-compendium-zh/chapter 12: graph neural networks/04. graph attention networks.md
T
flykhan 2536c937e3 feat: 完整中文翻译 maths-cs-ai-compendium(数学·计算机科学·AI 知识大全)
翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。

第01章 向量 | 第02章 矩阵 | 第03章 微积分
第04章 统计学 | 第05章 概率论 | 第06章 机器学习
第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音
第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络
第13章 计算与操作系统 | 第14章 数据结构与算法
第15章 生产级软件工程 | 第16章 SIMD与GPU编程
第17章 AI推理 | 第18章 ML系统设计
第19章 应用人工智能 | 第20章 前沿人工智能

翻译说明:
- 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留
- mkdocs.yml 配置中文导航 + language: zh
- README.md 已翻译为中文(兼 docs/index.md)
- docs/ 目录包含指向各章文件的 symlink
- 约 29,000 行中文内容,排除 .cache/ 构建缓存
2026-05-03 10:23:20 +08:00

15 KiB
Raw Blame History

图注意力网络

图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性

  • 在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重($\approx 1/3$)。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。

  • 图注意力网络通过使用与Transformer(第7章)相同的注意力机制来学习关注哪些邻居,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。

GAT:图注意力网络

  • GATVeličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点 i 和邻居 $j$
e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)
  • 其中 W \in \mathbb{R}^{d' \times d} 是共享的线性变换,\| 表示拼接,\mathbf{a} \in \mathbb{R}^{2d'} 是可学习的注意力向量。分数 e_{ij} 衡量节点 j 的特征对节点 i 的重要程度。

  • 原始分数使用softmax在所有邻居之间进行归一化:

\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}
  • 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为:
\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)

GCN为所有邻居分配固定的等权重;GAT学习依赖数据的注意力权重

  • 与GCN的关键区别:权重 \alpha_{ij}从数据中学习的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。

  • 注意,注意力仅在边上计算(节点 i 只关注其邻居 $\mathcal{N}(i)$),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。

多头图注意力

  • 正如在Transformer中(第7章),多头注意力并行运行 K 个独立的注意力机制,每个都有自己的参数 W^k 和 $\mathbf{a}^k$。结果在中间层进行拼接,在最终层取平均:
\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)
  • 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。

  • 使用 K 个头和每个头输出维度 $d'$,拼接后的输出维度为 $K \times d'$。最后一层通常使用平均而不是拼接来产生固定大小的输出。

GATv2:修复静态注意力

  • 原始GAT有一个微妙的限制:其注意力函数是静态的(也称为基于排序的)。注意力分数取决于拼接 $[W\mathbf{h}_i | W\mathbf{h}_j]$,但由于注意力向量 \mathbf{a} 在拼接之后应用,它可以分解为两个独立的分量:$\mathbf{a}^T [W\mathbf{h}_i | W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j$。

  • 这意味着对于给定节点 $i$,邻居的排序完全由邻居的特征 \mathbf{h}_j 决定(项 \mathbf{a}_1^T W\mathbf{h}_ii 的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点 i 和节点 k 将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。

  • GATv2(Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题:

e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)
  • 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为动态:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。

图Transformer

  • 标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过 k 层后,来自 k 跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。

  • 图Transformer通过将全局自注意力应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。

  • 基本思想:将所有节点视为标记(token),应用Transformer自注意力:

\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
  • 其中 $Q = XW_Q$$K = XW_K$V = XW_V 是节点特征 X 的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 $K_n$,文件2)上的GNN。

  • 问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点:

  • Graphormer(Ying等,2021)通过注意力分数中的偏置项将图结构注入Transformer

A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)
  • 空间偏置 b_{\text{spatial}} 编码节点 ij 之间的最短路径距离。边偏置 b_{\text{edge}} 编码沿最短路径的边特征。此外,Graphormer使用中心性编码,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。

  • GPS(通用、强大、可扩展的图TransformerRampášek等,2022)在每一层中结合了局部消息传递和全局注意力:

\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)
  • 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。

位置编码与结构编码

  • 序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。

  • 拉普拉斯特征向量编码使用图拉普拉斯算子(文件2)的特征向量作为位置特征。k 个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。

  • 一个微妙之处:拉普拉斯特征向量有符号模糊性(如果 \mathbf{u} 是特征向量,-\mathbf{u} 也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。

  • 随机游走编码计算从节点 i 开始的随机游走经过 k 步后返回节点 i 的概率,对于 $k = 1, 2, \ldots, K$。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 $p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}$,其中 A_{\text{rw}} = D^{-1}A 是随机游走转移矩阵。

  • 度数编码简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。

  • 这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。

可扩展性

  • GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。

  • GNN的小批量训练比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种邻域爆炸意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。

  • 邻域采样(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 15^2 = 225 个节点,与完整图的大小无关。

  • Cluster-GCN(Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。

  • 图Transformer的可扩展性更困难,因为全局注意力是 O(n^2) 的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括:

    • 稀疏注意力模式(只关注图中距离最近的 k 个节点)
    • 线性注意力近似
    • 将局部消息传递(廉价,$O(|E|)$)与粗化图上的全局注意力(更少的节点)相结合

时序图与动态图

  • 我们迄今为止研究的图是静态的:节点、边和特征都是固定的。但许多现实世界的图会随时间演化:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。

  • 时序图为每条边增加一个时间戳:(i, j, t) 表示节点 i 在时间 t 与节点 j 发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。

  • 存在两种范式:

  • 离散时间动态图(DTDG:图被表示为一系列快照 $G_1, G_2, \ldots, G_T$,每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。

  • 连续时间动态图(CTDG:事件被建模为带时间戳的交互流。每个事件 (i, j, t) 在其发生的准确时间更新节点 ij 的表示。这保留了所有时序信息。

  • 时序图网络(TGNRossi等,2020)是领先的CTDG架构。每个节点维护一个记忆状态 $\mathbf{s}_i(t)$,每当节点参与交互时更新:

\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)
  • 其中 \mathbf{m}_i(t) 是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。

  • 时间编码表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征:

\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]
  • 这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。

  • **时序图注意力(TGAT)**在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。

  • 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。

编程任务(使用CoLab或notebook

  1. 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)

n_nodes, d_in, d_out = 5, 4, 3

# 随机节点特征
H = jax.random.normal(k1, (n_nodes, d_in))

# 可学习参数
W = jax.random.normal(k2, (d_in, d_out)) * 0.5
a = jax.random.normal(k3, (2 * d_out,)) * 0.5

# 邻接(节点0连接到1, 2, 3
neighbours_of_0 = [1, 2, 3]

# 变换特征
Wh = H @ W  # (n_nodes, d_out)

# 计算节点0的注意力分数
h_i = Wh[0]
scores = []
for j in neighbours_of_0:
    h_j = Wh[j]
    e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
    e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
    scores.append(float(e_ij))

scores = jnp.array(scores)
alpha = jax.nn.softmax(scores)

print(f"原始分数: {scores}")
print(f"注意力权重: {alpha}")
print(f"权重之和: {alpha.sum():.4f}")

# 加权聚合
h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
print(f"更新后的节点0特征: {h_new}")
  1. 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。
import jax
import jax.numpy as jnp

# 4个节点:节点0连接到1, 2, 3
A = jnp.array([[0,1,1,1],
               [1,0,0,0],
               [1,0,0,0],
               [1,0,0,0]], dtype=float)

# 特征:节点1非常相关,节点2是噪声,节点3中等
H = jnp.array([[0.0, 0.0],   # 节点0
               [1.0, 0.0],   # 节点1(信号)
               [0.0, 0.0],   # 节点2(噪声)
               [0.5, 0.0]])  # 节点3(中等)

# GCN:归一化邻接权重
A_hat = A + jnp.eye(4)
D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
gcn_weights = (D_inv @ A_hat)[0]  # 节点0的权重
print(f"GCN中节点0的权重: {gcn_weights}")
print("  → 所有邻居获得大致相等的权重")

# GAT:学习到的注意力(模拟)
# 假设注意力机制学会关注节点1
gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # 学习到的
print(f"\nGAT中节点0的权重: {gat_weights}")
print("  → 最具信息量的节点1获得最多关注")

gcn_output = gcn_weights @ H
gat_output = gat_weights @ H
print(f"\nGCN输出: {gcn_output}  (被噪声稀释)")
print(f"GAT输出: {gat_output}  (聚焦于信号)")
  1. 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。
import jax.numpy as jnp
import matplotlib.pyplot as plt

# 杠铃图:两个团由一条桥连接
n = 10
A = jnp.zeros((n, n))
# 团1:节点0-4
for i in range(5):
    for j in range(i+1, 5):
        A = A.at[i,j].set(1).at[j,i].set(1)
# 团2:节点5-9
for i in range(5, 10):
    for j in range(i+1, 10):
        A = A.at[i,j].set(1).at[j,i].set(1)
# 桥
A = A.at[4,5].set(1).at[5,4].set(1)

D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)

# 使用前3个非平凡特征向量作为位置编码
pe = eigenvectors[:, 1:4]

print("拉普拉斯位置编码:")
for i in range(n):
    group = "团1" if i < 5 else "团2"
    bridge = " (桥)" if i in [4, 5] else ""
    print(f"  节点 {i} ({group}{bridge}): {pe[i]}")

plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
            label="桥节点", zorder=5)
plt.legend(); plt.grid(True)
plt.title("拉普拉斯特征向量位置编码")
plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2")
plt.show()