Press "Enter" to skip to content

🧨 在JAX / Flax中稳定扩散!

🧨 在JAX / Flax中稳定扩散! 四海 第1张

🤗 Hugging Face Diffusers 自版本 0.5.1 起支持 Flax!这使得可以在 Google TPUs 上进行超快速推理,例如在 Colab、Kaggle 或 Google Cloud Platform 上可用的 TPUs。

本文展示了如何使用 JAX / Flax 进行推理。如果您想了解 Stable Diffusion 的更多细节,或者想在 GPU 上运行它,请参考此 Colab 笔记本。

如果您想跟随操作,请点击上方的按钮将此文章作为 Colab 笔记本打开。

首先,确保您正在使用 TPU 后端。如果您在 Colab 上运行此笔记本,请选择上方的 Runtime 菜单,然后选择 “Change runtime type” 选项,然后在 Hardware accelerator 设置下选择 TPU

请注意,JAX 不仅限于 TPUs,但在该硬件上表现出色,因为每个 TPU 服务器都有 8 个并行工作的 TPU 加速器。

设置

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"发现 {num_devices} 个类型为 {device_type} 的 JAX 设备。")
assert "TPU" in device_type, "可用设备不是 TPU,请从 Edit > Notebook settings > Hardware accelerator 中选择 TPU"

输出:

    发现 8 个类型为 TPU v2 的 JAX 设备。

确保已安装 diffusers

!pip install diffusers==0.5.1

然后导入所有依赖项。

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

模型加载

在使用模型之前,您需要接受模型许可证以便下载和使用权重。

该许可证旨在减轻这种强大机器学习系统可能带来的潜在有害影响。我们要求用户完整仔细阅读许可证。以下是许可证的摘要:

  1. 您不能使用模型故意生成或共享非法或有害的输出或内容,
  2. 我们对您生成的输出不享有任何权利,您可以自由使用它们,并对其使用负责,不得违反许可证中的规定,
  3. 您可以重新分发权重并将模型用于商业用途和/或作为服务。如果您这样做,请注意您必须包含与许可证中相同的使用限制,并向所有用户共享 CreativeML OpenRAIL-M 的副本。

Flax 权重可在 Hugging Face Hub 上作为 Stable Diffusion 存储库的一部分获得。Stable Diffusion 模型根据 CreateML OpenRail-M 许可证进行分发。这是一种开放许可证,对您生成的输出不享有任何权利,并禁止您故意生成非法或有害的内容。模型卡提供了更多细节,请花些时间阅读并仔细考虑是否接受许可证。如果接受,请在 Hub 中注册用户并使用访问令牌使代码工作。有两种方法可以提供访问令牌:

  • 在终端中使用 huggingface-cli login 命令行工具,并在提示时粘贴您的令牌。它将保存在您的计算机上的文件中。
  • 或在笔记本中使用 notebook_login(),其功能相同。

以下单元格将呈现一个登录界面,除非您之前在此计算机上进行过身份验证。您需要粘贴您的访问令牌。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPU 设备支持 bfloat16,一种高效的半浮点类型。我们将在测试中使用它,但您也可以使用 float32 来使用完整的精度。

dtype = jnp.bfloat16

Flax是一个功能性框架,所以模型是无状态的,参数存储在模型外部。加载预训练的Flax管道将返回管道本身和模型权重(或参数)。我们正在使用权重的bf16版本,这会导致类型警告,您可以安全地忽略。

pipeline,params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

推理

由于TPU通常有8个并行工作的设备,我们将多次复制我们的提示,每个设备一次。然后,我们将同时在8个设备上执行推理,每个设备负责生成一张图片。因此,我们将在相同的时间内获得8张图片,这与一块芯片生成一张图片所需的时间相同。

在复制提示后,我们通过调用管道的prepare_inputs函数来获取令牌化的文本ID。令牌化文本的长度设置为77个令牌,这是基础CLIP文本模型的配置要求。

prompt = "Morgan Freeman扮演Jimi Hendrix的电影剧照,画像,40mm镜头,浅景深,特写,分割光线,电影"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

输出:

    (8, 77)

复制和并行化

模型参数和输入必须在我们拥有的8个并行设备上复制。使用flax.jax_utils.replicate复制参数字典,它遍历字典并更改权重的形状,使其重复8次。使用shard复制数组。

p_params = replicate(params)

prompt_ids = shard(prompt_ids)
prompt_ids.shape

输出:

    (8, 1, 77)

这个形状意味着每个8个设备将接收一个形状为(1, 77)jnp数组作为输入。因此,1是每个设备的批量大小。在具有足够内存的TPU中,如果我们想要一次生成多张图像(每个芯片),批量大小可以大于1

