使用Diffusers和PEFT為Flux實現快速LoRA推理
LoRA介面卡為各種大小的模型提供了極大的定製化能力。在影像生成方面,它們可以賦予模型不同的風格、不同的角色等更多功能。有時,它們還可以用於減少推理延遲。因此,它們的重要性是至關重要的,尤其是在定製和微調模型時。
在這篇文章中,我們選擇了Flux.1-Dev模型進行文字到影像生成,因為它廣受歡迎且應用廣泛。我們探討了如何在使用LoRA時最佳化其推理速度(約2.3倍)。根據Hugging Face Hub平臺上的報告,該模型已訓練了超過3萬個介面卡。因此,它對社群的重要性是巨大的。
請注意,儘管我們演示了Flux的加速效果,但我們相信我們的方法足夠通用,可以應用於其他模型。
如果您迫不及待想開始編碼,請檢視隨附的程式碼庫。
目錄
最佳化LoRA推理的障礙
在提供LoRA服務時,通常會進行熱插拔(即插拔不同的LoRA)。LoRA會改變基礎模型的架構。此外,LoRA之間也可能不同——每個LoRA可能具有不同的秩,並針對不同的層進行適配。為了應對LoRA的這些動態特性,我們必須採取必要的措施來確保我們應用的最佳化是穩健的。
例如,我們可以在載入了特定LoRA的模型上應用torch.compile
,以提高推理延遲。但是,一旦我們將LoRA替換為另一個(可能具有不同配置的)LoRA,就會遇到重新編譯的問題,導致推理速度下降。
還可以將LoRA引數融合到基礎模型引數中,執行編譯,然後在載入新引數時解除LoRA引數的融合。然而,這種方法在每次執行推理時,由於潛在的架構級更改,仍然會遇到重新編譯的問題。
我們的最佳化方法考慮了上述情況,以儘可能地切合實際。以下是我們最佳化方法的核心組成部分:
- Flash Attention 3 (FA3)
torch.compile
- TorchAO的FP8量化
- 支援熱插拔
請注意,在上述元件中,FP8量化是無損的,但通常能提供最強大的速度-記憶體權衡。儘管我們主要使用NVIDIA GPU測試了該方法,但它也應該適用於AMD GPU。
最佳化方法
在我們之前的部落格文章(文章1和文章2)中,我們已經討論了使用我們最佳化方法前三個元件的好處。逐一應用它們只需幾行程式碼。
from diffusers import DiffusionPipeline, TorchAoConfig
from diffusers.quantizers import PipelineQuantizationConfig
from utils.fa3_processor import FlashFluxAttnProcessor3_0
import torch
# quantize the Flux transformer with FP8
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
quantization_config=PipelineQuantizationConfig(
quant_mapping={"transformer": TorchAoConfig("float8dq_e4m3_row")}
)
).to("cuda")
# use Flash-attention 3
pipe.transformer.set_attn_processor(FlashFluxAttnProcessor3_0())
# use torch.compile()
pipe.transformer.compile(fullgraph=True, mode="max-autotune")
# perform inference
pipe_kwargs = {
"prompt": "A cat holding a sign that says hello world",
"height": 1024,
"width": 1024,
"guidance_scale": 3.5,
"num_inference_steps": 28,
"max_sequence_length": 512,
}
# first time will be slower, subsequent runs will be faster
image = pipe(**pipe_kwargs).images[0]
FA3處理器來自此處。
當我們將LoRA熱插拔到已編譯的擴散Transformer(`pipe.transformer`)中而不觸發重新編譯時,問題開始浮現。
通常,載入和解除安裝LoRA會需要重新編譯,這會抵消編譯帶來的任何速度優勢。幸運的是,有一種方法可以避免重新編譯。透過傳遞`hotswap=True`,Diffusers將保持模型架構不變,只交換LoRA介面卡本身的權重,這不需要重新編譯。
pipe.enable_lora_hotswap(target_rank=max_rank)
pipe.load_lora_weights(<lora-adapter-name1>)
# compile *after* loading the first LoRA
pipe.transformer.compile(mode="max-autotune", fullgraph=True)
image = pipe(**pipe_kwargs).images[0]
# from this point on, load new LoRAs with `hotswap=True`
pipe.load_lora_weights(<lora-adapter-name2>, hotswap=True)
image = pipe(**pipe_kwargs).images[0]
(提醒一下,第一次呼叫`pipe`會很慢,因為`torch.compile`是即時編譯器。然而,隨後的呼叫應該會顯著加快。)
這通常允許在不重新編譯的情況下交換 LoRA,但存在一些限制:
- 我們需要提前提供所有 LoRA 介面卡中的最大秩。因此,如果我們有一個秩為 16 的介面卡,另一個秩為 32 的介面卡,我們需要傳遞 `max_rank=32`。
- 熱插拔的LoRA介面卡只能針對第一個LoRA所針對的相同層或其子集。
- 目前尚不支援文字編碼器目標化。
有關Diffusers中熱插拔及其限制的更多資訊,請訪問文件中的熱插拔部分。
當我們檢視不使用編譯進行熱插拔時的推理延遲時,這種工作流程的好處變得顯而易見。
選項 | 時間 (秒) ⬇️ | 加速 (對比基線) ⬆️ | 備註 |
---|---|---|---|
基準 | 7.8910 | – | 基線 |
已最佳化 | 3.5464 | 2.23倍 | 熱插拔 + 編譯,無重新編譯卡頓(預設開啟FP8) |
無FP8 | 4.3520 | 1.81倍 | 與“已最佳化”相同,但停用FP8量化 |
無FA3 | 4.3020 | 1.84倍 | 停用 FA3 (flash‑attention v3) |
基線 + 編譯 | 5.0920 | 1.55倍 | 編譯開啟,但受間歇性重新編譯停頓影響 |
無FA3_FP8 | 5.0850 | 1.55倍 | 停用 FA3 和 FP8 |
無編譯_FP8 | 7.5190 | 1.05倍 | 停用 FP8 量化和編譯 |
無編譯 | 10.4340 | 0.76倍 | 停用編譯:最慢的設定 |
主要收穫:
- “常規+編譯”選項比常規選項提供了不錯的加速,但它會引發重新編譯問題,從而增加總執行時間。在我們的基準測試中,我們沒有給出編譯時間。
- 透過熱插拔消除重新編譯問題(也稱為“最佳化”選項)時,我們實現了最高的加速。
- 在“最佳化”選項中,FP8量化已啟用,這可能導致質量損失。即使不使用FP8,我們也能獲得不錯的加速(“無FP8”選項)。
- 為了演示目的,我們使用一個包含兩個LoRA的池進行編譯熱插拔。有關完整程式碼,請參閱隨附的程式碼庫。
我們迄今討論的最佳化方法假定能夠訪問像H100這樣的強大GPU。然而,當我們受限於使用RTX 4090等消費級GPU時,我們能做些什麼呢?讓我們一探究竟。
在消費級GPU上最佳化LoRA推理
Flux.1-Dev(不帶任何LoRA)使用Bfloat16資料型別執行,佔用約33GB記憶體。根據LoRA模組的大小,如果不進行任何最佳化,記憶體佔用還會進一步增加。許多消費級GPU,如RTX 4090,只有24GB記憶體。在本節的其餘部分,我們將RTX 4090機器作為我們的測試平臺。
首先,為了實現Flux.1-Dev的端到端執行,我們可以應用CPU解除安裝,將不需要執行當前計算的元件解除安裝到CPU,以釋放更多加速器記憶體。這樣做可以在RTX 4090上以約22GB的記憶體執行整個管道,耗時**35.403秒**。啟用編譯可以將延遲降低到**31.205秒**(1.12倍加速)。在程式碼方面,只需幾行:
pipe = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()
# Instead of full compilation, we apply regional compilation
# here to take advantage of `fullgraph=True` and also to reduce
# compilation time. More details can be found here:
# https://huggingface.co/docs/diffusers/main/en/optimization/fp16#regional-compilation
pipe.transformer.compile_repeated_blocks(fullgraph=True)
image = pipe(**pipe_kwargs).images[0]
請注意,我們在此處沒有應用FP8量化,因為它不支援CPU解除安裝和編譯(支援問題執行緒)。因此,僅將FP8量化應用於Flux Transformer不足以緩解記憶體耗盡問題。在這種情況下,我們決定將其移除。
因此,為了利用FP8量化方案,我們需要找到一種無需CPU解除安裝的方法。對於Flux.1-Dev,如果再對T5文字編碼器進行量化,我們應該能夠在24GB記憶體中載入和執行完整的管道。下面是T5文字編碼器量化(來自bitsandbytes
的NF4量化)和未量化時的結果比較。
如上圖所示,量化T5文字編碼器並不會造成太大的質量損失。將量化後的T5文字編碼器和FP8量化後的Flux Transformer與`torch.compile`結合使用,我們得到了相當不錯的結果——從32.27秒降至**9.668秒**(大幅加速約3.3倍),且沒有明顯的質量下降。
即使不量化T5文字編碼器,也可以用24GB的VRAM生成影像,但這會使我們的生成流程稍微複雜一些。
我們現在有了一種在RTX 4090上使用FP8量化執行整個Flux.1-Dev管道的方法。我們可以在相同的硬體上應用先前建立的最佳化LoRA推理方法。由於RTX 4090不支援FA3,我們將堅持以下最佳化方法,並新增T5量化:
- FP8量化
torch.compile
- 支援熱插拔
- T5量化 (使用NF4)
在下表中,我們展示了應用上述元件不同組合的推理延遲資料。
選項 | 關鍵引數標誌 | 時間 (秒) ⬇️ | 加速 (對比基線) ⬆️ |
---|---|---|---|
基準 | disable_fp8=False disable_compile=True quantize_t5=True offload=False |
23.6060 | – |
已最佳化 | disable_fp8=False disable_compile=False quantize_t5=True offload=False |
11.5715 | 2.04倍 |
簡要說明:
- 編譯比基線提供了巨大的2倍加速。
- 即使啟用了解除安裝,其他選項也導致了OOM錯誤。
熱插拔的技術細節
為了實現熱插拔而不觸發重新編譯,必須克服兩個障礙。首先,LoRA的縮放因子必須從浮點數轉換為torch張量,這相對容易實現。其次,LoRA權重的形狀需要填充到所需的最大形狀。這樣,可以替換權重中的資料而無需重新分配整個屬性。這就是為什麼上面討論的`max_rank`引數至關重要。由於我們將值用零填充,結果保持不變,儘管計算速度會根據填充的大小而稍有減慢。
由於沒有新增新的LoRA屬性,這也要求第一個LoRA之後的每個LoRA只能針對第一個LoRA所針對的相同層或其子集。因此,請明智地選擇載入順序。如果LoRA針對不相交的層,則可以建立一個針對所有目標層並集的虛擬LoRA。
要檢視此實現的詳細資訊,請訪問PEFT中的`hotswap.py`檔案。
結論
本文概述了一種用於Flux快速LoRA推理的最佳化方法,並展示了顯著的加速效果。我們的方法結合了Flash Attention 3、`torch.compile`和FP8量化,同時確保了熱插拔功能,避免了重新編譯問題。在H100等高階GPU上,這種最佳化設定比基線提供了2.23倍的加速。
對於消費級GPU,特別是RTX 4090,我們透過引入T5文字編碼器量化(NF4)和利用區域編譯解決了記憶體限制。這種全面的方法實現了顯著的2.04倍加速,即使在有限的VRAM下,也能使Flux上的LoRA推理變得可行且高效。關鍵在於,透過仔細管理編譯和量化,LoRA的優勢可以在不同的硬體配置上充分實現。
希望本文提供的秘訣能激發您最佳化基於LoRA的用例,從而受益於快速推理。
資源
以下是本文中引用的重要資源列表