它们是什么以及它们的作用
当我开始学习HuggingFace时,数据收集器是我最不直观的组件之一。我很难理解它们,也没有找到足够好的资源来直观地解释它们。
在本文中,我们将看一下数据收集器是什么,它们之间的区别以及如何编写自定义数据收集器。
数据收集器:高层次
数据收集器是HuggingFace中数据处理的重要组成部分。在对数据进行标记化后,我们都会在将数据传递给训练器对象以训练模型之前使用它们。
简而言之,它们将一系列样本组合成一个小型训练批次。它们的操作取决于它们所定义的任务,但至少会对输入样本进行填充或截断,以确保在一个小批次中的所有样本具有相同的长度。典型的小批次大小范围从16到256个样本,这取决于模型的大小、数据和硬件限制。
数据收集器是特定于任务的。以下每个任务都有一个数据收集器:
- 因果语言建模(CLM)
- 遮蔽语言建模(MLM)
- 序列分类
- Seq2Seq
- 标记分类
有些数据收集器很简单。例如,“序列分类”任务的数据收集器只需要将一个小批次中的所有序列进行填充,以确保它们具有相同的长度。然后将它们连接到一个张量中。
有些数据收集器非常复杂,因为它们需要处理该任务的数据处理。
基本数据收集器
最基本的两个数据收集器如下:
1) DefaultDataCollator:它不进行任何填充或截断操作。它假设所有输入样本具有相同的长度。如果您的输入样本长度不一致,将导致错误。
import torchfrom transformers import DefaultDataCollatortexts = ["Hello world", "How are you?"]# Tokenizefrom transformers import AutoTokenizertokenizer =…