我们几乎准备好生成图像了!我们只需要创建一个随机数生成器传递给生成函数。这是Flax的标准过程,它对随机数非常认真和有意见 – 所有处理随机数的函数都应该接收一个生成器。这样可以确保即使我们在多个分布式设备上进行训练,也能保证可重复性。

下面的辅助函数使用种子初始化随机数生成器。只要我们使用相同的种子,就会得到完全相同的结果。在笔记本中稍后探索结果时,可以使用不同的种子。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

我们获得一个随机数生成器,然后将其“划分”为8个部分,以便每个设备接收一个不同的生成器。因此,每个设备将创建一个不同的图像,整个过程是可重复的。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX代码可以编译为高效的表示形式,运行非常快。然而,我们需要确保后续调用中所有的输入具有相同的形状;否则,JAX将不得不重新编译代码,我们将无法利用优化的速度。

如果我们传递jit = True作为参数,Flax管道可以为我们编译代码。它还将确保模型在8个可用设备上并行运行。

第一次运行下面的单元格时,编译需要很长时间,但是后续调用(即使输入不同)将快得多。例如,在我测试的TPU v2-8中,编译需要超过一分钟的时间,但是以后的推理运行大约需要7s

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

输出:

    CPU 时间:用户 464 毫秒,系统:105 毫秒,总共:569 毫秒
    墙上时间:7.07 秒

返回的数组的形状为(8, 1, 512, 512, 3)。我们将其重塑以去除第二维,并获得512×512×3的8个图像,然后将它们转换为PIL格式。

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

可视化

让我们创建一个辅助函数以在网格中显示图像。

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

image_grid(images, 2, 4)

🧨 在JAX / Flax中稳定扩散! 四海 第2张

使用不同的提示

我们不必在所有设备上复制相同的提示。我们可以做任何我们想做的事情:生成2个提示,每个重复4次,甚至一次生成8个不同的提示。让我们这样做!

首先,我们将将输入准备代码重构为一个方便的函数:

prompts = [
    "以北斗之神风格绘制的拉布拉多犬",
    "一只松鼠在纽约滑冰的绘画作品",
    "以梵高风格绘制的HAL-9000",
    "时代广场被水淹没,周围有鱼和海豚在游泳",
    "描绘一名男子在笔记本电脑上工作的古罗马壁画",
    "高质量的近景摄影,展示了年轻的黑人女性与城市背景,有迷离的效果",
    "鳄梨形状的扶手椅",
    "太空中的小丑宇航员,背景是地球",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)

🧨 在JAX / Flax中稳定扩散! 四海 第3张


并行化是如何工作的?

我们之前说过,diffusers Flax流水线会自动编译模型并在所有可用设备上并行运行。现在,我们将简要介绍一下该过程的内部,以展示其工作原理。

JAX并行化可以以多种方式进行。最简单的方法是使用jax.pmap函数实现单程序、多数据(SPMD)并行化。这意味着我们将在不同的数据输入上运行同一份代码的多个副本。更复杂的方法也是可行的,如果您感兴趣,我们邀请您阅读JAX文档和pjit页面来探索这个主题!

jax.pmap为我们完成了两件事:

  • 编译(或jit)代码,就好像我们调用了jax.jit()一样。这不会在调用pmap时发生,而是在第一次调用pmapped函数时发生。
  • 确保编译后的代码在所有可用设备上并行运行。

为了展示它是如何工作的,我们将对流水线的_generate方法使用pmap,这是运行图像生成的私有方法。请注意,该方法在将来的diffusers版本中可能会更名或删除。

p_generate = pmap(pipeline._generate)

在使用了pmap之后,准备好的函数p_generate将在概念上执行以下操作:

  • 在每个设备上调用底层函数pipeline._generate的副本。
  • 将每个设备发送不同部分的输入参数。这就是分片用于的目的。在我们的案例中,prompt_ids的形状为(8, 1, 77, 768)。该数组将被分割成8个部分,每个_generate的副本将接收到一个形状为(1, 77, 768)的输入。

我们可以完全忽略将会以并行方式调用的_generate函数的事实来编写代码。我们只关心我们的批次大小(在此示例中为1)以及对我们的代码有意义的维度,不需要改变任何东西就可以使其在并行中工作。

就像我们使用管道调用时一样,第一次运行以下单元格时需要一些时间,但之后会更快。

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

输出:

    CPU 时间:用户 118 毫秒,系统:83.9 毫秒,总计:202 毫秒
    真实时间:6.82 秒

    (8, 1, 512, 512, 3)

我们使用block_until_ready()来正确测量推理时间,因为JAX使用异步调度,并尽快将控制权返回给Python循环。您不需要在代码中使用它;当您想使用尚未实现的计算结果时,阻塞将自动发生。

Leave a Reply

Your email address will not be published. Required fields are marked *