Press "Enter" to skip to content

使用🤗 Transformers对多语言ASR进行微调的Fine-Tune Whisper

使用🤗 Transformers对多语言ASR进行微调的Fine-Tune Whisper 四海 第1张

在本博客中,我们使用Hugging Face 🤗 Transformers为任何多语种ASR数据集提供了Whisper微调的逐步指南。本博客提供了对Whisper模型、Common Voice数据集以及微调背后原理的深入解释,并附带了执行数据准备和微调步骤的代码单元格。如需更简洁版本的笔记本,其中包含更少的解释但包含所有代码,请参阅附带的Google Colab。

目录

  1. 介绍
  2. 在Google Colab中微调Whisper
    1. 准备环境
    2. 加载数据集
    3. 准备特征提取器、标记器和数据
    4. 训练和评估
    5. 构建演示
  3. 结束语

介绍

Whisper是由Alec Radford等人于2022年9月在OpenAI发布的用于自动语音识别(ASR)的预训练模型。与其许多前辈模型(如Wav2Vec 2.0)不同,Whisper在大量的标记音频转录数据上进行了预训练,准确地说是680,000小时。这比用于训练Wav2Vec 2.0的无标记音频数据(60,000小时)多一个数量级。此外,这个预训练数据中的117,000小时是多语种ASR数据。这导致可以应用于96种以上语言的检查点,其中许多语言被认为是低资源语言。

这个大量的标记数据使得Whisper能够直接在监督任务(语音识别)上进行预训练,从标记的音频转录预训练数据中学习从语音到文本的映射。因此,Whisper只需要很少的额外微调就能够产生高性能的ASR模型。这与Wav2Vec 2.0形成对比,后者在无监督任务(遮蔽预测)上进行预训练。在这种情况下,模型被训练来学习从无标记音频数据到隐藏状态的中间映射。虽然无监督预训练可以生成高质量的语音表示,但它并不学习从语音到文本的映射。这个映射只有在微调过程中学习,因此需要更多的微调才能产生有竞争力的性能。

当扩展到680,000小时的标记预训练数据时,Whisper模型展示了很强的泛化能力,适用于许多数据集和领域。预训练检查点在LibriSpeech ASR的测试-清洁子集上实现了与最先进的ASR系统竞争的结果,字错误率(WER)接近3%,并在TED-LIUM上取得了4.7%的WER新记录(参见Whisper论文的表8)。Whisper在预训练过程中获得的广泛多语种ASR知识可以用于其他低资源语言;通过微调,预训练检查点可以针对特定数据集和语言进行调整,进一步改善这些结果。

Whisper是一种基于Transformer的编码器-解码器模型,也称为序列到序列模型。它将一系列音频频谱特征映射到一系列文本标记。首先,原始音频输入通过特征提取器转换为对数梅尔频谱图。然后,Transformer编码器将频谱图编码为编码器隐藏状态的序列。最后,解码器根据先前的标记和编码器隐藏状态联合条件地自回归预测文本标记。图1总结了Whisper模型。

Figure 1: Whisper model. The architecture follows the standard Transformer-based encoder-decoder model. A log-Mel spectrogram is input to the encoder. The last encoder hidden states are input to the decoder via cross-attention mechanisms. The decoder autoregressively predicts text tokens, jointly conditional on the encoder hidden states and previously predicted tokens. Figure source: OpenAI Whisper Blog .

在序列到序列模型中,编码器将音频输入转换为一组隐藏状态表示,从口语中提取重要特征。解码器扮演语言模型的角色,处理隐藏状态表示并生成相应的文本转录。在系统架构中“内部”地结合语言模型被称为深度融合。这与“外部”将语言模型与编码器组合(例如CTC + n n n-gram,参见内部语言模型估计)形成对比。通过深度融合,整个系统可以使用相同的训练数据和损失函数进行端到端的训练,具有更大的灵活性和通常更优越的性能(参见ESB基准)。

Whisper使用交叉熵目标函数进行预训练和微调,这是在分类任务上训练序列到序列系统的标准目标函数。在这里,系统被训练以正确分类来自预定义文本令牌词汇表的目标文本令牌。

