图神经网络(GNN)

核心定位: 处理非欧几里得结构数据的深度学习架构,从 2009 年起源到 2024 年广泛应用的演进史。 关联笔记: 神经网络神经网络层类型详解卷积神经网络循环神经网络Transformer


1 为什么需要 GNN

1.1 图数据无处不在

真实世界的图结构:

  • 社交网络: 用户(节点)+ 关系(边)
  • 分子结构: 原子(节点)+ 化学键(边)
  • 知识图谱: 实体(节点)+ 关系(边)
  • 推荐系统: 用户 - 商品二部图
  • 交通网络: 路口(节点)+ 道路(边)
  • 蛋白质结构: 氨基酸(节点)+ 空间关系(边)
  • 代码分析: 函数/变量(节点)+ 调用/依赖(边)

1.2 传统方法的局限

问题 1:CNN 无法应用

  • CNN 假设:网格结构(图像)、局部性、平移不变性
  • 图特点:不规则结构、节点度数不同、无固定邻居顺序

问题 2:传统图算法无法泛化

  • PageRank、最短路径:任务特定
  • 手工特征工程:费时费力
  • 无法端到端学习

问题 3:表示学习困难

  • 如何编码图结构?
  • 如何保留拓扑信息?
  • 如何处理动态图?

1.3 GNN 的核心优势

结构感知: 利用图的连接信息 ✅ 端到端学习: 自动学习节点/图表示 ✅ 归纳能力: 可泛化到未见过的图 ✅ 任务多样: 节点/边/图级预测


2 图的基础知识

2.1 图的数学定义

无向图:

  • :节点集合,
  • :边集合,
  • 邻接矩阵: 表示 相连

有向图: 表示从 的边

加权图: (边权重)

属性图(Attributed Graph):

  • 节点特征:(每个节点 维特征)
  • 边特征:(可选)

2.2 图的关键概念

概念定义示例
度(Degree)节点连接的边数
邻居与节点直接相连的节点集
路径节点序列,相邻节点间有边
连通性任意两节点间存在路径社交网络的连通分量
度矩阵(对角)用于归一化
拉普拉斯矩阵谱图理论基础

2.3 图的表示方式

2.3.1 邻接矩阵(Adjacency Matrix)

