Diffusers 文件
JAX/Flax
並獲得增強的文件體驗
開始使用
JAX/Flax
🤗 Diffusers 支援 Flax,可在 Google TPU 上實現超快速推理,例如 Colab、Kaggle 或 Google Cloud Platform 上提供的 TPU。本指南將向您展示如何使用 JAX/Flax 執行 Stable Diffusion 推理。
在開始之前,請確保您已安裝必要的庫:
# uncomment to install the necessary libraries in Colab
#!pip install -q jax==0.3.25 jaxlib==0.3.25 flax transformers ftfy
#!pip install -q diffusers
您還應確保使用 TPU 後端。雖然 JAX 不僅限於在 TPU 上執行,但您將在 TPU 上獲得最佳效能,因為每個伺服器有 8 個 TPU 加速器並行工作。
如果您在 Colab 中執行本指南,請選擇上方選單中的 *執行時*,選擇 *更改執行時型別* 選項,然後在 *硬體加速器* 設定下選擇 *TPU*。匯入 JAX 並快速檢查您是否正在使用 TPU:
import jax
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind
print(f"Found {num_devices} JAX devices of type {device_type}.")
assert (
"TPU" in device_type,
"Available device is not a TPU, please select TPU from Runtime > Change runtime type > Hardware accelerator"
)
# Found 8 JAX devices of type Cloud TPU.
太好了,現在您可以匯入所需的其餘依賴項了:
import jax.numpy as jnp
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxStableDiffusionPipeline
載入模型
Flax 是一個功能性框架,因此模型是無狀態的,引數儲存在模型之外。載入預訓練的 Flax 流水線會同時返回流水線和模型權重(或引數)。在本指南中,您將使用 `bfloat16`,這是一種更高效的半精度浮點型別,受 TPU 支援(如果您願意,也可以使用 `float32` 獲取全精度)。
dtype = jnp.bfloat16
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4",
variant="bf16",
dtype=dtype,
)
推理
TPU 通常有 8 個裝置並行工作,因此讓我們為每個裝置使用相同的提示。這意味著您可以同時在 8 個裝置上執行推理,每個裝置生成一張影像。因此,您將在單個晶片生成一張影像所需的時間內獲得 8 張影像!
在 並行化如何工作? 部分了解更多詳情。
複製提示後,透過呼叫流水線上的 `prepare_inputs` 函式獲取分詞文字 ID。分詞文字的長度設定為 77 個 token,這是底層 CLIP 文字模型配置所要求的。
prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape
# (8, 77)
模型引數和輸入必須在 8 個並行裝置上覆制。引數字典透過 flax.jax_utils.replicate
複製,該函式遍歷字典並更改權重的形狀,使其重複 8 次。陣列使用 `shard` 進行復制。
# parameters
p_params = replicate(params)
# arrays
prompt_ids = shard(prompt_ids)
prompt_ids.shape
# (8, 1, 77)
這種形狀意味著 8 個裝置中的每一個都接收一個形狀為 `(1, 77)` 的 `jnp` 陣列作為輸入,其中 `1` 是每個裝置的批次大小。在具有足夠記憶體的 TPU 上,如果您想一次生成多個影像(每個晶片),則可以擁有大於 `1` 的批次大小。
接下來,建立一個隨機數生成器傳遞給生成函式。這是 Flax 中的標準過程,它對隨機數非常重視並有自己的觀點。所有處理隨機數的函式都應接收一個生成器,以確保可重現性,即使您在多個分散式裝置上進行訓練時也是如此。
下面的輔助函式使用種子初始化隨機數生成器。只要您使用相同的種子,您將獲得完全相同的結果。在本指南後續探索結果時,請隨意使用不同的種子。
def create_key(seed=0):
return jax.random.PRNGKey(seed)
輔助函式,即 `rng`,被分成 8 份,以便每個裝置接收不同的生成器並生成不同的影像。
rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())
為了利用 JAX 在 TPU 上的最佳化速度,將 `jit=True` 傳遞給流水線,以將 JAX 程式碼編譯成高效的表示,並確保模型在 8 個裝置上並行執行。
您需要確保所有後續呼叫中的輸入具有相同的形狀,否則 JAX 將需要重新編譯程式碼,這會更慢。
第一次推理執行需要更多時間,因為它需要編譯程式碼,但後續呼叫(即使輸入不同)會快得多。例如,在 TPU v2-8 上編譯耗時超過一分鐘,但在未來的推理執行中,它只需大約 **7 秒**!
%%time
images = pipeline(prompt_ids, p_params, rng, jit=True)[0]
# CPU times: user 56.2 s, sys: 42.5 s, total: 1min 38s
# Wall time: 1min 29s
返回的陣列形狀為 `(8, 1, 512, 512, 3)`,應將其重塑以去除第二個維度並獲得 8 張 `512 × 512 × 3` 的影像。然後您可以使用 numpy_to_pil() 函式將陣列轉換為影像。
from diffusers.utils import make_image_grid
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, rows=2, cols=4)
使用不同的提示
您不一定必須在所有裝置上使用相同的提示。例如,要生成 8 個不同的提示:
prompts = [
"Labrador in the style of Hokusai",
"Painting of a squirrel skating in New York",
"HAL-9000 in the style of Van Gogh",
"Times Square under water, with fish and a dolphin swimming around",
"Ancient Roman fresco showing a man working on his laptop",
"Close-up photograph of young black woman against urban background, high quality, bokeh",
"Armchair in the shape of an avocado",
"Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0] * images.shape[1],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
make_image_grid(images, 2, 4)
並行化如何工作?
🤗 Diffusers 中的 Flax 流水線會自動編譯模型並使其在所有可用裝置上並行執行。讓我們仔細看看這個過程是如何工作的。
JAX 並行化可以透過多種方式完成。最簡單的方式是圍繞使用 jax.pmap
函式來實現單程式多資料 (SPMD) 並行化。這意味著運行同一程式碼的多個副本,每個副本處理不同的資料輸入。更復雜的方法是可能的,如果您有興趣,可以查閱 JAX 文件 以瞭解更多詳細資訊!
jax.pmap
有兩個作用:
- 編譯(或“`jit`”)程式碼,這類似於 `jax.jit()`。這在您呼叫 `pmap` 時不會發生,而是在 `pmap` 化的函式第一次被呼叫時發生。
- 確保編譯後的程式碼在所有可用裝置上並行執行。
為了演示,在流水線的 `_generate` 方法上呼叫 `pmap`(這是一個生成影像的私有方法,在 🤗 Diffusers 的未來版本中可能會被重新命名或移除):
p_generate = pmap(pipeline._generate)
呼叫 `pmap` 後,準備好的函式 `p_generate` 將會:
- 在每個裝置上覆制底層函式 `pipeline._generate`。
- 向每個裝置傳送輸入引數的不同部分(這就是為什麼需要呼叫 `shard` 函式的原因)。在本例中,`prompt_ids` 的形狀為 `(8, 1, 77, 768)`,因此陣列被分成 8 份,`_generate` 的每個副本都接收一個形狀為 `(1, 77, 768)` 的輸入。
這裡最重要的是要關注批次大小(本例中為 1)以及對您的程式碼有意義的輸入維度。您無需更改任何其他內容即可使程式碼並行工作。
第一次呼叫流水線需要更多時間,但之後的呼叫會快得多。`block_until_ready` 函式用於正確測量推理時間,因為 JAX 使用非同步排程,並在可以時立即將控制權返回給 Python 迴圈。您無需在程式碼中使用它;當您想要使用尚未實現的計算結果時,阻塞會自動發生。
%%time
images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
# CPU times: user 1min 15s, sys: 18.2 s, total: 1min 34s
# Wall time: 1min 15s
檢查您的影像尺寸,看它們是否正確:
images.shape
# (8, 1, 512, 512, 3)
資源
要了解更多關於 JAX 如何與 Stable Diffusion 協同工作的資訊,您可能對閱讀以下內容感興趣:
< > 在 GitHub 上更新