跳转至

图注意力网络

图注意力网络将均匀的邻居聚合替换为学习到的、依赖数据的加权。本章涵盖GAT、多头图注意力、GATv2、图Transformer、位置和结构编码以及可扩展性

  • 在GCN(文件3)中,每个节点使用由图结构确定的固定权重(归一化邻接矩阵)聚合其邻居特征。一个有三个邻居的节点会给每个邻居大致相等的权重(\(\approx 1/3\))。但并非所有邻居都同等重要:来自密切合作者的消息应比来自远方熟人的消息更重要。

  • 图注意力网络通过使用与Transformer(第7章)相同的注意力机制来学习关注哪些邻居,从而解决了这一问题。与固定的、基于结构的权重不同,每个节点在其邻居上计算动态的、基于内容的注意力分数。

GAT:图注意力网络

  • GAT(Veličković等,2018)计算每个节点与其邻居之间的注意力系数。对于节点 \(i\) 和邻居 \(j\)
\[e_{ij} = \text{LeakyReLU}\left(\mathbf{a}^T \left[W\mathbf{h}_i \| W\mathbf{h}_j\right]\right)\]
  • 其中 \(W \in \mathbb{R}^{d' \times d}\) 是共享的线性变换,\(\|\) 表示拼接,\(\mathbf{a} \in \mathbb{R}^{2d'}\) 是可学习的注意力向量。分数 \(e_{ij}\) 衡量节点 \(j\) 的特征对节点 \(i\) 的重要程度。

  • 原始分数使用softmax在所有邻居之间进行归一化:

\[\alpha_{ij} = \text{softmax}_j(e_{ij}) = \frac{\exp(e_{ij})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ik})}\]
  • 这确保了每个节点邻域上的注意力权重之和为1,就像Transformer注意力一样(第7章)。节点更新后的特征为:
\[\mathbf{h}_i' = \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij} W\mathbf{h}_j\right)\]

GCN为所有邻居分配固定的等权重;GAT学习依赖数据的注意力权重

  • 与GCN的关键区别:权重 \(\alpha_{ij}\)从数据中学习的,而非由图结构固定。节点可以学会关注信息量最大的邻居,同时忽略噪声或无关的邻居。

  • 注意,注意力仅在边上计算(节点 \(i\) 只关注其邻居 \(\mathcal{N}(i)\)),而不是在所有节点对之间。这使得计算量与边的数量成正比,而不是节点数的平方。

多头图注意力

  • 正如在Transformer中(第7章),多头注意力并行运行 \(K\) 个独立的注意力机制,每个都有自己的参数 \(W^k\)\(\mathbf{a}^k\)。结果在中间层进行拼接,在最终层取平均:
\[\mathbf{h}_i' = \Big\|_{k=1}^{K} \sigma\left(\sum_{j \in \mathcal{N}(i)} \alpha_{ij}^k W^k \mathbf{h}_j\right)\]
  • 每个头可以关注邻域的不同方面:一个头可能关注结构特征,另一个关注语义相似性。这与Transformer中多头注意力的动机相同:不同的头捕获不同类型的关系。

  • 使用 \(K\) 个头和每个头输出维度 \(d'\),拼接后的输出维度为 \(K \times d'\)。最后一层通常使用平均而不是拼接来产生固定大小的输出。

GATv2:修复静态注意力

  • 原始GAT有一个微妙的限制:其注意力函数是静态的(也称为基于排序的)。注意力分数取决于拼接 \([W\mathbf{h}_i \| W\mathbf{h}_j]\),但由于注意力向量 \(\mathbf{a}\) 在拼接之后应用,它可以分解为两个独立的分量:\(\mathbf{a}^T [W\mathbf{h}_i \| W\mathbf{h}_j] = \mathbf{a}_1^T W\mathbf{h}_i + \mathbf{a}_2^T W\mathbf{h}_j\)

  • 这意味着对于给定节点 \(i\),邻居的排序完全由邻居的特征 \(\mathbf{h}_j\) 决定(项 \(\mathbf{a}_1^T W\mathbf{h}_i\)\(i\) 的所有邻居中是常数)。注意力排名并不真正依赖于查询节点的特征。节点 \(i\) 和节点 \(k\) 将以完全相同的方式对同一组邻居进行排序,这限制了表达能力。

  • GATv2(Brody等,2022)通过在注意力向量之前应用非线性函数来修复这个问题:

\[e_{ij} = \mathbf{a}^T \text{LeakyReLU}\left(W \left[\mathbf{h}_i \| \mathbf{h}_j\right]\right)\]
  • 将LeakyReLU移到计算内部意味着注意力分数是联合特征的非线性函数,不能分解为独立项。这使得注意力变为动态:邻居的排序现在依赖于特定的查询节点。GATv2严格比GAT更具表达能力,且没有额外的计算成本。

图Transformer

  • 标准消息传递GNN受到图拓扑的限制:一个节点只能关注其直接邻居。经过 \(k\) 层后,来自 \(k\) 跳邻居的信息已通过多个聚合步骤混合,失去了保真度。这种局部瓶颈(再加上文件3中的过平滑)限制了捕获长距离依赖关系的能力。

  • 图Transformer通过将全局自注意力应用于所有节点对(无论它们之间是否有边)来突破这个瓶颈。每个节点可以在单层中关注每个其他节点,就像标准Transformer一样(第7章)。

  • 基本思想:将所有节点视为标记(token),应用Transformer自注意力:

\[\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V\]
  • 其中 \(Q = XW_Q\)\(K = XW_K\)\(V = XW_V\) 是节点特征 \(X\) 的查询、键和值投影(与第7章完全相同)。这是完全连接图(完全图 \(K_n\),文件2)上的GNN。

  • 问题:完全连接图忽略了实际的图结构。边信息(谁实际连接到谁)丢失了。两种方法恢复了这一点:

  • Graphormer(Ying等,2021)通过注意力分数中的偏置项将图结构注入Transformer:

\[A_{ij} = \frac{(\mathbf{h}_i W_Q)(W_K^T \mathbf{h}_j^T)}{\sqrt{d_k}} + b_{\text{spatial}}(i, j) + b_{\text{edge}}(i, j)\]
  • 空间偏置 \(b_{\text{spatial}}\) 编码节点 \(i\)\(j\) 之间的最短路径距离。边偏置 \(b_{\text{edge}}\) 编码沿最短路径的边特征。此外,Graphormer使用中心性编码,将节点的度数添加到其输入嵌入中,为模型提供关于每个节点结构角色的信息。

  • GPS(通用、强大、可扩展的图Transformer,Rampášek等,2022)在每一层中结合了局部消息传递和全局注意力:

\[\mathbf{h}_i' = \text{MLP}\left(\mathbf{h}_i^{\text{MPNN}} + \mathbf{h}_i^{\text{Attention}}\right)\]
  • 每一层同时应用标准GNN(用于局部结构)和Transformer(用于全局上下文),然后组合结果。这获得了两个世界的优点:来自消息传递的局部结构和来自注意力的长距离依赖关系。

位置编码与结构编码

  • 序列上的Transformer使用位置编码(第7章)来注入顺序信息。图没有规范的顺序,因此需要特定于图的编码。

  • 拉普拉斯特征向量编码使用图拉普拉斯算子(文件2)的特征向量作为位置特征。\(k\) 个最小的非平凡特征向量提供了图的谱嵌入:在图中"附近"的节点具有相似的特征向量值。这些被拼接到节点特征中。

  • 一个微妙之处:拉普拉斯特征向量有符号模糊性(如果 \(\mathbf{u}\) 是特征向量,\(-\mathbf{u}\) 也是)。模型必须对这些符号翻转保持不变。解决方案包括在训练期间使用随机符号翻转作为数据增强,或学习符号不变的变换。

  • 随机游走编码计算从节点 \(i\) 开始的随机游走经过 \(k\) 步后返回节点 \(i\) 的概率,对于 \(k = 1, 2, \ldots, K\)。这些概率编码了局部结构信息:密集簇中的节点具有高的返回概率,而稀疏区域中的节点返回概率低。着陆概率 \(p_{ii}^{(k)} = (A_{\text{rw}}^k)_{ii}\),其中 \(A_{\text{rw}} = D^{-1}A\) 是随机游走转移矩阵。

  • 度数编码简单地将节点度数作为一个特征添加。这出奇地有效,因为度数是一个强大的结构信号:叶节点(度数为1)、桥接节点和枢纽节点的行为不同。

  • 这些编码提供了普通Transformer所缺乏的结构信息,使图Transformer在需要长距离推理的任务上能够超越标准消息传递GNN。

可扩展性

  • GNN的基本可扩展性挑战在于图可能拥有数百万个节点和数十亿条边。在完整图上训练GNN需要将所有节点特征和整个邻接矩阵存储在内存中,这通常是不可行的。

  • GNN的小批量训练比图像或序列更复杂,因为节点之间是相互连接的。朴素地采样一批节点需要它们的邻居(第1层)、邻居的邻居(第2层),依此类推。这种邻域爆炸意味着一个包含1000个目标节点的小批量可能需要计算图中数百万个节点。

  • 邻域采样(GraphSAGE风格,文件3)通过每层每个节点采样固定数量的邻居来限制爆炸。使用2层和每层15个样本,每个目标节点的子图最多有 \(15^2 = 225\) 个节点,与完整图的大小无关。

  • Cluster-GCN(Chiang等,2019)使用图聚类算法(例如METIS)将图划分为簇,然后一次在一个簇上训练。簇内边是密集的(大多数邻居在同一个簇内),因此子图捕获了相关结构。跨簇边通过偶尔包含簇之间的边来处理。

  • 图Transformer的可扩展性更困难,因为全局注意力是 \(O(n^2)\) 的。对于具有数百万个节点的图,完整的注意力是不可行的。解决方案包括:

    • 稀疏注意力模式(只关注图中距离最近的 \(k\) 个节点)
    • 线性注意力近似
    • 将局部消息传递(廉价,\(O(|E|)\))与粗化图上的全局注意力(更少的节点)相结合

时序图与动态图

  • 我们迄今为止研究的图是静态的:节点、边和特征都是固定的。但许多现实世界的图会随时间演化:新用户加入社交网络、金融交易创建边、交通模式全天变化、分子相互作用发生波动。

  • 时序图为每条边增加一个时间戳:\((i, j, t)\) 表示节点 \(i\) 在时间 \(t\) 与节点 \(j\) 发生了交互。挑战在于学习同时捕获图结构和时序动态的表示。

  • 存在两种范式:

  • 离散时间动态图(DTDG):图被表示为一系列快照 \(G_1, G_2, \ldots, G_T\),每个时间步一个。GNN处理每个快照,RNN或时序注意力机制捕获快照间的演化。这很简单,但丢失了精细的时间信息(快照之间的事件丢失了),并且需要选择快照频率。

  • 连续时间动态图(CTDG):事件被建模为带时间戳的交互流。每个事件 \((i, j, t)\) 在其发生的准确时间更新节点 \(i\)\(j\) 的表示。这保留了所有时序信息。

  • 时序图网络(TGN)(Rossi等,2020)是领先的CTDG架构。每个节点维护一个记忆状态 \(\mathbf{s}_i(t)\),每当节点参与交互时更新:

\[\mathbf{s}_i(t^+) = \text{GRU}\left(\mathbf{s}_i(t^-), \; \mathbf{m}_i(t)\right)\]
  • 其中 \(\mathbf{m}_i(t)\) 是从交互中计算出的消息(结合了两个节点的特征、边特征和时间编码)。GRU(第6章)选择性地保留和遗忘过去的信息,使记忆能够捕获长期模式,同时适应近期事件。

  • 时间编码表示自上次交互以来经过的时间,类似于Transformer中的位置编码(第7章)。常用方法使用可学习的傅里叶特征:

\[\Phi(t) = \left[\cos(\omega_1 t), \sin(\omega_1 t), \ldots, \cos(\omega_d t), \sin(\omega_d t)\right]\]
  • 这为模型提供了时间间隔的丰富表示:"该用户上次活跃是5分钟前"与"3个月前"以不同的方式嵌入。

  • 时序图注意力(TGAT)在节点的时间邻域上应用自注意力:一组最近的交互,每个交互同时按特征相关性(如GAT)和时间近度加权。来自遥远过去的交互自然地被降低权重。

  • 应用包括欺诈检测(金融图中的异常交易模式)、交通预测(从历史流量模式预测拥堵)、社交网络动态(预测病毒内容传播)以及随时间推移的药物相互作用预测。

编程任务(使用CoLab或notebook)

  1. 从头实现一个单头GAT注意力。计算节点与其邻居之间的注意力权重,并验证权重之和为1。

    import jax
    import jax.numpy as jnp
    
    rng = jax.random.PRNGKey(0)
    k1, k2, k3 = jax.random.split(rng, 3)
    
    n_nodes, d_in, d_out = 5, 4, 3
    
    # 随机节点特征
    H = jax.random.normal(k1, (n_nodes, d_in))
    
    # 可学习参数
    W = jax.random.normal(k2, (d_in, d_out)) * 0.5
    a = jax.random.normal(k3, (2 * d_out,)) * 0.5
    
    # 邻接(节点0连接到1, 2, 3)
    neighbours_of_0 = [1, 2, 3]
    
    # 变换特征
    Wh = H @ W  # (n_nodes, d_out)
    
    # 计算节点0的注意力分数
    h_i = Wh[0]
    scores = []
    for j in neighbours_of_0:
        h_j = Wh[j]
        e_ij = jnp.dot(a, jnp.concatenate([h_i, h_j]))
        e_ij = jax.nn.leaky_relu(e_ij, negative_slope=0.2)
        scores.append(float(e_ij))
    
    scores = jnp.array(scores)
    alpha = jax.nn.softmax(scores)
    
    print(f"原始分数: {scores}")
    print(f"注意力权重: {alpha}")
    print(f"权重之和: {alpha.sum():.4f}")
    
    # 加权聚合
    h_new = sum(alpha[k] * Wh[neighbours_of_0[k]] for k in range(len(neighbours_of_0)))
    print(f"更新后的节点0特征: {h_new}")
    

  2. 比较GCN(固定权重)和GAT(学习权重)的聚合。展示GAT可以为邻居分配不同的权重,而GCN统一对待它们。

    import jax
    import jax.numpy as jnp
    
    # 4个节点:节点0连接到1, 2, 3
    A = jnp.array([[0,1,1,1],
                   [1,0,0,0],
                   [1,0,0,0],
                   [1,0,0,0]], dtype=float)
    
    # 特征:节点1非常相关,节点2是噪声,节点3中等
    H = jnp.array([[0.0, 0.0],   # 节点0
                   [1.0, 0.0],   # 节点1(信号)
                   [0.0, 0.0],   # 节点2(噪声)
                   [0.5, 0.0]])  # 节点3(中等)
    
    # GCN:归一化邻接权重
    A_hat = A + jnp.eye(4)
    D_inv = jnp.diag(1.0 / A_hat.sum(axis=1))
    gcn_weights = (D_inv @ A_hat)[0]  # 节点0的权重
    print(f"GCN中节点0的权重: {gcn_weights}")
    print("  → 所有邻居获得大致相等的权重")
    
    # GAT:学习到的注意力(模拟)
    # 假设注意力机制学会关注节点1
    gat_weights = jnp.array([0.1, 0.7, 0.05, 0.15])  # 学习到的
    print(f"\nGAT中节点0的权重: {gat_weights}")
    print("  → 最具信息量的节点1获得最多关注")
    
    gcn_output = gcn_weights @ H
    gat_output = gat_weights @ H
    print(f"\nGCN输出: {gcn_output}  (被噪声稀释)")
    print(f"GAT输出: {gat_output}  (聚焦于信号)")
    

  3. 演示位置编码的益处。计算图的拉普拉斯特征向量编码,展示结构相似的节点获得相似的编码。

    import jax.numpy as jnp
    import matplotlib.pyplot as plt
    
    # 杠铃图:两个团由一条桥连接
    n = 10
    A = jnp.zeros((n, n))
    # 团1:节点0-4
    for i in range(5):
        for j in range(i+1, 5):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # 团2:节点5-9
    for i in range(5, 10):
        for j in range(i+1, 10):
            A = A.at[i,j].set(1).at[j,i].set(1)
    # 桥
    A = A.at[4,5].set(1).at[5,4].set(1)
    
    D = jnp.diag(A.sum(axis=1))
    L = D - A
    eigenvalues, eigenvectors = jnp.linalg.eigh(L)
    
    # 使用前3个非平凡特征向量作为位置编码
    pe = eigenvectors[:, 1:4]
    
    print("拉普拉斯位置编码:")
    for i in range(n):
        group = "团1" if i < 5 else "团2"
        bridge = " (桥)" if i in [4, 5] else ""
        print(f"  节点 {i} ({group}{bridge}): {pe[i]}")
    
    plt.scatter(pe[:5, 0], pe[:5, 1], c="#3498db", s=80, label="团1")
    plt.scatter(pe[5:, 0], pe[5:, 1], c="#e74c3c", s=80, label="团2")
    plt.scatter(pe[[4,5], 0], pe[[4,5], 1], c="black", s=120, marker="*",
                label="桥节点", zorder=5)
    plt.legend(); plt.grid(True)
    plt.title("拉普拉斯特征向量位置编码")
    plt.xlabel("特征向量 1"); plt.ylabel("特征向量 2")
    plt.show()