Whisper的检查点有五种不同模型大小的配置。最小的四个是在仅英语或多语言数据上训练的。最大的检查点仅支持多语言。在Hugging Face Hub上提供了这九个预训练检查点。下表总结了这些检查点,并提供了指向Hub上模型的链接:

为了演示目的,我们将使用244M参数(约1GB)的“small”检查点的多语言版本进行微调。至于我们的数据,我们将在Common Voice数据集中选择一种低资源语言进行系统的训练和评估。我们将展示,即使只有8小时的微调数据,我们也可以在这种语言中获得强大的性能。


1 {}^1 1 Whisper这个名字来自于“WSPSR”这个首字母缩写,它代表“面向Web规模的有监督语音预训练”。

在Google Colab中微调Whisper

准备环境

我们将使用几个流行的Python包来微调Whisper模型。我们将使用datasets下载和准备训练数据,使用transformers加载和训练Whisper模型。我们还需要soundfile包来预处理音频文件,evaluatejiwer来评估模型的性能。最后,我们将使用gradio来构建我们微调模型的演示。

!pip install datasets>=2.6.1
!pip install git+https://github.com/huggingface/transformers
!pip install librosa
!pip install evaluate>=0.30
!pip install jiwer
!pip install gradio

我们强烈建议您在训练过程中直接从Hugging Face Hub上上传模型检查点。Hub提供以下功能:

  • 集成版本控制:您可以确保在训练过程中不会丢失任何模型检查点。
  • Tensorboard日志:跟踪训练过程中的重要指标。
  • 模型卡片:记录模型的功能和预期用途。
  • 社区:与社区分享和协作的简单方式!

将笔记本连接到Hub非常简单 – 只需要在提示时输入您的Hub身份验证令牌即可。在此处找到您的Hub身份验证令牌:

from huggingface_hub import notebook_login

notebook_login()

输出结果:

登录成功
您的令牌已保存到/root/.huggingface/token

加载数据集

Common Voice是一系列众包数据集,其中发言者以各种语言录制维基百科的文本。我们将使用Common Voice数据集的最新版本(版本11)。至于我们的语言,我们将在印地语上微调我们的模型,印地语是印度北部、中部、东部和西部使用的印度-雅利安语系的语言。Common Voice 11.0包含大约12小时的标记印地语数据,其中4小时用作测试数据。

让我们前往Hub并查看Common Voice的数据集页面:mozilla-foundation/common_voice_11_0。

我们第一次查看此页面时,将被要求接受使用条款。之后,我们将获得对数据集的完全访问权限。

一旦我们提供了使用数据集的身份验证,我们将会看到数据集预览。数据集预览显示了数据集的前100个样本。而且,它加载了音频样本,可以实时播放。我们可以通过将子集设置为hi来选择Common Voice的印地语子集,使用下拉菜单(hi是印地语的语言标识符代码):

使用🤗 Transformers对多语言ASR进行微调的Fine-Tune Whisper 四海 第3张

如果我们点击第一个样本的播放按钮,我们可以听到音频并查看相应的文本。浏览训练集和测试集的样本,以更好地了解我们处理的音频和文本数据。从语调和风格可以看出,这些录音是从叙述性的语音中提取的。您还可能注意到说话者和录音质量的巨大变化,这是众包数据的共同特征。

使用🤗数据集,下载和准备数据非常简单。我们只需一行代码即可下载和准备Common Voice的拆分数据。由于印地语的资源非常有限,我们将合并trainvalidation拆分,以获得约8小时的训练数据。我们将使用4小时的test数据作为我们的保留测试集:

from datasets import load_dataset, DatasetDict

common_voice = DatasetDict()

common_voice["train"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="train+validation", use_auth_token=True)
common_voice["test"] = load_dataset("mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True)

print(common_voice)

输出结果:

DatasetDict({
    train: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 6540
    })
    test: Dataset({
        features: ['client_id', 'path', 'audio', 'sentence', 'up_votes', 'down_votes', 'age', 'gender', 'accent', 'locale', 'segment'],
        num_rows: 2894
    })
})

大多数ASR数据集只提供输入音频样本(audio)和相应的转录文本(sentence)。Common Voice还包含其他元数据信息,如accentlocale,在ASR中我们可以忽略这些信息。为了使笔记本尽可能通用,我们只考虑输入音频和转录文本进行微调,忽略其他附加的元数据信息:

common_voice = common_voice.remove_columns(["accent", "age", "client_id", "down_votes", "gender", "locale", "path", "segment", "up_votes"])

Common Voice只是我们可以从Hub下载的多语言ASR数据集之一-还有很多其他可用的数据集!要查看可用于语音识别的数据集范围,请访问链接:Hub上的ASR数据集

准备特征提取器、分词器和数据

ASR流水线可以分解为三个组件:

  1. 特征提取器,用于预处理原始音频输入
  2. 执行序列到序列映射的模型
  3. 分词器,用于将模型输出后处理为文本格式

在🤗 Transformers中,Whisper模型有一个关联的特征提取器和分词器,分别称为WhisperFeatureExtractor和WhisperTokenizer。

我们将逐个详细介绍特征提取器和分词器的细节!

加载WhisperFeatureExtractor

语音由随时间变化的一维数组表示。数组在任何给定时间步的值是该点处信号的振幅。仅从振幅信息中,我们可以重建音频的频谱并恢复所有声学特征。

由于语音是连续的,它包含无限数量的振幅值。这对于预期获得有限数组的计算机设备带来问题。因此,我们通过在固定时间步长上从信号中采样值来离散化我们的语音信号。我们采样音频的间隔称为采样率,通常以样本/秒或赫兹(Hz)表示。使用更高的采样率进行采样可以更好地近似连续的语音信号,但也需要每秒存储更多的值。

我们必须确保音频输入的采样率与模型期望的采样率匹配,因为不同采样率的音频信号具有非常不同的分布。音频样本只能使用正确的采样率进行处理。否则会导致意想不到的结果!例如,使用采样率为16kHz的音频样本,以8kHz的采样率进行播放会使音频听起来像是半速。同样,传递具有错误采样率的音频可能会导致期望一种采样率的ASR模型失败并接收到另一种采样率。Whisper特征提取器期望采样率为16kHz的音频输入,因此我们需要将我们的输入与此值匹配。我们不希望无意中在慢动作训练语音识别系统!

Whisper特征提取器执行两个操作。首先,它填充/截断一批音频样本,使所有样本的输入长度为30秒。长度小于30秒的样本通过在序列末尾添加零(在音频信号中表示无信号或静音)进行填充。长度大于30秒的样本被截断为30秒。由于批处理中的所有元素在输入空间中都被填充/截断到最大长度,所以在将音频输入转发到Whisper模型时,我们不需要注意力掩码。在这方面,Whisper是独特的 – 对于大多数音频模型,您可以提供一个指示哪些序列已被填充,因此应在自注意机制中忽略的注意力掩码。Whisper经过训练,可以在没有注意力掩码的情况下进行操作,并直接从语音信号中推断出在哪些位置忽略输入。

Whisper特征提取器执行的第二个操作是将填充的音频数组转换为对数-Mel频谱图。这些频谱图是信号频率的可视化表示,类似于傅里叶变换。图2显示了一个示例频谱图。y轴是Mel通道,对应于特定的频率区间。x轴是时间。每个像素的颜色表示给定时间的该频率区间的对数强度。对数-Mel频谱图是Whisper模型期望的输入形式。

Mel通道(频率区间)在语音处理中是标准的,并且被选择为近似人类听觉范围。对于Whisper微调,我们只需要知道频谱图是语音信号中频率的可视化表示即可。有关Mel通道的更多详细信息,请参阅Mel频率倒谱。

图2:采样音频数组转换为对数-Mel频谱图。左侧:采样的一维音频信号。右侧:对应的对数-Mel频谱图。图源:Google SpecAugment博客。

幸运的是,🤗 Transformers的Whisper特征提取器只需要一行代码就能同时进行填充和频谱图转换!让我们继续从预训练的检查点中加载特征提取器,以备使用我们的音频数据:

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")

加载WhisperTokenizer

现在让我们看一下如何加载Whisper分词器。Whisper模型输出文本标记,指示预测文本在词汇字典中的索引。分词器将一系列文本标记映射到实际文本字符串(例如 [1169, 3797, 3332] -> “the cat sat”)。

在传统的ASR中,使用仅编码器模型时,我们使用连

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

