我们正处于机器学习(ML)广泛应用的激动人心的拐点上,我们相信大多数客户体验和应用将通过生成式人工智能(AI)重新创造。生成式AI可以创建新的内容和想法,包括对话、故事、图像、视频和音乐。与大多数AI一样,生成式AI由机器学习模型驱动,这些模型是基于大量数据进行训练的非常庞大的模型,通常被称为基础模型(FMs)。FMs基于transformers。由于模型的庞大规模,transformers 在生成长文本序列时速度较慢且占用大量内存。用于生成文本序列的大型语言模型(LLMs)需要巨大的计算能力,并且难以访问可用的高带宽内存(HBM)和计算能力。这是因为大部分可用内存带宽被模型参数加载和自回归解码过程消耗。因此,即使拥有大量计算能力,LLMs 仍受到内存I/O和计算限制的限制,无法充分利用可用的硬件资源。
总体而言,LLMs的生成式推理有三个主要挑战(根据Pope等人2022年的研究):
- 由于庞大的模型参数和解码过程中的临时状态,内存占用量较大。这些参数往往超过单个加速器芯片的内存。注意力键值缓存也需要大量内存。
- 低并行性增加了延迟,尤其是在内存占用量较大的情况下,需要将参数和缓存数据传输到计算核心中的每个步骤。这导致总内存带宽需求很高,以满足延迟目标。
- 相对于序列长度,注意力机制计算的计算量呈二次扩展,加剧了延迟和计算挑战。
批处理是解决这些挑战的技术之一。批处理是指将多个输入序列一起发送给LLM,从而优化LLM推理的性能。这种方法有助于提高吞吐量,因为不需要为每个输入序列加载模型参数。参数可以加载一次并用于处理多个输入序列。批处理有效地利用了加速器的HBM带宽,提高了计算利用率,改善了吞吐量,并实现了成本效益的推理。
本文讨论了使用批处理技术来最大化吞吐量的技术,用于LLMs中并行生成式推理。我们讨论了不同的批处理方法,以减少内存占用量,增加并行性,并减轻注意力二次扩展对吞吐量的影响。我们的目标是充分利用像HBM和加速器这样的硬件资源,克服内存、I/O和计算中的瓶颈。然后,我们重点介绍了Amazon SageMaker大型模型推理(LMI)深度学习容器(DLCs)如何帮助实现这些技术。最后,我们在SageMaker上使用LMI DLCs对每个批处理策略的吞吐量改进进行了比较分析,以提高像Llama v2这样的模型的吞吐量。您可以在SageMaker示例GitHub存储库中找到相应的示例笔记本。
大型语言模型(LLMs)的推理
自回归解码是像GPT这样的语言模型生成文本输出的过程,它通过将生成的标记递归地反馈到模型中作为输入序列的一部分来预测后续标记。具体步骤如下:
- 模型接收先前的标记序列作为输入。对于第一步,这是用户提供的起始提示。
- 模型预测下一个标记的词汇分布。
- 选择具有最高预测概率的标记,并将其附加到输出序列中。步骤2和3是解码的一部分。截至撰写本文时,最常见的解码方法有贪婪搜索、波束搜索、对比搜索和抽样。
- 将这个新的标记添加到下一个解码步骤的输入序列中。
- 模型重复执行这些步骤,每一步生成一个新的标记,直到产生一个序列结束标记或达到所需的输出长度。
LLMs的模型服务
LLMs的模型服务是指接收用于文本生成的输入请求、进行推理并将结果返回给请求的应用程序的过程。模型服务涉及以下关键概念:
- 客户端生成多个推理请求,每个请求包含一系列标记或输入提示
- 推理服务器(例如DJLServing、TorchServe、Triton或Hugging Face TGI)接收请求
- 推理服务器将推理请求进行批处理,并将批次安排给执行引擎,该引擎包括模型分区库(例如Transformers-NeuronX、DeepSpeed、Accelerate或FasterTransformer),用于在生成式语言模型上运行前向传播(预测输出标记序列)
- 执行引擎生成响应标记,并将响应发送回推理服务器
- 推理服务器用生成的结果回复客户端
当推理服务器与执行引擎在请求级别进行交互时,存在请求级别调度的挑战,例如每个请求使用一个Python进程,这需要一个单独的模型副本,这对内存有限制。例如,如下图所示,您只能在具有96 GB总加速器设备内存的机器学习(ML)实例上加载一个大小为80 GB的模型的副本。如果要同时处理其他请求,则需要加载整个模型的附加副本。这不是内存和成本高效的。
现在我们了解了请求级别调度带来的挑战,让我们来看一下可以帮助优化吞吐量的不同批处理技术。
批处理技术
在本节中,我们解释不同的批处理技术,并展示如何使用SageMaker LMI容器实现它们。
推理请求的批处理有两种主要类型:
- 客户端端(静态) – 通常,当客户端向服务器发送请求时,默认情况下服务器会顺序处理每个请求,这对于吞吐量来说不是最佳选择。为了优化吞吐量,客户端会将推理请求批处理为单个有效载荷,服务器实现预处理逻辑,将批处理拆分为多个请求,并分别运行推理。在此选项中,客户端需要更改批处理的代码,该解决方案与批处理大小紧密耦合。
- 服务器端(动态) – 另一种批处理技术是使用推理来帮助在服务器端进行批处理。当独立的推理请求到达服务器时,推理服务器可以在服务器端动态地将它们分组到较大的批次中。推理服务器可以管理批处理以满足指定的延迟目标,最大限度地提高吞吐量,同时保持在所需的延迟范围内。推理服务器会自动处理这个过程,因此不需要更改客户端代码。服务器端批处理包括不同的技术,以进一步优化生成式语言模型的吞吐量,这些批处理技术包括动态批处理、连续批处理和PagedAttention(vLLM)批处理。
动态批处理
动态批处理是指将输入请求组合在一起,并作为一个批次发送进行推理。动态批处理是一种通用的服务器端批处理技术,适用于所有任务,包括计算机视觉(CV)、自然语言处理(NLP)等。
在LMI容器中,可以基于serving.properties中的以下设置配置请求的批处理:
- batch_size – 批处理大小
- max_batch_delay – 批聚合的最大延迟
如果满足这些阈值之一(达到最大批处理大小或等待期限完成),则准备一个新的批次并将其推送到模型进行推理。以下图示显示了具有不同输入序列长度的请求的动态批处理由模型一起处理。
您可以通过配置LMI容器的serving.properties在SageMaker上实现动态批处理,如下所示:
#动态批处理
engine=Python
option.entryPoint=djl_python.huggingface
batch_size=64 #示例
max_batch_delay=1000 #示例
option.tensor_parallel_degree=2 #示例
尽管与无批处理相比,动态批处理可以提供高达四倍的吞吐量增加,但我们观察到在这种情况下GPU利用率不是最佳的,因为系统在所有请求完成处理之前无法接受另一个批次。
连续批处理
连续批处理是文本生成的一种优化方法。它提高了吞吐量,而不会牺牲首字节延迟时间。连续批处理(也称为迭代批处理或滚动批处理)解决了空闲GPU时间的挑战,并在动态批处理方法的基础上不断地将新请求推入批处理中。下图显示了请求的连续批处理。当请求2和3完成处理时,将安排另一组请求。
下面的交互式图解深入介绍了连续批处理的工作原理。
(来源:https://github.com/InternLM/lmdeploy)
您可以使用一种强大的技术来提高LLM和文本生成的效率:缓存一些注意力矩阵。这意味着提示的第一次通行与随后的前向通行不同。对于第一次通行,您必须计算整个注意力矩阵,而后续通行只需要计算新的令牌注意力。本代码库将第一次通行称为预填充(prefill),而后续通行称为解码(decode)。由于预填充比解码更加昂贵,我们不希望一直进行预填充,但当前正在运行的查询可能正在进行解码。如果我们想要使用前面解释的连续批处理方法,我们需要在某个时刻运行预填充,以创建所需的注意力矩阵,以便能够加入解码组。
与不进行批处理相比,这种技术可以使吞吐量增加多达20倍,有效利用空闲的GPU。
您可以在LMI容器的serving.properties
文件中调整以下参数以使用连续批处理:
- engine – 代码的运行时引擎。可选值包括
Python
、DeepSpeed
、FasterTransformer
和MPI
。使用MPI
来启用连续批处理。 - rolling_batch – 使用支持的策略启用迭代级别批处理。可选值包括
auto
、scheduler
和lmi-dist
。我们在Llama 2中使用lmi-dist
来启用连续批处理。 - max_rolling_batch_size – 限制连续批处理中的并发请求数量。默认为32。
- max_rolling_batch_prefill_tokens – 限制缓存的令牌数量。这需要根据批处理大小和输入序列长度进行调整,以避免GPU内存不足。仅当
rolling_batch=lmi-dist
时支持此选项。我们建议根据每个请求所需的输入令牌和输出令牌的数量乘以并发请求的数量来设置该值。
以下是用于配置连续批处理的serving.properties
示例代码:
#连续批处理
engine=MPI
option.entryPoint=djl_python.huggingface
option.rolling_batch=auto
option.max_rolling_batch_size=64 #示例
option.paged_attention=false
option.max_rolling_batch_prefill_tokens=16080 #示例
option.tensor_parallel_degree=2 #示例
PagedAttention批处理
在自回归解码过程中,LLM的所有输入令牌都会生成它们的关注键(attention key)和值(value)张量,并将这些张量保留在GPU内存中以生成下一个令牌。这些缓存的关注键和值张量通常称为KV缓存或注意力缓存。根据论文vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention,KV缓存在Llama 13B中占用高达1.7GB的内存。它也是动态的,其大小取决于高度可变和不可预测的序列长度。因此,有效管理KV缓存是一个重大挑战。该论文发现,现有系统由于碎片化和过度预留而浪费了60-80%的内存。
PagedAttention 是由加州大学伯克利分校开发的一种新的优化算法,通过允许注意力缓存(KV缓存)在内存中以固定大小的页面或块的方式进行非连续分配,改进了连续批处理过程。这受到操作系统使用的虚拟内存和分页概念的启发。
根据 vLLM 论文,每个令牌序列的注意力缓存被分成块,并通过块表映射到物理块。在计算注意力时,PagedAttention 内核可以使用块表从物理内存中高效地获取块。这大大减少了内存浪费,并允许更大的批量大小、增加的 GPU 利用率和更高的吞吐量。下图说明了将注意力缓存分割为非连续页面的过程。
下图显示了使用 PagedAttention 的推理示例。关键步骤如下:
- 接收带有输入提示的推理请求。
- 在预填充阶段,计算注意力并将键值存储在非连续的物理内存中,并映射到逻辑键值块。这个映射存储在一个块表中。
- 通过模型运行输入提示(前向传递)生成第一个响应令牌。在生成响应令牌时,使用预填充阶段的注意力缓存。
- 在后续的令牌生成过程中,如果当前的物理块已满,则以非连续的方式分配额外的内存,实现即时分配。
PagedAttention 有助于实现近乎最佳的内存使用和减少内存浪费。这允许更多请求一起进行批处理,从而显著提高推理的吞吐量。
下面的代码是在 SageMaker 上配置 LMI 容器中 PagedAttention 批处理的示例 serving.properties
:
#Paged Attention Batching
engine=MPI
option.entryPoint=djl_python.huggingface
option.rolling_batch=auto
option.max_rolling_batch_size=64 #example
option.paged_attention=true
option.max_rolling_batch_prefill_tokens=16080 #example
option.tensor_parallel_degree=2 #example
何时使用哪种批处理技术
下图总结了 LMI 在 SageMaker 上的服务器端批处理技术以及示例 serving.properties
。
下表总结了不同的批处理技术及其使用场景。
PagedAttention 批处理 | 连续批处理 | 动态批处理 | 客户端批处理 | 无批处理 | |
工作原理 | 始终在令牌级别合并新请求,同时使用分页块进行批量推理。 | 始终在令牌级别合并新请求,并进行批量推理。 | 将新请求合并到请求级别;可以延迟几毫秒以形成批处理。 | 客户端负责在发送到推理服务器之前对多个推理请求进行批处理。 | 当请求到达时,立即运行推理。 |
最佳适用场景 | 这是支持的仅解码模型的推荐方法。适用于吞吐量优化的工作负载。仅适用于文本生成模型。 | 以不同时间到达的并发请求,使用相同的解码策略。适用于吞吐量优化的工作负载。仅适用于文本生成模型。 | 以不同时间到达的并发请求,使用相同的解码策略。适用于需要更高吞吐量的响应时间敏感的工作负载。适用于计算机视觉(CV)、自然语言处理(NLP)和其他类型的模型。 | 适用于没有延迟约束以最大化吞吐量的离线推理用例。 | 不频繁的推理请求或具有不同解码策略的推理请求。适用于具有严格响应时间延迟需求的工作负载。 |
使用SageMaker比较不同批处理技术的吞吐量
我们在SageMaker上使用LMI容器和本文讨论的不同批处理技术对Llama v2 7B模型进行了性能基准测试,同时并发进入请求为50,总请求数为5,000。
我们在性能测试中使用了三个不同长度的输入提示。在连续和PagedAttention批处理中,三个输入提示的输出令牌长度分别设置为64、128和256。对于动态批处理,我们使用了一个固定的输出令牌长度(128个令牌)。我们使用ml.g5.24xlarge实例类型部署了SageMaker端点进行测试。以下表格包含了性能基准测试的结果。
模型 | 批处理策略 | 在ml.g5.24xlarge上的每秒请求次数 |
LLaMA2-7b | 动态批处理 | 3.24 |
LLaMA2-7b | 连续批处理 | 6.92 |
LLaMA2-7b | PagedAttention批处理 | 7.41 |
通过使用PagedAttention批处理,相比于动态批处理,我们发现Llama2-7B模型在SageMaker上使用LMI容器的吞吐量增加了约2.3倍。
结论
在本文中,我们解释了LLMs推理的不同批处理技术以及它如何帮助提高吞吐量。我们展示了如何通过使用连续和PagedAttention批处理来提高硬件效率,并提供比动态批处理更高的吞吐量值。我们发现,通过使用PagedAttention批处理,相比于动态批处理,Llama2-7B模型在SageMaker上使用LMI容器的吞吐量增加了约2.3倍。您可以在GitHub上找到用于测试不同批处理技术的笔记本。