Press "Enter" to skip to content

多查询注意力解析

多查询注意力(MQA) 是一种注意力机制,可以加速解码器生成标记的速度,同时确保模型性能。

它被广泛应用于大语言模型时代,许多大语言模型采用了MQA,例如 FalconPaLMStarCoder等。

多头注意力(MHA)

在介绍MQA之前,让我们首先回顾一下变压器的默认注意力机制。

多头注意力是变压器模型的默认注意力机制,如图1所示:

图1

然而,当涉及文本生成时,基于变压器解码器的自回归语言模型存在一个问题。

在训练过程中,我们可以访问真实的目标序列,并且可以高效地实现并行处理。

然而,在推理过程中,每个位置的查询都会关注到该位置之前生成的所有键值对。换句话说,特定位置的自注意层的输出会影响下一个标记的生成。由于无法进行并行计算,解码过程变慢。

下面是基于变压器解码器的自回归语言模型中自注意层的解码过程:

def MHAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):    q = tf.einsum("bd, hdk−>bhk", x, P_q)    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, hdk−>bhk", x, P_k), axis = 2)], axis = 2)    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, hdv−>bhv", x, P_v), axis = 2)], axis = 2)    logits = tf.einsum("bhk, bhmk−>bhm", q, new_K)    weights = tf.softmax(logits)    O = tf.einsum("bhm, bhmv−>bhv", weights, new_V)    Y = tf.einsum("bhv, hdv−>bd", O, P_o)    return Y, new_K, new_V

变量:

  • x:当前步骤的输入张量,即第m+1步骤,形状为 [b, d]
  • P_q, P_k:查询和键投影张量,形状为 [h, d, k]
  • P_v:值投影张量,形状为 [h, d, v]
  • P_o:学习的线性投影,形状为 [h, d, v]
  • Prev_K:上一步的键张量,形状为 [b, h, m, k]
  • Prev_V:上一步的值张量,形状为 [b, h, m, v]
  • new_K:当前步骤的键张量的加法,形状为 [b, h, m+1, k]
  • new_V:当前步骤的值张量的加法,形状为 [b, h, m+1, v]

维度:

  • m:已执行的先前步骤的数量
  • b:批量大小
  • d:输入和输出的维度
  • h:头数
  • k:Q、K张量的另一个维度
  • v:V张量的另一个维度

多查询注意力(MQA)

多查询注意力是多头注意力的变种。

MQA的方法是保持Q的原始头数,但对于K和V只有一个头。这意味着所有的Q头共享相同的K和V头,因此称为多查询,如图2所示:

图2

MQA的解码过程的代码与MHA的代码本质上是相同的,只是tf.einsum方程的K、V、P_k和P_v中表示头维度的字母”h”被移除:

def MQAForDecoder(x, prev_K, prev_V, P_q, P_k, P_v, P_o):    q = tf.einsum("bd, hdk−>bhk", x, P_q)    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum ("bd, dk−>bk", x, P_k), axis = 2)], axis = 2)    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd, dv−>bv", x, P_v), axis = 2)], axis = 2)    logits = tf.einsum("bhk, bmk−>bhm", q, new_K)    weights = tf.softmax(logits)    O = tf.einsum("bhm, bmv−>bhv", weights, new_V)    Y = tf.einsum("bhv, hdv−>bd", O, P_o)    return Y, new_K, new_V

性能

MQA实际上能提高多少速度?让我们来看一下原文提供的结果图表:

多查询注意力解析 四海 第3张

从上表中可以看出,MQA在编码器上的速度改善并不是非常显著,但在解码器上则相当显著。

论文中还有关于质量的实验,结果显示MQA与基准相比仅略有降低。更多详情请参考论文,链接在本文底部。

为什么MQA能实现推理加速?

内存更高效

在MQA中,键和值张量的大小为b * k和b * v,而在MHA中,键和值的大小为b * h * k和b * h * v,其中h表示头数。

计算复杂度较低

通过使用KV缓存,在MQA的每个步骤中计算张量键和值的计算成本是MHA的1 / h,其中h表示头数。

总结

总的来说,MQA通过以下几种方法实现推理加速:

  • KV缓存大小减少了h(头数)倍,这意味着需要存储在GPU内存中的张量也减少了。节省的空间可以用来增加批量大小,从而提高效率。
  • 减少了从内存中读取的数据量,降低了计算单元的等待时间,提高了计算利用率。
  • MQA拥有相对较小的KV缓存,可以适应缓存(SRAM)中,而MHA则拥有较大的KV缓存,无法完全存储在缓存中,需要从GPU内存(DRAM)中读取,这需要耗费时间。

结论

值得一提的是,MQA提出于2019年,当时它的应用还不如今天广泛。这是因为先前的模型不需要考虑这些方面,例如,LSTM只需要维护一个状态,而不需要保留任何缓存。

变压器最初被提出时,主要用于Seq2Seq任务,特别是在编码器-解码器模型中。然而,这些模型的规模并不是很大,也没有太多的实际需求,因此MQA并没有引起太多关注。

后来,代表性的模型BERT,也是基于变压器编码器结构,进行了直接的前向传递。

直到最近基于变压器解码器的大型语言模型(如GPT)获得了广泛的应用,才发现了推理的瓶颈。因此,人们重新审视了几年前的技巧,并发现它们非常有用。换句话说,这主要是由于对大规模GPT风格生成模型的实际需求。

最后,如果这段文字中有任何错误或遗漏,请随时指出。

参考资料

MQA论文:快速变压器解码:一个写头就够了

关注力就是一切

https://paperswithcode.com/method/multi-query-attention

Leave a Reply

Your email address will not be published. Required fields are marked *