Press "Enter" to skip to content

AdaTape 具有自适应计算和动态读写的基础模型

作者:Google研究实习生Fuzhao Xue和研究科学家Mostafa Dehghani

自适应计算是指机器学习系统根据环境变化调整其行为的能力。传统神经网络具有固定的功能和计算能力,即它们对不同输入的处理都花费相同数量的FLOPs,而具有自适应和动态计算的模型会根据输入的复杂性调节其分配给处理每个输入的计算预算。

神经网络中的自适应计算具有两个关键原因的吸引力。首先,引入自适应性的机制提供了归纳偏差,在解决一些具有挑战性的任务中起到关键作用。例如,为不同输入启用不同数量的计算步骤对于解决需要建模不同深度层次的算术问题至关重要。其次,通过动态计算提供的更大灵活性,它使从业者能够调整推理的成本,因为这些模型可以根据需要调整花费更多的FLOPs来处理新的输入。

可以通过使用不同的函数或计算预算来使神经网络具有自适应性。深度神经网络可以被看作是一个根据输入和其参数输出结果的函数。为了实现自适应函数类型,根据输入有选择地激活参数的子集,这个过程被称为条件计算。基于函数类型的自适应性已经在混合专家研究中得到了探索,其中每个输入样本的稀疏激活参数是通过路由确定的。

自适应计算的另一个研究领域涉及动态计算预算。与标准神经网络(如T5、GPT-3、PaLM和ViT)不同,它们的计算预算对于不同样本是固定的,最近的研究表明,自适应计算预算可以提高在转换器无法胜任的任务上的性能。其中许多作品通过使用动态深度来分配计算预算来实现自适应性。例如,提出了自适应计算时间(ACT)算法,为递归神经网络提供自适应的计算预算。通用变压器将ACT算法扩展到变压器中,通过使计算预算依赖于用于每个输入示例或令牌的变压器层数的数量。最近的研究,如PonderNet,在改进动态停止机制的同时采用了类似的方法。

在论文“自适应计算与弹性输入序列”中,我们介绍了一种利用自适应计算的新模型,称为AdaTape。这个模型是基于变压器的架构,它使用一组动态的令牌来创建弹性输入序列,与之前的作品相比,提供了一种独特的适应性视角。AdaTape使用自适应的读带机制来确定根据输入复杂性添加到每个输入的令牌的数量。AdaTape的实现非常简单,提供了一个有效的旋钮,可以在需要时增加准确性,但与其他自适应基线相比,它也更加高效,因为它直接将适应性注入输入序列而不是模型深度。最后,AdaTape在标准任务(如图像分类)和算法任务上提供了更好的性能,同时保持了有利的质量和成本平衡。

可变输入序列的自适应计算变压器

AdaTape同时使用自适应函数类型和动态计算预算。具体而言,在分词后的一批输入序列(例如,从视觉变压器中的图像的非重叠块的线性投影)中,AdaTape使用表示每个输入的向量来动态选择一个可变大小的读带令牌序列。

AdaTape使用一个令牌库,称为“读带库”,用于存储通过自适应的读带机制与模型交互的所有候选读带令牌。我们探索了两种不同的方法来创建读带库:基于输入驱动的令牌库和可学习的令牌库。

基于输入驱动的令牌库的一般思想是从输入中提取一组令牌,并使用与原始模型分词器不同的方法将原始输入映射到一系列输入令牌。这使得可以动态、按需地访问从输入中获取的信息,该信息是使用不同的视角获得的,例如不同的图像分辨率或不同的抽象级别。

在某些情况下,以不同抽象级别进行分词是不可能的,因此无法使用基于输入驱动的读带库,例如在图形变压器中难以进一步分割每个节点的情况。为了解决这个问题,AdaTape通过使用一组可训练向量作为读带令牌提供了一种更通用的生成读带库的方法。这种方法被称为可学习的令牌库,可以看作是一个嵌入层,模型可以根据输入示例的复杂性动态检索令牌。可学习的令牌库使AdaTape能够生成更灵活的读带库,使其能够根据每个输入示例的复杂性动态调整计算预算,例如,更复杂的示例从库中检索更多的令牌,这不仅让模型使用存储在库中的知识,还可以花费更多的FLOPs来处理输入,因为输入现在更大了。

