Press "Enter" to skip to content

使用🤗 Transformers对ViT进行微调,用于图像分类

使用🤗 Transformers对ViT进行微调,用于图像分类 四海 第1张

正如基于Transformer的模型改变了自然语言处理领域一样,我们现在看到了将其应用于各种其他领域的论文的爆炸式增长。其中最具革命性的是Vision Transformer(ViT),它是由Google Brain的研究人员于2021年6月推出的。

本论文探讨了如何对图像进行标记,就像对句子进行标记一样,以便可以将它们传递给Transformer模型进行训练。这其实是一个非常简单的概念…

  1. 将图像分割成子图像块的网格
  2. 使用线性投影对每个子图像块进行嵌入
  3. 每个嵌入的子图像块成为一个标记,嵌入的子图像块序列就是传递给模型的序列。

事实证明,一旦完成了上述步骤,你可以像处理自然语言处理任务一样预训练和微调Transformer模型。相当不错 😎。


在本博客文章中,我们将介绍如何利用🤗 datasets下载和处理图像分类数据集,然后使用它们来微调预训练的ViT模型,使用🤗 transformers

首先,让我们安装这两个包。

pip install datasets transformers

加载数据集

让我们首先加载一个小的图像分类数据集,并查看其结构。

我们将使用beans数据集,该数据集是一组健康和不健康的豆叶图片。🍃

from datasets import load_dataset

ds = load_dataset('beans')
ds

让我们看一下'train'数据集中第400个示例。你会注意到数据集中的每个示例都有3个特征:

  1. image:PIL图像
  2. image_file_path:作为image加载的图像文件的路径str
  3. labels:一个datasets.ClassLabel特征,它是标签的整数表示。(稍后你将看到如何获取字符串类名,不用担心!)
ex = ds['train'][400]
ex

{
  'image': <PIL.JpegImagePlugin ...>,
  'image_file_path': '/root/.cache/.../bean_rust_train.4.jpg',
  'labels': 1
}

让我们看一下这个图像👀

image = ex['image']
image

这明显是一片叶子!但是是什么种类的叶子呢?😅

由于该数据集的'labels'特征是datasets.features.ClassLabel,我们可以使用它来查找对应于该示例标签ID的名称。

首先,让我们访问'labels'的特征定义。

labels = ds['train'].features['labels']
labels

ClassLabel(num_classes=3, names=['angular_leaf_spot', 'bean_rust', 'healthy'], names_file=None, id=None)

现在,让我们打印出我们示例的类标签。你可以使用ClassLabelint2str函数来实现,该函数允许将类的整数表示传递给查找字符串标签。

labels.int2str(ex['labels'])

'bean_rust'

结果显示,上面显示的叶子受到了豆锈病的感染,这是豆类植物中的一种严重病害。😢

让我们编写一个函数,显示每个类别的示例网格,以便更好地了解你要处理的内容。

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.truetype("/usr/share/fonts/truetype/liberation/LiberationMono-Bold.ttf", 24)

    for label_id, label in enumerate(labels):

        # 通过单个标签对数据集进行过滤、打乱顺序并获取一些样本
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # 在一行中绘制该标签的示例
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

数据集中每个类别的几个示例的网格

从我所看到的,

  • 角叶斑病:具有不规则的棕色斑块
  • 豆锈病:具有被白黄色环围的圆形棕色斑点
  • 健康:…看起来健康。🤷‍♂️

加载 ViT 特征提取器

现在我们知道了我们的图像是什么样的,并更好地理解了我们试图解决的问题。让我们看看如何为我们的模型准备这些图像!

在训练 ViT 模型时,会对输入到模型中的图像应用特定的转换。如果在图像上应用错误的转换,模型将无法理解它所看到的内容!🖼 ➡️ 🔢

为确保我们应用正确的转换,我们将使用一个初始化的 ViTFeatureExtractor,其中包含与我们计划使用的预训练模型一起保存的配置。在我们的例子中,我们将使用 google/vit-base-patch16-224-in21k 模型,因此让我们从 Hugging Face Hub 加载它的特征提取器。

from transformers import ViTFeatureExtractor

model_name_or_path = 'google/vit-base-patch16-224-in21k'
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)

您可以通过打印特征提取器的配置来查看它。

ViTFeatureExtractor {
  "do_normalize": true,
  "do_resize": true,
  "feature_extractor_type": "ViTFeatureExtractor",
  "image_mean": [
    0.5,
    0.5,
    0.5
  ],
  "image_std": [
    0.5,
    0.5,
    0.5
  ],
  "resample": 2,
  "size": 224
}

要处理一张图像,只需将其传递给特征提取器的调用函数。这将返回一个包含像素值的字典,这是传递给模型的数值表示。

默认情况下,您会得到一个 NumPy 数组,但如果添加return_tensors='pt'参数,您将得到torch张量。

feature_extractor(image, return_tensors='pt')

应该会得到类似于以下的结果…

