在 Cloud TPU v5e 上使用 JAX 加速 Stable Diffusion XL 推理

釋出於 2023 年 10 月 3 日
在 GitHub 上更新

像 Stable Diffusion XL (SDXL) 這樣的生成式 AI 模型能夠建立高質量、逼真的內容,並具有廣泛的應用。然而,要發揮這些模型的力量也帶來了巨大的挑戰和計算成本。SDXL 是一個大型的影像生成模型,其 UNet 元件的大小約為該模型先前版本的三倍。由於記憶體需求和推理時間的增加,在生產環境中部署這樣的模型具有挑戰性。今天,我們激動地宣佈,Hugging Face Diffusers 現在支援在 Cloud TPU 上使用 JAX 來服務 SDXL,從而實現高效能、高性價比的推理。

Google Cloud TPU 是定製化設計的 AI 加速器,專為大型 AI 模型(包括最先進的 LLM 和生成式 AI 模型如 SDXL)的訓練和推理進行了最佳化。新的 Cloud TPU v5e 專為大規模 AI 訓練推理提供所需的成本效益和效能。TPU v5e 的成本不到 TPU v4 的一半,使得更多的組織能夠訓練和部署 AI 模型。

🧨 Diffusers JAX 整合提供了一種透過 XLA 在 TPU 上執行 SDXL 的便捷方式,我們構建了一個演示來展示它。您可以在此 Space 或下面嵌入的 playground 中進行嘗試

在底層,該演示執行在多個 TPU v5e-4 例項上(每個例項有 4 個 TPU 晶片),並利用並行化技術在大約 4 秒內提供四張 1024×1024 的大圖。這個時間包括格式轉換、通訊時間和前端處理;實際的生成時間約為 2.3 秒,我們將在下面看到!

在這篇博文中,

  1. 我們描述了為什麼 JAX + TPU + Diffusers 是執行 SDXL 的強大框架
  2. 解釋如何使用 Diffusers 和 JAX 編寫一個簡單的影像生成流水線
  3. 展示比較不同 TPU 設定的基準測試

為什麼選擇 JAX + TPU v5e 來執行 SDXL?

透過專用 TPU 硬體和為效能最佳化的軟體棧相結合,可以在 Cloud TPU v5e 上使用 JAX 以高效能和高成本效益地服務 SDXL。下面我們強調兩個關鍵因素:JAX 的即時(jit)編譯和使用 JAX pmap 實現的 XLA 編譯器驅動的並行化。

即時編譯

JAX 的一個顯著特點是其即時(jit)編譯。JIT 編譯器在第一次執行時跟蹤程式碼,並生成高度最佳化的 TPU 二進位制檔案,這些檔案在後續呼叫中被重用。這個過程的要點在於,它要求所有輸入、中間和輸出的形狀都是**靜態**的,這意味著它們必須是預先知道的。每當我們改變形狀,就會再次觸發一個新的、代價高昂的編譯過程。JIT 編譯非常適合那些可以圍繞靜態形狀設計的服務:編譯只執行一次,然後我們就可以享受超快的推理速度。

影像生成非常適合 JIT 編譯。如果我們總是生成相同數量且大小相同的影像,那麼輸出形狀就是恆定且預先知道的。文字輸入也是恆定的:按照設計,Stable Diffusion 和 SDXL 使用固定形狀的嵌入向量(帶有填充)來表示使用者輸入的提示。因此,我們可以編寫依賴於固定形狀的 JAX 程式碼,從而可以被極大地最佳化!

針對高批次大小的高效能吞吐量

使用 JAX 的 pmap,可以將工作負載擴充套件到多個裝置上,它表達了單程式多資料(SPMD)程式。將 pmap 應用於一個函式會使用 XLA 編譯該函式,然後在各種 XLA 裝置上並行執行它。對於文字到影像的生成工作負載,這意味著同時增加渲染的影像數量很容易實現,並且不會影響效能。例如,在有 8 個晶片的 TPU 上執行 SDXL 將在與 1 個晶片建立單個影像相同的時間內生成 8 張影像。

TPU v5e 例項有多種規格,包括 1、4 和 8 晶片的配置,一直到 256 個晶片(一個完整的 TPU v5e pod),晶片之間有超快的 ICI 連結。這允許您選擇最適合您用例的 TPU 規格,並輕鬆利用 JAX 和 TPU 提供的並行性。

如何用 JAX 編寫一個影像生成流水線

我們將一步步介紹使用 JAX 實現超快推理所需的程式碼!首先,讓我們匯入依賴項。

# Show best practices for SDXL JAX
import jax
import jax.numpy as jnp
import numpy as np
from flax.jax_utils import replicate
from diffusers import FlaxStableDiffusionXLPipeline
import time

現在,我們將載入基礎 SDXL 模型以及推理所需的其他元件。diffusers 流水線會為我們處理下載和快取所有內容。遵循 JAX 的函式式方法,模型的引數會單獨返回,並且在推理時必須傳遞給流水線。

pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", split_head_dim=True
)

預設情況下,模型引數以 32 位精度下載。為了節省記憶體並加快計算速度,我們會將它們轉換為 bfloat16,這是一種高效的 16 位表示。然而,這裡有一個注意事項:為了獲得最佳效果,我們必須將_排程器狀態_保持在 float32,否則精度誤差會累積,導致低質量甚至全黑的影像。

scheduler_state = params.pop("scheduler")
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params["scheduler"] = scheduler_state

我們現在準備好設定我們的提示和流水線的其餘輸入了。

