保姆级教程:用torch_geometric和KarateClub数据集入门图神经网络GCN
从零开始用KarateClub数据集构建你的第一个GCN模型第一次接触图神经网络GNN时很多人会被复杂的数学公式和抽象的概念吓退。但事实上GNN的核心思想非常简单——就像人类社交网络中信息的传播一样自然。想象一下你加入了一个新的社交圈子通过朋友的朋友逐渐了解这个群体的结构和特点这就是GNN在做的事情。本文将用最直观的空手道俱乐部社交网络作为案例带你从零开始构建第一个图卷积网络GCN模型。1. 环境准备与数据理解在开始之前我们需要确保环境配置正确。推荐使用Python 3.8和PyTorch 1.10环境。安装torch_geometric最简单的方式是通过pippip install torch torch_geometricKarateClub数据集是一个经典的社交网络数据集记录了美国一所大学空手道俱乐部34名成员之间的社交关系。在1970-1972年期间由于教练节点0与管理员节点33之间的分歧这个俱乐部最终分裂成了两个群体。我们的目标是让GCN模型通过学习成员间的社交关系自动识别出这两个群体。让我们先看看这个数据集的基本结构from torch_geometric.datasets import KarateClub dataset KarateClub() data dataset[0] print(f节点数量: {data.num_nodes}) # 34 print(f边数量: {data.num_edges}) # 78 print(f节点特征维度: {data.num_features}) # 34 print(f类别数量: {data.num_classes}) # 2每个节点都有一个34维的特征向量初始为单位矩阵edge_index则存储了所有78条边的关系。值得注意的是这里的边是无向的意味着社交关系是双向的。提示在PyG中edge_index是一个形状为[2, num_edges]的张量第一行是源节点索引第二行是目标节点索引。这种存储方式比邻接矩阵更节省空间尤其对于稀疏图。2. 可视化原始社交网络理解数据结构最直观的方式就是可视化。我们将使用networkx库来绘制这个社交网络import matplotlib.pyplot as plt import networkx as nx from torch_geometric.utils import to_networkx G to_networkx(data, to_undirectedTrue) plt.figure(figsize(10, 8)) pos nx.spring_layout(G, seed42) nx.draw_networkx(G, pospos, with_labelsTrue, node_colordata.y.numpy(), cmapSet2) plt.title(Karate Club社交网络结构) plt.show()从可视化结果中我们可以清晰地看到两个主要群体一个围绕教练节点0另一个围绕管理员节点33。节点颜色代表了它们最终归属的群体这正是我们的GCN模型需要学习的分类目标。3. 构建GCN模型架构现在我们来构建一个简单的两层GCN模型。与传统的CNN不同GCN的关键在于消息传递机制——每个节点通过聚合邻居节点的信息来更新自己的表示。import torch from torch.nn import Linear from torch_geometric.nn import GCNConv class GCN(torch.nn.Module): def __init__(self): super().__init__() torch.manual_seed(1234) self.conv1 GCNConv(dataset.num_features, 4) # 第一层将34维特征压缩到4维 self.conv2 GCNConv(4, 2) # 第二层输出2维方便可视化 self.classifier Linear(2, dataset.num_classes) # 最后的分类层 def forward(self, x, edge_index): # 第一层GCN 激活函数 h self.conv1(x, edge_index).tanh() # 第二层GCN h self.conv2(h, edge_index).tanh() # 分类输出 out self.classifier(h) return out, h # 返回分类结果和中间表示(用于可视化)这个模型的关键组件是GCNConv层它实现了图卷积操作。与普通卷积不同图卷积需要考虑节点的邻居关系这正是通过edge_index参数传递的。注意我们选择tanh作为激活函数是因为它在(-1,1)区间有输出适合后续的可视化展示。在实际应用中ReLU可能更常见。4. 训练与可视化节点嵌入现在我们可以开始训练模型了。由于这是一个小型社交网络我们采用全监督学习方式model GCN() criterion torch.nn.CrossEntropyLoss() optimizer torch.optim.Adam(model.parameters(), lr0.01) def train(): model.train() optimizer.zero_grad() out, h model(data.x, data.edge_index) loss criterion(out, data.y) # 使用所有节点的标签 loss.backward() optimizer.step() return loss, h # 训练循环 for epoch in range(100): loss, h train() if epoch % 10 0: print(fEpoch: {epoch:03d}, Loss: {loss:.4f})随着训练的进行损失值应该逐渐下降。更有趣的是观察节点嵌入h在训练过程中的变化。我们可以绘制这些二维嵌入的散点图def visualize_embedding(h, color, epochNone, lossNone): plt.figure(figsize(7, 7)) plt.scatter(h[:, 0], h[:, 1], s140, ccolor, cmapSet2) if epoch is not None and loss is not None: plt.xlabel(fEpoch: {epoch}, Loss: {loss:.4f}, fontsize16) plt.show() # 初始状态的可视化 model.eval() _, h model(data.x, data.edge_index) visualize_embedding(h.detach().numpy(), colordata.y, epoch0, lossloss) # 训练后的可视化 visualize_embedding(h.detach().numpy(), colordata.y, epoch100, lossloss)你会看到随着训练进行两个群体的节点在嵌入空间中逐渐分离。这正是GCN的魅力所在——它通过学习图结构信息将拓扑关系编码到了节点的特征表示中。5. 模型评估与实战技巧虽然我们的模型在训练集上表现良好但在实际应用中还需要考虑以下关键点数据划分真实场景中通常只有部分节点有标签超参数调优学习率、隐藏层维度等模型深度GCN不宜过深通常2-3层效果最佳让我们修改数据加载方式模拟半监督学习场景# 随机选择4个节点作为训练集 data.train_mask torch.zeros(data.num_nodes, dtypetorch.bool) data.train_mask[[0, 33, 5, 10]] True # 教练、管理员和两个普通成员 # 修改训练函数 def train_semi_supervised(): model.train() optimizer.zero_grad() out, _ model(data.x, data.edge_index) loss criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() return loss # 测试函数 def test(): model.eval() out, _ model(data.x, data.edge_index) pred out.argmax(dim1) correct (pred data.y).sum() acc int(correct) / data.num_nodes return acc在这种半监督设置下模型需要从极少的标签信息中推断整个图的结构更能体现GCN的优势。6. GCN与MLP的对比实验为了展示GCN的真正价值我们可以将其与普通的MLP进行对比。MLP只能看到节点特征而无法利用图结构信息class MLP(torch.nn.Module): def __init__(self): super().__init__() self.lin1 Linear(dataset.num_features, 16) self.lin2 Linear(16, dataset.num_classes) def forward(self, x): x self.lin1(x).relu() x self.lin2(x) return x mlp_model MLP() mlp_optimizer torch.optim.Adam(mlp_model.parameters(), lr0.01) # 训练MLP for epoch in range(100): mlp_model.train() mlp_optimizer.zero_grad() out mlp_model(data.x) loss criterion(out[data.train_mask], data.y[data.train_mask]) loss.backward() mlp_optimizer.step() # 测试 mlp_model.eval() with torch.no_grad(): out mlp_model(data.x) pred out.argmax(dim1) correct (pred data.y).sum() acc int(correct) / data.num_nodes print(fEpoch: {epoch:03d}, Acc: {acc:.4f})实验结果表明GCN在半监督设置下的准确率通常能达到90%以上而MLP往往只有50-60%这清晰地展示了图结构信息的重要性。7. 进阶思考与应用扩展掌握了基本GCN后你可以进一步探索其他GNN架构GraphSAGE、GAT、GIN等各有特点更复杂的数据集如Cora、PubMed等学术引用网络工业级应用社交推荐、分子属性预测、知识图谱等例如在推荐系统中用户和商品可以构成二分图GNN能够同时利用用户-商品交互和用户-用户相似性# 伪代码示例推荐系统中的GNN应用 class RecommenderGNN(torch.nn.Module): def __init__(self): super().__init__() self.user_emb Embedding(num_users, emb_dim) self.item_emb Embedding(num_items, emb_dim) self.conv1 GCNConv(emb_dim, emb_dim) def forward(self, user, item, edge_index): x torch.cat([self.user_emb.weight, self.item_emb.weight]) h self.conv1(x, edge_index).relu() user_emb, item_emb h[:num_users], h[num_users:] return (user_emb[user] * item_emb[item]).sum(1) # 内积作为预测分数在实际项目中我发现GNN对特征工程的要求相对较低因为它能自动学习图结构中的有用信息。但对于边权重、节点重要性等先验知识适当融入模型往往能获得额外提升。