{
  'pixel_values': tensor([[[[ 0.2706,  0.3255,  0.3804,  ...]]]])
}

…其中张量的形状是(1, 3, 224, 224)

处理数据集

现在您知道如何读取图像并将其转换为输入,让我们编写一个函数,将这两个步骤组合起来处理数据集中的一个示例。

def process_example(example):
    inputs = feature_extractor(example['image'], return_tensors='pt')
    inputs['labels'] = example['labels']
    return inputs

process_example(ds['train'][0])

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ...]]]]),
  'labels': 0
}

虽然您可以调用ds.map并将其应用于所有示例,但这可能非常慢,特别是如果您使用的是较大的数据集。相反,您可以将一个转换应用于数据集。转换仅在索引示例时应用。

不过,首先,您需要更新最后一个函数,以接受一批数据,因为这是ds.with_transform所期望的。

ds = load_dataset('beans')

def transform(example_batch):
    # 将 PIL 图像列表转换为像素值
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')

    # 不要忘记包括标签!
    inputs['labels'] = example_batch['labels']
    return inputs

您可以使用ds.with_transform(transform)将其直接应用于数据集。

prepared_ds = ds.with_transform(transform)

现在,每当您从数据集中获取一个示例时,转换将实时应用于它(对样本和切片都适用,如下所示)

prepared_ds['train'][0:2]

这次,生成的pixel_values张量的形状将为(2, 3, 224, 224)

{
  'pixel_values': tensor([[[[-0.6157, -0.6000, -0.6078,  ..., ]]]]),
  'labels': [0, 0]
}

数据已经处理好,您可以开始设置训练流程了。本篇博客将使用🤗的Trainer,但在此之前需要完成以下几个步骤:

  • 定义一个collate函数。

  • 定义一个评估指标。在训练过程中,模型的预测准确率需要进行评估。您需要相应地定义一个compute_metrics函数。

  • 加载预训练的检查点。您需要加载一个预训练的检查点,并正确配置它以进行训练。

  • 定义训练配置。

在对模型进行微调后,您将在评估数据上正确评估模型,并验证它是否确实学会了正确分类图像。

定义我们的数据整合器

批次以字典的列表形式传入,因此您只需将其解包+堆叠为批次张量即可。

由于collate_fn将返回一个批次字典,因此稍后可以**解包模型的输入。✨

import torch

def collate_fn(batch):
    return {
        'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }

定义一个评估指标

可以使用datasets中的准确度指标来将预测结果与标签进行比较。下面是如何在compute_metrics函数中使用它,该函数将由Trainer使用。

import numpy as np
from datasets import load_metric

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

让我们加载预训练模型。我们将在初始化时添加num_labels,以便模型创建具有正确单位数的分类头。我们还将包含id2labellabel2id映射,以便在Hub小部件中显示可读的标签(如果选择push_to_hub)。

from transformers import ViTForImageClassification

labels = ds['train'].features['labels'].names

model = ViTForImageClassification.from_pretrained(
    model_name_or_path,
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

即将开始训练!在此之前,还需要设置训练配置,即定义TrainingArguments

其中大部分都很好理解,但这里有一个非常重要的参数remove_unused_columns=False。这个参数将丢弃模型调用函数中未使用的任何特征。默认情况下,该参数为True,因为通常最好丢弃未使用的特征列,这样更容易将输入解包到模型的调用函数中。但是,在我们的情况下,我们需要未使用的特征(特别是’image’)来创建’pixel_values’。

我的意思是,如果您忘记设置remove_unused_columns=False,将会遇到麻烦。

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./vit-base-beans",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  fp16=True,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

现在,所有实例都可以传递给Trainer,我们准备好开始训练了!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=feature_extractor,
)

训练 🚀

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

评估 📊

metrics = trainer.evaluate(prepared_ds['validation'])
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

这是我的评估结果 – 很酷的结果!抱歉,我不得不说出来。

***** 评估指标 *****
  epoch                   =        4.0
  eval_accuracy           =      0.985
  eval_loss               =     0.0637
  eval_runtime            = 0:00:02.13
  eval_samples_per_second =     62.356
  eval_steps_per_second   =       7.97

最后,如果你愿意,你可以将你的模型推送到hub上。在这里,如果你在训练配置中指定了push_to_hub=True,我们将把它推送到hub。请注意,为了推送到hub,你必须安装git-lfs并登录到你的Hugging Face账户(可以通过huggingface-cli login来完成)。

kwargs = {
    "finetuned_from": model.config._name_or_path,
    "tasks": "image-classification",
    "dataset": 'beans',
    "tags": ['image-classification'],
}

if training_args.push_to_hub:
    trainer.push_to_hub('🍻 干杯', **kwargs)
else:
    trainer.create_model_card(**kwargs)

生成的模型已经分享到nateraw/vit-base-beans。我猜你身边没有豆叶的图片,所以我为你添加了一些示例供你尝试!🚀

Leave a Reply

Your email address will not be published. Required fields are marked *