我们可以通过对Common Voice数据集中的第一个样本进行编码和解码来验证分词器是否正确编码了印地语字符。在对转录进行编码时,分词器在序列的开头和结尾附加了“特殊标记”,包括转录的开头/结尾标记、语言标记和任务标记(根据前一步中的参数指定)。在解码标签ID时,我们可以选择“跳过”这些特殊标记,从而返回原始输入形式的字符串:

input_str = common_voice["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"输入:                   {input_str}")
print(f"解码(包含特殊标记):   {decoded_with_special}")
print(f"解码(不包含特殊标记): {decoded_str}")
print(f"相等:                   {input_str == decoded_str}")

打印输出:

输入:                   खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
解码(包含特殊标记):   <|startoftranscript|><|hi|><|transcribe|><|notimestamps|>खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई<|endoftext|>
解码(不包含特殊标记): खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई
相等:                   True

合并创建WhisperProcessor

为了简化特征提取器和分词器的使用,我们可以将它们包装到一个单独的WhisperProcessor类中。该处理器对象继承自WhisperFeatureExtractorWhisperProcessor,可以根据需要用于音频输入和模型预测。这样,在训练期间我们只需要跟踪两个对象:处理器(processor)和模型(model):

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="Hindi", task="transcribe")

准备数据

让我们打印Common Voice数据集的第一个示例,以查看数据的形式:

print(common_voice["train"][0])

打印输出:

{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3', 
           'array': array([0.0000000e+00, 0.0000000e+00, 0.0000000e+00, ..., 9.6724887e-07,
       1.5334779e-06, 1.0415988e-06], dtype=float32), 
           'sampling_rate': 48000},
 'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}

我们可以看到,我们得到了一个一维输入音频数组和相应的目标转录。我们已经深入讨论了采样率的重要性,以及我们需要使音频的采样率与Whisper模型的采样率(16kHz)匹配。由于我们的输入音频采样率为48kHz,我们需要将其下采样为16kHz,然后将其传递给Whisper特征提取器。

我们将使用数据集的cast_column方法将音频输入设置为正确的采样率。这个操作不会在原地改变音频,而是向datasets发出信号,在首次加载音频样本时动态重新采样:

from datasets import Audio

common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000))

重新加载Common Voice数据集中的第一个音频样本将对其进行所需的采样率转换:

print(common_voice["train"][0])

输出结果:

{'audio': {'path': '/home/sanchit_huggingface_co/.cache/huggingface/datasets/downloads/extracted/607848c7e74a89a3b5225c0fa5ffb9470e39b7f11112db614962076a847f3abf/cv-corpus-11.0-2022-09-21/hi/clips/common_voice_hi_25998259.mp3', 
           'array': array([ 0.0000000e+00,  0.0000000e+00,  0.0000000e+00, ...,
       -3.4206650e-07,  3.2979898e-07,  1.0042874e-06], dtype=float32),
           'sampling_rate': 16000},
 'sentence': 'खीर की मिठास पर गरमाई बिहार की सियासत, कुशवाहा ने दी सफाई'}

太棒了!我们可以看到采样率已经降低到16kHz。数组的值也不同了,因为我们现在只有大约每三个振幅值才有一个。

现在我们可以编写一个函数来准备我们的数据,以便为模型做好准备:

  1. 通过调用batch["audio"]加载和重新采样音频数据。如上所述,🤗数据集将根据需要执行任何必要的重新采样操作。
  2. 我们使用特征提取器从一维音频数组计算出log-Mel谱图输入特征。
  3. 我们通过使用分词器将转录编码为标签ID。
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch

我们可以使用数据集的.map方法将数据准备函数应用于所有的训练样本:

common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4)

好了!现在我们已经完全准备好训练数据了!接下来,让我们看看如何使用这些数据来微调Whisper模型。

注意:目前datasets同时使用torchaudiolibrosa进行音频加载和重新采样。如果您希望实现自己定制的数据加载/采样,可以使用"path"列获取音频文件路径,并忽略"audio"列。

训练和评估

现在我们已经准备好了数据,可以开始进行训练流程了。🤗 Trainer将为我们完成大部分繁重的工作。我们只需要:

  • 定义一个数据整理器:数据整理器接受我们预处理的数据,并准备好用于模型的PyTorch张量。

  • 评估指标:在评估过程中,我们希望使用词错误率(WER)指标评估模型。我们需要定义一个compute_metrics函数来处理这个计算。

  • 加载预训练的检查点:我们需要加载一个预训练的检查点,并正确配置它进行训练。

  • 定义训练参数:这些参数将由🤗 Trainer在构建训练计划时使用。

