Press "Enter" to skip to content

使用可视化实现的图注意力网络(GAT)的详细解析

顶层可视化的步行材料

理解图神经网络(GNNs)在变压器继续处理Open Graph Benchmark等图问题时变得越来越重要。即使自然语言是图所需的一切,但GNNs仍然是未来方法的丰富灵感来源。

在本文中,我将介绍一个基本的GNN层的实现,然后展示一个基本的图注意力层的修改,如ICLR论文中所描述的图注意网络。

文档图的5 x 5邻接矩阵

首先,想象一下我们有一个以有向无环图(DAG)表示的文本文档图。文档0与文档1、2和3之间存在边,因此在这些列中的0行中有1。

为了进行可视化的实现,我将使用Graphbook,这是一个视觉AI建模工具。有关如何理解Graphbook中的可视化表示的更多信息,请参阅我的其他文章。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第3张

我们还有每个文档的一些节点特征。我将每个文档作为一个单独的[5] 1D数组输入BERT,以生成一个[5, 768]形状的池化器输出中的嵌入。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第4张

出于教学目的,我将仅使用BERT输出的前8个维度作为节点特征,以便我们可以更容易地跟踪数据形状。现在,我们有了邻接矩阵和节点特征。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第5张

GNN层

GNN层的一般公式是,对于每个节点,我们将每个节点的所有邻居的特征乘以一个权重矩阵相加,然后通过激活函数。我创建了一个空白块,标题是这个公式,并将其传递给邻接矩阵和节点特征,然后我将在块内实现这个公式。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第6张

当我们实现这个公式时,我们不想运行循环。如果我们可以完全向量化,那么使用GPU进行任何训练和推理将会更快,因为乘法可以成为单个计算步骤。因此,我们将节点特征广播到3D形状,即我们从[5, 8]的节点特征形状变为[5, 5, 8]的形状,其中第0维中的每个单元格是节点特征的重复。现在我们可以将最后一个维度视为“邻居”特征。每个节点有5个可能的邻居。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第7张

我们不能直接将节点特征从[5, 8]广播到[5, 5, 8]的形状。相反,我们必须先广播到[25, 8],因为在广播时,形状中的每个维度都必须大于或等于原始维度。所以这就是为什么我们得到了形状的5和8部分(get_sub_arrays),然后将第一个乘以以获取25,然后将它们全部连接在一起。最后,我们将得到的[25, 8]重新形状为[5, 5, 8],我们确实可以在Graphbook中验证最后2个维度中每组节点特征是相同的。

接下来,我们还要将邻接矩阵广播到相同的形状。这意味着对于邻接矩阵中第 i 行和第 j 列的每个 1,在维度 [i, j] 上都有一个 num_feats 的 1.0 行。因此,在这个邻接矩阵中,第 0 行在第 1、2、3 列都有一个 1,所以在第 0 个单元格(即 [0, 1:3, :])中的第 1、2、3 行有 num_feats 个 1.0。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第8张

这里的实现非常简单,只需将邻接矩阵解析为小数,并从 [5, 5] 形状广播到 [5, 5, 8]。现在,我们可以逐元素地将这个邻接掩码与我们平铺的节点邻居特征相乘。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第9张

我们还想在邻接矩阵中包含一个自环,这样当我们对邻居特征求和时,也会包括该节点自身的节点特征。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第10张

经过逐元素相乘(并包括自环)后,我们得到了每个节点的邻居特征以及那些没有通过边连接(不是邻居)的节点的零特征。对于第 0 个节点,这包括节点 0 到 3 的特征。对于第 3 个节点,这包括第 3 和第 4 个节点。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第11张

接下来,我们将重塑为 [25, 8],使每个邻居特征成为自己的行,并将其传递给一个具有所需隐藏大小的参数化线性层。这里我选择了 32,并保存为全局常量,以便重复使用。线性层的输出将为 [25, hidden_size]。只需重塑该输出,创建形状为 [5, 5, hidden_size],现在我们终于准备好进行公式的求和部分了!

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第12张

我们在中间维度(维度索引 1)上求和,这样我们就可以对每个节点的邻居特征进行求和。结果是一个 [5, hidden_size] 的节点嵌入集,经过了 1 层。只需将这些层连接在一起,就可以得到一个 GNN 网络,并参考 https://www.youtube.com/@Graphbook 上的指南进行训练。

图注意力层

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第13张

从论文中可以看出,图注意力层的秘诀在于注意系数,其在上述公式中给出。本质上,我们将处于一条边上的节点嵌入进行拼接,并通过另一个线性层运行,然后应用 softmax 函数。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第14张

然后,这些注意系数用于计算与原始节点特征相对应的特征的线性组合。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第15张

我们需要做的是将每个节点的特征复制给每个邻居,并将其与节点的邻居特征进行拼接。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第16张

秘诀是为每个邻居获取节点的特征平铺。为此,在遮罩之前,我们交换平铺节点特征的0和1维度。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第17张

结果仍然是一个[5, 5, 8]形状的数组,但现在[i, :, :]中的每一行都是相同的,并且对应于节点i的特征。现在,我们可以使用逐元素乘法来创建只在它们包含邻居时重复的节点特征。最后,我们将其与为GNN创建的邻居特征进行连接,并生成连接后的特征。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第18张

我们快要完成了!现在我们有了连接后的特征,我们可以将其通过线性层。我们需要重新调整形状为[5, 5, hidden_size],以便我们可以在中间维度上进行softmax并产生我们的注意力系数。

使用可视化实现的图注意力网络(GAT)的详细解析 四海 第19张

现在,我们有了形状为[5, 5, hidden_size]的注意力系数,这在我们的n节点图中实际上是每个图边的一个嵌入。论文中提到这些应该被转置(交换维度),所以我在ReLU之前进行了转置,并现在在最后一个维度上进行softmax,以使其在隐藏尺寸维度上的每个维度索引上归一化。我们将用这些系数乘以原始节点嵌入。回想一下,原始节点嵌入的形状是[5, 5, 8],其中8是随意从BERT对我们的文本文档进行编码时切片的前8个特征。

乘以[5, hidden_size, 5]形状和[5, 5, 8]形状得到[5, hidden_size, 8]形状。然后,我们在hidden_size维度上求和,最终输出匹配我们的输入形状的[5, 8]。我们现在也可以将其通过非线性函数(如另一个ReLU)并多次链接该层。

结论

到目前为止,我们已经介绍了单个GNN层和GAT层的可视化实现。您可以在此GitHub存储库中找到该项目。在论文中,他们还解释了如何将该方法扩展为多头注意力。如果您还希望我介绍这部分内容,或者如果您还希望我使用Graphbook介绍其他内容,请在评论中让我知道。

Leave a Reply

Your email address will not be published. Required fields are marked *