Press "Enter" to skip to content

用JAX和Haiku从零开始实现Transformer编码器 🤖

了解Transformer的基本构建块。

Transformer,以Edward Hopper的风格呈现(由Dall.E 3生成)

Transformer架构于2017年的重要论文“注意力就是一切”[0]中首次亮相。可以说,它是近年来深度学习领域最具影响力的突破之一,使得大型语言模型得以兴起,并且在计算机视觉等领域发挥了作用。

相较于先前依赖于重复(recurrence)的最先进架构,如长短期记忆(LSTM)网络或门控循环单元(GRU),Transformer引入了自注意(self-attention)的概念,并结合了编码器/解码器结构。

在本文中,我们将逐步从头开始实现Transformer的前半部分,即编码器。我们将使用JAX作为我们的主要框架,以及DeepMind的深度学习库之一——Haiku。

如果您不熟悉JAX,或需要提醒您其惊人功能的信息,我已经在我的前一篇文章中介绍过该主题,重点是在强化学习的背景下:

使用JAX对RL环境进行向量化和并行化:以光速进行Q学习⚡

学会对GridWorld环境进行向量化,并使用CPU并行训练30个Q学习代理,步速为180万步每次…

towardsdatascience.com

我们将逐个介绍构成编码器的每个模块,并学会高效地实现它们。具体而言,本文大纲包括:

  • 嵌入层(Embedding Layer)和位置编码(Positional Encodings)
  • 多头注意力(Multi-Head Attention)
  • 残差连接(Residual Connections)和层归一化(Layer Normalization)
  • 位置前馈网络(Position-wise Feed-Forward Networks)

免责声明:本文旨在实现为主,而非对这些概念的完整介绍。如果需要,您可以参考本文末尾的资源。

同往常一样,本文的完整代码及说明性笔记本都可在 GitHub上找到,如果您喜欢本文,请给该存储库加星标!

GitHub – RPegoud/jab:在JAX中实现的基础深度学习模型集合

在JAX中实现的基础深度学习模型集合 – GitHub – RPegoud/jab:一个…

github.com

主要参数

在开始之前,我们需要定义一些在编码器模块中起关键作用的参数:

  • 序列长度(seq_len):序列中的令牌或单词数。
  • 嵌入维度(embed_dim):嵌入的维度,换句话说,用于描述单个令牌或单词的数值数量。
  • 批量大小(batch_size):一批输入的大小,即同时处理的序列数目。

我们的编码器模型的输入序列通常是形状为(batch_size, seq_len)的。在本文中,我们将使用 batch_size=32seq_len=10,这意味着我们的编码器将同时处理32个长度为10的序列。

在处理的每个步骤中注意数据的形状将使我们能够更好地可视化和理解数据在编码器块中的流动。以下是我们的编码器的高级概述,我们将从底部开始,从嵌入层位置编码开始:

Transformer 编码器块的表示(作者创建)

嵌入层和位置编码

如前所述,我们的模型接受批量化的令牌序列作为输入。生成这些令牌可以简单地收集数据集中唯一单词的集合,并为每个单词分配一个索引。然后,我们将抽样32个长度为10的序列,并用词汇表中的索引替换每个单词。这个过程将为我们提供一个形状为(batch_size, seq_len)的数组,正如我们期望的那样。

现在我们准备好开始编码器的工作了。第一步是为我们的序列创建“位置嵌入”。位置嵌入是词嵌入位置编码总和

词嵌入

词嵌入允许我们在词汇表中编码单词之间的含义语义关系。在本文中,嵌入维度被固定为64。这意味着每个单词都由一个64维向量表示,使得具有相似含义的单词具有相似的坐标。此外,我们可以操作这些向量以提取单词之间的关系,如下图所示。

从词嵌入导出的类比示例(源自developers.google.com)

使用 Haiku,生成可学习的嵌入只需调用:

hk.Embed(vocab_size, embed_dim)

这些嵌入将在模型训练过程中与其他可学习参数一起更新(稍后会详细介绍)。

位置编码

与循环神经网络不同,Transformer 无法通过共享隐藏状态来推断令牌的位置,因为它们缺乏循环卷积结构。因此引入了位置编码,它们是传达令牌在输入序列中的位置向量

实质上,每个令牌都被分配一个由交替的正弦和余弦值组成的位置向量。这些向量与词嵌入的维度匹配,以便两者可以相加。

具体来说,原始 Transformer 论文使用以下函数:

用JAX和Haiku从零开始实现Transformer编码器 🤖 四海 第4张

位置编码函数(摘自“Attention is all you need”,Vaswani et al. 2017)

下面的图形使我们进一步了解位置编码的功能。让我们看看最上面图的第一行,我们可以看到0和1的交替序列。确实,行代表序列中令牌的位置(pos变量),而列代表嵌入维度(i变量)。

