Press "Enter" to skip to content

🧨 使用云TPU v5e和JAX加速稳定的XL推理扩散

生成AI模型,例如Stable Diffusion XL(SDXL),可以创建具有广泛应用的高质量、逼真的内容。然而,利用这种模型的威力面临着重大的挑战和计算成本。SDXL是一个大型图像生成模型,其UNet组件比模型的先前版本的大约三倍。将这样的模型部署到生产环境中具有挑战性,因为它增加了内存需求,并增加了推理时间。今天,我们非常高兴地宣布,Hugging Face Diffusers现在支持使用JAX在Cloud TPUs上提供SDXL,实现高性能和高效的推理。

Google Cloud TPUs是定制的AI加速器,经过优化,用于训练和推理大型AI模型,包括最先进的语言模型和生成AI模型,例如SDXL。新的Cloud TPU v5e专为大规模AI训练和推理提供所需的成本效益和性能。TPU v5e的成本不到TPU v4的一半,使更多组织能够训练和部署AI模型成为可能。

🧨 Diffusers JAX集成提供了一种方便的方式,通过XLA在TPU上运行SDXL,我们构建了一个演示来展示它。您可以在这个空间或下面的嵌入式平台上尝试它。

在底层,这个演示在几个TPU v5e-4实例上运行(每个实例有4个TPU芯片),利用并行化在大约4秒内提供四个1024×1024大小的大图像。这个时间包括格式转换、通讯时间和前端处理;实际生成时间约为2.3秒,我们后面会看到的!

在这篇博文中,

  1. 我们描述了为什么JAX + TPU + Diffusers是运行SDXL的强大框架
  2. 解释了如何使用Diffusers和JAX编写一个简单的图像生成流水线
  3. 展示了比较不同TPU设置的基准测试

为什么选择JAX + TPU v5e运行SDXL?

利用JAX在Cloud TPU v5e上高效而经济地提供SDXL的服务,得益于专为TPU硬件构建的目标和为性能优化的软件堆栈的组合。以下我们强调两个关键因素:JAX即时编译和JAX pmap的XLA编译器驱动的并行性。

即时编译

JAX的一个显著特点是其即时编译。JIT编译器在首次运行期间跟踪代码,并生成高度优化的TPU二进制文件,后续调用中可以重复使用。这个过程的关键是需要所有输入、中间和输出的形状都是静态的,即必须事先知道。每次改变形状时,都会触发一次新的、代价高昂的编译过程。JIT编译适用于可以围绕静态形状进行设计的服务:编译只运行一次,然后我们可以利用超快的推理时间。

图像生成非常适合JIT编译。如果我们总是生成相同数量和相同尺寸的图像,那么输出形状是恒定的,事先是已知的。文本输入也是恒定的:根据设计,Stable Diffusion和SDXL使用了固定形状的嵌入向量(带有填充)来表示用户输入的提示。因此,我们可以编写依赖于固定形状的JAX代码,并进行大幅度的优化!

高性能吞吐量,适用于高批量大小

使用JAX的pmap可以在多个设备上扩展工作负载,它表达了单程序多数据(SPMD)程序。将pmap应用于函数将使用XLA编译函数,然后在各种XLA设备上并行执行。对于文本到图像生成工作负载来说,这意味着增加同时渲染的图像数量很容易实现,而且不会影响性能。例如,在具有8个芯片的TPU上运行SDXL将在1个芯片创建单个图像所需的时间内生成8个图像。

TPU v5e实例有多种形状,包括1、4和8芯片形状,最多可达256个芯片(完整的TPU v5e pod),芯片之间具有超快的ICI连接。这使您可以选择最适合您的用例的TPU形状,并轻松利用JAX和TPU提供的并行性。

如何在JAX中编写图像生成流程

我们将逐步介绍使用JAX快速运行推理所需的代码!首先,让我们导入所需的依赖项。

# Show best practices for SDXL JAXimport jaximport jax.numpy as jnpimport numpy as npfrom flax.jax_utils import replicatefrom diffusers import FlaxStableDiffusionXLPipelineimport time

现在,我们将加载基本的SDXL模型和推理所需的其他组件。Diffusers流程会为我们负责下载和缓存一切。遵循JAX的函数式方法,模型的参数将分别返回,并在推理过程中需要将它们传递给流程:

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True)

模型参数默认以32位精度下载。为了节省内存和更快地运行计算,我们将它们转换为高效的16位表示bfloat16。然而,有一个注意事项:为了获得最佳结果,我们必须将调度器状态保持在float32中,否则精度误差会累积,并导致质量低或甚至黑色图像。

scheduler_state = params.pop("scheduler")params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)params["scheduler"] = scheduler_state

