Diffusers 文件

Wuerstchen

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Wuerstchen

Wuerstchen 模型透過將潛在空間壓縮 42 倍,顯著降低了計算成本,同時不影響影像質量並加速了推理。在訓練過程中,Wuerstchen 使用兩個模型(VQGAN + 自動編碼器)來壓縮潛在空間,然後第三個模型(文字條件潛在擴散模型)在此高度壓縮的空間上進行條件訓練以生成影像。

為了將先驗模型適應到 GPU 記憶體並加速訓練,可以嘗試分別啟用 gradient_accumulation_stepsgradient_checkpointingmixed_precision

本指南探討了 train_text_to_image_prior.py 指令碼,以幫助您更熟悉它,以及如何根據您自己的用例對其進行調整。

在執行指令碼之前,請確保從原始碼安裝庫

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

然後導航到包含訓練指令碼的示例資料夾並安裝您正在使用的指令碼所需的依賴項

cd examples/wuerstchen/text_to_image
pip install -r requirements.txt

🤗 Accelerate 是一個幫助您在多個 GPU/TPU 上或使用混合精度進行訓練的庫。它會根據您的硬體和環境自動配置您的訓練設定。請檢視 🤗 Accelerate 快速入門 以瞭解更多資訊。

初始化 🤗 Accelerate 環境

accelerate config

要設定預設的 🤗 Accelerate 環境而不選擇任何配置

accelerate config default

或者如果您的環境不支援互動式 shell(例如筆記本),您可以使用

from accelerate.utils import write_basic_config

write_basic_config()

最後,如果您想在自己的資料集上訓練模型,請檢視 建立訓練資料集 指南,瞭解如何建立與訓練指令碼相容的資料集。

以下部分重點介紹了訓練指令碼中對於理解如何修改它很重要的部分,但它並未詳細介紹 指令碼 的所有方面。如果您有興趣瞭解更多資訊,請隨時閱讀指令碼並告訴我們您是否有任何問題或疑慮。

指令碼引數

訓練指令碼提供了許多引數來幫助您自定義訓練執行。所有引數及其描述都可以在 parse_args() 函式中找到。它為每個引數提供了預設值,例如訓練批處理大小和學習率,但您也可以在訓練命令中設定自己的值。

例如,要使用 fp16 格式的混合精度加速訓練,請將 --mixed_precision 引數新增到訓練命令中

accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16"

大多數引數與 文字到影像 訓練指南中的引數相同,所以讓我們直接深入瞭解 Wuerstchen 訓練指令碼!

訓練指令碼

訓練指令碼也與 文字到影像 訓練指南類似,但它已經過修改以支援 Wuerstchen。本指南側重於 Wuerstchen 訓練指令碼特有的程式碼。

main() 函式首先初始化影像編碼器——一個 EfficientNet——以及常用的排程器和分詞器。

with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
    pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
    state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
    image_encoder = EfficientNetEncoder()
    image_encoder.load_state_dict(state_dict["effnet_state_dict"])
    image_encoder.eval()

您還將載入 WuerstchenPrior 模型進行最佳化。

prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")

optimizer = optimizer_cls(
    prior.parameters(),
    lr=args.learning_rate,
    betas=(args.adam_beta1, args.adam_beta2),
    weight_decay=args.adam_weight_decay,
    eps=args.adam_epsilon,
)

接下來,您將對影像應用一些 轉換,並對標題進行 分詞

def preprocess_train(examples):
    images = [image.convert("RGB") for image in examples[image_column]]
    examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
    examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
    return examples

最後,訓練迴圈 負責使用 EfficientNetEncoder 將影像壓縮到潛在空間,向潛在空間新增噪聲,並使用 WuerstchenPrior 模型預測噪聲殘差。

pred_noise = prior(noisy_latents, timesteps, prompt_embeds)

如果您想了解更多關於訓練迴圈如何工作的資訊,請檢視 理解管道、模型和排程器 教程,它分解了去噪過程的基本模式。

啟動指令碼

完成所有更改或對預設配置滿意後,您就可以啟動訓練指令碼了!🚀

DATASET_NAME 環境變數設定為 Hub 中資料集的名稱。本指南使用 Naruto BLIP captions 資料集,但您也可以建立並訓練自己的資料集(請參閱 建立訓練資料集 指南)。

要使用 Weights & Biases 監控訓練進度,請在訓練命令中新增 --report_to=wandb 引數。您還需要在訓練命令中新增 --validation_prompt 以跟蹤結果。這對於除錯模型和檢視中間結果非常有用。

export DATASET_NAME="lambdalabs/naruto-blip-captions"

accelerate launch  train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --validation_prompts="A robot naruto, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-naruto-model"

訓練完成後,您可以使用新訓練的模型進行推理!

import torch
from diffusers import AutoPipelineForText2Image
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS

pipeline = AutoPipelineForText2Image.from_pretrained("path/to/saved/model", torch_dtype=torch.float16).to("cuda")

caption = "A cute bird naruto holding a shield"
images = pipeline(
    caption,
    width=1024,
    height=1536,
    prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
    prior_guidance_scale=4.0,
    num_images_per_prompt=2,
).images

下一步

恭喜您訓練了一個 Wuerstchen 模型!要了解如何使用您的新模型,以下內容可能會有所幫助

  • 檢視 Wuerstchen API 文件,瞭解如何使用流水線進行文字到影像生成及其限制。
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.