一旦我们对模型进行了微调,我们将在测试数据上对其进行评估,以验证我们是否正确地训练了它以转录印地语的语音。

定义数据整理器

序列到序列语音模型的数据整理器在处理input_featureslabels时是独立的:特征提取器处理input_features,而分词器处理labels

input_features 已经被填充到30秒,并转换为固定维度的对数梅尔频谱图,所以我们只需要将它们转换为批处理的PyTorch张量。我们使用特征提取器的.pad方法和return_tensors=pt来实现这一点。请注意,这里不需要进行额外的填充,因为输入数据的维度是固定的,input_features只需转换为PyTorch张量即可。

另一方面,labels没有进行填充。我们首先使用分词器的.pad方法将序列填充到批处理中的最大长度。然后,将填充标记替换为-100,以便在计算损失时不将这些标记考虑在内。然后,我们从标签序列的开头删除转录标记,因为我们在训练期间稍后追加它。

我们可以利用之前定义的WhisperProcessor来执行特征提取器和分词器的操作:

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # 拆分输入和标签,因为它们的长度必须不同,并且需要不同的填充方法
        # 首先处理音频输入,只需返回torch张量
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # 获取分词后的标签序列
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # 将标签填充到最大长度
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # 用-100替换填充以正确忽略损失
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # 如果在前一个分词步骤中附加了bos标记,
        # 则在此处删除bos标记,因为它稍后会追加
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

让我们初始化刚刚定义的数据整理器:

data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

评估指标

接下来,我们定义在评估集上使用的评估指标。我们将使用字错误率(WER)指标,这是评估ASR系统的“事实上”的指标。有关更多信息,请参阅WER文档。我们将从🤗 Evaluate加载WER指标:

import evaluate

metric = evaluate.load("wer")

然后,我们只需要定义一个函数,该函数接受我们的模型预测并返回WER指标。这个名为compute_metrics的函数首先将label_ids中的-100替换为pad_token_id(撤消数据整理器中应用的步骤,以正确忽略填充的标记)。然后它将预测的id和标签id解码为字符串。最后,它计算预测和参考标签之间的WER:

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # 用pad_token_id替换-100
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # 在计算指标时,我们不想对令牌进行分组
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

加载预训练的检查点

现在让我们加载预训练的 Whisper small 检查点。同样,通过使用 🤗 Transformers 是非常简单的!

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

Whisper 模型具有在自回归生成开始之前被强制作为模型输出的令牌标识符( forced_decoder_ids )。这些令牌标识符控制着零样本自动语音识别的转录语言和任务。对于微调,我们将把这些标识符设置为 None ,因为我们将训练模型来预测正确的语言(印地语)和任务(转录)。在生成过程中,还有一些完全被抑制的令牌( suppress_tokens )。这些令牌的对数概率被设置为 -inf ,因此它们永远不会被采样。我们将这些令牌覆盖为一个空列表,表示不抑制任何令牌:

model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

定义训练参数

在最后一步,我们定义与训练相关的所有参数。以下是部分参数的解释:

  • output_dir :用于保存模型权重的本地目录。这也将是 Hugging Face Hub 上的存储库名称。
  • generation_max_length :在评估过程中自动回归生成的最大令牌数。
  • save_steps :在训练过程中,中间检查点将被保存并异步上传到 Hub,每 save_steps 个训练步骤保存一次。
  • eval_steps :在训练过程中,每 eval_steps 个训练步骤执行一次中间检查点的评估。
  • report_to :日志保存位置。支持的平台有 "azure_ml""comet_ml""mlflow""neptune""tensorboard""wandb" 。选择您喜欢的平台,或者保留为 "tensorboard" 以将日志记录到 Hub。

有关其他训练参数的详细信息,请参阅 Seq2SeqTrainingArguments 文档。

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-hi",  # 更改为您选择的存储库名称
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # 对于每减少一半的批次大小,增加2倍
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

注意 :如果不想将模型检查点上传到 Hub,请将 push_to_hub=False

