🤗 Hugging Face Diffusers 自版本 0.5.1
起支持 Flax!这使得可以在 Google TPUs 上进行超快速推理,例如在 Colab、Kaggle 或 Google Cloud Platform 上可用的 TPUs。
本文展示了如何使用 JAX / Flax 进行推理。如果您想了解 Stable Diffusion 的更多细节,或者想在 GPU 上运行它,请参考此 Colab 笔记本。
如果您想跟随操作,请点击上方的按钮将此文章作为 Colab 笔记本打开。
首先,确保您正在使用 TPU 后端。如果您在 Colab 上运行此笔记本,请选择上方的 Runtime
菜单,然后选择 “Change runtime type” 选项,然后在 Hardware accelerator
设置下选择 TPU
。
请注意,JAX 不仅限于 TPUs,但在该硬件上表现出色,因为每个 TPU 服务器都有 8 个并行工作的 TPU 加速器。
设置
import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"发现 {num_devices} 个类型为 {device_type} 的 JAX 设备。")
assert "TPU" in device_type, "可用设备不是 TPU,请从 Edit > Notebook settings > Hardware accelerator 中选择 TPU"
输出:
发现 8 个类型为 TPU v2 的 JAX 设备。
确保已安装 diffusers
。
!pip install diffusers==0.5.1
然后导入所有依赖项。
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image
from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline
模型加载
在使用模型之前,您需要接受模型许可证以便下载和使用权重。
该许可证旨在减轻这种强大机器学习系统可能带来的潜在有害影响。我们要求用户完整仔细阅读许可证。以下是许可证的摘要:
- 您不能使用模型故意生成或共享非法或有害的输出或内容,
- 我们对您生成的输出不享有任何权利,您可以自由使用它们,并对其使用负责,不得违反许可证中的规定,
- 您可以重新分发权重并将模型用于商业用途和/或作为服务。如果您这样做,请注意您必须包含与许可证中相同的使用限制,并向所有用户共享 CreativeML OpenRAIL-M 的副本。
Flax 权重可在 Hugging Face Hub 上作为 Stable Diffusion 存储库的一部分获得。Stable Diffusion 模型根据 CreateML OpenRail-M 许可证进行分发。这是一种开放许可证,对您生成的输出不享有任何权利,并禁止您故意生成非法或有害的内容。模型卡提供了更多细节,请花些时间阅读并仔细考虑是否接受许可证。如果接受,请在 Hub 中注册用户并使用访问令牌使代码工作。有两种方法可以提供访问令牌:
- 在终端中使用
huggingface-cli login
命令行工具,并在提示时粘贴您的令牌。它将保存在您的计算机上的文件中。 - 或在笔记本中使用
notebook_login()
,其功能相同。
以下单元格将呈现一个登录界面,除非您之前在此计算机上进行过身份验证。您需要粘贴您的访问令牌。
if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()
TPU 设备支持 bfloat16
,一种高效的半浮点类型。我们将在测试中使用它,但您也可以使用 float32
来使用完整的精度。
dtype = jnp.bfloat16
Flax是一个功能性框架,所以模型是无状态的,参数存储在模型外部。加载预训练的Flax管道将返回管道本身和模型权重(或参数)。我们正在使用权重的bf16
版本,这会导致类型警告,您可以安全地忽略。
pipeline,params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
revision="bf16",
dtype=dtype,
)
推理
由于TPU通常有8个并行工作的设备,我们将多次复制我们的提示,每个设备一次。然后,我们将同时在8个设备上执行推理,每个设备负责生成一张图片。因此,我们将在相同的时间内获得8张图片,这与一块芯片生成一张图片所需的时间相同。
在复制提示后,我们通过调用管道的prepare_inputs
函数来获取令牌化的文本ID。令牌化文本的长度设置为77个令牌,这是基础CLIP文本模型的配置要求。
prompt = "Morgan Freeman扮演Jimi Hendrix的电影剧照,画像,40mm镜头,浅景深,特写,分割光线,电影"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
输出:
(8, 77)
复制和并行化
模型参数和输入必须在我们拥有的8个并行设备上复制。使用flax.jax_utils.replicate
复制参数字典,它遍历字典并更改权重的形状,使其重复8次。使用shard
复制数组。
p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape
输出:
(8, 1, 77)
这个形状意味着每个8
个设备将接收一个形状为(1, 77)
的jnp
数组作为输入。因此,1
是每个设备的批量大小。在具有足够内存的TPU中,如果我们想要一次生成多张图像(每个芯片),批量大小可以大于1
。
我们几乎准备好生成图像了!我们只需要创建一个随机数生成器传递给生成函数。这是Flax的标准过程,它对随机数非常认真和有意见 – 所有处理随机数的函数都应该接收一个生成器。这样可以确保即使我们在多个分布式设备上进行训练,也能保证可重复性。
下面的辅助函数使用种子初始化随机数生成器。只要我们使用相同的种子,就会得到完全相同的结果。在笔记本中稍后探索结果时,可以使用不同的种子。
def create_key(seed=0):
return jax.random.PRNGKey(seed)
我们获得一个随机数生成器,然后将其“划分”为8个部分,以便每个设备接收一个不同的生成器。因此,每个设备将创建一个不同的图像,整个过程是可重复的。
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
JAX代码可以编译为高效的表示形式,运行非常快。然而,我们需要确保后续调用中所有的输入具有相同的形状;否则,JAX将不得不重新编译代码,我们将无法利用优化的速度。
如果我们传递jit = True
作为参数,Flax管道可以为我们编译代码。它还将确保模型在8个可用设备上并行运行。
第一次运行下面的单元格时,编译需要很长时间,但是后续调用(即使输入不同)将快得多。例如,在我测试的TPU v2-8中,编译需要超过一分钟的时间,但是以后的推理运行大约需要7s
。
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
输出:
CPU 时间:用户 464 毫秒,系统:105 毫秒,总共:569 毫秒
墙上时间:7.07 秒
返回的数组的形状为(8, 1, 512, 512, 3)
。我们将其重塑以去除第二维,并获得512×512×3
的8个图像,然后将它们转换为PIL格式。
images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
可视化
让我们创建一个辅助函数以在网格中显示图像。
def image_grid(imgs, rows, cols):
w,h = imgs[0].size
grid = Image.new('RGB', size=(cols*w, rows*h))
for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
return grid
image_grid(images, 2, 4)
使用不同的提示
我们不必在所有设备上复制相同的提示。我们可以做任何我们想做的事情:生成2个提示,每个重复4次,甚至一次生成8个不同的提示。让我们这样做!
首先,我们将将输入准备代码重构为一个方便的函数:
prompts = [
"以北斗之神风格绘制的拉布拉多犬",
"一只松鼠在纽约滑冰的绘画作品",
"以梵高风格绘制的HAL-9000",
"时代广场被水淹没,周围有鱼和海豚在游泳",
"描绘一名男子在笔记本电脑上工作的古罗马壁画",
"高质量的近景摄影,展示了年轻的黑人女性与城市背景,有迷离的效果",
"鳄梨形状的扶手椅",
"太空中的小丑宇航员,背景是地球",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)
并行化是如何工作的?
我们之前说过,diffusers
Flax流水线会自动编译模型并在所有可用设备上并行运行。现在,我们将简要介绍一下该过程的内部,以展示其工作原理。
JAX并行化可以以多种方式进行。最简单的方法是使用jax.pmap
函数实现单程序、多数据(SPMD)并行化。这意味着我们将在不同的数据输入上运行同一份代码的多个副本。更复杂的方法也是可行的,如果您感兴趣,我们邀请您阅读JAX文档和pjit
页面来探索这个主题!
jax.pmap
为我们完成了两件事:
- 编译(或
jit
)代码,就好像我们调用了jax.jit()
一样。这不会在调用pmap
时发生,而是在第一次调用pmapped函数时发生。 - 确保编译后的代码在所有可用设备上并行运行。
为了展示它是如何工作的,我们将对流水线的_generate
方法使用pmap
,这是运行图像生成的私有方法。请注意,该方法在将来的diffusers
版本中可能会更名或删除。
p_generate = pmap(pipeline._generate)
在使用了pmap
之后,准备好的函数p_generate
将在概念上执行以下操作:
- 在每个设备上调用底层函数
pipeline._generate
的副本。 - 将每个设备发送不同部分的输入参数。这就是分片用于的目的。在我们的案例中,
prompt_ids
的形状为(8, 1, 77, 768)
。该数组将被分割成8
个部分,每个_generate
的副本将接收到一个形状为(1, 77, 768)
的输入。
我们可以完全忽略将会以并行方式调用的_generate
函数的事实来编写代码。我们只关心我们的批次大小(在此示例中为1
)以及对我们的代码有意义的维度,不需要改变任何东西就可以使其在并行中工作。
就像我们使用管道调用时一样,第一次运行以下单元格时需要一些时间,但之后会更快。
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape
输出:
CPU 时间:用户 118 毫秒,系统:83.9 毫秒,总计:202 毫秒
真实时间:6.82 秒
(8, 1, 512, 512, 3)
我们使用block_until_ready()
来正确测量推理时间,因为JAX使用异步调度,并尽快将控制权返回给Python循环。您不需要在代码中使用它;当您想使用尚未实现的计算结果时,阻塞将自动发生。