default_prompt = "high-quality photo of a baby dolphin ​​playing in a pool and wearing a party hat"
default_neg_prompt = "illustration, low-quality"
default_seed = 33
default_guidance_scale = 5.0
default_num_steps = 25

提示必須作為張量提供給流水線,並且它們在每次呼叫中必須具有相同的維度。這使得推理呼叫可以被編譯。流水線的 prepare_inputs 方法為我們執行了所有必要的步驟,所以我們將建立一個輔助函式來準備我們的提示和負面提示作為張量。我們稍後會在 generate 函式中使用它。

def tokenize_prompt(prompt, neg_prompt):
    prompt_ids = pipeline.prepare_inputs(prompt)
    neg_prompt_ids = pipeline.prepare_inputs(neg_prompt)
    return prompt_ids, neg_prompt_ids

為了利用並行化,我們將在裝置之間複製輸入。一個 Cloud TPU v5e-4 有 4 個晶片,所以透過複製輸入,我們可以讓每個晶片並行生成一個不同的影像。我們需要小心為每個晶片提供一個不同的隨機種子,這樣 4 張影像才會不同。

NUM_DEVICES = jax.device_count()

# Model parameters don't change during inference,
# so we only need to replicate them once.
p_params = replicate(params)

def replicate_all(prompt_ids, neg_prompt_ids, seed):
    p_prompt_ids = replicate(prompt_ids)
    p_neg_prompt_ids = replicate(neg_prompt_ids)
    rng = jax.random.PRNGKey(seed)
    rng = jax.random.split(rng, NUM_DEVICES)
    return p_prompt_ids, p_neg_prompt_ids, rng

我們現在準備將所有東西整合到一個生成函式中。

def generate(
    prompt,
    negative_prompt,
    seed=default_seed,
    guidance_scale=default_guidance_scale,
    num_inference_steps=default_num_steps,
):
    prompt_ids, neg_prompt_ids = tokenize_prompt(prompt, negative_prompt)
    prompt_ids, neg_prompt_ids, rng = replicate_all(prompt_ids, neg_prompt_ids, seed)
    images = pipeline(
        prompt_ids,
        p_params,
        rng,
        num_inference_steps=num_inference_steps,
        neg_prompt_ids=neg_prompt_ids,
        guidance_scale=guidance_scale,
        jit=True,
    ).images

    # convert the images to PIL
    images = images.reshape((images.shape[0] * images.shape[1], ) + images.shape[-3:])
    return pipeline.numpy_to_pil(np.array(images))

jit=True 表示我們希望編譯流水線呼叫。這將在我們第一次呼叫 generate 時發生,並且會非常慢——JAX 需要跟蹤操作,最佳化它們,並將其轉換為低階原語。我們將執行第一次生成來完成這個過程並進行預熱。

start = time.time()
print(f"Compiling ...")
generate(default_prompt, default_neg_prompt)
print(f"Compiled in {time.time() - start}")

我們第一次執行時,這大約花了三分鐘。但是一旦程式碼被編譯,推理就會變得超快。讓我們再試一次!

start = time.time()
prompt = "llama in ancient Greece, oil on canvas"
neg_prompt = "cartoon, illustration, animation"
images = generate(prompt, neg_prompt)
print(f"Inference in {time.time() - start}")

現在生成這 4 張圖片只花了大約 2 秒!

基準測試

以下測量結果是在執行 SDXL 1.0 base 模型 20 個步驟,並使用預設的 Euler Discrete 排程器獲得的。我們比較了相同批次大小下 Cloud TPU v5e 與 TPUv4 的效能。請注意,由於並行性,像我們在演示中使用的 TPU v5e-4,在使用批次大小為 1 時將生成 **4 張影像**(或在使用批次大小為 2 時生成 8 張影像)。同樣,TPU v5e-8 在使用批次大小為 1 時將生成 8 張影像。

Cloud TPU 測試使用 Python 3.10 和 jax 0.4.16 版本進行。這些規格與我們的演示 Space 中使用的相同。

批次大小 延遲 價效比 (Perf/$)
TPU v5e-4 (JAX) 4 2.33 秒 21.46
8 4.99 秒 20.04
TPU v4-8 (JAX) 4 2.16 秒 9.05
8 4.17 8.98

TPU v5e 在 SDXL 上的價效比高達 TPU v4 的 2.4 倍,展示了最新一代 TPU 的成本效益。

為了衡量推理效能,我們使用行業標準的吞吐量指標。首先,我們測量模型編譯和載入後每張影像的延遲。然後,我們透過將批次大小除以每個晶片的延遲來計算吞吐量。因此,吞吐量衡量的是模型在生產環境中的效能,無論使用多少晶片。然後,我們將吞吐量除以標價,得到單位成本的效能。

這個演示是如何工作的?

我們之前展示的演示是使用一個指令碼構建的,該指令碼基本上遵循了我們在這篇博文中釋出的代​​碼。它執行在幾個各帶 4 個晶片的 Cloud TPU v5e 裝置上,還有一個簡單的負載均衡伺服器,隨機將使用者請求路由到後端伺服器。當您在演示中輸入提示時,您的請求將被分配到其中一個後端伺服器,然後您將收到它生成的 4 張影像。

這是一個基於多個預分配 TPU 例項的簡單解決方案。在未來的文章中,我們將介紹如何使用 GKE 建立適應負載的動態解決方案。

演示的所有程式碼都是開源的,並且現在可以在 Hugging Face Diffusers 中找到。我們很期待看到您使用 Diffusers + JAX + Cloud TPU 構建的作品!

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.