使用PyTorch Geometric的逐步指南
图神经网络(Graph Neural Networks,GNNs)是深度学习领域中最吸引人且发展迅速的架构之一。作为一种用于处理图结构数据的深度学习模型,GNNs具有出色的灵活性和强大的学习能力。
在各种类型的GNNs中,图卷积网络(Graph Convolutional Networks,GCNs)已成为最普遍和广泛应用的模型。GCNs之所以创新,是因为它们能够利用节点的特征和局部性来进行预测,为处理图结构数据提供了一种有效的方式。
在本文中,我们将深入探讨GCN层的机制并解释其内部工作原理。此外,我们将使用PyTorch Geometric作为我们的工具,探索其在节点分类任务中的实际应用。
PyTorch Geometric是PyTorch的专门扩展,专为开发和实现GNNs而创建。它是一个先进而用户友好的库,提供了一套全面的工具,用于简化基于图的机器学习。为了开始我们的旅程,我们需要安装PyTorch Geometric。如果您使用Google Colab,PyTorch应该已经安装好了,所以我们只需要执行几个额外的命令。
所有代码都可以在Google Colab和GitHub上找到。
!pip install torch_geometric
import torchimport numpy as npimport networkx as nximport matplotlib.pyplot as plt
现在PyTorch Geometric已经安装完成,让我们来探索本教程中将使用的数据集。
🌐 一、图数据
图是表示对象之间关系的基本结构。在许多现实世界的场景中,都可以遇到图数据,例如社交网络、计算机网络、分子的化学结构、自然语言处理和图像识别等。
在本文中,我们将研究著名且广泛使用的Zachary’s karate club数据集。
Zachary’s karate club数据集体现了Wayne W. Zachary在1970年代观察到的一个空手道俱乐部内形成的关系。它是一种社交网络,其中每个节点代表一个俱乐部成员,节点之间的边表示在俱乐部环境之外发生的互动。
在这个特定的场景中,俱乐部成员被分为四个不同的群体。我们的任务是根据他们的互动模式,为每个成员分配正确的群体(节点分类)。
让我们使用PyG的内置函数导入数据集,并试图了解它使用的Datasets
对象。
from torch_geometric.datasets import KarateClub
# 从PyTorch Geometric导入数据集dataset = KarateClub()# 打印信息print(dataset)print('------------')print(f'图的数量:{len(dataset)}')print(f'特征的数量:{dataset.num_features}')print(f'类别的数量:{dataset.num_classes}')
KarateClub()------------图的数量:1特征的数量:34类别的数量:4
该数据集仅包含1个图,每个节点具有34维的特征向量,并属于四个类别之一(我们的四个群体)。实际上,Datasets
对象可以被看作是一组Data
(图)对象。
我们可以进一步检查我们唯一的图,以了解更多信息。
# 打印第一个元素print(f'图:{dataset[0]}')
图:Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])
Data
对象特别有趣。打印它可以提供对我们研究的图的概要:
x=[34, 34]
是节点特征矩阵,形状为(节点数量,特征数量)。在我们的案例中,这意味着我们有34个节点(我们的34个成员),每个节点与一个34维特征向量相关联。edge_index=[2, 156]
表示图的连接性(节点如何连接),形状为(2,有向边的数量)。y=[34]
是节点的真实标签。在这个问题中,每个节点分配给一个类(组),因此我们对每个节点有一个值。train_mask=[34]
是一个可选属性,用于指示哪些节点应该用于训练,它是一个由True
或False
语句组成的列表。
让我们打印每个张量以了解它们存储了什么。让我们从节点特征开始。
data = dataset[0]
print(f'x = {data.x.shape}')print(data.x)
x = torch.Size([34, 34])tensor([[1., 0., 0., ..., 0., 0., 0.], [0., 1., 0., ..., 0., 0., 0.], [0., 0., 1., ..., 0., 0., 0.], ..., [0., 0., 0., ..., 1., 0., 0.], [0., 0., 0., ..., 0., 1., 0.], [0., 0., 0., ..., 0., 0., 1.]])
在这里,节点特征矩阵x
是一个单位矩阵:它不包含任何有关节点的相关信息。它可以包含年龄、技能水平等信息,但在这个数据集中不是这样。这意味着我们必须仅通过查看节点的连接来对节点进行分类。
现在,让我们打印边索引。
print(f'edge_index = {data.edge_index.shape}')print(data.edge_index)
edge_index = torch.Size([2, 156])tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 13, 13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21, 21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27, 27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31, 31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33], [ 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2, 3, 7, 13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0, 1, 2, 7, 12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1, 2, 3, 0, 2, 30, 32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2, 3, 33, 32, 33, 32, 33, 5, 6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0, 1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23, 24, 33, 2, 31, 33, 23, 26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32, 33, 2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])
在图论和网络分析中,节点之间的连通性使用各种数据结构进行存储。其中,edge_index
是一种数据结构,图的连接关系存储在两个列表中(156个有向边,相当于78个双向边)。之所以使用这两个列表,是因为一个列表存储源节点,而第二个列表标识目标节点。
这种方法被称为坐标列表(COO)格式,本质上是一种高效存储稀疏矩阵的方式。稀疏矩阵是一种高效存储大部分元素为零的矩阵的数据结构。在COO格式中,只存储非零元素,节省内存和计算资源。
相反,表示图连通性的更直观和简单的方式是使用邻接矩阵 A。这是一个方阵,其中每个元素 Aᵢⱼ 指定了图中从节点 i 到节点 j 的边的存在或不存在。换句话说,非零元素 Aᵢⱼ 表示从节点 i 到节点 j 的连接,而零表示没有直接连接。
然而,邻接矩阵对于稀疏矩阵或边少的图来说并不是很节省空间。然而,出于清晰和易于解释的考虑,邻接矩阵仍然是表示图连通性的常用选择。
可以使用实用函数to_dense_adj()
从edge_index
推断出邻接矩阵。
from torch_geometric.utils import to_dense_adj
A = to_dense_adj(data.edge_index)[0].numpy().astype(int)print(f'A = {A.shape}')print(A)
A = (34, 34)[[0 1 1 ... 1 0 0] [1 0 1 ... 0 0 0] [1 1 0 ... 0 1 0] ... [1 0 0 ... 0 1 1] [0 0 1 ... 1 0 1] [0 0 0 ... 1 1 0]]
对于图数据来说,节点之间的密集连接相对较少。正如您所见,我们的邻接矩阵 A 是稀疏的(填充了零)。
在许多现实世界的图中,大多数节点只与少数其他节点相连,导致邻接矩阵中有大量的零。存储这么多零并不高效,这就是为什么 PyG 采用 COO 格式的原因。
相反,真实标签易于理解。
print(f'y = {data.y.shape}')print(data.y)
y = torch.Size([34])tensor([1, 1, 1, 1, 3, 3, 3, 1, 0, 1, 3, 1, 1, 1, 0, 0, 3, 1, 0, 1, 0, 1, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0])
我们在y
中存储的节点真实标签简单地编码了每个节点的组号(0、1、2、3),这就是为什么有34个值的原因。
最后,让我们打印训练掩码。
print(f'train_mask = {data.train_mask.shape}')print(data.train_mask)
train_mask = torch.Size([34])tensor([ True, False, False, False, True, False, False, False, True, False, False, False, False, False, False, False, False, False, False, False, False, False, False, False, True, False, False, False, False, False, False, False, False])
训练掩码显示了应该用于训练的节点,其中True
表示。这些节点表示训练集,而其他节点可以被视为测试集。这种划分有助于模型评估,通过提供未见过的数据进行测试。
但我们还没有完成!Data
对象还有更多功能。它提供了各种实用函数,可以用来探索图的几个属性。例如:
is_directed()
告诉你图是否为有向图。有向图表示邻接矩阵不对称,即边的方向在节点之间的连接中起作用。isolated_nodes()
检查是否有节点不连接到图的其他部分。这些节点可能会在分类等任务中造成挑战,因为它们缺乏连接。has_self_loops()
指示是否至少有一个节点与自身相连。这与循环的概念不同:循环意味着一条从同一节点开始并结束的路径,中间经过其他节点。
在Zachary的空手道俱乐部数据集的上下文中,所有这些属性返回False
。这意味着该图不是有向图,没有任何孤立的节点,并且没有节点与自身相连。
print(f'边是有向的:{data.is_directed()}')print(f'图中有孤立节点:{data.has_isolated_nodes()}')print(f'图中有循环:{data.has_self_loops()}')
最后,我们可以使用to_networkx
将PyTorch Geometric中的图转换为流行的图库NetworkX。这对于使用networkx
和matplotlib
可视化小型图形非常有用。
让我们用不同的颜色为每个组绘制数据集。
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)plt.figure(figsize=(12,12))plt.axis('off')nx.draw_networkx(G, pos=nx.spring_layout(G, seed=0), with_labels=True, node_size=800, node_color=data.y, cmap="hsv", vmin=-2, vmax=3, width=0.8, edge_color="grey", font_size=14 )plt.show()
这个Zachary的空手道俱乐部的图显示了我们的34个节点,78条(双向)边和4个具有4种不同颜色的标签。现在我们已经了解了使用PyTorch Geometric加载和处理数据集的基本要点,我们可以介绍图卷积网络架构。
✉️ II. 图卷积网络
本节旨在从基础开始介绍和构建图卷积层。
在传统神经网络中,线性层对输入数据应用线性变换。这个变换通过使用权重矩阵𝐖将输入特征x转换为隐藏向量h。暂时忽略偏差,可以表示为:
在图数据中,通过节点之间的连接添加了额外的复杂性。这些连接很重要,因为通常在网络中,假设相似的节点更有可能相互链接,而不是不相似的节点,这种现象被称为网络同质性。
我们可以通过将节点的特征与其邻居的特征合并来丰富节点表示。这个操作称为卷积或邻域聚合。让我们将节点i及其自身包含在内的邻域表示为Ñ。
与卷积神经网络(CNNs)中的过滤器不同,我们的权重矩阵𝐖是唯一的,并且在每个节点之间共享。但还有一个问题:节点没有像像素那样的固定数量的邻居。
我们如何处理一个节点只有一个邻居,而另一个节点有500个邻居的情况?如果我们简单地对特征向量求和,那么具有500个邻居的节点的嵌入h将会更大。为了确保所有节点的数值范围相似并且可比性,我们可以根据节点的度对结果进行归一化,其中度指的是节点的连接数。
我们就快完成了!Kipf等人在2016年提出的图卷积层还有一个最终的改进。
作者们观察到,具有大量邻居的节点的特征比孤立节点的特征更容易传播。为了抵消这种影响,他们建议为具有较少邻居的节点的特征分配更大的权重,从而在所有节点之间平衡影响。这个操作可以写成:
请注意,当i和j具有相同数量的邻居时,等价于我们自己的层。现在,让我们看看如何使用PyTorch Geometric在Python中实现它。
🧠 III. 实现GCN
PyTorch Geometric提供了GCNConv
函数,直接实现了图卷积层。
在这个例子中,我们将创建一个基本的图卷积网络,其中包含一个GCN层、一个ReLU激活函数和一个线性输出层。这个输出层将生成与我们的四个类别对应的四个值,其中最高值确定每个节点的类别。
在下面的代码块中,我们使用一个3维隐藏层定义了GCN层。
from torch.nn import Linearfrom torch_geometric.nn import GCNConv
class GCN(torch.nn.Module): def __init__(self): super().__init__() self.gcn = GCNConv(dataset.num_features, 3) self.out = Linear(3, dataset.num_classes) def forward(self, x, edge_index): h = self.gcn(x, edge_index).relu() z = self.out(h) return h, zmodel = GCN()print(model)
GCN( (gcn): GCNConv(34, 3) (out): Linear(in_features=3, out_features=4, bias=True))
如果我们添加了第二个GCN层,我们的模型不仅会聚合每个节点的邻居的特征向量,还会聚合这些邻居的邻居的特征向量。
我们可以堆叠多个图层来聚合越来越远的值,但是有一个问题:如果我们添加太多的层,聚合变得非常强烈,所有的嵌入最终看起来都一样。这种现象称为过度平滑,当层数过多时可能会成为一个真正的问题。
现在,我们已经定义了我们的GNN,让我们用PyTorch写一个简单的训练循环。由于这是一个多类分类任务,我选择了常规的交叉熵损失,并使用Adam作为优化器。在本文中,我们不会实现训练/测试分割,以保持简单并专注于GNN的学习方式。
训练循环是标准的:我们尝试预测正确的标签,将GCN的结果与data.y
中存储的值进行比较。误差通过交叉熵损失计算,并使用Adam进行反向传播,以微调我们的GNN的权重和偏差。最后,我们每10个epoch打印一次指标。
criterion = torch.nn.CrossEntropyLoss()optimizer = torch.optim.Adam(model.parameters(), lr=0.02)
# 计算准确性def accuracy(pred_y, y): return (pred_y == y).sum() / len(y)# 用于动画的数据embeddings = []losses = []accuracies = []outputs = []# 训练循环for epoch in range(201): # 清除梯度 optimizer.zero_grad() # 前向传播 h, z = model(data.x, data.edge_index) # 计算损失函数 loss = criterion(z, data.y) # 计算准确性 acc = accuracy(z.argmax(dim=1), data.y) # 计算梯度 loss.backward() # 调整参数 optimizer.step() # 存储数据用于动画 embeddings.append(h) losses.append(loss) accuracies.append(acc) outputs.append(z.argmax(dim=1)) # 每10个epoch打印指标 if epoch % 10 == 0: print(f'Epoch {epoch:>3} | Loss: {loss:.2f} | Acc: {acc*100:.2f}%')
Epoch 0 | Loss: 1.40 | Acc: 41.18%Epoch 10 | Loss: 1.21 | Acc: 47.06%Epoch 20 | Loss: 1.02 | Acc: 67.65%Epoch 30 | Loss: 0.80 | Acc: 73.53%Epoch 40 | Loss: 0.59 | Acc: 73.53%Epoch 50 | Loss: 0.39 | Acc: 94.12%Epoch 60 | Loss: 0.23 | Acc: 97.06%Epoch 70 | Loss: 0.13 | Acc: 100.00%Epoch 80 | Loss: 0.07 | Acc: 100.00%Epoch 90 | Loss: 0.05 | Acc: 100.00%Epoch 100 | Loss: 0.03 | Acc: 100.00%Epoch 110 | Loss: 0.02 | Acc: 100.00%Epoch 120 | Loss: 0.02 | Acc: 100.00%Epoch 130 | Loss: 0.02 | Acc: 100.00%Epoch 140 | Loss: 0.01 | Acc: 100.00%Epoch 150 | Loss: 0.01 | Acc: 100.00%Epoch 160 | Loss: 0.01 | Acc: 100.00%Epoch 170 | Loss: 0.01 | Acc: 100.00%Epoch 180 | Loss: 0.01 | Acc: 100.00%Epoch 190 | Loss: 0.01 | Acc: 100.00%Epoch 200 | Loss: 0.01 | Acc: 100.00%
太好了!毫不意外地,我们在训练集(完整数据集)上达到了100%的准确率。这意味着我们的模型学会了正确地将每个空手道俱乐部成员分配到其正确的组中。
我们可以通过对图形进行动画处理并查看GNN在训练过程中的预测演变来生成一个整洁的可视化效果。
%%capturefrom IPython.display import HTMLfrom matplotlib import animationplt.rcParams["animation.bitrate"] = 3000
def animate(i): G = to_networkx(data, to_undirected=True) nx.draw_networkx(G, pos=nx.spring_layout(G, seed=0), with_labels=True, node_size=800, node_color=outputs[i], cmap="hsv", vmin=-2, vmax=3, width=0.8, edge_color="grey", font_size=14 ) plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%', fontsize=18, pad=20)fig = plt.figure(figsize=(12, 12))plt.axis('off')anim = animation.FuncAnimation(fig, animate, \ np.arange(0, 200, 10), interval=500, repeat=True)html = HTML(anim.to_html5_video())display(html)
初始预测是随机的,但是GCN在一段时间后完美地标记了每个节点。事实上,最终的图与我们在第一部分末尾绘制的图相同。但是,GCN真正学到了什么呢?
通过聚合相邻节点的特征,GNN学习了网络中每个节点的向量表示(或嵌入)。在我们的模型中,最后一层只是学习如何使用这些表示来产生最佳分类。然而,嵌入是GNN的真正产品。
让我们打印出我们的模型学到的嵌入。
# 打印嵌入print(f'Final embeddings = {h.shape}')print(h)
Final embeddings = torch.Size([34, 3])tensor([[1.9099e+00, 2.3584e+00, 7.4027e-01], [2.6203e+00, 2.7997e+00, 0.0000e+00], [2.2567e+00, 2.2962e+00, 6.4663e-01], [2.0802e+00, 2.8785e+00, 0.0000e+00], [0.0000e+00, 0.0000e+00, 2.9694e+00], [0.0000e+00, 0.0000e+00, 3.3817e+00], [0.0000e+00, 1.5008e-04, 3.4246e+00], [1.7593e+00, 2.4292e+00, 2.4551e-01], [1.9757e+00, 6.1032e-01, 1.8986e+00], [1.7770e+00, 1.9950e+00, 6.7018e-01], [0.0000e+00, 1.1683e-04, 2.9738e+00], [1.8988e+00, 2.0512e+00, 2.6225e-01], [1.7081e+00, 2.3618e+00, 1.9609e-01], [1.8303e+00, 2.1591e+00, 3.5906e-01], [2.0755e+00, 2.7468e-01, 1.9804e+00], [1.9676e+00, 3.7185e-01, 2.0011e+00], [0.0000e+00, 0.0000e+00, 3.4787e+00], [1.6945e+00, 2.0350e+00, 1.9789e-01], [1.9808e+00, 3.2633e-01, 2.1349e+00], [1.7846e+00, 1.9585e+00, 4.8021e-01], [2.0420e+00, 2.7512e-01, 1.9810e+00], [1.7665e+00, 2.1357e+00, 4.0325e-01], [1.9870e+00, 3.3886e-01, 2.0421e+00], [2.0614e+00, 5.1042e-01, 2.4872e+00],... [2.1778e+00, 4.4730e-01, 2.0077e+00], [3.8906e-02, 2.3443e+00, 1.9195e+00], [3.0748e+00, 0.0000e+00, 3.0789e+00], [3.4316e+00, 1.9716e-01, 2.5231e+00]], grad_fn=<ReluBackward0>)
如您所见,嵌入向量的维度不需要与特征向量相同。在这里,我选择将维度从34(dataset.num_features
)降低到三维,以便在3D中获得良好的可视化效果。
在任何训练发生之前,让我们先绘制这些嵌入向量,即在epoch 0时。
# 获取epoch=0的第一个嵌入向量
embed = h.detach().cpu().numpy()
fig = plt.figure(figsize=(12, 12))
ax = fig.add_subplot(projection='3d')
ax.patch.set_alpha(0)
plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2], s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
plt.show()
我们可以看到扎卡里的空手道俱乐部中的每个节点都有其真实标签(而不是模型的预测)。目前,它们分散在各处,因为GNN尚未训练。但是,如果我们在训练循环的每个步骤中绘制这些嵌入向量,我们将能够可视化GNN真正学习到的内容。
让我们看看随着GCN在分类节点方面变得越来越好,它们是如何随时间演变的。
%%capture
def animate(i):
embed = embeddings[i].detach().cpu().numpy()
ax.clear()
ax.scatter(embed[:, 0], embed[:, 1], embed[:, 2], s=200, c=data.y, cmap="hsv", vmin=-2, vmax=3)
plt.title(f'Epoch {i} | Loss: {losses[i]:.2f} | Acc: {accuracies[i]*100:.2f}%', fontsize=18, pad=40)
fig = plt.figure(figsize=(12, 12))
plt.axis('off')
ax = fig.add_subplot(projection='3d')
plt.tick_params(left=False, bottom=False, labelleft=False, labelbottom=False)
anim = animation.FuncAnimation(fig, animate, np.arange(0, 200, 10), interval=800, repeat=True)
html = HTML(anim.to_html5_video())
display(html)
我们的图卷积网络(GCN)已经有效地学习到了将相似节点分组成不同聚类的嵌入向量。这使得最后的线性层能够轻松将它们分成不同的类别。
嵌入向量不仅适用于GNN,它们在深度学习中无处不在。它们也不一定是三维的:实际上,它们很少是三维的。例如,像BERT这样的语言模型会产生768或1024维的嵌入向量。
额外的维度存储了关于节点、文本、图像等的更多信息,但也会创建更难以训练的更大模型。这就是为什么尽可能保持低维嵌入向量的优势。
结论
图卷积网络是一种非常多样化的架构,可以应用于许多上下文环境。在本文中,我们熟悉了PyTorch Geometric库和诸如Datasets
和Data
之类的对象。然后,我们成功地从头开始重建了一个图卷积层。接下来,我们通过实现一个GCN来将理论付诸实践,这使我们了解了实际方面以及各个组件之间的相互作用。最后,我们可视化了训练过程,并清楚地了解了这样一个网络所涉及的内容。
扎卡里的空手道俱乐部是一个简单的数据集,但足以理解图数据和GNN中最重要的概念。尽管本文只讨论了节点分类,但GNN还可以完成其他任务:如链接预测(例如,推荐朋友)、图分类(例如,标记分子)、图生成(例如,创建新分子)等。
除了GCN,研究人员还提出了许多GNN层和架构。在下一篇文章中,我们将介绍图注意力网络(GAT)架构,该架构使用注意力机制动态计算GCN的归一化因子和每个连接的重要性。
如果你想更多了解图神经网络,可以通过我的书《Hands-On Graph Neural Networks》深入探索GNN的世界。
下一篇文章
第二章:图注意力网络:自注意力解释
使用PyTorch Geometric的自注意力图神经网络指南
towardsdatascience.com
了解更多机器学习知识并支持我的工作,只需一键点击,成为VoAGI会员:
使用我的推荐链接加入VoAGI — Maxime Labonne
作为VoAGI会员,你的会费的一部分将用于支持你阅读的作者,并且你可以完全访问每个故事…
VoAGI.com
如果你已经是会员,你可以在VoAGI上关注我。