Press "Enter" to skip to content

使用Transformer进行图分类

在之前的博客中,我们探讨了关于图机器学习的一些理论方面。这篇博客将介绍如何使用Transformers库进行图分类(您也可以通过下载演示笔记本来跟随这个过程!)

目前,在Transformers中唯一可用的图转换模型是微软的Graphormer,所以我们将在这里使用它。我们期待看到其他人将会使用和整合哪些模型 🤗

要求

要按照本教程操作,您需要安装datasets和transformers(版本>=4.27.2),您可以使用pip install -U datasets transformers来安装。

数据

要使用图数据,您可以从自己的数据集开始,或者使用Hub上提供的数据集。我们将重点介绍如何使用已有的数据集,但是您也可以随意添加您自己的数据集!

加载

从Hub加载图数据集非常简单。让我们加载”ogbg-mohiv”数据集(Stanford的Open Graph Benchmark中的一个基准数据集),该数据集存储在OGB仓库中:

from datasets import load_dataset

# Hub上只有一个分割
dataset = load_dataset("OGB/ogbg-molhiv")

dataset = dataset.shuffle(seed=0)

这个数据集已经包含了三个分割,即train,validation和test,并且所有这些分割都包含了我们感兴趣的5个列(edge_index,edge_attr,y,num_nodes,node_feat),您可以通过print(dataset)来查看。

如果您有其他的图库,您可以使用它们来绘制图形和进一步检查数据集。例如,使用PyGeometric和matplotlib:

import networkx as nx
import matplotlib.pyplot as plt

# 我们想要绘制第一个train图形
graph = dataset["train"][0]

edges = graph["edge_index"]
num_edges = len(edges[0])
num_nodes = graph["num_nodes"]

# 转换为networkx格式
G = nx.Graph()
G.add_nodes_from(range(num_nodes))
G.add_edges_from([(edges[0][i], edges[1][i]) for i in range(num_edges)])

# 绘制图形
nx.draw(G)

格式

在Hub上,图数据集主要以图的列表形式存储(使用jsonl格式)。

一个单独的图是一个字典,以下是我们图分类数据集的期望格式:

  • edge_index包含边中节点的索引,存储为包含两个平行列表的列表。
    • 类型:2个整数列表的列表。
    • 示例:一个包含四个节点(0、1、2和3)并且连接关系为1->2、1->3和3->1的图将具有edge_index = [[1, 1, 3], [2, 3, 1]]。您可能会注意到这里没有出现节点0,因为它本身并不是一个边的一部分。这就是下一个属性的重要性所在。
  • num_nodes指示图中可用的节点总数(默认情况下,假设节点按顺序编号)。
    • 类型:整数
    • 示例:在上面的例子中,num_nodes = 4
  • y将每个图映射到我们想要从中预测的内容(可以是类别、属性值或多个二进制标签用于不同的任务)。
    • 类型:整数列表(用于多类分类)、浮点数列表(用于回归)或由1和0组成的列表(用于二元多任务分类)
    • 示例:我们可以预测图的大小(小 = 0,VoAGI = 1,大 = 2)。在这里,y = [0]
  • node_feat包含图中每个节点的可用特征(如果有的话),按节点索引排序。
    • 类型:整数列表的列表(可选)
    • 示例:我们上面的节点可以有不同的类型(如分子中的不同原子)。这可能会给出node_feat = [[1], [0], [1], [1]]
  • edge_attr包含图中每个边的可用属性(如果有的话),按照edge_index的顺序排列。
    • 类型:整数列表的列表(可选)
    • 示例:我们上面的边可以有不同的类型(如分子键)。这可能会给出edge_attr = [[0], [1], [1]]

预处理

图变换框架通常会对数据集进行特定的预处理,以生成增加的特征和属性,这些特征和属性有助于底层的学习任务(在我们的案例中是分类任务)。在这里,我们使用Graphormer的默认预处理,该预处理生成节点的入度/出度信息、节点之间的最短路径矩阵以及模型所需的其他属性。

from transformers.models.graphormer.collating_graphormer import preprocess_item, GraphormerDataCollator

dataset_processed = dataset.map(preprocess_item, batched=False)

也可以在DataCollator的参数中实时应用这种预处理(将on_the_fly_processing设置为True):并非所有的数据集都像ogbg-molhiv这样小,对于大型图,提前存储所有预处理数据可能会成本过高。

模型

加载

在这里,我们加载一个现有的预训练模型/检查点,并在我们的下游任务上进行微调,这是一个二分类任务(因此num_classes = 2)。我们也可以在回归任务(num_classes = 1)或多任务分类上进行微调。

from transformers import GraphormerForGraphClassification

model = GraphormerForGraphClassification.from_pretrained(
    "clefourrier/pcqm4mv2_graphormer_base",
    num_classes=2, # 下游任务的类别数
    ignore_mismatched_sizes=True,
)

让我们更详细地看看这个过程。

在我们的模型上调用from_pretrained方法会为我们下载并缓存权重。由于预测的类别数是依赖于数据集的,我们同时传递了新的num_classes以及ignore_mismatched_sizesmodel_checkpoint中。这确保了一个自定义的分类头会被创建,专门用于我们的任务,因此可能与原始解码头不同。

也可以创建一个新的随机初始化模型,从头开始进行训练,可以选择遵循给定检查点的已知参数或手动选择参数。

训练或微调

为了简单地训练我们的模型,我们将使用一个Trainer。为了实例化它,我们需要定义训练配置和评估指标。其中最重要的是TrainingArguments,它是一个包含所有属性的类,用于自定义训练。它需要一个文件夹名称,用于保存模型的检查点。

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    "graph-classification",
    logging_dir="graph-classification",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    auto_find_batch_size=True, # 批大小可以自动更改以防止内存溢出
    gradient_accumulation_steps=10,
    dataloader_num_workers=4, #1, 
    num_train_epochs=20,
    evaluation_strategy="epoch",
    logging_strategy="epoch",
    push_to_hub=False,
)

对于图数据集,调整批大小和梯度累积步数以在足够样本上进行训练同时避免内存溢出是特别重要的。

最后一个参数push_to_hub允许Trainer在训练过程中定期将模型推送到Hub。

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_processed["train"],
    eval_dataset=dataset_processed["validation"],
    data_collator=GraphormerDataCollator(),
)

对于图分类的Trainer,重要的是传递给定图数据集的特定数据整合器,它将将单个图转换为训练用的批次。

train_results = trainer.train()
trainer.push_to_hub()

当模型训练完成后,可以使用push_to_hub将模型与所有相关的训练结果保存到Hub。

由于这个模型非常庞大,在CPU(IntelCore i7)上训练/微调20个时期大约需要一天的时间。为了加快速度,可以使用强大的GPU和并行化,在Colab笔记本或直接在所选集群上运行代码。

结束语

现在您已经了解如何使用transformers来训练图分类模型,我们希望您尝试在Hub上分享您喜欢的图转换器检查点、模型和数据集,供社区的其他人使用!

Leave a Reply

Your email address will not be published. Required fields are marked *