优点: 直观,便于矩阵运算 缺点: 稀疏图浪费空间(

# 示例:4 个节点的图
A = [[0, 1, 1, 0],
     [1, 0, 0, 1],
     [1, 0, 0, 1],
     [0, 1, 1, 0]]

2.3.2 边列表(Edge List)

格式: 边的起点 - 终点对 优点: 空间高效(

edge_list = [(0, 1), (0, 2), (1, 3), (2, 3)]

2.3.3 邻接表(Adjacency List)

格式: 每个节点的邻居列表 优点: 适合稀疏图

adj_list = {
    0: [1, 2],
    1: [0, 3],
    2: [0, 3],
    3: [1, 2]
}

2.4 常见图类型

图类型特点示例
稀疏图社交网络(平均度 < 100)
稠密图全连接层(完全图)
同质图单一节点/边类型论文引用网络
异质图多种节点/边类型知识图谱(人、地点、事件)
静态图结构固定分子结构
动态图时变结构社交网络演化
二部图两类节点,类内无边用户 - 商品推荐

3 GNN 核心思想:消息传递

3.1 消息传递神经网络(MPNN)框架

核心理念: 节点通过迭代聚合邻居信息来更新自己的表示

通用公式(第 层):

三个关键步骤:

  1. 消息生成(Message): 邻居发送消息

  2. 消息聚合(Aggregate): 汇总邻居消息

  3. 节点更新(Update): 结合自身信息更新

3.2 聚合函数的选择

聚合方式公式特点代表模型
求和(Sum)大小不变性GCN, GraphSAGE
平均(Mean)$\frac{1}{\mathcal{N}(v)} \sum_{u} h_u$
最大(Max)提取显著特征GraphSAGE
注意力自适应加权GAT
LSTMLSTM over neighbors考虑顺序(需排序)GraphSAGE

置换不变性(Permutation Invariance):

  • 邻居顺序不应影响结果
  • Sum、Mean、Max 满足
  • LSTM 不满足(需要固定顺序)

3.3 消息传递的直观理解

以社交网络为例:

初始状态 (k=0):
你 (h_v^0): [兴趣: 音乐, 年龄: 25]
朋友A (h_A^0): [兴趣: 音乐, 年龄: 24]
朋友B (h_B^0): [兴趣: 运动, 年龄: 26]

第1层 (k=1):
1. 朋友发送信息 (Message)
2. 你接收并聚合 (Aggregate): Mean([h_A^0, h_B^0])
3. 更新你的表示 (Update): h_v^1 = NN([h_v^0, mean_msg])
   → h_v^1 包含了"我和我朋友圈"的信息

第2层 (k=2):
h_v^2 包含了"我、我朋友、朋友的朋友"的信息
→ 感受野扩大到 2-hop 邻居

感受野(Receptive Field):

  • 层 GNN:感受野 = -hop 邻居
  • 类似 CNN 的感受野,但在图上

4 经典 GNN 架构演进

4.1 谱图卷积(Spectral GCN)

4.1.1 理论基础:图拉普拉斯

拉普拉斯矩阵:

归一化拉普拉斯:

谱分解:

  • :特征向量矩阵(图傅里叶基)
  • :特征值矩阵

图傅里叶变换:

  • 前向:
  • 逆向:

4.1.2 谱卷积(Spectral Convolution)

定义: 在频域进行卷积

其中 是逐元素乘法。

问题:

  • 计算复杂度:(特征分解)
  • 不能泛化到不同大小的图

4.1.3 ChebNet (2016)

创新: 使用切比雪夫多项式近似

其中 是切比雪夫多项式,

优势:

  • -localized:只需 -hop 邻居
  • 复杂度(线性于边数)

4.2 GCN:图卷积网络 (2017) ⭐

作者: Thomas Kipf & Max Welling 论文: “Semi-Supervised Classification with Graph Convolutional Networks”

4.2.1 简化推导

从 ChebNet 取 (一阶近似):

其中:

  • (加自环)
  • :可学习权重
  • :激活函数(ReLU)

4.2.2 逐层传播规则

单节点视角:

归一化因子

  • 防止度数大的节点主导
  • 对称归一化

4.2.3 GCN 特点

优势:

  • ✅ 简单高效
  • ✅ 半监督学习(少量标签)
  • ✅ 可叠加多层
  • ✅ 可泛化到不同图

局限:

  • ❌ 过平滑(Over-smoothing):深层后节点表示趋同
  • ❌ 所有邻居等权重(无差异化)
  • ❌ 无法处理边特征

4.3 GraphSAGE (2017)

全称: Graph Sample and Aggregate 作者: Hamilton et al. (Stanford) 创新: 采样 + 多种聚合器 + 归纳学习

4.3.1 核心算法

对于每个节点 v:
  1. 采样固定数量邻居 (Sample)
  2. 聚合邻居特征 (Aggregate)
     h_N^(k) = AGG({h_u^(k-1), ∀u ∈ Sample(N(v))})
  3. 结合自身特征 (Concatenate)
     h_v^(k) = σ(W · [h_v^(k-1) || h_N^(k)])
  4. 归一化 (L2 normalization)
     h_v^(k) = h_v^(k) / ||h_v^(k)||_2

4.3.2 聚合器对比

聚合器公式复杂度特点
Mean最稳定
LSTMLSTM over random permutation$O(d \cdot\mathcal{N}
Pooling元素级最大池化

4.3.3 采样策略

固定大小采样:

  • 1-hop:采样 25 个邻居
  • 2-hop:采样 10 个邻居
  • 总复杂度:(可控)

优势:

  • 计算可预测(Mini-batch 训练)
  • 适合大图(不需要整个图)
  • 归纳学习(新节点无需重训练)

4.4 GAT:图注意力网络 (2018) ⭐

作者: Veličković et al. 创新: 自注意力机制加权邻居

4.4.1 注意力系数计算

步骤 1:计算注意力分数

其中:

  • :特征变换矩阵
  • :注意力参数向量
  • :拼接

步骤 2:Softmax 归一化

步骤 3:加权聚合

4.4.2 多头注意力(Multi-Head Attention)

并行计算 个注意力头:

或在最后一层使用平均:

4.4.3 GAT 优势

自适应权重: 重要邻居权重大 ✅ 可解释性: 注意力分数可视化 ✅ 处理异质图: 不同类型边可用不同注意力 ✅ 并行化: 所有节点同时计算

应用示例:

  • 分子性质预测:化学键重要性不同
  • 推荐系统:用户兴趣权重动态调整

4.5 GIN:图同构网络 (2019)

全称: Graph Isomorphism Network 作者: Xu et al. 理论贡献: 证明了 GNN 的表达能力上界

4.5.1 WL-Test(Weisfeiler-Lehman)

图同构问题: 判断两个图是否结构相同

WL-Test 算法:

  1. 初始化节点标签
  2. 迭代更新:
  3. 如果两图标签序列相同 → 可能同构

4.5.2 GIN 公式

最大表达力的 GNN(等价于 WL-Test):

其中 可学习或固定。

关键设计:

  • 单射聚合: Sum(而非 Mean/Max)
  • MLP 更新: 比单层线性更强表达力
  • 包含自身:

4.5.3 理论结果

定理: GIN 可以区分任何 WL-Test 可区分的图。

含义:

  • GCN/GraphSAGE(Mean/Max):表达力弱于 WL-Test
  • GAT:理论上不强于 GCN(注意力不改变表达力上界)
  • GIN:达到理论上界

4.6 其他重要变体

4.6.1 消息传递变体

模型年份核心创新适用场景
GatedGCN2018门控机制(类似 LSTM)过平滑缓解
PNA2020多尺度聚合器组合通用任务
DeeperGCN2020残差连接 + 归一化深层网络(50+ 层)
GRAND2021图随机扩散鲁棒性增强

4.6.2 位置编码

问题: GNN 无法区分对称节点

解决方案:

  • 拉普拉斯位置编码(LapPE): 使用拉普拉斯矩阵特征向量
  • 随机游走编码(RWSE): 基于随机游走统计量
  • 距离编码: 到关键节点的最短路径

5 图级任务与池化

5.1 图分类问题

任务: 给整个图分配标签

示例:

  • 分子性质预测(有毒/无毒)
  • 代码漏洞检测
  • 蛋白质功能预测

5.2 图池化方法

5.2.1 全局池化

简单聚合:

方法公式特点
Sum保留总量
Mean$\frac{1}{V
Max显著特征
Attention加权聚合

Set2Set(2018):

  • 使用 LSTM 读取节点集合
  • 顺序不变但表达力强

5.2.2 层次化池化

DiffPool (2018):

  • 学习软分配矩阵
  • 逐层粗化图结构

Top-K Pooling:

  • 根据节点得分保留 Top-K 个节点
  • 类似 CNN 的最大池化

SAGPool (2019):

  • 自注意力图池化
  • 选择得分最高的 个节点

6 GNN 的挑战与解决方案

6.1 过平滑(Over-Smoothing)

问题: 多层 GNN 后,所有节点表示趋于相同

数学解释:

  • 每层聚合相当于拉普拉斯平滑
  • 层后,节点表示混合了 -hop 邻域
  • 小世界网络中,几层后覆盖整个图

实验现象:

2层 GCN: 准确率 80%
4层 GCN: 准确率 75%
8层 GCN: 准确率 50%(随机猜测)

解决方案:

  1. 残差连接(Skip Connections):

  2. Jumping Knowledge(2018):

    • 聚合所有层的表示:
  3. DropEdge(2020):

    • 训练时随机丢弃边
    • 减少过度聚合
  4. PairNorm(2020):

    • 保持节点表示的方差

6.2 过拟合(Over-Fitting)

挑战: 图数据集通常很小(<1000 个图)

解决方案:

方法原理效果
Dropout随机丢弃节点特征标准正则
DropEdge随机丢弃边结构正则
图数据增强添加/删除边,特征扰动扩充数据
对比学习自监督预训练利用无标签数据

6.3 异质图建模

异质图(Heterogeneous Graph): 多种节点/边类型

示例:学术网络

  • 节点:论文、作者、会议
  • 边:写作、发表、引用

解决方案:

HAN(Heterogeneous Attention Network, 2019):

  1. 节点级注意力: 对不同类型邻居加权
  2. 语义级注意力: 对不同元路径加权

RGCN(Relational GCN, 2018):

每种关系 有独立权重

6.4 大图扩展性

挑战: 完整图无法放入 GPU 内存

解决方案:

  1. 采样方法:

    • GraphSAGE: 邻居采样
    • FastGCN: 层级采样
    • ClusterGCN: 图聚类后 Mini-batch
  2. 简化模型:

    • SGC(Simple GCN): 移除非线性,预计算
    • SIGN: 预计算多跳邻域,并行训练
  3. 分布式训练:

    • 图划分
    • 跨设备通信

7 应用场景

7.1 节点级任务

7.1.1 节点分类

示例:

  • 社交网络: 用户兴趣分类
  • 学术网络: 论文主题分类
  • 生物网络: 蛋白质功能预测

典型数据集:

数据集节点数边数类别数特征维度
Cora2,7085,42971,433
Citeseer3,3274,73263,703
PubMed19,71744,3383500

2024 SOTA(Cora):

  • GCN: 81.5%
  • GAT: 83.0%
  • GraphTransformer: 85.2% 🔥

7.1.2 链接预测

任务: 预测两节点间是否存在边

应用:

  • 推荐系统(用户 - 商品)
  • 知识图谱补全
  • 药物 - 靶点相互作用

方法:

  1. 学习节点嵌入:
  2. 计算相似度:

7.2 边级任务

边分类/回归:

  • 社交网络:关系类型预测
  • 交通网络:道路拥堵预测

7.3 图级任务

7.3.1 分子性质预测 ⭐

应用: 新药研发、材料设计

数据集:

  • QM9: 130K 小分子,量子化学性质
  • ZINC: 250K 药物样分子
  • OGB(Open Graph Benchmark): 百万级分子

模型:

  • SchNet(2018):连续滤波器卷积
  • DimeNet(2020):方向性消息传递
  • Graphormer(2021):Transformer on Graphs 🔥

2024 进展:

  • UniMol: 3D 分子预训练模型
  • MolGPT: 生成式分子设计

7.3.2 推荐系统

图构建:

  • 节点:用户 + 商品
  • 边:交互历史(点击、购买)

模型:

  • PinSage(Pinterest, 2018): 随机游走 + GraphSAGE
  • NGCF(2019): 协同过滤 + GCN
  • LightGCN(2020): 简化 GCN(去激活函数)

7.3.3 知识图谱

任务:

  • 实体链接
  • 关系推理
  • 问答系统

模型:

  • R-GCN(2018)
  • CompGCN(2020)
  • NBFNet(2021):神经贝尔曼 - 福特网络

7.4 时空图(Spatio-Temporal Graphs)

应用:

  • 交通流量预测
  • 流行病传播建模
  • 视频分析

模型:

  • STGCN(2018): GCN + 时间卷积
  • Graph WaveNet(2019): 自适应邻接矩阵
  • MTGNN(2020): 多任务时空图

8 2024-2025 前沿进展

8.1 Graph Transformers 🔥

动机: 结合 Transformer 的全局建模能力

挑战:

  • Transformer 是全连接(),大图无法承受
  • 如何注入图结构偏置?

解决方案:

8.1.1 Graphormer (2021, Microsoft)

创新:

  1. 中心性编码(Centrality Encoding): 节点度数
  2. 空间编码(Spatial Encoding): 最短路径距离
  3. 边编码(Edge Encoding): 路径上的边特征

性能: OGB 排行榜多项第一

8.1.2 GraphGPS (2022)

设计: 消息传递 + Transformer 混合

Layer = MPNN(局部) + Transformer(全局) + FFN

优势:

  • 局部归纳偏置(MPNN)
  • 全局信息流(Transformer)

8.1.3 Exphormer (2023)

创新: 基于扩展图(Expander Graph)的稀疏注意力

复杂度: (线性!)

8.2 图基础模型(Graph Foundation Models)🔥

目标: 类似 GPT/BERT 的预训练 - 微调范式

8.2.1 GraphMAE (2022)

方法: 掩码自编码(Masked Autoencoder)

  • 随机掩盖节点特征
  • 重建被掩盖的特征

8.2.2 GraphGPT (2023)

创新: 将图结构转换为序列,用 LLM 处理

流程:

  1. 图 → 序列化(BFS/DFS)
  2. 输入 GPT 风格模型
  3. 生成图级预测

8.2.3 OFA-GNN (2024) 🔥

全称: One-For-All Graph Neural Network

目标: 单一模型处理所有图任务

技术:

  • 统一任务表示
  • 多任务训练
  • 提示学习(Prompt-based)

8.3 几何深度学习 🔥

理论框架: 统一 CNN、RNN、Transformer、GNN

E(n)-Equivariant GNN:

  • 保持旋转/平移不变性
  • 应用:3D 分子、点云、物理模拟

代表模型:

  • EGNN(2021): E(n)- 等变图神经网络
  • GemNet(2021): 几何消息传递
  • Equiformer(2023): 等变 Transformer

8.4 图生成模型 🔥

任务: 生成新的图结构

应用:

  • 分子设计(新药发现)
  • 社交网络模拟
  • 电路设计

方法:

8.4.1 GraphRNN (2018)

思想: 用 RNN 顺序生成节点和边

8.4.2 GraphVAE (2018)

框架: 变分自编码器(VAE)

8.4.3 Diffusion Models for Graphs (2023) 🔥

方法: 扩散模型(类似 DALL-E)

  • 逐步加噪声
  • 学习去噪过程

代表:

  • DiGress: 离散扩散生成
  • GraphGDP: 几何扩散过程

8.5 因果 GNN 🔥

问题: GNN 学到相关性而非因果关系

解决:

  • CIGA(2022): 因果不变图学习
  • DIR(2023): 去相关信息路由

应用: 提升鲁棒性和泛化能力


9 代码实现

9.1 从零实现 GCN 层

import torch
import torch.nn as nn
import torch.nn.functional as F
 
class GCNLayer(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        
    def forward(self, X, A):
        """
        X: 节点特征矩阵 (N, in_features)
        A: 邻接矩阵 (N, N)
        """
        # 添加自环
        A_hat = A + torch.eye(A.size(0), device=A.device)
        
        # 度矩阵
        D = torch.diag(A_hat.sum(dim=1))
        
        # 对称归一化: D^(-1/2) * A * D^(-1/2)
        D_inv_sqrt = torch.diag(1.0 / torch.sqrt(D.diagonal() + 1e-6))
        A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
        
        # 传播: A_norm * X * W
        out = A_norm @ X
        out = self.linear(out)
        
        return out
 
 
class GCN(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, num_layers=2, dropout=0.5):
        super().__init__()
        self.layers = nn.ModuleList()
        
        # 第一层
        self.layers.append(GCNLayer(in_features, hidden_features))
        
        # 中间层
        for _ in range(num_layers - 2):
            self.layers.append(GCNLayer(hidden_features, hidden_features))
        
        # 输出层
        self.layers.append(GCNLayer(hidden_features, out_features))
        
        self.dropout = dropout
        
    def forward(self, X, A):
        for i, layer in enumerate(self.layers):
            X = layer(X, A)
            if i < len(self.layers) - 1:  # 最后一层不用激活
                X = F.relu(X)
                X = F.dropout(X, p=self.dropout, training=self.training)
        return F.log_softmax(X, dim=1)
 
# 测试
N, D, C = 100, 16, 7  # 100节点, 16特征, 7类别
X = torch.randn(N, D)
A = torch.randint(0, 2, (N, N)).float()  # 随机邻接矩阵
A = (A + A.T) / 2  # 对称化
 
model = GCN(in_features=D, hidden_features=32, out_features=C)
output = model(X, A)
print(f"Output shape: {output.shape}")  # (100, 7)

9.2 使用 PyTorch Geometric

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
 
# ===== 1. 节点分类 =====
class NodeGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
 
# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
 
model = NodeGCN(
    in_channels=dataset.num_features,
    hidden_channels=16,
    out_channels=dataset.num_classes
)
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
 
# 训练
model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    if epoch % 20 == 0:
        print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
 
# 测试
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index).argmax(dim=1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
    acc = int(correct) / int(data.test_mask.sum())
    print(f'Test Accuracy: {acc:.4f}')
 
 
# ===== 2. GAT 实现 =====
class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
        # 输出层只用1个头
        self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
        
    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)
 
 
# ===== 3. 图分类 =====
class GraphClassifier(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = torch.nn.Linear(hidden_channels, out_channels)
        
    def forward(self, x, edge_index, batch):
        # 节点嵌入
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        
        # 图级池化(同一图的节点聚合)
        x = global_mean_pool(x, batch)
        
        # 分类
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return F.log_softmax(x, dim=1)
 
# 使用示例(需要图分类数据集)
from torch_geometric.datasets import TUDataset
 
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
 
model = GraphClassifier(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
)
 
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
 
for epoch in range(100):
    model.train()
    total_loss = 0
    for data in loader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        loss = F.nll_loss(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    
    print(f'Epoch {epoch}, Loss: {total_loss / len(dataset):.4f}')

9.3 GraphSAGE 实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
 
class GraphSAGE(nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
        super().__init__()
        self.num_layers = num_layers
        self.convs = nn.ModuleList()
        
        # 第一层
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        
        # 中间层
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        
        # 输出层
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < self.num_layers - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.5, training=self.training)
        return x
    
    @torch.no_grad()
    def inference(self, x_all, subgraph_loader):
        """大图推理(逐层采样)"""
        for i, conv in enumerate(self.convs):
            xs = []
            for batch in subgraph_loader:
                x = x_all[batch.n_id].to(batch.x.device)
                x = conv(x, batch.edge_index)
                if i < self.num_layers - 1:
                    x = F.relu(x)
                xs.append(x[:batch.batch_size].cpu())
            
            x_all = torch.cat(xs, dim=0)
        
        return x_all

9.4 链接预测实现

import torch
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score
 
class LinkPredictorGCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        
    def encode(self, x, edge_index):
        """编码节点"""
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        return x
    
    def decode(self, z, edge_label_index):
        """预测边存在概率"""
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        return (src * dst).sum(dim=-1)  # 内积
    
    def decode_all(self, z):
        """预测所有可能的边"""
        prob_adj = z @ z.t()
        return (prob_adj > 0).nonzero(as_tuple=False).t()
 
 
def train_link_prediction(model, data, optimizer):
    model.train()
    optimizer.zero_grad()
    
    # 编码
    z = model.encode(data.x, data.train_pos_edge_index)
    
    # 正样本
    pos_pred = model.decode(z, data.train_pos_edge_index)
    
    # 负采样
    neg_edge_index = negative_sampling(
        edge_index=data.train_pos_edge_index,
        num_nodes=data.num_nodes,
        num_neg_samples=data.train_pos_edge_index.size(1)
    )
    neg_pred = model.decode(z, neg_edge_index)
    
    # 损失(二分类)
    loss = -torch.log(torch.sigmoid(pos_pred) + 1e-15).mean()
    loss -= torch.log(1 - torch.sigmoid(neg_pred) + 1e-15).mean()
    
    loss.backward()
    optimizer.step()
    
    return loss.item()
 
 
@torch.no_grad()
def test_link_prediction(model, data):
    model.eval()
    z = model.encode(data.x, data.train_pos_edge_index)
    
    # 测试集预测
    pos_pred = model.decode(z, data.test_pos_edge_index).cpu()
    neg_pred = model.decode(z, data.test_neg_edge_index).cpu()
    
    # AUC
    pred = torch.cat([pos_pred, neg_pred]).numpy()
    label = torch.cat([
        torch.ones(pos_pred.size(0)),
        torch.zeros(neg_pred.size(0))
    ]).numpy()
    
    return roc_auc_score(label, pred)

10 总结与展望

10.1 核心要点回顾

GNN 演进史:

谱方法 (ChebNet)
    ↓
空间方法 (GCN, GraphSAGE)
    ↓
注意力机制 (GAT)
    ↓
理论分析 (GIN)
    ↓
深层网络 (DeeperGCN)
    ↓
Transformer 融合 (Graphormer, GraphGPS)
    ↓
基础模型 (GraphMAE, GraphGPT) 🔥

关键概念:

  1. 消息传递: 节点通过邻居更新表示
  2. 置换不变性: 聚合函数顺序无关
  3. 表达能力: GIN 达到 WL-Test 上界
  4. 过平滑: 深层网络的主要挑战
  5. 归纳学习: 泛化到新图

10.2 实践建议

2025 年选择指南:

场景推荐模型理由
入门学习GCN简单直观
节点分类(同质图)GAT / GCN性能稳定
节点分类(异质图)HAN / RGCN支持多类型
大图(百万节点)GraphSAGE / SIGN采样友好
图分类GIN / DeeperGCN表达力强
分子性质预测Graphormer / DimeNetSOTA 性能
推荐系统LightGCN / PinSage工业验证
时空预测Graph WaveNet时空建模
科研前沿Graph Transformers最新方向

10.3 未来趋势

2025-2026 展望:

  1. 图基础模型(Graph FM): 🔥🔥🔥

    • 预训练 - 微调范式
    • 跨域迁移学习
    • 零样本/少样本学习
  2. 大规模图学习:

    • 十亿节点图
    • 分布式 GNN
    • 硬件加速(GPU/TPU)
  3. 多模态图学习:

    • 文本 + 图结构
    • 图像 + 图结构
    • 统一表示
  4. 可解释性:

    • GNNExplainer
    • 因果分析
    • 对抗鲁棒性
  5. 与 LLM 结合:

    • LLM as Graph Reasoner
    • Graph-enhanced LLM
    • 知识图谱 + 生成模型

11 扩展阅读

11.1 必读论文

经典:

  • GCN: “Semi-Supervised Classification with Graph Convolutional Networks” (Kipf, 2017)
  • GraphSAGE: “Inductive Representation Learning on Large Graphs” (Hamilton, 2017)
  • GAT: “Graph Attention Networks” (Veličković, 2018)
  • GIN: “How Powerful are Graph Neural Networks?” (Xu, 2019)

现代:

  • Graphormer: “Do Transformers Really Perform Bad for Graph Representation?” (Ying, 2021) 🔥
  • GraphGPS: “Recipe for a General, Powerful, Scalable Graph Transformer” (Rampášek, 2022) 🔥
  • GraphMAE: “GraphMAE: Self-Supervised Masked Graph Autoencoders” (Hou, 2022) 🔥
  • GraphGPT: “GraphGPT: Graph Instruction Tuning for Large Language Models” (Tang, 2023) 🔥

11.2 开源资源

框架:

数据集:

教程:

代码示例:


12 相关笔记


最后更新: 2025-01-07 维护者: sean2077