🧨 JAX / Flax 中的 Stable Diffusion!

釋出於 2022 年 10 月 13 日
在 GitHub 上更新
Open In Colab

🤗 Hugging Face Diffusers0.5.1 版本開始支援 Flax!這使得在 Google TPU 上進行超快速推理成為可能,例如 Colab、Kaggle 或 Google Cloud Platform 中可用的 TPU。

這篇帖子展示瞭如何使用 JAX / Flax 執行推理。如果您想了解更多關於 Stable Diffusion 如何工作的詳細資訊,或者想在 GPU 上執行它,請參閱 此 Colab 筆記本

如果您想跟著操作,請點選上面的按鈕,將此帖子作為 Colab 筆記本開啟。

首先,請確保您正在使用 TPU 後端。如果您在 Colab 中執行此筆記本,請在上方選單中選擇 Runtime,然後選擇“更改執行時型別”選項,然後在 Hardware accelerator 設定下選擇 TPU

請注意,JAX 並非 TPU 獨有,但它在這種硬體上表現出色,因為每個 TPU 伺服器都有 8 個 TPU 加速器並行工作。

設定

import jax
num_devices = jax.device_count()
device_type = jax.devices()[0].device_kind

print(f"Found {num_devices} JAX devices of type {device_type}.")
assert "TPU" in device_type, "Available device is not a TPU, please select TPU from Edit > Notebook settings > Hardware accelerator"

輸出:

    Found 8 JAX devices of type TPU v2.

請確保已安裝 diffusers

!pip install diffusers==0.5.1

然後我們匯入所有依賴項。

import numpy as np
import jax
import jax.numpy as jnp

from pathlib import Path
from jax import pmap
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from PIL import Image

from huggingface_hub import notebook_login
from diffusers import FlaxStableDiffusionPipeline

模型載入

在使用模型之前,您需要接受模型 許可證 才能下載和使用權重。

該許可證旨在減輕這種強大機器學習系統可能造成的有害影響。我們要求使用者**完整並仔細閱讀許可證**。以下是摘要:

  1. 您不得使用該模型故意生成或共享非法或有害的輸出或內容,
  2. 我們不主張您生成的輸出的任何權利,您可以自由使用它們,並對其使用負責,其使用不應違反許可證中規定的條款,並且
  3. 您可以重新分發權重並將其商業化和/或作為服務使用。如果您這樣做,請注意您必須包含與許可證中相同的限制,並將 CreativeML OpenRAIL-M 的副本分享給所有使用者。

Flax 權重作為 Stable Diffusion 倉庫的一部分,在 Hugging Face Hub 中可用。Stable Diffusion 模型根據 CreateML OpenRail-M 許可證分發。這是一個開放許可證,不對您生成的輸出主張任何權利,並禁止您故意生成非法或有害內容。模型卡提供了更多詳細資訊,請花點時間閱讀並仔細考慮您是否接受該許可證。如果您接受,您需要成為 Hub 中的註冊使用者並使用訪問令牌才能使程式碼正常工作。您有兩種選擇來提供您的訪問令牌:

  • 在您的終端中使用 huggingface-cli login 命令列工具,並在提示時貼上您的令牌。它將儲存在您計算機上的檔案中。
  • 或者在筆記本中使用 notebook_login(),它做的是同樣的事情。

除非您之前已在此計算機上進行過身份驗證,否則以下單元格將顯示登入介面。您需要貼上您的訪問令牌。

if not (Path.home()/'.huggingface'/'token').exists(): notebook_login()

TPU 裝置支援 bfloat16,一種高效的半精度浮點型別。我們將在測試中使用它,但您也可以使用 float32 來代替使用全精度。

dtype = jnp.bfloat16

Flax 是一個函式式框架,因此模型是無狀態的,引數儲存在模型之外。載入預訓練的 Flax pipeline 將同時返回 pipeline 本身和模型權重(或引數)。我們正在使用 bf16 版本的權重,這會導致型別警告,您可以安全地忽略它們。

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    dtype=dtype,
)

推理

由於 TPU 通常有 8 個裝置並行工作,我們將把提示覆制多次,以匹配裝置的數量。然後我們將同時在 8 個裝置上執行推理,每個裝置負責生成一張影像。因此,我們將在單個晶片生成一張影像的相同時間內獲得 8 張影像。

複製提示後,我們透過呼叫 pipeline 的 prepare_inputs 函式獲得分詞後的文字 ID。分詞後的文字長度設定為 77 個 token,這是底層 CLIP Text 模型的配置所要求的。

prompt = "A cinematic film still of Morgan Freeman starring as Jimi Hendrix, portrait, 40mm lens, shallow depth of field, close up, split lighting, cinematic"
prompt = [prompt] * jax.device_count()
prompt_ids = pipeline.prepare_inputs(prompt)
prompt_ids.shape

輸出:

    (8, 77)

複製與並行化

模型引數和輸入必須在我們的 8 個並行裝置之間複製。引數字典使用 flax.jax_utils.replicate 複製,該函式遍歷字典並改變權重的形狀,使其重複 8 次。陣列使用 shard 複製。

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
prompt_ids.shape

輸出:

    (8, 1, 77)

這種形狀意味著每個 8 個裝置將接收一個形狀為 (1, 77)jnp 陣列作為輸入。因此,1 是每個裝置的批處理大小。在記憶體充足的 TPU 中,如果我們想一次生成多張影像(每個晶片),它可能會大於 1

