Press "Enter" to skip to content

在HuggingFace中的数据整理者

它们是什么以及它们的作用

来自unsplash.com的图像

当我开始学习HuggingFace时,数据收集器是我最不直观的组件之一。我很难理解它们,也没有找到足够好的资源来直观地解释它们。

在本文中,我们将看一下数据收集器是什么,它们之间的区别以及如何编写自定义数据收集器。

数据收集器:高层次

数据收集器是HuggingFace中数据处理的重要组成部分。在对数据进行标记化后,我们都会在将数据传递给训练器对象以训练模型之前使用它们。

简而言之,它们将一系列样本组合成一个小型训练批次。它们的操作取决于它们所定义的任务,但至少会对输入样本进行填充或截断,以确保在一个小批次中的所有样本具有相同的长度。典型的小批次大小范围从16到256个样本,这取决于模型的大小、数据和硬件限制。

数据收集器是特定于任务的。以下每个任务都有一个数据收集器:

  • 因果语言建模(CLM)
  • 遮蔽语言建模(MLM)
  • 序列分类
  • Seq2Seq
  • 标记分类

有些数据收集器很简单。例如,“序列分类”任务的数据收集器只需要将一个小批次中的所有序列进行填充,以确保它们具有相同的长度。然后将它们连接到一个张量中。

有些数据收集器非常复杂,因为它们需要处理该任务的数据处理。

基本数据收集器

最基本的两个数据收集器如下:

1) DefaultDataCollator它不进行任何填充或截断操作。它假设所有输入样本具有相同的长度。如果您的输入样本长度不一致,将导致错误。

import torchfrom transformers import DefaultDataCollatortexts = ["Hello world", "How are you?"]# Tokenizefrom transformers import AutoTokenizertokenizer =…
Leave a Reply

Your email address will not be published. Required fields are marked *