图神经网络(GNN)
核心定位: 处理非欧几里得结构数据的深度学习架构,从 2009 年起源到 2024 年广泛应用的演进史。 关联笔记: 神经网络、神经网络层类型详解、卷积神经网络、循环神经网络、Transformer
1 为什么需要 GNN
1.1 图数据无处不在
真实世界的图结构:
- 社交网络: 用户(节点)+ 关系(边)
- 分子结构: 原子(节点)+ 化学键(边)
- 知识图谱: 实体(节点)+ 关系(边)
- 推荐系统: 用户 - 商品二部图
- 交通网络: 路口(节点)+ 道路(边)
- 蛋白质结构: 氨基酸(节点)+ 空间关系(边)
- 代码分析: 函数/变量(节点)+ 调用/依赖(边)
1.2 传统方法的局限
问题 1:CNN 无法应用
- CNN 假设:网格结构(图像)、局部性、平移不变性
- 图特点:不规则结构、节点度数不同、无固定邻居顺序
问题 2:传统图算法无法泛化
- PageRank、最短路径:任务特定
- 手工特征工程:费时费力
- 无法端到端学习
问题 3:表示学习困难
- 如何编码图结构?
- 如何保留拓扑信息?
- 如何处理动态图?
1.3 GNN 的核心优势
✅ 结构感知: 利用图的连接信息 ✅ 端到端学习: 自动学习节点/图表示 ✅ 归纳能力: 可泛化到未见过的图 ✅ 任务多样: 节点/边/图级预测
2 图的基础知识
2.1 图的数学定义
无向图:
- :节点集合,
- :边集合,
- 邻接矩阵: , 表示 和 相连
有向图: 表示从 到 的边
加权图: (边权重)
属性图(Attributed Graph):
- 节点特征:(每个节点 维特征)
- 边特征:(可选)
2.2 图的关键概念
| 概念 | 定义 | 示例 |
|---|---|---|
| 度(Degree) | 节点连接的边数 | |
| 邻居 | 与节点直接相连的节点集 | |
| 路径 | 节点序列,相邻节点间有边 | |
| 连通性 | 任意两节点间存在路径 | 社交网络的连通分量 |
| 度矩阵 | (对角) | 用于归一化 |
| 拉普拉斯矩阵 | 谱图理论基础 |
2.3 图的表示方式
2.3.1 邻接矩阵(Adjacency Matrix)
优点: 直观,便于矩阵运算 缺点: 稀疏图浪费空间()
# 示例:4 个节点的图
A = [[0, 1, 1, 0],
[1, 0, 0, 1],
[1, 0, 0, 1],
[0, 1, 1, 0]]2.3.2 边列表(Edge List)
格式: 边的起点 - 终点对 优点: 空间高效()
edge_list = [(0, 1), (0, 2), (1, 3), (2, 3)]2.3.3 邻接表(Adjacency List)
格式: 每个节点的邻居列表 优点: 适合稀疏图
adj_list = {
0: [1, 2],
1: [0, 3],
2: [0, 3],
3: [1, 2]
}2.4 常见图类型
| 图类型 | 特点 | 示例 |
|---|---|---|
| 稀疏图 | 社交网络(平均度 < 100) | |
| 稠密图 | 全连接层(完全图) | |
| 同质图 | 单一节点/边类型 | 论文引用网络 |
| 异质图 | 多种节点/边类型 | 知识图谱(人、地点、事件) |
| 静态图 | 结构固定 | 分子结构 |
| 动态图 | 时变结构 | 社交网络演化 |
| 二部图 | 两类节点,类内无边 | 用户 - 商品推荐 |
3 GNN 核心思想:消息传递
3.1 消息传递神经网络(MPNN)框架
核心理念: 节点通过迭代聚合邻居信息来更新自己的表示
通用公式(第 层):
三个关键步骤:
-
消息生成(Message): 邻居发送消息
-
消息聚合(Aggregate): 汇总邻居消息
-
节点更新(Update): 结合自身信息更新
3.2 聚合函数的选择
| 聚合方式 | 公式 | 特点 | 代表模型 |
|---|---|---|---|
| 求和(Sum) | 大小不变性 | GCN, GraphSAGE | |
| 平均(Mean) | $\frac{1}{ | \mathcal{N}(v) | } \sum_{u} h_u$ |
| 最大(Max) | 提取显著特征 | GraphSAGE | |
| 注意力 | 自适应加权 | GAT | |
| LSTM | LSTM over neighbors | 考虑顺序(需排序) | GraphSAGE |
置换不变性(Permutation Invariance):
- 邻居顺序不应影响结果
- Sum、Mean、Max 满足
- LSTM 不满足(需要固定顺序)
3.3 消息传递的直观理解
以社交网络为例:
初始状态 (k=0):
你 (h_v^0): [兴趣: 音乐, 年龄: 25]
朋友A (h_A^0): [兴趣: 音乐, 年龄: 24]
朋友B (h_B^0): [兴趣: 运动, 年龄: 26]
第1层 (k=1):
1. 朋友发送信息 (Message)
2. 你接收并聚合 (Aggregate): Mean([h_A^0, h_B^0])
3. 更新你的表示 (Update): h_v^1 = NN([h_v^0, mean_msg])
→ h_v^1 包含了"我和我朋友圈"的信息
第2层 (k=2):
h_v^2 包含了"我、我朋友、朋友的朋友"的信息
→ 感受野扩大到 2-hop 邻居
感受野(Receptive Field):
- 层 GNN:感受野 = -hop 邻居
- 类似 CNN 的感受野,但在图上
4 经典 GNN 架构演进
4.1 谱图卷积(Spectral GCN)
4.1.1 理论基础:图拉普拉斯
拉普拉斯矩阵:
归一化拉普拉斯:
谱分解:
- :特征向量矩阵(图傅里叶基)
- :特征值矩阵
图傅里叶变换:
- 前向:
- 逆向:
4.1.2 谱卷积(Spectral Convolution)
定义: 在频域进行卷积
其中 是逐元素乘法。
问题:
- 计算复杂度:(特征分解)
- 不能泛化到不同大小的图
4.1.3 ChebNet (2016)
创新: 使用切比雪夫多项式近似
其中 是切比雪夫多项式,。
优势:
- -localized:只需 -hop 邻居
- 复杂度(线性于边数)
4.2 GCN:图卷积网络 (2017) ⭐
作者: Thomas Kipf & Max Welling 论文: “Semi-Supervised Classification with Graph Convolutional Networks”
4.2.1 简化推导
从 ChebNet 取 (一阶近似):
其中:
- (加自环)
- :可学习权重
- :激活函数(ReLU)
4.2.2 逐层传播规则
单节点视角:
归一化因子 :
- 防止度数大的节点主导
- 对称归一化
4.2.3 GCN 特点
优势:
- ✅ 简单高效
- ✅ 半监督学习(少量标签)
- ✅ 可叠加多层
- ✅ 可泛化到不同图
局限:
- ❌ 过平滑(Over-smoothing):深层后节点表示趋同
- ❌ 所有邻居等权重(无差异化)
- ❌ 无法处理边特征
4.3 GraphSAGE (2017)
全称: Graph Sample and Aggregate 作者: Hamilton et al. (Stanford) 创新: 采样 + 多种聚合器 + 归纳学习
4.3.1 核心算法
对于每个节点 v:
1. 采样固定数量邻居 (Sample)
2. 聚合邻居特征 (Aggregate)
h_N^(k) = AGG({h_u^(k-1), ∀u ∈ Sample(N(v))})
3. 结合自身特征 (Concatenate)
h_v^(k) = σ(W · [h_v^(k-1) || h_N^(k)])
4. 归一化 (L2 normalization)
h_v^(k) = h_v^(k) / ||h_v^(k)||_2
4.3.2 聚合器对比
| 聚合器 | 公式 | 复杂度 | 特点 |
|---|---|---|---|
| Mean | 最稳定 | ||
| LSTM | LSTM over random permutation | $O(d \cdot | \mathcal{N} |
| Pooling | 元素级最大池化 |
4.3.3 采样策略
固定大小采样:
- 1-hop:采样 25 个邻居
- 2-hop:采样 10 个邻居
- 总复杂度:(可控)
优势:
- 计算可预测(Mini-batch 训练)
- 适合大图(不需要整个图)
- 归纳学习(新节点无需重训练)
4.4 GAT:图注意力网络 (2018) ⭐
作者: Veličković et al. 创新: 自注意力机制加权邻居
4.4.1 注意力系数计算
步骤 1:计算注意力分数
其中:
- :特征变换矩阵
- :注意力参数向量
- :拼接
步骤 2:Softmax 归一化
步骤 3:加权聚合
4.4.2 多头注意力(Multi-Head Attention)
并行计算 个注意力头:
或在最后一层使用平均:
4.4.3 GAT 优势
✅ 自适应权重: 重要邻居权重大 ✅ 可解释性: 注意力分数可视化 ✅ 处理异质图: 不同类型边可用不同注意力 ✅ 并行化: 所有节点同时计算
应用示例:
- 分子性质预测:化学键重要性不同
- 推荐系统:用户兴趣权重动态调整
4.5 GIN:图同构网络 (2019)
全称: Graph Isomorphism Network 作者: Xu et al. 理论贡献: 证明了 GNN 的表达能力上界
4.5.1 WL-Test(Weisfeiler-Lehman)
图同构问题: 判断两个图是否结构相同
WL-Test 算法:
- 初始化节点标签
- 迭代更新:
- 如果两图标签序列相同 → 可能同构
4.5.2 GIN 公式
最大表达力的 GNN(等价于 WL-Test):
其中 可学习或固定。
关键设计:
- 单射聚合: Sum(而非 Mean/Max)
- MLP 更新: 比单层线性更强表达力
- 包含自身:
4.5.3 理论结果
定理: GIN 可以区分任何 WL-Test 可区分的图。
含义:
- GCN/GraphSAGE(Mean/Max):表达力弱于 WL-Test
- GAT:理论上不强于 GCN(注意力不改变表达力上界)
- GIN:达到理论上界
4.6 其他重要变体
4.6.1 消息传递变体
| 模型 | 年份 | 核心创新 | 适用场景 |
|---|---|---|---|
| GatedGCN | 2018 | 门控机制(类似 LSTM) | 过平滑缓解 |
| PNA | 2020 | 多尺度聚合器组合 | 通用任务 |
| DeeperGCN | 2020 | 残差连接 + 归一化 | 深层网络(50+ 层) |
| GRAND | 2021 | 图随机扩散 | 鲁棒性增强 |
4.6.2 位置编码
问题: GNN 无法区分对称节点
解决方案:
- 拉普拉斯位置编码(LapPE): 使用拉普拉斯矩阵特征向量
- 随机游走编码(RWSE): 基于随机游走统计量
- 距离编码: 到关键节点的最短路径
5 图级任务与池化
5.1 图分类问题
任务: 给整个图分配标签
示例:
- 分子性质预测(有毒/无毒)
- 代码漏洞检测
- 蛋白质功能预测
5.2 图池化方法
5.2.1 全局池化
简单聚合:
| 方法 | 公式 | 特点 |
|---|---|---|
| Sum | 保留总量 | |
| Mean | $\frac{1}{ | V |
| Max | 显著特征 | |
| Attention | 加权聚合 |
Set2Set(2018):
- 使用 LSTM 读取节点集合
- 顺序不变但表达力强
5.2.2 层次化池化
DiffPool (2018):
- 学习软分配矩阵
- 逐层粗化图结构
Top-K Pooling:
- 根据节点得分保留 Top-K 个节点
- 类似 CNN 的最大池化
SAGPool (2019):
- 自注意力图池化
- 选择得分最高的 个节点
6 GNN 的挑战与解决方案
6.1 过平滑(Over-Smoothing)
问题: 多层 GNN 后,所有节点表示趋于相同
数学解释:
- 每层聚合相当于拉普拉斯平滑
- 层后,节点表示混合了 -hop 邻域
- 小世界网络中,几层后覆盖整个图
实验现象:
2层 GCN: 准确率 80%
4层 GCN: 准确率 75%
8层 GCN: 准确率 50%(随机猜测)
解决方案:
-
残差连接(Skip Connections):
-
Jumping Knowledge(2018):
- 聚合所有层的表示:
-
DropEdge(2020):
- 训练时随机丢弃边
- 减少过度聚合
-
PairNorm(2020):
- 保持节点表示的方差
6.2 过拟合(Over-Fitting)
挑战: 图数据集通常很小(<1000 个图)
解决方案:
| 方法 | 原理 | 效果 |
|---|---|---|
| Dropout | 随机丢弃节点特征 | 标准正则 |
| DropEdge | 随机丢弃边 | 结构正则 |
| 图数据增强 | 添加/删除边,特征扰动 | 扩充数据 |
| 对比学习 | 自监督预训练 | 利用无标签数据 |
6.3 异质图建模
异质图(Heterogeneous Graph): 多种节点/边类型
示例:学术网络
- 节点:论文、作者、会议
- 边:写作、发表、引用
解决方案:
HAN(Heterogeneous Attention Network, 2019):
- 节点级注意力: 对不同类型邻居加权
- 语义级注意力: 对不同元路径加权
RGCN(Relational GCN, 2018):
每种关系 有独立权重 。
6.4 大图扩展性
挑战: 完整图无法放入 GPU 内存
解决方案:
-
采样方法:
- GraphSAGE: 邻居采样
- FastGCN: 层级采样
- ClusterGCN: 图聚类后 Mini-batch
-
简化模型:
- SGC(Simple GCN): 移除非线性,预计算
- SIGN: 预计算多跳邻域,并行训练
-
分布式训练:
- 图划分
- 跨设备通信
7 应用场景
7.1 节点级任务
7.1.1 节点分类
示例:
- 社交网络: 用户兴趣分类
- 学术网络: 论文主题分类
- 生物网络: 蛋白质功能预测
典型数据集:
| 数据集 | 节点数 | 边数 | 类别数 | 特征维度 |
|---|---|---|---|---|
| Cora | 2,708 | 5,429 | 7 | 1,433 |
| Citeseer | 3,327 | 4,732 | 6 | 3,703 |
| PubMed | 19,717 | 44,338 | 3 | 500 |
2024 SOTA(Cora):
- GCN: 81.5%
- GAT: 83.0%
- GraphTransformer: 85.2% 🔥
7.1.2 链接预测
任务: 预测两节点间是否存在边
应用:
- 推荐系统(用户 - 商品)
- 知识图谱补全
- 药物 - 靶点相互作用
方法:
- 学习节点嵌入:
- 计算相似度:
7.2 边级任务
边分类/回归:
- 社交网络:关系类型预测
- 交通网络:道路拥堵预测
7.3 图级任务
7.3.1 分子性质预测 ⭐
应用: 新药研发、材料设计
数据集:
- QM9: 130K 小分子,量子化学性质
- ZINC: 250K 药物样分子
- OGB(Open Graph Benchmark): 百万级分子
模型:
- SchNet(2018):连续滤波器卷积
- DimeNet(2020):方向性消息传递
- Graphormer(2021):Transformer on Graphs 🔥
2024 进展:
- UniMol: 3D 分子预训练模型
- MolGPT: 生成式分子设计
7.3.2 推荐系统
图构建:
- 节点:用户 + 商品
- 边:交互历史(点击、购买)
模型:
- PinSage(Pinterest, 2018): 随机游走 + GraphSAGE
- NGCF(2019): 协同过滤 + GCN
- LightGCN(2020): 简化 GCN(去激活函数)
7.3.3 知识图谱
任务:
- 实体链接
- 关系推理
- 问答系统
模型:
- R-GCN(2018)
- CompGCN(2020)
- NBFNet(2021):神经贝尔曼 - 福特网络
7.4 时空图(Spatio-Temporal Graphs)
应用:
- 交通流量预测
- 流行病传播建模
- 视频分析
模型:
- STGCN(2018): GCN + 时间卷积
- Graph WaveNet(2019): 自适应邻接矩阵
- MTGNN(2020): 多任务时空图
8 2024-2025 前沿进展
8.1 Graph Transformers 🔥
动机: 结合 Transformer 的全局建模能力
挑战:
- Transformer 是全连接(),大图无法承受
- 如何注入图结构偏置?
解决方案:
8.1.1 Graphormer (2021, Microsoft)
创新:
- 中心性编码(Centrality Encoding): 节点度数
- 空间编码(Spatial Encoding): 最短路径距离
- 边编码(Edge Encoding): 路径上的边特征
性能: OGB 排行榜多项第一
8.1.2 GraphGPS (2022)
设计: 消息传递 + Transformer 混合
Layer = MPNN(局部) + Transformer(全局) + FFN
优势:
- 局部归纳偏置(MPNN)
- 全局信息流(Transformer)
8.1.3 Exphormer (2023)
创新: 基于扩展图(Expander Graph)的稀疏注意力
复杂度: (线性!)
8.2 图基础模型(Graph Foundation Models)🔥
目标: 类似 GPT/BERT 的预训练 - 微调范式
8.2.1 GraphMAE (2022)
方法: 掩码自编码(Masked Autoencoder)
- 随机掩盖节点特征
- 重建被掩盖的特征
8.2.2 GraphGPT (2023)
创新: 将图结构转换为序列,用 LLM 处理
流程:
- 图 → 序列化(BFS/DFS)
- 输入 GPT 风格模型
- 生成图级预测
8.2.3 OFA-GNN (2024) 🔥
全称: One-For-All Graph Neural Network
目标: 单一模型处理所有图任务
技术:
- 统一任务表示
- 多任务训练
- 提示学习(Prompt-based)
8.3 几何深度学习 🔥
理论框架: 统一 CNN、RNN、Transformer、GNN
E(n)-Equivariant GNN:
- 保持旋转/平移不变性
- 应用:3D 分子、点云、物理模拟
代表模型:
- EGNN(2021): E(n)- 等变图神经网络
- GemNet(2021): 几何消息传递
- Equiformer(2023): 等变 Transformer
8.4 图生成模型 🔥
任务: 生成新的图结构
应用:
- 分子设计(新药发现)
- 社交网络模拟
- 电路设计
方法:
8.4.1 GraphRNN (2018)
思想: 用 RNN 顺序生成节点和边
8.4.2 GraphVAE (2018)
框架: 变分自编码器(VAE)
8.4.3 Diffusion Models for Graphs (2023) 🔥
方法: 扩散模型(类似 DALL-E)
- 逐步加噪声
- 学习去噪过程
代表:
- DiGress: 离散扩散生成
- GraphGDP: 几何扩散过程
8.5 因果 GNN 🔥
问题: GNN 学到相关性而非因果关系
解决:
- CIGA(2022): 因果不变图学习
- DIR(2023): 去相关信息路由
应用: 提升鲁棒性和泛化能力
9 代码实现
9.1 从零实现 GCN 层
import torch
import torch.nn as nn
import torch.nn.functional as F
class GCNLayer(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.linear = nn.Linear(in_features, out_features)
def forward(self, X, A):
"""
X: 节点特征矩阵 (N, in_features)
A: 邻接矩阵 (N, N)
"""
# 添加自环
A_hat = A + torch.eye(A.size(0), device=A.device)
# 度矩阵
D = torch.diag(A_hat.sum(dim=1))
# 对称归一化: D^(-1/2) * A * D^(-1/2)
D_inv_sqrt = torch.diag(1.0 / torch.sqrt(D.diagonal() + 1e-6))
A_norm = D_inv_sqrt @ A_hat @ D_inv_sqrt
# 传播: A_norm * X * W
out = A_norm @ X
out = self.linear(out)
return out
class GCN(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_layers=2, dropout=0.5):
super().__init__()
self.layers = nn.ModuleList()
# 第一层
self.layers.append(GCNLayer(in_features, hidden_features))
# 中间层
for _ in range(num_layers - 2):
self.layers.append(GCNLayer(hidden_features, hidden_features))
# 输出层
self.layers.append(GCNLayer(hidden_features, out_features))
self.dropout = dropout
def forward(self, X, A):
for i, layer in enumerate(self.layers):
X = layer(X, A)
if i < len(self.layers) - 1: # 最后一层不用激活
X = F.relu(X)
X = F.dropout(X, p=self.dropout, training=self.training)
return F.log_softmax(X, dim=1)
# 测试
N, D, C = 100, 16, 7 # 100节点, 16特征, 7类别
X = torch.randn(N, D)
A = torch.randint(0, 2, (N, N)).float() # 随机邻接矩阵
A = (A + A.T) / 2 # 对称化
model = GCN(in_features=D, hidden_features=32, out_features=C)
output = model(X, A)
print(f"Output shape: {output.shape}") # (100, 7)9.2 使用 PyTorch Geometric
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
# ===== 1. 节点分类 =====
class NodeGCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]
model = NodeGCN(
in_channels=dataset.num_features,
hidden_channels=16,
out_channels=dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
# 训练
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
if epoch % 20 == 0:
print(f'Epoch {epoch}, Loss: {loss.item():.4f}')
# 测试
model.eval()
with torch.no_grad():
pred = model(data.x, data.edge_index).argmax(dim=1)
correct = (pred[data.test_mask] == data.y[data.test_mask]).sum()
acc = int(correct) / int(data.test_mask.sum())
print(f'Test Accuracy: {acc:.4f}')
# ===== 2. GAT 实现 =====
class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, heads=8):
super().__init__()
self.conv1 = GATConv(in_channels, hidden_channels, heads=heads, dropout=0.6)
# 输出层只用1个头
self.conv2 = GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=0.6)
def forward(self, x, edge_index):
x = F.dropout(x, p=0.6, training=self.training)
x = F.elu(self.conv1(x, edge_index))
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
# ===== 3. 图分类 =====
class GraphClassifier(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
self.conv3 = GCNConv(hidden_channels, hidden_channels)
self.lin = torch.nn.Linear(hidden_channels, out_channels)
def forward(self, x, edge_index, batch):
# 节点嵌入
x = F.relu(self.conv1(x, edge_index))
x = F.relu(self.conv2(x, edge_index))
x = F.relu(self.conv3(x, edge_index))
# 图级池化(同一图的节点聚合)
x = global_mean_pool(x, batch)
# 分类
x = F.dropout(x, p=0.5, training=self.training)
x = self.lin(x)
return F.log_softmax(x, dim=1)
# 使用示例(需要图分类数据集)
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
loader = DataLoader(dataset, batch_size=32, shuffle=True)
model = GraphClassifier(
in_channels=dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(100):
model.train()
total_loss = 0
for data in loader:
optimizer.zero_grad()
out = model(data.x, data.edge_index, data.batch)
loss = F.nll_loss(out, data.y)
loss.backward()
optimizer.step()
total_loss += loss.item() * data.num_graphs
print(f'Epoch {epoch}, Loss: {total_loss / len(dataset):.4f}')9.3 GraphSAGE 实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super().__init__()
self.num_layers = num_layers
self.convs = nn.ModuleList()
# 第一层
self.convs.append(SAGEConv(in_channels, hidden_channels))
# 中间层
for _ in range(num_layers - 2):
self.convs.append(SAGEConv(hidden_channels, hidden_channels))
# 输出层
self.convs.append(SAGEConv(hidden_channels, out_channels))
def forward(self, x, edge_index):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index)
if i < self.num_layers - 1:
x = F.relu(x)
x = F.dropout(x, p=0.5, training=self.training)
return x
@torch.no_grad()
def inference(self, x_all, subgraph_loader):
"""大图推理(逐层采样)"""
for i, conv in enumerate(self.convs):
xs = []
for batch in subgraph_loader:
x = x_all[batch.n_id].to(batch.x.device)
x = conv(x, batch.edge_index)
if i < self.num_layers - 1:
x = F.relu(x)
xs.append(x[:batch.batch_size].cpu())
x_all = torch.cat(xs, dim=0)
return x_all9.4 链接预测实现
import torch
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score
class LinkPredictorGCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, hidden_channels)
def encode(self, x, edge_index):
"""编码节点"""
x = self.conv1(x, edge_index).relu()
x = self.conv2(x, edge_index)
return x
def decode(self, z, edge_label_index):
"""预测边存在概率"""
src = z[edge_label_index[0]]
dst = z[edge_label_index[1]]
return (src * dst).sum(dim=-1) # 内积
def decode_all(self, z):
"""预测所有可能的边"""
prob_adj = z @ z.t()
return (prob_adj > 0).nonzero(as_tuple=False).t()
def train_link_prediction(model, data, optimizer):
model.train()
optimizer.zero_grad()
# 编码
z = model.encode(data.x, data.train_pos_edge_index)
# 正样本
pos_pred = model.decode(z, data.train_pos_edge_index)
# 负采样
neg_edge_index = negative_sampling(
edge_index=data.train_pos_edge_index,
num_nodes=data.num_nodes,
num_neg_samples=data.train_pos_edge_index.size(1)
)
neg_pred = model.decode(z, neg_edge_index)
# 损失(二分类)
loss = -torch.log(torch.sigmoid(pos_pred) + 1e-15).mean()
loss -= torch.log(1 - torch.sigmoid(neg_pred) + 1e-15).mean()
loss.backward()
optimizer.step()
return loss.item()
@torch.no_grad()
def test_link_prediction(model, data):
model.eval()
z = model.encode(data.x, data.train_pos_edge_index)
# 测试集预测
pos_pred = model.decode(z, data.test_pos_edge_index).cpu()
neg_pred = model.decode(z, data.test_neg_edge_index).cpu()
# AUC
pred = torch.cat([pos_pred, neg_pred]).numpy()
label = torch.cat([
torch.ones(pos_pred.size(0)),
torch.zeros(neg_pred.size(0))
]).numpy()
return roc_auc_score(label, pred)10 总结与展望
10.1 核心要点回顾
GNN 演进史:
谱方法 (ChebNet)
↓
空间方法 (GCN, GraphSAGE)
↓
注意力机制 (GAT)
↓
理论分析 (GIN)
↓
深层网络 (DeeperGCN)
↓
Transformer 融合 (Graphormer, GraphGPS)
↓
基础模型 (GraphMAE, GraphGPT) 🔥
关键概念:
- 消息传递: 节点通过邻居更新表示
- 置换不变性: 聚合函数顺序无关
- 表达能力: GIN 达到 WL-Test 上界
- 过平滑: 深层网络的主要挑战
- 归纳学习: 泛化到新图
10.2 实践建议
2025 年选择指南:
| 场景 | 推荐模型 | 理由 |
|---|---|---|
| 入门学习 | GCN | 简单直观 |
| 节点分类(同质图) | GAT / GCN | 性能稳定 |
| 节点分类(异质图) | HAN / RGCN | 支持多类型 |
| 大图(百万节点) | GraphSAGE / SIGN | 采样友好 |
| 图分类 | GIN / DeeperGCN | 表达力强 |
| 分子性质预测 | Graphormer / DimeNet | SOTA 性能 |
| 推荐系统 | LightGCN / PinSage | 工业验证 |
| 时空预测 | Graph WaveNet | 时空建模 |
| 科研前沿 | Graph Transformers | 最新方向 |
10.3 未来趋势
2025-2026 展望:
-
图基础模型(Graph FM): 🔥🔥🔥
- 预训练 - 微调范式
- 跨域迁移学习
- 零样本/少样本学习
-
大规模图学习:
- 十亿节点图
- 分布式 GNN
- 硬件加速(GPU/TPU)
-
多模态图学习:
- 文本 + 图结构
- 图像 + 图结构
- 统一表示
-
可解释性:
- GNNExplainer
- 因果分析
- 对抗鲁棒性
-
与 LLM 结合:
- LLM as Graph Reasoner
- Graph-enhanced LLM
- 知识图谱 + 生成模型
11 扩展阅读
11.1 必读论文
经典:
- GCN: “Semi-Supervised Classification with Graph Convolutional Networks” (Kipf, 2017)
- GraphSAGE: “Inductive Representation Learning on Large Graphs” (Hamilton, 2017)
- GAT: “Graph Attention Networks” (Veličković, 2018)
- GIN: “How Powerful are Graph Neural Networks?” (Xu, 2019)
现代:
- Graphormer: “Do Transformers Really Perform Bad for Graph Representation?” (Ying, 2021) 🔥
- GraphGPS: “Recipe for a General, Powerful, Scalable Graph Transformer” (Rampášek, 2022) 🔥
- GraphMAE: “GraphMAE: Self-Supervised Masked Graph Autoencoders” (Hou, 2022) 🔥
- GraphGPT: “GraphGPT: Graph Instruction Tuning for Large Language Models” (Tang, 2023) 🔥
11.2 开源资源
框架:
- PyTorch Geometric - 最流行的 GNN 库
- DGL (Deep Graph Library) - 支持多种后端
- GraphGym - GNN 实验平台
数据集:
- OGB (Open Graph Benchmark) - 标准评测基准
- TU Datasets - 图分类数据集
- PyG Datasets - 内置数据集
教程:
- Stanford CS224W - 机器学习与图
- 几何深度学习书 - 理论基础
代码示例:
12 相关笔记
- 神经网络 - 总览
- 神经网络层类型详解 - GNN 层详解
- 卷积神经网络 - 欧几里得空间卷积
- 循环神经网络 - 序列建模
- Transformer - 注意力机制
- 图嵌入 - Node2Vec, DeepWalk
- 知识图谱 - 结构化知识表示
- 分子性质预测 - GNN 在化学中的应用
最后更新: 2025-01-07 维护者: sean2077