介紹 Würstchen:用於影像生成的快速擴散模型

釋出於 2023 年 9 月 13 日
在 GitHub 上更新

Collage of images created with Würstchen

什麼是 Würstchen?

Würstchen 是一種擴散模型,其文字條件元件在高度壓縮的影像潛在空間中工作。為什麼這很重要?壓縮資料可以將訓練和推理的計算成本降低幾個數量級。在 1024×1024 影像上訓練比在 32×32 影像上訓練昂貴得多。通常,其他工作使用相對較小的壓縮,空間壓縮範圍為 4 倍 - 8 倍。Würstchen 將其推向極致。透過其新穎的設計,它實現了 42 倍的空間壓縮!這以前從未見過,因為常見方法在 16 倍空間壓縮後無法忠實地重建詳細影像。Würstchen 採用兩階段壓縮,我們稱之為階段 A 和階段 B。階段 A 是 VQGAN,階段 B 是擴散自編碼器(更多詳細資訊可以在論文中找到)。階段 A 和 B 合稱為*解碼器*,因為它們將壓縮影像解碼回畫素空間。第三個模型,階段 C,在該高度壓縮的潛在空間中學習。這種訓練所需的計算量只是當前頂級模型所需計算量的一小部分,同時還允許更便宜、更快的推理。我們將階段 C 稱為*先驗*。

Würstchen images with Prompts

為什麼需要另一個文字到影像模型?

嗯,這個模型非常快且高效。Würstchen 的最大優勢在於它可以比 Stable Diffusion XL 等模型更快地生成影像,同時使用更少的記憶體!因此,對於我們這些沒有 A100 的人來說,這將非常有用。以下是與 SDXL 在不同批次大小下的比較:

Inference Speed Plots

此外,Würstchen 的另一個重要優勢是降低了訓練成本。在 512x512 解析度下工作的 Würstchen v1 僅需要 9,000 GPU 小時的訓練。與 Stable Diffusion 1.4 所花費的 150,000 GPU 小時相比,這表明成本降低了 16 倍,這不僅有利於研究人員進行新實驗,還為更多組織訓練此類模型打開了大門。Würstchen v2 使用了 24,602 GPU 小時。在解析度達到 1536 的情況下,這仍然比僅在 512x512 解析度下訓練的 SD1.4 便宜 6 倍。

Inference Speed Plots

你也可以在這裡找到詳細的解釋影片

如何使用 Würstchen?

你可以在這裡嘗試使用演示:

此外,該模型透過 Diffusers 庫提供,因此你可以使用你已經熟悉的介面。例如,以下是如何使用 `AutoPipeline` 執行推理:

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")

caption = "Anthropomorphic cat dressed as a firefighter"
images = pipeline(
    caption,
    height=1024,
    width=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=4,
).images

Anthropomorphic cat dressed as a fire-fighter

Würstchen 適用於哪些影像尺寸?

Würstchen 在 1024x1024 到 1536x1536 之間的影像解析度上進行訓練。我們有時也會在 1024x2048 等解析度下觀察到不錯的輸出。歡迎隨意嘗試。我們還觀察到先驗(Stage C)能極快地適應新解析度。因此,在 2048x2048 解析度下進行微調的計算成本應該很低。

Hub 上的模型

所有檢查點都可以在Huggingface Hub上找到。那裡可以找到多個檢查點,以及未來的演示和模型權重。目前,先驗有 3 個檢查點可用,解碼器有 1 個檢查點。請參閱文件,其中解釋了檢查點以及不同的先驗模型的用途。

Diffusers 整合

因為 Würstchen 完全整合在 `diffusers` 中,所以它自動附帶了各種開箱即用的便利功能和最佳化。其中包括:

  • 自動使用PyTorch 2 `SDPA`加速注意力,如下所述。
  • 如果需要使用 PyTorch 1.x 而非 2,則支援xFormers flash attention實現。
  • 模型解除安裝,以便在不使用時將未使用的元件移至 CPU。這可以節省記憶體,同時效能影響可忽略不計。
  • 順序 CPU 解除安裝,適用於記憶體非常寶貴的情況。記憶體使用將最小化,但推理速度會變慢。
  • 使用 Compel 庫進行提示權重
  • 支援 Apple Silicon Mac 上的`mps` 裝置
  • 使用生成器實現可復現性
  • 針對推理的合理預設值,可在大多數情況下生成高質量結果。當然,您可以根據需要調整所有引數!

最佳化技術 1:Flash Attention

從 2.0 版本開始,PyTorch 集成了高度最佳化且資源友好的注意力機制版本,稱為`torch.nn.functional.scaled_dot_product_attention` 或 SDPA。根據輸入的性質,此函式利用多種底層最佳化。其效能和記憶體效率超越了傳統的注意力模型。值得注意的是,SDPA 函式反映了 Dao 及其團隊撰寫的《Fast and Memory-Efficient Exact Attention with IO-Awareness》研究論文中強調的 *flash attention* 技術的特點。

如果您使用的是 PyTorch 2.0 或更高版本的 Diffusers,並且 SDPA 函式可訪問,則會自動應用這些增強功能。請根據官方指南設定 torch 2.0 或更高版本以開始使用!

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

有關 `diffusers` 如何利用 SDPA 的深入瞭解,請查閱文件

如果您使用的是 Pytorch 2.0 之前的版本,仍然可以使用 xFormers 庫來實現記憶體高效的注意力機制。

pipeline.enable_xformers_memory_efficient_attention()

最佳化技術 2:Torch Compile

如果你正在尋求額外的效能提升,可以使用 `torch.compile`。最好將其應用於先驗模型和解碼器的主模型,以實現最大的效能提升。

pipeline.prior_prior = torch.compile(pipeline.prior_prior , mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)

請記住,首次推理步驟將花費很長時間(長達 2 分鐘),因為模型正在編譯。之後,您可以正常執行推理

images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images

好訊息是這種編譯是一次性執行的。之後,您將始終體驗到相同影像解析度下的更快推理。編譯的初始時間投入很快就會被隨後的速度優勢抵消。有關 `torch.compile` 及其細微差別的深入探討,請檢視官方文件

模型是如何訓練的?

能夠訓練這個模型,完全得益於 Stability AI 提供的計算資源。我們要特別感謝 Stability 讓我們能夠進行這類研究,並有機會讓更多人接觸到它!

資源

  • 有關此模型的更多資訊,請參閱官方 Diffusers 文件
  • 所有檢查點都可以在hub上找到
  • 你可以在這裡試用演示。
  • 如果你想討論未來的專案或貢獻自己的想法,請加入我們的 Discord
  • 訓練程式碼等更多內容可以在官方 GitHub 儲存庫中找到。

社群

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.