因此,当pos=0时,前面的方程返回sin(0)=0对于偶数嵌入维度和cos(0)=1对于奇数维度。

此外,我们看到相邻的行共享相似的值,而第一行和最后一行相差很大。这个属性有助于模型评估序列中单词之间的距离以及它们的顺序

最后,第三个图表示位置编码和嵌入的总和,这是嵌入块的输出。

单词嵌入和位置编码的表示,其中seq_len=16,embed_dim=64(由作者制作)

使用Haiku,我们将嵌入层定义如下。与其他深度学习框架类似,Haiku允许我们定义自定义模块(这里是hk.Module)来存储可学习参数并定义我们模型组件的行为。

每个Haiku模块都需要一个__init____call__函数。在这里,调用函数只需使用hk.Embed函数和位置编码计算嵌入,然后将它们相加。

位置编码函数使用了JAX的功能,如vmaplax.cond以实现高性能。如果你对这些函数不熟悉,请随意查看我之前的文章,其中对它们进行了更深入的介绍。

简而言之,vmap允许我们为一个单个样本定义一个函数,并对其进行向量化,以便可以应用于数据批次。而lax.cond是Python的if/else语句的XLA兼容版本。

自注意力和多头注意力

注意力旨在计算相对于输入词每个词的重要性。例如,在以下句子中:

“黑猫跳上沙发,躺下并入睡,因为它累了。”

词“”对于模型来说可能相当模棱两可,因为从技术上讲,它既可以指代“”又可以指“沙发”。一个经过良好训练的注意力模型将能理解“”指的是“”,因此会相应地分配给其余句子注意力值。

本质上,注意力值可以被视为描述给定上下文的情况下某个词的重要性权重。例如,词“”的注意力向量会对“”(是什么跳?)、“”和“沙发”(跳到哪里?)等与其上下文相关的词具有较高的值,因为这些词对于其上下文是相关的

注意向量的可视化表示(由作者制作)

在 Transformer 论文中,注意力是使用缩放点积注意力进行计算的。它的总结公式如下:

缩放点积注意力(来源:“Attention is all you need”,Vaswani et al. 2017)

在这里,Q、K 和 V 代表 Queries(查询)、Keys(关键字)Values(值)。这些矩阵是通过将学习到的权重向量 WQ、WK 和 WV 与位置编码相乘得到的。

这些名称主要是用来帮助理解信息在注意力模块中的处理和加权过程的抽象概念。它们是对检索系统词汇[2](例如在 YouTube 上搜索视频)的暗示。

这里有一个直观的解释:

  • Queries(查询):它们可以被解释为对序列中所有位置提出的一组问题。例如,询问一个词的上下文,并尝试识别序列中最相关的部分。
  • Keys(关键字):它们可以被看作是与查询交互的信息,查询与关键字之间的兼容性确定查询应该对相应的值给予多少注意。
  • Values(值):匹配关键字和查询使我们能够决定哪些关键字是相关的,值是与关键字配对的实际内容。

在下图中,查询是一个 YouTube 搜索,关键字是视频描述和元数据,而值则是相关的视频。

Queries、Keys、Values 概念的直观表示法(作者制作)

在我们的案例中,查询、关键字和值来自同一来源(因为它们源自输入序列),因此被称为自注意(self-attention)。

注意力分数的计算通常会并行执行多次,每次使用一部分嵌入(embeddings)。这个机制称为“多头注意力”,它使每个头能够并行学习数据的几个不同表示,从而得到更稳健的模型。

单个注意力头通常会处理形状为 (batch_size,seq_len,d_k) 的数组,其中 d_k 可以设置为头的数量与嵌入的维度之间的比率(d_k = n_heads/embed_dim)。通过方便地连接每个头的输出,就可以得到形状为(batch_size,seq_len,embed_dim)的数组,作为输入。

注意力矩阵的计算可以分解为以下几个步骤:

  • 首先,我们定义可学习的权重向量 WQ、WK 和 WV。这些向量的形状为(n_heads,embed_dim,d_k)
  • 同时,我们将位置编码权重向量相乘。我们得到形状为(batch_size,seq_len,d_k)的 Q、K 和 V 矩阵。
  • 然后,我们对 Q 和 K 的点积进行缩放操作(对每个矩阵的行进行缩放)。这个缩放操作包括将点积的结果除以 d_k 的平方根,并在矩阵的行上应用 softmax 函数。因此,输入令牌(即行)的注意力分数总和为1,这有助于防止值变得过大并减慢计算速度。输出的形状为 (batch_size,seq_len,seq_len)。
  • 最后,我们将上述操作的结果与 V 进行点积运算,使输出的形状为 (batch_size,seq_len,d_k)。
