使用Cleanlab、PCA和Procrustes来可视化在CIFAR-10上对ViT进行微调
在机器学习领域,Vision Transformers(ViT)是一种用于图像分类的模型类型。与传统的卷积神经网络不同,ViT使用Transformer架构来处理图像,该架构最初是为自然语言处理任务设计的。对这些模型进行微调以实现最佳性能可能是一个复杂的过程。
在之前的一篇文章中,我使用动画来演示微调过程中嵌入的变化。这是通过对嵌入进行主成分分析(PCA)来实现的。这些嵌入是从微调过程中不同阶段的模型及其对应的检查点生成的。
![我如何在微调过程中创建了嵌入的动画 四海 第1张-四海吧 Projection of embeddings with PCA during fine-tuning of a Vision Transformer (ViT) model [1] on CIFAR10 [3]; Source: created by the author — Published before in Changes of Embeddings during Fine-Tuning](https://miro.medium.com/v2/resize:fit:640/1*jYdWl_8UM6ecV1ux8_qr1Q.gif)
这个动画获得了超过200,000次的印象。它受到了很好的反响,许多读者表达了对其创建方式的兴趣。本文旨在支持那些读者和其他对创建类似可视化的人感兴趣的人。
在本文中,我旨在提供一个全面的指南,介绍如何创建这样的动画,详细介绍涉及的步骤:微调、嵌入创建、异常值检测、PCA、Procrustes、回顾和动画创建。
动画的完整代码也可以在GitHub上的附带笔记本中找到。
准备工作:微调
第一步是微调google/vit-base-patch16–224-in21k Vision Transformer(ViT)模型[1],该模型是预训练的。我们使用CIFAR-10数据集[2]进行微调,该数据集包含60,000张图像,分为十个不同的类别:飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车。
您可以按照Hugging Face教程中关于使用transformers进行图像分类的步骤来执行CIFAR-10的微调过程。此外,我们还使用TrainerCallback
将训练过程中的损失值存储到CSV文件中,以供后续在动画中使用。
from transformers import TrainerCallbackclass PrinterCallback(TrainerCallback): def on_log(self, args, state, control, logs=None, **kwargs): _ = logs.pop("total_flos", None) if state.is_local_process_zero: if len(logs) == 3: # skip last row with open("log.csv", "a") as f: f.write(",".join(map(str, logs.values())) + "\n")
重要的是通过将save_strategy="step"
和save_step
的值设置为较低的值来增加检查点的保存间隔,以确保有足够的检查点用于动画。动画中的每一帧对应一个检查点。在训练过程中,为每个检查点和CSV文件创建一个文件夹,以备后续使用。
嵌入的创建
我们使用Transformers库中的AutoFeatureExtractor
和AutoModel
从CIFAR-10数据集的测试集中使用不同的模型检查点生成嵌入。
每个嵌入是一个768维的向量,表示一个模型检查点下的10,000个测试图像中的一个。这些嵌入可以存储在与检查点相同的文件夹中,以保持良好的概览。
提取异常值
我们可以使用Cleanlab库提供的OutOfDistribution
类根据每个检查点的嵌入来识别异常值。然后,生成的分数可以用于识别动画中的前10个异常值。
from cleanlab.outlier import OutOfDistributiondef get_ood(sorted_checkpoint_folder, df): ... ood = OutOfDistribution() ood_train_feature_scores = ood.fit_score(features=embedding_np) df["scores"] = ood_train_feature_scores
应用PCA和Procrustes分析
使用scikit-learn包的主成分分析(PCA),我们通过将768维向量降低到2维来在二维空间中可视化嵌入。当对每个时间步重新计算PCA时,由于轴翻转或旋转可能会导致动画中的大跳跃。为了解决这个问题,我们使用SciPy包中的附加Procrustes分析[3]对每一帧进行几何变换,将其几何变换到最后一帧,这只涉及平移、旋转和均匀缩放。这样可以实现动画中更平滑的过渡。
from sklearn.decomposition import PCAfrom scipy.spatial import procrustesdef make_pca(sorted_checkpoint_folder, pca_np): ... embedding_np_flat = embedding_np.reshape(-1, 768) pca = PCA(n_components=2) pca_np_new = pca.fit_transform(embedding_np_flat) _, pca_np_new, disparity = procrustes(pca_np, pca_np_new)
在Spotlight中进行评估
在最终完成整个动画之前,我们在Spotlight中进行了评估。在这个过程中,我们使用第一个和最后一个检查点进行嵌入生成、PCA和异常值检测。我们在Spotlight中加载生成的DataFrame:

Spotlight在左上角提供了一个综合表格,展示了数据集中的所有字段。右上角显示了两个PCA表示:一个是使用第一个检查点生成的嵌入,另一个是使用最后一个检查点生成的嵌入。最后,在底部部分展示了选定的图像。
免责声明:本文的作者也是Spotlight的开发人员之一。
创建动画
对于每个检查点,我们创建一个图像,并将其与相应的检查点一起存储。
这是通过使用make_pca(...)
和get_ood(...)
函数实现的,这两个函数分别生成表示嵌入的2D点和提取前8个异常值。2D点以对应类别的颜色绘制。异常值根据其分数排序,并在高分榜中显示相应的图像。训练损失从CSV文件中加载,并以折线图方式绘制。
最后,所有图像可以使用imageio或类似的库编译成GIF。

结论
本文详细介绍了如何创建一种动画,以可视化Vision Transformer(ViT)模型的微调过程。我们已经逐步介绍了生成和分析嵌入、可视化结果以及创建将这些元素结合起来的动画的步骤。
创建这样的动画不仅有助于理解微调ViT模型的复杂过程,而且还是向他人传达这些概念的强大工具。
动画的完整代码可在 GitHub 上的附带笔记本中找到。
我是一名专业人士,擅长创建用于交互式探索非结构化数据的高级软件解决方案。我写作关于非结构化数据,并使用强大的可视化工具进行分析和做出明智的决策。
参考资料
[1] Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby, An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale (2020), arXiv
[2] Alex Krizhevsky, Learning Multiple Layers of Features from Tiny Images (2009), University Toronto
[3] Gower, John C. Generalized procrustes analysis (1975), Psychometrika