最后,选择的磁带标记被附加到原始输入上,并传递给后续的Transformer层。对于每个Transformer层,相同的多头注意力应用于所有输入和磁带标记。然而,使用了两个不同的前馈网络(FFN):一个用于所有来自原始输入的标记,另一个用于所有磁带标记。我们观察到,对于输入和磁带标记使用独立的前馈网络可以稍微提高质量。

AdaTape 具有自适应计算和动态读写的基础模型 四海 第1张
AdaTape概述。对于不同的样本,我们从磁带库中选择一个可变数量的不同标记。磁带库可以由输入驱动,例如通过提取一些额外的细粒度信息,或者可以是一组可训练的向量。自适应磁带读取用于递归地选择不同长度的磁带标记序列,以适应不同的输入。然后,这些标记简单地附加到输入中,并传递给Transformer编码器。

AdaTape提供了有益的归纳偏差

我们在奇偶性任务上评估AdaTape,这对于标准Transformer来说是一个非常具有挑战性的任务,以研究AdaTape中的归纳偏差的影响。在奇偶性任务中,给定一个由1、0和-1组成的序列,模型必须预测序列中1的数量是偶数还是奇数。奇偶性是最简单的非计数自由或周期正则语言,但令人惊讶的是,标准Transformer无法解决这个任务。

AdaTape 具有自适应计算和动态读写的基础模型 四海 第2张
奇偶性任务的评估。标准Transformer和通用Transformer都无法执行此任务,两者的性能都与随机猜测基线相同。

尽管在短而简单的序列上评估,但标准Transformer和通用Transformer都无法执行奇偶性任务,因为它们无法在模型内部保持计数器。然而,AdaTape优于所有基线,因为它在其输入选择机制中结合了轻量级的循环,提供了一个归纳偏差,使得隐式地维护计数器成为可能,而这在标准Transformer中是不可能的。

图像分类评估

我们还对图像分类任务评估了AdaTape。为此,我们从头开始在ImageNet-1K上训练了AdaTape。下图显示了AdaTape和基线方法(包括A-ViT以及通用Transformer ViT(UViT和U2T))的准确性与速度(每秒处理的图像数量)之间的关系。在质量和成本的权衡方面,AdaTape比其他自适应Transformer基线表现得更好。在效率方面,参数数量较大的AdaTape模型比较小的基线更快。这样的结果与之前的研究结果一致,显示自适应模型深度架构不适用于许多加速器,如TPU。

AdaTape 具有自适应计算和动态读写的基础模型 四海 第3张
我们通过在ImageNet上从头开始训练来评估AdaTape。对于A-ViT,我们不仅报告了论文中的结果,还重新实现了A-ViT的从头开始训练,即A-ViT(我们自己的版本)。

对AdaTape行为的研究

除了在奇偶任务和ImageNet-1K上的性能之外,我们还使用基于输入的bank在JFT-300M验证集上评估了AdaTape的令牌选择行为。为了更好地理解模型的行为,我们将基于输入的bank上的令牌选择结果可视化为热图,其中较浅的颜色表示该位置被更频繁选择。热图显示AdaTape更频繁地选择中心补丁。这与我们的先前知识一致,因为中心补丁通常更具信息量,尤其是在具有自然图像的数据集上,其中主要对象位于图像中央。这个结果突出了AdaTape的智能,它可以有效地识别和优先选择更具信息量的补丁,以提高性能。

AdaTape 具有自适应计算和动态读写的基础模型 四海 第4张
我们可视化了AdaTape-B/32(左)和AdaTape-B/16(右)的磁带令牌选择热图。热度/较浅的颜色表示该位置的补丁被更频繁选择。

结论

AdaTape的特点是通过自适应磁带阅读机制生成的弹性序列长度。这也引入了一种新的归纳偏差,使得AdaTape有潜力解决对标准Transformer和现有自适应Transformer都具有挑战性的任务。通过在图像识别基准测试上进行全面的实验,我们证明了当计算保持恒定时,AdaTape优于标准Transformer和自适应架构Transformer。

致谢

本篇文章的一位作者Mostafa Dehghani现在在Google DeepMind工作。

Leave a Reply

Your email address will not be published. Required fields are marked *