Press "Enter" to skip to content

我如何在微调过程中创建了嵌入的动画

使用Cleanlab、PCA和Procrustes来可视化在CIFAR-10上对ViT进行微调

在机器学习领域,Vision Transformers(ViT)是一种用于图像分类的模型类型。与传统的卷积神经网络不同,ViT使用Transformer架构来处理图像,该架构最初是为自然语言处理任务设计的。对这些模型进行微调以实现最佳性能可能是一个复杂的过程。

在之前的一篇文章中,我使用动画来演示微调过程中嵌入的变化。这是通过对嵌入进行主成分分析(PCA)来实现的。这些嵌入是从微调过程中不同阶段的模型及其对应的检查点生成的。

Projection of embeddings with PCA during fine-tuning of a Vision Transformer (ViT) model [1] on CIFAR10 [3]; Source: created by the author — Published before in Changes of Embeddings during Fine-Tuning

这个动画获得了超过200,000次的印象。它受到了很好的反响,许多读者表达了对其创建方式的兴趣。本文旨在支持那些读者和其他对创建类似可视化的人感兴趣的人。

在本文中,我旨在提供一个全面的指南,介绍如何创建这样的动画,详细介绍涉及的步骤:微调、嵌入创建、异常值检测、PCA、Procrustes、回顾和动画创建。

动画的完整代码也可以在GitHub上的附带笔记本中找到。

准备工作:微调

第一步是微调google/vit-base-patch16–224-in21k Vision Transformer(ViT)模型[1],该模型是预训练的。我们使用CIFAR-10数据集[2]进行微调,该数据集包含60,000张图像,分为十个不同的类别:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。

您可以按照Hugging Face教程中关于使用transformers进行图像分类的步骤来执行CIFAR-10的微调过程。此外,我们还使用TrainerCallback将训练过程中的损失值存储到CSV文件中,以供后续在动画中使用。

from transformers import TrainerCallbackclass PrinterCallback(TrainerCallback):    def on_log(self, args, state, control, logs=None, **kwargs):        _ = logs.pop("total_flos", None)        if state.is_local_process_zero:            if len(logs) == 3:  # skip last row                with open("log.csv", "a") as f:                    f.write(",".join(map(str, logs.values())) + "\n")

重要的是通过将save_strategy="step"save_step的值设置为较低的值来增加检查点的保存间隔,以确保有足够的检查点用于动画。动画中的每一帧对应一个检查点。在训练过程中,为每个检查点和CSV文件创建一个文件夹,以备后续使用。

嵌入的创建

我们使用Transformers库中的AutoFeatureExtractorAutoModel从CIFAR-10数据集的测试集中使用不同的模型检查点生成嵌入。

每个嵌入是一个768维的向量,表示一个模型检查点下的10,000个测试图像中的一个。这些嵌入可以存储在与检查点相同的文件夹中,以保持良好的概览。

提取异常值

我们可以使用Cleanlab库提供的OutOfDistribution类根据每个检查点的嵌入来识别异常值。然后,生成的分数可以用于识别动画中的前10个异常值。

from cleanlab.outlier import OutOfDistributiondef get_ood(sorted_checkpoint_folder, df):  ...  ood = OutOfDistribution()  ood_train_feature_scores = ood.fit_score(features=embedding_np)  df["scores"] = ood_train_feature_scores

应用PCA和Procrustes分析

使用scikit-learn包的主成分分析(PCA),我们通过将768维向量降低到2维来在二维空间中可视化嵌入。当对每个时间步重新计算PCA时,由于轴翻转或旋转可能会导致动画中的大跳跃。为了解决这个问题,我们使用SciPy包中的附加Procrustes分析[3]对每一帧进行几何变换,将其几何变换到最后一帧,这只涉及平移、旋转和均匀缩放。这样可以实现动画中更平滑的过渡。

from sklearn.decomposition import PCAfrom scipy.spatial import procrustesdef make_pca(sorted_checkpoint_folder, pca_np):  ...  embedding_np_flat = embedding_np.reshape(-1, 768)  pca = PCA(n_components=2)  pca_np_new = pca.fit_transform(embedding_np_flat)  _, pca_np_new, disparity = procrustes(pca_np, pca_np_new)

在Spotlight中进行评估

在最终完成整个动画之前,我们在Spotlight中进行了评估。在这个过程中,我们使用第一个和最后一个检查点进行嵌入生成、PCA和异常值检测。我们在Spotlight中加载生成的DataFrame:

CIFAR-10的嵌入:通过短时间微调的第一个和最后一个检查点的PCA和8个最差异常值的可视化 - 使用github.com/renumics/spotlight进行可视化,来源:作者创建

Spotlight在左上角提供了一个综合表格,展示了数据集中的所有字段。右上角显示了两个PCA表示:一个是使用第一个检查点生成的嵌入,另一个是使用最后一个检查点生成的嵌入。最后,在底部部分展示了选定的图像。

免责声明:本文的作者也是Spotlight的开发人员之一。

创建动画

对于每个检查点,我们创建一个图像,并将其与相应的检查点一起存储。

这是通过使用make_pca(...)get_ood(...)函数实现的,这两个函数分别生成表示嵌入的2D点和提取前8个异常值。2D点以对应类别的颜色绘制。异常值根据其分数排序,并在高分榜中显示相应的图像。训练损失从CSV文件中加载,并以折线图方式绘制。

最后,所有图像可以使用imageio或类似的库编译成GIF。

我如何在微调过程中创建了嵌入的动画 四海 第3张

我如何在微调过程中创建了嵌入的动画 四海 第4张

微调过程的前三个写入检查点生成的三个帧显示出轻微聚类。预计在后续步骤中会出现更明显的聚类。来源:作者创建

结论

本文详细介绍了如何创建一种动画,以可视化Vision Transformer(ViT)模型的微调过程。我们已经逐步介绍了生成和分析嵌入、可视化结果以及创建将这些元素结合起来的动画的步骤。

创建这样的动画不仅有助于理解微调ViT模型的复杂过程,而且还是向他人传达这些概念的强大工具。

动画的完整代码可在 GitHub 上的附带笔记本中找到。

我是一名专业人士,擅长创建用于交互式探索非结构化数据的高级软件解决方案。我写作关于非结构化数据,并使用强大的可视化工具进行分析和做出明智的决策。

参考资料

[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby, An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale (2020), arXiv

[2] Alex Krizhevsky, Learning Multiple Layers of Features from Tiny Images (2009), University Toronto

[3] Gower, John C. Generalized procrustes analysis (1975), Psychometrika

Leave a Reply

Your email address will not be published. Required fields are marked *