现在,我们准备好设置我们的提示和管道的其余输入。

default_prompt = "高质量照片,描绘了一只在游泳池玩耍并戴着派对帽的小海豚"default_neg_prompt = "插图,低质量"default_seed = 33default_guidance_scale = 5.0default_num_steps = 25

提示必须作为张量提供给管道,并且在每次调用时始终具有相同的维度。这样可以编译推理调用。管道的prepare_inputs方法为我们执行所有必要的步骤,因此我们将创建一个帮助函数来准备提示和负面提示的张量。我们稍后将在generate函数中使用它:

def tokenize_prompt(prompt, neg_prompt):    prompt_ids = pipeline.prepare_inputs(prompt)    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)    return prompt_ids, neg_prompt_ids

为了利用并行化,我们将在设备之间复制输入。Cloud TPU v5e-4有4个芯片,因此通过复制输入,我们可以使每个芯片并行生成不同的图像。我们需要小心向每个芯片提供不同的随机种子,以使这4个图像不同:

NUM_DEVICES = jax.device_count()# 模型参数在推理过程中不会改变,所以我们只需要复制一次.p_params = replicate(params)def replicate_all(prompt_ids, neg_prompt_ids, seed):    p_prompt_ids = replicate(prompt_ids)    p_neg_prompt_ids = replicate(neg_prompt_ids)    rng = jax.random.PRNGKey(seed)    rng = jax.random.split(rng, NUM_DEVICES)    return p_prompt_ids, p_neg_prompt_ids, rng

现在,我们准备好在生成函数中将所有内容组合在一起:

def generate(    prompt,    negative_prompt,    seed=default_seed,    guidance_scale=default_guidance_scale,    num_inference_steps=default_num_steps,):    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)    images = pipeline(        prompt_ids,        p_params,        rng,        num_inference_steps=num_inference_steps,        neg_prompt_ids=neg_prompt_ids,        guidance_scale=guidance_scale,        jit=True,    ).images    # 将图像转换为PIL    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])    return pipeline.numpy_to_pil(np.array(images))

jit=True表示我们希望对管道调用进行编译。这将在我们第一次调用generate时发生,并且非常慢 – JAX需要跟踪操作,优化它们,并将它们转换为低级原语。我们将进行第一次生成以完成此过程并使其热身:

start = time.time()print(f"编译中...")generate(default_prompt, default_neg_prompt)print(f"编译完成,耗时 {time.time() - start} 秒")

第一次运行大约花费了三分钟。但是一旦代码被编译,推断将会非常快。让我们再试一次!

start = time.time()prompt = "古希腊的美洲驼,画布上的油画"neg_prompt = "卡通,插图,动画"images = generate(prompt, neg_prompt)print(f"推断耗时 {time.time() - start} 秒")

现在生成4张图片只需要大约2秒钟!

基准测试

以下数据是在默认的Euler离散调度器下使用SDXL 1.0 base进行20个步骤的运行结果。我们比较了Cloud TPU v5e和TPUv4在相同批量大小下的性能。请注意,由于并行处理的原因,我们在演示中使用的TPU v5e-4类似的设备在批量大小为1时会生成4张图片(批量大小为2时为8张图片)。同样地,使用批量大小为1的TPU v5e-8会生成8张图片。

云TPU测试使用的是Python 3.10和jax版本0.4.16。这些规格与我们的演示空间中使用的规格相同。

与TPU v4相比,TPU v5e在SDXL上的性能/美元比达到了最高2.4倍,展示了最新TPU一代的成本效益。

为了衡量推断性能,我们使用行业标准的吞吐量指标。首先,我们测量模型编译和加载后每张图片的延迟。然后,通过将批量大小除以每个芯片的延迟来计算吞吐量。因此,吞吐量衡量了模型在生产环境中的性能,而不管使用了多少个芯片。然后,我们将吞吐量除以列表价格以获得性能每美元。

演示如何工作?

之前展示的演示是使用一个脚本构建的,该脚本基本上按照我们在博客文章中发布的代码进行操作。它在几个Cloud TPU v5e设备上运行,每个设备有4个芯片,并有一个简单的负载均衡服务器将用户请求随机路由到后端服务器。当您在演示中输入提示时,您的请求将分配给其中一个后端服务器,并将收到它生成的4张图片。

这是一个基于几个预分配的TPU实例的简单解决方案。在未来的文章中,我们将介绍如何使用GKE创建根据负载自适应的动态解决方案。

演示中的所有代码都是开源的,并且在Hugging Face Diffusers上可用。我们很期待看到您如何利用Diffusers + JAX + Cloud TPUs构建!

Leave a Reply

Your email address will not be published. Required fields are marked *