Diffusers 文件

JAX/Flax

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

JAX/Flax

🤗 Diffusers 支援 Flax,可在 Google TPU 上實現超快速推理,例如 Colab、Kaggle 或 Google Cloud Platform 上提供的 TPU。本指南將向您展示如何使用 JAX/Flax 執行 Stable Diffusion 推理。

在開始之前,請確保您已安裝必要的庫:

# uncomment to install the necessary libraries in Colab
#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
#!pip install -q diffusers

您還應確保使用 TPU 後端。雖然 JAX 不僅限於在 TPU 上執行,但您將在 TPU 上獲得最佳效能,因為每個伺服器有 8 個 TPU 加速器並行工作。

如果您在 Colab 中執行本指南,請選擇上方選單中的 *執行時*,選擇 *更改執行時型別* 選項,然後在 *硬體加速器* 設定下選擇 *TPU*。匯入 JAX 並快速檢查您是否正在使用 TPU:

import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
    "TPU" in device_type,
    "Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
)
# Found 8 JAX devices of type Cloud TPU.

太好了,現在您可以匯入所需的其餘依賴項了:

import jax.numpy as jnp
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline

載入模型

Flax 是一個功能性框架,因此模型是無狀態的,引數儲存在模型之外。載入預訓練的 Flax 流水線會同時返回流水線和模型權重(或引數)。在本指南中,您將使用 `bfloat16`,這是一種更高效的半精度浮點型別,受 TPU 支援(如果您願意,也可以使用 `float32` 獲取全精度)。

dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    variant="bf16",
    dtype=dtype,
)

推理

TPU 通常有 8 個裝置並行工作,因此讓我們為每個裝置使用相同的提示。這意味著您可以同時在 8 個裝置上執行推理,每個裝置生成一張影像。因此,您將在單個晶片生成一張影像所需的時間內獲得 8 張影像!

並行化如何工作? 部分了解更多詳情。

複製提示後,透過呼叫流水線上的 `prepare_inputs` 函式獲取分詞文字 ID。分詞文字的長度設定為 77 個 token,這是底層 CLIP 文字模型配置所要求的。

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
# (8, 77)

模型引數和輸入必須在 8 個並行裝置上覆制。引數字典透過 flax.jax_utils.replicate 複製,該函式遍歷字典並更改權重的形狀,使其重複 8 次。陣列使用 `shard` 進行復制。

# parameters
p_params = replicate(params)

# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
# (8, 1, 77)

這種形狀意味著 8 個裝置中的每一個都接收一個形狀為 `(1, 77)` 的 `jnp` 陣列作為輸入,其中 `1` 是每個裝置的批次大小。在具有足夠記憶體的 TPU 上,如果您想一次生成多個影像(每個晶片),則可以擁有大於 `1` 的批次大小。

接下來,建立一個隨機數生成器傳遞給生成函式。這是 Flax 中的標準過程,它對隨機數非常重視並有自己的觀點。所有處理隨機數的函式都應接收一個生成器,以確保可重現性,即使您在多個分散式裝置上進行訓練時也是如此。

下面的輔助函式使用種子初始化隨機數生成器。只要您使用相同的種子,您將獲得完全相同的結果。在本指南後續探索結果時,請隨意使用不同的種子。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

輔助函式,即 `rng`,被分成 8 份,以便每個裝置接收不同的生成器並生成不同的影像。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

為了利用 JAX 在 TPU 上的最佳化速度,將 `jit=True` 傳遞給流水線,以將 JAX 程式碼編譯成高效的表示,並確保模型在 8 個裝置上並行執行。

您需要確保所有後續呼叫中的輸入具有相同的形狀,否則 JAX 將需要重新編譯程式碼,這會更慢。

第一次推理執行需要更多時間,因為它需要編譯程式碼,但後續呼叫(即使輸入不同)會快得多。例如,在 TPU v2-8 上編譯耗時超過一分鐘,但在未來的推理執行中,它只需大約 **7 秒**!

%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
# Wall time: 1min 29s

返回的陣列形狀為 `(8, 1, 512, 512, 3)`,應將其重塑以去除第二個維度並獲得 8 張 `512 × 512 × 3` 的影像。然後您可以使用 numpy_to_pil() 函式將陣列轉換為影像。

from diffusers.utils import make_image_grid

images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, rows=2, cols=4)

img

使用不同的提示

您不一定必須在所有裝置上使用相同的提示。例如,要生成 8 個不同的提示:

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]

prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

make_image_grid(images, 2, 4)

img

並行化如何工作?

🤗 Diffusers 中的 Flax 流水線會自動編譯模型並使其在所有可用裝置上並行執行。讓我們仔細看看這個過程是如何工作的。

JAX 並行化可以透過多種方式完成。最簡單的方式是圍繞使用 jax.pmap 函式來實現單程式多資料 (SPMD) 並行化。這意味著運行同一程式碼的多個副本,每個副本處理不同的資料輸入。更復雜的方法是可能的,如果您有興趣,可以查閱 JAX 文件 以瞭解更多詳細資訊!

jax.pmap 有兩個作用:

  1. 編譯(或“`jit`”)程式碼,這類似於 `jax.jit()`。這在您呼叫 `pmap` 時不會發生,而是在 `pmap` 化的函式第一次被呼叫時發生。
  2. 確保編譯後的程式碼在所有可用裝置上並行執行。

為了演示,在流水線的 `_generate` 方法上呼叫 `pmap`(這是一個生成影像的私有方法,在 🤗 Diffusers 的未來版本中可能會被重新命名或移除):

p_generate = pmap(pipeline._generate)

呼叫 `pmap` 後,準備好的函式 `p_generate` 將會:

  1. 在每個裝置上覆制底層函式 `pipeline._generate`。
  2. 向每個裝置傳送輸入引數的不同部分(這就是為什麼需要呼叫 `shard` 函式的原因)。在本例中,`prompt_ids` 的形狀為 `(8, 1, 77, 768)`,因此陣列被分成 8 份,`_generate` 的每個副本都接收一個形狀為 `(1, 77, 768)` 的輸入。

這裡最重要的是要關注批次大小(本例中為 1)以及對您的程式碼有意義的輸入維度。您無需更改任何其他內容即可使程式碼並行工作。

第一次呼叫流水線需要更多時間,但之後的呼叫會快得多。`block_until_ready` 函式用於正確測量推理時間,因為 JAX 使用非同步排程,並在可以時立即將控制權返回給 Python 迴圈。您無需在程式碼中使用它;當您想要使用尚未實現的計算結果時,阻塞會自動發生。

%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()

# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
# Wall time: 1min 15s

檢查您的影像尺寸,看它們是否正確:

images.shape
# (8, 1, 512, 512, 3)

資源

要了解更多關於 JAX 如何與 Stable Diffusion 協同工作的資訊,您可能對閱讀以下內容感興趣:

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.