Press "Enter" to skip to content

使用Amazon SageMaker智能筛选,将深度学习模型训练加速高达35%

在当今快速发展的人工智能领域,深度学习模型已经成为创新的前沿,其应用涵盖计算机视觉(CV)、自然语言处理(NLP)和推荐系统。然而,训练和微调这些模型所需的成本不断增加,给企业带来了挑战。这种成本主要是由于训练深度学习模型所使用的大量数据的规模。如今,大型模型常常需要使用数千兆字节的数据进行训练,即使使用强大的GPU或基于AWS Trainium硬件进行训练,也可能需要数周的时间。通常情况下,客户依赖于改进模型训练循环效率的技术和优化方法,例如优化的内核或层、混合精度训练,或者利用Amazon SageMaker分布式训练库等功能。然而,现在对训练数据本身的效率关注较少。在模型训练过程中,并非所有数据对学习过程的贡献相同:计算资源的相当大一部分可能花费在处理并不对模型整体准确性有很大贡献的简单示例上。

以往,客户通常依赖于预处理技术,如上采样或下采样和去重,以改进数据的信息质量。这些技术可以有所帮助,但通常非常耗时,需要专门的数据科学经验,有时可能更多的是一门艺术而非科学。客户通常还依赖于经过精心筛选的数据集,例如RefinedWeb,以提高模型的性能;然而,这些数据集并不总是完全开源的,而且通常更多是面向通用用途,而非与您的特定用例相关。

除此之外,您还可以如何克服与模型训练期间低信息数据样本相关的低效率问题呢?

我们非常高兴地宣布SageMaker的智能筛选功能正在公开预览中,这项功能可以将深度学习模型的训练成本降低多达35%。智能筛选是一种全新的数据效率技术,它在训练过程中主动分析您的数据样本,并筛选出对模型影响较小的样本。通过仅在少量数据子集上进行训练,只选择对模型收敛最具贡献的样本,总体训练成本降低,准确性几乎不受影响。此外,由于该功能在模型训练期间在线操作,智能筛选不需要对上游数据或下游训练流程进行更改。

在本文中,我们将讨论以下主题:

  • SageMaker中的智能筛选功能及其工作原理
  • 如何在PyTorch训练工作负载中使用智能筛选

您还可以查看我们的文档和示例笔记本,获取更多关于如何开始使用智能筛选的资源。

SageMaker智能筛选的工作原理

我们首先概述一下智能筛选功能如何加速SageMaker上的模型训练。

智能筛选的任务是在训练过程中筛选您的训练数据,并仅将更具信息量的样本提供给模型。在使用PyTorch进行典型训练时,数据会以批次的形式迭代地发送到训练循环和加速设备(例如GPU或Trainium芯片)中,通过PyTorch DataLoader进行加载。智能筛选是在这个数据加载阶段实现的,因此与训练流程中的任何上游数据预处理是独立的。

智能筛选使用您的模型和用户指定的损失函数,在每个数据样本加载时进行评估性前向传递。高损失的样本将对模型训练产生实质影响,因此会被用于训练;相对较低损失的数据样本则会被设置到一边,不参与训练。

智能筛选的一个关键输入是要排除的数据比例:例如,通过将比例设置为33%(beta_value=0.5),每个批次中大约损失处于底部三分之一的样本将被排除在训练之外。当已经找到足够多的高损失样本来完成一个批次时,数据将通过完整的训练循环,模型进行正常学习和训练。在启用智能筛选时,您无需对训练循环进行任何更改。

下图说明了这个工作流程。

使用Amazon SageMaker智能筛选,将深度学习模型训练加速高达35% 四海 第1张

通过仅包含部分训练数据,智能筛选减少了训练模型所需的时间和计算量。在我们的测试中,总训练时间和成本减少了近40%。通过智能筛选数据,模型的准确性基本不受影响,因为被排除的样本对模型的损失较小。在下表中,我们列出了一组实验结果,展示了使用SageMaker智能筛选可能实现的性能提升。

使用Amazon SageMaker智能筛选,将深度学习模型训练加速高达35% 四海 第2张

在表中,“接受的百分比”一列表示所包含和用于训练循环的数据比例。增加这个可调参数可以降低成本(如“IMR节省百分比”一列所示),但也可能影响准确性。接受的百分比的适当设置取决于你的数据集和模型;你应该尝试并调整这个参数,以在减少成本和对准确性的影响之间取得最佳平衡。

解决方案概述

在接下来的几节中,我们将通过一个实际示例,介绍如何在SageMaker上使用PyTorch训练作业启用智能筛选。如果你想快速入门,可以跳到PyTorch或PyTorch Lightning示例

前提条件

我们假设你已经知道如何使用PyTorch或PyTorch Lightning、SageMaker Python SDK和Estimator类以及SageMaker Deep Learning Containers进行训练。如果不是,请在继续之前参考使用SageMaker Python SDK

开始使用SageMaker智能筛选

