介紹 Würstchen:用於影像生成的快速擴散模型

什麼是 Würstchen?
Würstchen 是一種擴散模型,其文字條件元件在高度壓縮的影像潛在空間中工作。為什麼這很重要?壓縮資料可以將訓練和推理的計算成本降低幾個數量級。在 1024×1024 影像上訓練比在 32×32 影像上訓練昂貴得多。通常,其他工作使用相對較小的壓縮,空間壓縮範圍為 4 倍 - 8 倍。Würstchen 將其推向極致。透過其新穎的設計,它實現了 42 倍的空間壓縮!這以前從未見過,因為常見方法在 16 倍空間壓縮後無法忠實地重建詳細影像。Würstchen 採用兩階段壓縮,我們稱之為階段 A 和階段 B。階段 A 是 VQGAN,階段 B 是擴散自編碼器(更多詳細資訊可以在論文中找到)。階段 A 和 B 合稱為*解碼器*,因為它們將壓縮影像解碼回畫素空間。第三個模型,階段 C,在該高度壓縮的潛在空間中學習。這種訓練所需的計算量只是當前頂級模型所需計算量的一小部分,同時還允許更便宜、更快的推理。我們將階段 C 稱為*先驗*。
為什麼需要另一個文字到影像模型?
嗯,這個模型非常快且高效。Würstchen 的最大優勢在於它可以比 Stable Diffusion XL 等模型更快地生成影像,同時使用更少的記憶體!因此,對於我們這些沒有 A100 的人來說,這將非常有用。以下是與 SDXL 在不同批次大小下的比較:
此外,Würstchen 的另一個重要優勢是降低了訓練成本。在 512x512 解析度下工作的 Würstchen v1 僅需要 9,000 GPU 小時的訓練。與 Stable Diffusion 1.4 所花費的 150,000 GPU 小時相比,這表明成本降低了 16 倍,這不僅有利於研究人員進行新實驗,還為更多組織訓練此類模型打開了大門。Würstchen v2 使用了 24,602 GPU 小時。在解析度達到 1536 的情況下,這仍然比僅在 512x512 解析度下訓練的 SD1.4 便宜 6 倍。
你也可以在這裡找到詳細的解釋影片
如何使用 Würstchen?
你可以在這裡嘗試使用演示:
此外,該模型透過 Diffusers 庫提供,因此你可以使用你已經熟悉的介面。例如,以下是如何使用 `AutoPipeline` 執行推理:
import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS
pipeline = AutoPipelineForText2Image.from_pretrained("warp-ai/wuerstchen", torch_dtype=torch.float16).to("cuda")
caption = "Anthropomorphic cat dressed as a firefighter"
images = pipeline(
caption,
height=1024,
width=1536,
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
prior_guidance_scale=4.0,
num_images_per_prompt=4,
).images
Würstchen 適用於哪些影像尺寸?
Würstchen 在 1024x1024 到 1536x1536 之間的影像解析度上進行訓練。我們有時也會在 1024x2048 等解析度下觀察到不錯的輸出。歡迎隨意嘗試。我們還觀察到先驗(Stage C)能極快地適應新解析度。因此,在 2048x2048 解析度下進行微調的計算成本應該很低。
Hub 上的模型
所有檢查點都可以在Huggingface Hub上找到。那裡可以找到多個檢查點,以及未來的演示和模型權重。目前,先驗有 3 個檢查點可用,解碼器有 1 個檢查點。請參閱文件,其中解釋了檢查點以及不同的先驗模型的用途。
Diffusers 整合
因為 Würstchen 完全整合在 `diffusers` 中,所以它自動附帶了各種開箱即用的便利功能和最佳化。其中包括:
- 自動使用PyTorch 2 `SDPA`加速注意力,如下所述。
- 如果需要使用 PyTorch 1.x 而非 2,則支援xFormers flash attention實現。
- 模型解除安裝,以便在不使用時將未使用的元件移至 CPU。這可以節省記憶體,同時效能影響可忽略不計。
- 順序 CPU 解除安裝,適用於記憶體非常寶貴的情況。記憶體使用將最小化,但推理速度會變慢。
- 使用 Compel 庫進行提示權重。
- 支援 Apple Silicon Mac 上的`mps` 裝置。
- 使用生成器實現可復現性。
- 針對推理的合理預設值,可在大多數情況下生成高質量結果。當然,您可以根據需要調整所有引數!
最佳化技術 1:Flash Attention
從 2.0 版本開始,PyTorch 集成了高度最佳化且資源友好的注意力機制版本,稱為`torch.nn.functional.scaled_dot_product_attention` 或 SDPA。根據輸入的性質,此函式利用多種底層最佳化。其效能和記憶體效率超越了傳統的注意力模型。值得注意的是,SDPA 函式反映了 Dao 及其團隊撰寫的《Fast and Memory-Efficient Exact Attention with IO-Awareness》研究論文中強調的 *flash attention* 技術的特點。
如果您使用的是 PyTorch 2.0 或更高版本的 Diffusers,並且 SDPA 函式可訪問,則會自動應用這些增強功能。請根據官方指南設定 torch 2.0 或更高版本以開始使用!
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
有關 `diffusers` 如何利用 SDPA 的深入瞭解,請查閱文件。
如果您使用的是 Pytorch 2.0 之前的版本,仍然可以使用 xFormers 庫來實現記憶體高效的注意力機制。
pipeline.enable_xformers_memory_efficient_attention()
最佳化技術 2:Torch Compile
如果你正在尋求額外的效能提升,可以使用 `torch.compile`。最好將其應用於先驗模型和解碼器的主模型,以實現最大的效能提升。
pipeline.prior_prior = torch.compile(pipeline.prior_prior , mode="reduce-overhead", fullgraph=True)
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True)
請記住,首次推理步驟將花費很長時間(長達 2 分鐘),因為模型正在編譯。之後,您可以正常執行推理
images = pipeline(caption, height=1024, width=1536, prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS, prior_guidance_scale=4.0, num_images_per_prompt=4).images
好訊息是這種編譯是一次性執行的。之後,您將始終體驗到相同影像解析度下的更快推理。編譯的初始時間投入很快就會被隨後的速度優勢抵消。有關 `torch.compile` 及其細微差別的深入探討,請檢視官方文件。
模型是如何訓練的?
能夠訓練這個模型,完全得益於 Stability AI 提供的計算資源。我們要特別感謝 Stability 讓我們能夠進行這類研究,並有機會讓更多人接觸到它!
資源
- 有關此模型的更多資訊,請參閱官方 Diffusers 文件。
- 所有檢查點都可以在hub上找到
- 你可以在這裡試用演示。
- 如果你想討論未來的專案或貢獻自己的想法,請加入我們的 Discord!
- 訓練程式碼等更多內容可以在官方 GitHub 儲存庫中找到。