展示了注意力块内矩阵运算的视觉表示(由作者制作)
  • 然后可以将每个注意力头的输出串联起来,形成一个形状为(batch_size, seq_len, embed_dim)的矩阵。Transformer论文在多头注意力模块的末尾还添加了一个线性层,以聚合组合来自所有注意力头的学习表示。
多头注意力矩阵和线性层的串联(由作者制作)

在Haiku中,可以按以下方式实现多头注意力模块。 __call__函数遵循与上述图形相同的逻辑,而类方法利用JAX工具,如vmap(将操作向量化到不同的注意力头和矩阵上)和tree_map(将矩阵点积映射到权重向量上)。

残差连接和层归一化

您可能已经注意到Transformer图中,多头注意力块和前馈网络后面跟着残差连接层归一化

残差连接还是跳过连接

残差连接是解决梯度消失问题的标准解决方案,当梯度变得太小以至于无法有效地更新模型参数时,就会出现梯度消失问题。

由于这个问题在特别深的架构中自然而然地出现,残差连接在各种复杂模型中被使用,例如计算机视觉中的ResNet(Kaiming et al,2015),强化学习中的AlphaZero(Silver et al,2017),当然还有Transformers

在实践中,残差连接简单地将特定层的输出转发到后续层,途中跳过一个或多个层。例如,多头注意力周围的残差连接等效于将多头注意力的输出与位置嵌入相加。

这样可以使梯度在反向传播过程中更有效地流向架构,并且通常可以导致更快的收敛和更稳定的训练

表现了Transformer中残差连接的表示(由作者制作)

层归一化

层归一化有助于确保在模型中传播的值不会“爆炸”(趋向于无限),在注意力块中很容易发生这种情况,因为每个前向传递中都会对多个矩阵进行乘法。

与批归一化不同,批归一化在假设均匀分布的批次维度上进行标准化,而层归一化在特征上进行操作。这种方法适用于句子批次,其中每个句子可能由于其独特的含义和词汇而具有不同的分布

通过在特征(如嵌入或注意力值)上进行归一化,层归一化可以将数据标准化到一致的尺度,而不会混淆不同的句子特征,同时保持每个句子的独特分布。

在Transformer的上下文中表示层标准化(由作者制作)

层标准化的实现非常简单,我们初始化可学习参数α和β,并沿着期望的特征轴进行归一化处理。

位置级前馈网络

编码器的最后一个组件是位置级前馈网络。这个全连接网络以注意力块的归一化输出作为输入,用于引入非线性并增加模型的容量来学习复杂函数。

它由两个由gelu激活分隔的稠密层组成:

在这个块之后,我们还有一个残差连接和层标准化来完成编码器。

总结

到此为止!现在您应该熟悉Transformer编码器的主要概念了。这是完整的编码器类,值得注意的是,在Haiku中,我们为每个层指定一个名称,以便将可学习参数分离并易于访问。 __call__函数提供了我们编码器不同步骤的良好摘要:

要在实际数据上使用此模块,我们必须对封装编码器类的函数应用hk.transform。确实,您可能还记得JAX采用了函数式编程范式,因此Haiku遵循相同的原则。

我们定义一个包含编码器类实例的函数,并返回正向传递的输出。应用hk.transform将返回一个转换后的对象,具有两个函数:initapply

前者允许我们使用随机键以及一些虚拟数据(请注意,在这里,我们传递一个形状为batch_size,seq_len的零数组)来初始化模块,而后者允许我们处理真实数据。

# 注意:以下两种语法是等效的# 1:将transform作为类装饰器使用@hk.transformdef encoder(x):  ...  return model(x)  encoder.init(...)encoder.apply(...)# 2:分别应用transformdef encoder(x):  ...  return model(x)encoder_fn = hk.transform(encoder)encoder_fn.init(...)encoder_fn.apply(...)

在下一篇文章中,我们将通过添加一个解码器完成Transformer架构,该解码器重用我们迄今为止介绍的大部分块,并学习如何使用Optax在特定任务上训练模型

结论

非常感谢您阅读至此,如果您有兴趣尝试代码,您可以在GitHub上完整注释的方式找到它,以及使用玩具数据集进行演练的额外细节和指南。

GitHub – RPegoud/jab: 用JAX实现的基础深度学习模型的集合

用JAX实现的基础深度学习模型的集合 – GitHub – RPegoud/jab: 一个收藏…

github.com

如果您想深入了解Transformer,下面的部分包含了一些帮助我撰写本文的文章。

下次见👋

参考和资源:

[1] Attention is all you need(2017),Vaswani等人,Google

[2] 在注意机制中,键、查询和值到底是什么? (2019) Stack Exchange

[3] 插图化转换器 (2018), Jay Alammar

[4] Transformer模型中位置编码的简要介绍 (2023), Mehreen Saeed, Machine Learning Mastery

图像来源

Leave a Reply

Your email address will not be published. Required fields are marked *