翻译自英文原版 maths-cs-ai-compendium,共 20 章全部完成。 第01章 向量 | 第02章 矩阵 | 第03章 微积分 第04章 统计学 | 第05章 概率论 | 第06章 机器学习 第07章 计算语言学 | 第08章 计算机视觉 | 第09章 音频与语音 第10章 多模态学习 | 第11章 自主系统 | 第12章 图神经网络 第13章 计算与操作系统 | 第14章 数据结构与算法 第15章 生产级软件工程 | 第16章 SIMD与GPU编程 第17章 AI推理 | 第18章 ML系统设计 第19章 应用人工智能 | 第20章 前沿人工智能 翻译说明: - 所有数学公式 $...$ / $$...$$、代码块、图片引用完整保留 - mkdocs.yml 配置中文导航 + language: zh - README.md 已翻译为中文(兼 docs/index.md) - docs/ 目录包含指向各章文件的 symlink - 约 29,000 行中文内容,排除 .cache/ 构建缓存
15 KiB
图注意力网络
图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性
-
在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重($\approx 1/3$)。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。
-
图注意力网络通过使用与Transformer(第7章)相同的注意力机制来学习关注哪些邻居,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。
GAT:图注意力网络
- GAT(Veličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点
i和邻居 $j$:
e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)
-
其中
W \in \mathbb{R}^{d' \times d}是共享的线性变换,\|表示拼接,\mathbf{a} \in \mathbb{R}^{2d'}是可学习的注意力向量。分数e_{ij}衡量节点j的特征对节点i的重要程度。 -
原始分数使用softmax在所有邻居之间进行归一化:
\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}
- 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为:
\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)
-
与GCN的关键区别:权重
\alpha_{ij}是从数据中学习的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。 -
注意,注意力仅在边上计算(节点
i只关注其邻居 $\mathcal{N}(i)$),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。
多头图注意力
- 正如在Transformer中(第7章),多头注意力并行运行
K个独立的注意力机制,每个都有自己的参数W^k和 $\mathbf{a}^k$。结果在中间层进行拼接,在最终层取平均:
\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)
-
每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。
-
使用
K个头和每个头输出维度 $d'$,拼接后的输出维度为 $K \times d'$。最后一层通常使用平均而不是拼接来产生固定大小的输出。
GATv2:修复静态注意力
-
原始GAT有一个微妙的限制:其注意力函数是静态的(也称为基于排序的)。注意力分数取决于拼接 $[W\mathbf{h}_i | W\mathbf{h}_j]$,但由于注意力向量
\mathbf{a}在拼接之后应用,它可以分解为两个独立的分量:$\mathbf{a}^T [W\mathbf{h}_i | W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j$。 -
这意味着对于给定节点 $i$,邻居的排序完全由邻居的特征
\mathbf{h}_j决定(项\mathbf{a}_1^T W\mathbf{h}_i在i的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点i和节点k将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。 -
GATv2(Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题:
e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)
- 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为动态:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。
图Transformer
-
标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过
k层后,来自k跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。 -
图Transformer通过将全局自注意力应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。
-
基本思想:将所有节点视为标记(token),应用Transformer自注意力:
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V
-
其中 $Q = XW_Q$,$K = XW_K$,
V = XW_V是节点特征X的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 $K_n$,文件2)上的GNN。 -
问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点:
-
Graphormer(Ying等,2021)通过注意力分数中的偏置项将图结构注入Transformer:
A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)
-
空间偏置
b_{\text{spatial}}编码节点i和j之间的最短路径距离。边偏置b_{\text{edge}}编码沿最短路径的边特征。此外,Graphormer使用中心性编码,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。 -
GPS(通用、强大、可扩展的图Transformer,Rampášek等,2022)在每一层中结合了局部消息传递和全局注意力:
\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)
- 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。
位置编码与结构编码
-
序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。
-
拉普拉斯特征向量编码使用图拉普拉斯算子(文件2)的特征向量作为位置特征。
k个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。 -
一个微妙之处:拉普拉斯特征向量有符号模糊性(如果
\mathbf{u}是特征向量,-\mathbf{u}也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。 -
随机游走编码计算从节点
i开始的随机游走经过k步后返回节点i的概率,对于 $k = 1, 2, \ldots, K$。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 $p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}$,其中A_{\text{rw}} = D^{-1}A是随机游走转移矩阵。 -
度数编码简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。
-
这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。
可扩展性
-
GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。
-
GNN的小批量训练比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种邻域爆炸意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。
-
邻域采样(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有
15^2 = 225个节点,与完整图的大小无关。 -
Cluster-GCN(Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。
-
图Transformer的可扩展性更困难,因为全局注意力是
O(n^2)的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括:- 稀疏注意力模式(只关注图中距离最近的
k个节点) - 线性注意力近似
- 将局部消息传递(廉价,$O(|E|)$)与粗化图上的全局注意力(更少的节点)相结合
- 稀疏注意力模式(只关注图中距离最近的
时序图与动态图
-
我们迄今为止研究的图是静态的:节点、边和特征都是固定的。但许多现实世界的图会随时间演化:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。
-
时序图为每条边增加一个时间戳:
(i, j, t)表示节点i在时间t与节点j发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。 -
存在两种范式:
-
离散时间动态图(DTDG):图被表示为一系列快照 $G_1, G_2, \ldots, G_T$,每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。
-
连续时间动态图(CTDG):事件被建模为带时间戳的交互流。每个事件
(i, j, t)在其发生的准确时间更新节点i和j的表示。这保留了所有时序信息。 -
时序图网络(TGN)(Rossi等,2020)是领先的CTDG架构。每个节点维护一个记忆状态 $\mathbf{s}_i(t)$,每当节点参与交互时更新:
\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)
-
其中
\mathbf{m}_i(t)是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。 -
时间编码表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征:
\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]
-
这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。
-
**时序图注意力(TGAT)**在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。
-
应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。
编程任务(使用CoLab或notebook)
- 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。
import jax
import jax.numpy as jnp
rng = jax.random.PRNGKey(0)
k1, k2, k3 = jax.random.split(rng, 3)
n_nodes, d_in, d_out = 5, 4, 3
# 随机节点特征
H = jax.random.normal(k1, (n_nodes, d_in))
# 可学习参数
W = jax.random.normal(k2, (d_in, d_out)) * 0.5
a = jax.random.normal(k3, (2 * d_out,)) * 0.5
# 邻接(节点0连接到1, 2, 3)
neighbours_of_0 = [1, 2, 3]
# 变换特征
Wh = H @ W # (n_nodes, d_out)
# 计算节点0的注意力分数
h_i = Wh[0]
scores = []
for j in neighbours_of_0:
h_j = Wh[j]
e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
scores.append(float(e_ij))
scores = jnp.array(scores)
alpha = jax.nn.softmax(scores)
print(f"原始分数: {scores}")
print(f"注意力权重: {alpha}")
print(f"权重之和: {alpha.sum():.4f}")
# 加权聚合
h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
print(f"更新后的节点0特征: {h_new}")
- 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。
import jax
import jax.numpy as jnp
# 4个节点:节点0连接到1, 2, 3
A = jnp.array([[0,1,1,1],
[1,0,0,0],
[1,0,0,0],
[1,0,0,0]], dtype=float)
# 特征:节点1非常相关,节点2是噪声,节点3中等
H = jnp.array([[0.0, 0.0], # 节点0
[1.0, 0.0], # 节点1(信号)
[0.0, 0.0], # 节点2(噪声)
[0.5, 0.0]]) # 节点3(中等)
# GCN:归一化邻接权重
A_hat = A + jnp.eye(4)
D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
gcn_weights = (D_inv @ A_hat)[0] # 节点0的权重
print(f"GCN中节点0的权重: {gcn_weights}")
print(" → 所有邻居获得大致相等的权重")
# GAT:学习到的注意力(模拟)
# 假设注意力机制学会关注节点1
gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15]) # 学习到的
print(f"\nGAT中节点0的权重: {gat_weights}")
print(" → 最具信息量的节点1获得最多关注")
gcn_output = gcn_weights @ H
gat_output = gat_weights @ H
print(f"\nGCN输出: {gcn_output} (被噪声稀释)")
print(f"GAT输出: {gat_output} (聚焦于信号)")
- 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。
import jax.numpy as jnp
import matplotlib.pyplot as plt
# 杠铃图:两个团由一条桥连接
n = 10
A = jnp.zeros((n, n))
# 团1:节点0-4
for i in range(5):
for j in range(i+1, 5):
A = A.at[i,j].set(1).at[j,i].set(1)
# 团2:节点5-9
for i in range(5, 10):
for j in range(i+1, 10):
A = A.at[i,j].set(1).at[j,i].set(1)
# 桥
A = A.at[4,5].set(1).at[5,4].set(1)
D = jnp.diag(A.sum(axis=1))
L = D - A
eigenvalues, eigenvectors = jnp.linalg.eigh(L)
# 使用前3个非平凡特征向量作为位置编码
pe = eigenvectors[:, 1:4]
print("拉普拉斯位置编码:")
for i in range(n):
group = "团1" if i < 5 else "团2"
bridge = " (桥)" if i in [4, 5] else ""
print(f" 节点 {i} ({group}{bridge}): {pe[i]}")
plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
label="桥节点", zorder=5)
plt.legend(); plt.grid(True)
plt.title("拉普拉斯特征向量位置编码")
plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2")
plt.show()