我們幾乎準備好生成影像了!我們只需要建立一個隨機數生成器來傳遞給生成函式。這是 Flax 中的標準過程,它對隨機數非常認真和有主見——所有處理隨機數的函式都應該接收一個生成器。這確保了可重現性,即使我們在多個分散式裝置上進行訓練。

下面的輔助函式使用一個種子來初始化一個隨機數生成器。只要我們使用相同的種子,我們將得到完全相同的結果。稍後在筆記本中探索結果時,隨意使用不同的種子。

def create_key(seed=0):
    return jax.random.PRNGKey(seed)

我們獲得一個隨機數生成器,然後將其“分割”成 8 份,以便每個裝置接收一個不同的生成器。因此,每個裝置將建立一個不同的影像,並且整個過程是可重現的。

rng = create_key(0)
rng = jax.random.split(rng, jax.device_count())

JAX 程式碼可以編譯成高效的表示,執行速度非常快。然而,我們需要確保所有輸入在後續呼叫中都具有相同的形狀;否則,JAX 將不得不重新編譯程式碼,我們將無法利用最佳化的速度。

如果我們將 jit = True 作為引數傳遞,Flax pipeline 可以為我們編譯程式碼。它還將確保模型在 8 個可用裝置上並行執行。

我們第一次執行以下單元格時,編譯將花費很長時間,但隨後的呼叫(即使輸入不同)也會快得多。例如,我在 TPU v2-8 上測試時,編譯花費了超過一分鐘,但隨後的推理執行僅需約 7秒

images = pipeline(prompt_ids, p_params, rng, jit=True)[0]

輸出:

    CPU times: user 464 ms, sys: 105 ms, total: 569 ms
    Wall time: 7.07 s

返回的陣列形狀為 (8, 1, 512, 512, 3)。我們將其重塑以去除第二個維度,得到 8 張 512 × 512 × 3 的影像,然後將它們轉換為 PIL 影像。

images = images.reshape((images.shape[0],) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)

視覺化

讓我們建立一個輔助函式來以網格形式顯示影像。

def image_grid(imgs, rows, cols):
    w,h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs): grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid
image_grid(images, 2, 4)

png

使用不同的提示

我們不必在所有裝置上覆制相同的提示。我們可以做任何我們想做的事情:生成 2 個提示,每個提示 4 次,甚至一次生成 8 個不同的提示。讓我們開始吧!

首先,我們將輸入準備程式碼重構為一個方便的函式

prompts = [
    "Labrador in the style of Hokusai",
    "Painting of a squirrel skating in New York",
    "HAL-9000 in the style of Van Gogh",
    "Times Square under water, with fish and a dolphin swimming around",
    "Ancient Roman fresco showing a man working on his laptop",
    "Close-up photograph of young black woman against urban background, high quality, bokeh",
    "Armchair in the shape of an avocado",
    "Clown astronaut in space, with Earth in the background",
]
prompt_ids = pipeline.prepare_inputs(prompts)
prompt_ids = shard(prompt_ids)
images = pipeline(prompt_ids, p_params, rng, jit=True).images
images = images.reshape((images.shape[0], ) + images.shape[-3:])
images = pipeline.numpy_to_pil(images)
image_grid(images, 2, 4)

png


並行化如何工作?

我們之前提到,diffusers Flax pipeline 會自動編譯模型並在所有可用裝置上並行執行。現在,我們將簡要了解該過程的內部工作原理。

JAX 並行化可以透過多種方式完成。最簡單的方法是使用 jax.pmap 函式來實現單程式多資料(SPMD)並行化。這意味著我們將在不同的資料輸入上運行同一程式碼的多個副本。更復雜的方法也是可能的,如果您感興趣,我們邀請您查閱 JAX 文件pjit 頁面,以探索此主題!

jax.pmap 為我們做了兩件事

  • 編譯(或 jit)程式碼,就像我們呼叫了 jax.jit() 一樣。這在呼叫 pmap 時不會發生,而是在第一次呼叫 pmapped 函式時發生。
  • 確保編譯後的程式碼在所有可用裝置上並行執行。

為了展示它的工作原理,我們使用 pmap 處理 pipeline 的 _generate 方法,這是執行生成影像的私有方法。請注意,此方法在未來的 diffusers 版本中可能會被重新命名或刪除。

p_generate = pmap(pipeline._generate)

使用 pmap 後,準備好的函式 p_generate 將在概念上執行以下操作

  • 在每個裝置中呼叫底層函式 pipeline._generate 的副本。
  • 向每個裝置傳送輸入引數的不同部分。這就是分片的目的。在我們的例子中,prompt_ids 的形狀是 (8, 1, 77, 768)。此陣列將被分成 8 份,每個 _generate 副本將接收一個形狀為 (1, 77, 768) 的輸入。

我們可以完全忽略它將並行呼叫的事實來編寫 _generate。我們只關心我們的批處理大小(本例中為 1)和對我們的程式碼有意義的維度,無需更改任何內容即可使其並行工作。

與我們使用 pipeline 呼叫時一樣,第一次執行以下單元格需要一段時間,但之後會快得多。

images = p_generate(prompt_ids, p_params, rng)
images = images.block_until_ready()
images.shape

輸出:

    CPU times: user 118 ms, sys: 83.9 ms, total: 202 ms
    Wall time: 6.82 s

    (8, 1, 512, 512, 3)

我們使用 block_until_ready() 來正確測量推理時間,因為 JAX 使用非同步排程並在它能夠返回 Python 迴圈時立即返回控制。您不需要在程式碼中使用它;當您想要使用尚未實現的計算結果時,阻塞會自動發生。

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.