我们可以将训练参数与我们的模型、数据集、数据整理器和 compute_metrics 函数一起传递给 🤗 Trainer:

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=common_voice["train"],
    eval_dataset=common_voice["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

准备好开始训练了!

训练

要启动训练,只需执行:

trainer.train()

训练将花费大约 5-10 小时,具体取决于您的 GPU 或分配给 Google Colab 的 GPU。根据您的 GPU,当您开始训练时,可能会遇到 CUDA "out-of-memory" 错误。在这种情况下,您可以逐步减小 per_device_train_batch_size 的大小,每次减小一半,并使用 gradient_accumulation_steps 进行补偿。

打印输出:

我们最佳的WER是32.0% – 对于8小时的训练数据来说还不错!最大的问题是与其他ASR系统相比如何。为此,我们可以查看hf-speech-bench,一个按语言和数据集对模型进行分类,并根据其WER对其进行排名的排行榜。

使用🤗 Transformers对多语言ASR进行微调的Fine-Tune Whisper 四海 第5张

我们的微调模型显著改善了Whisper small检查点的零-shot性能,突显了Whisper强大的迁移学习能力。

当我们将训练结果推送到Hub时,我们可以自动将我们的检查点提交到排行榜上 – 我们只需设置适当的关键字参数(kwargs)。您可以更改这些值以匹配您的数据集、语言和模型名称:

kwargs = {
    "dataset_tags": "mozilla-foundation/common_voice_11_0",
    "dataset": "Common Voice 11.0",  # 训练数据集的'漂亮'名称
    "dataset_args": "config: hi, split: test",
    "language": "hi",
    "model_name": "Whisper Small Hi - Sanchit Gandhi",  # 您模型的'漂亮'名称
    "finetuned_from": "openai/whisper-small",
    "tasks": "automatic-speech-recognition",
    "tags": "hf-asr-leaderboard",
}

现在可以将训练结果上传到Hub。要执行此操作,请执行push_to_hub命令:

trainer.push_to_hub(**kwargs)

现在您可以使用Hub上的链接与任何人分享此模型。他们也可以使用标识符"your-username/the-name-you-picked"加载它,例如:

from transformers import WhisperForConditionalGeneration, WhisperProcessor

model = WhisperForConditionalGeneration.from_pretrained("sanchit-gandhi/whisper-small-hi")
processor = WhisperProcessor.from_pretrained("sanchit-gandhi/whisper-small-hi")

虽然在Common Voice印地语测试数据上,微调模型的结果令人满意,但并不是最佳结果。本笔记本的目的是演示如何在任何多语言ASR数据集上微调预训练的Whisper检查点。通过优化训练超参数(如学习率和dropout)并使用更大的预训练检查点(VoAGIlarge),结果可能会得到改善。

构建演示

现在我们已经微调了我们的模型,我们可以构建一个演示来展示它的ASR能力!我们将使用🤗 Transformers的pipeline,它将负责整个ASR流程,从预处理音频输入到解码模型预测。我们将使用Gradio构建我们的交互式演示。Gradio可能是构建机器学习演示的最简单方法;使用Gradio,我们只需要几分钟就可以构建一个演示!

运行下面的示例将生成一个Gradio演示,我们可以通过计算机的麦克风录制语音,并将其输入到我们微调的Whisper模型中以转录相应的文本:

from transformers import pipeline
import gradio as gr

pipe = pipeline(model="sanchit-gandhi/whisper-small-hi")  # 更改为"your-username/the-name-you-picked"

def transcribe(audio):
    text = pipe(audio)["text"]
    return text

iface = gr.Interface(
    fn=transcribe, 
    inputs=gr.Audio(source="microphone", type="filepath"), 
    outputs="text",
    title="Whisper Small Hindi",
    description="使用微调的Whisper小模型进行印地语实时语音识别的演示。",
)

iface.launch()

结束语

在本博客中,我们介绍了使用🤗 Datasets、Transformers和Hugging Face Hub对Whisper进行多语言ASR的微调的逐步指南。如果您希望尝试自己进行微调,请参考Google Colab。如果您对微调其他Transformers模型(包括英语和多语言ASR)感兴趣,请务必查看examples/pytorch/speech-recognition中的示例脚本。

Leave a Reply

Your email address will not be published. Required fields are marked *