在一个典型的PyTorch训练作业中,你使用你的数据集和其他必需的参数初始化PyTorch训练DataLoader,该DataLoader在训练过程中提供输入批次。为了启用训练数据的智能筛选,你将使用一个新的DataLoader类:smart_sifting.dataloader.sift_dataloader.SiftingDataloader。这个类被用作现有PyTorchDataLoader的包装器,训练过程将使用SiftingDataloader获取输入批次。SiftingDataLoader从原始的PyTorchDataLoader中获取输入批次,评估批次中样本的重要性,并构建一个包含高损失样本的筛选批次,然后将其传递给训练步骤。包装器的代码如下:

from smart_sifting.dataloader.sift_dataloader import SiftingDataloadertrain_dataloader =  SiftingDataloader(    sift_config = sift_config,    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),    loss_impl=BertLoss(),    model=self.model)

SiftingDataloader需要一些额外的参数来分析你的训练数据,你可以通过参数指定。首先,创建一个smart_sifting.sift_config.sift_configs.RelativeProbabilisticSiftConfig对象。这个对象保存可配置的和必需的beta_valueloss_history_length,分别定义了要保留的样本比例和在评估相对损失时要包含的样本窗口。注意,由于智能筛选使用你的模型来定义样本的重要性,如果我们使用完全随机权重的模型,可能会产生负面影响。相反,你可以使用loss_based_sift_configsift_delay来延迟筛选过程,直到模型中的参数权重超过随机值。 (有关更多详细信息,请参见将智能筛选应用于您的训练脚本。)在下面的代码中,我们定义了sift_config并指定了beta_valueloss_history_length,同时也延迟了筛选的开始使用loss_based_sift_config

from smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfigsift_config = RelativeProbabilisticSiftConfig(    beta_value=3,    loss_history_length=500,    loss_based_sift_config=LossConfig(         sift_config=SiftingBaseConfig(sift_delay=10)    ))

接下来,您还必须在SiftingDataloader对象中包含一个loss_impl参数。智能筛选在个体样本级别上工作,所以能够访问损失计算方法来确定样本的重要性至关重要。您必须实现一个筛选损失方法,该方法返回一个nx1的张量,其中存储了n个样本的损失值。通常情况下,您要指定与训练中使用的model相同的损失方法。最后,在SiftingDataloader对象中包含对您的模型的引用,该引用用于在样本被包含在训练中之前对其进行评估。请参考以下代码:

from smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfig## 定义筛选损失class SiftBertLoss(Loss):    # 您应该添加以下初始化函数    # 以计算每个样本的损失,而不是每个批次    def __init__(self):        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')    def loss(            self,            model: torch.nn.Module,            transformed_batch: SiftingBatch,            original_batch: Any = None,    ) -> torch.Tensor:            device = next(model.parameters()).device        batch = [t.to(device) for t in original_batch]        # 计算损失        outputs = model(batch)        return self.celoss(outputs.logits, batch[2])........train_dataloader =  SiftingDataloader(    sift_config = sift_config,    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),    loss_impl=SiftBertLoss(),    model=self.model)

以下代码展示了如何在现有的BERT训练任务中启用智能筛选功能:

from smart_sifting.dataloader.sift_dataloader import SiftingDataloaderfrom smart_sifting.loss.abstract_sift_loss_module import Lossfrom smart_sifting.sift_config.sift_configs import RelativeProbabilisticSiftConfig, LossConfig, SiftingBaseConfig.........## 定义筛选损失class SiftBertLoss(Loss):    # 您应该添加以下初始化函数    # 以计算每个样本的损失,而不是每个批次    def __init__(self):        self.celoss = torch.nn.CrossEntropyLoss(reduction='none')    def loss(            self,            model: torch.nn.Module,            transformed_batch: SiftingBatch,            original_batch: Any = None,    ) -> torch.Tensor:            device = next(model.parameters()).device        batch = [t.to(device) for t in original_batch]        # 计算损失        outputs = model(batch)        return self.celoss(outputs.logits, batch[2])              .... .... ....  sift_config = RelativeProbabilisticSiftConfig(    beta_value=3,    loss_history_length=500,    loss_based_sift_config=LossConfig(        sift_config=SiftingBaseConfig(sift_delay=10)    ))train_dataloader =  SiftingDataloader(    sift_config = sift_config,    orig_dataloader=DataLoader(self.train, self.batch_size, shuffle=True),    loss_impl=SiftBertLoss(),    model=self.model)......# 在剩余的训练逻辑中使用train_dataloader.

结论

在本文中,我们探讨了SageMaker智能筛选的公开预览版,这是一种可以将深度学习模型训练成本降低高达35%的新功能。此功能在训练期间提高了数据效率,过滤掉信息量较少的数据样本。通过仅包含对模型收敛最有影响力的数据,您可以显著减少训练时间和成本,同时保持准确性。而且,它可以与您现有的流程无缝集成,无需对数据或训练管道进行修改。

要深入了解SageMaker智能筛选的工作原理,并在PyTorch训练工作负载中实施它,请查看我们的文档示例笔记本,立即开始使用这项新功能。

Leave a Reply

Your email address will not be published. Required fields are marked *