TRL 文件

去噪擴散策略最佳化

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

去噪擴散策略最佳化

目的

之前 DDPO 微調後

透過強化學習開始微調 Stable Diffusion

使用強化學習微調 Stable Diffusion 模型的機制大量使用了 HuggingFace 的 `diffusers` 庫。之所以這樣說,是因為入門需要對 `diffusers` 庫的概念,主要是其中的兩個——管道(pipelines)和排程器(schedulers),有一定的熟悉。開箱即用(`diffusers` 庫)的情況下,沒有適合強化學習微調的 `Pipeline` 或 `Scheduler` 例項。需要進行一些調整。

本庫提供了一個管道介面,該介面必須實現才能與 `DDPOTrainer` 結合使用,`DDPOTrainer` 是使用強化學習微調 Stable Diffusion 的主要機制。**注意:目前僅支援 StableDiffusion 架構。** 本庫提供了一個預設實現,您可以直接使用。假設預設實現足夠或者為了快速啟動,請參考本指南中的訓練示例。

該介面的目的是將管道和排程器融合到一個物件中,從而最大限度地將所有約束集中在一個地方。設計該介面的目的是希望在本文撰寫之時,能支援此倉庫及其他地方示例之外的管道和排程器。此外,排程器步驟是此管道介面的一個方法,這可能看起來有些多餘,因為原始排程器可以透過介面訪問,但這是將排程器步驟輸出限制為符合當前演算法(DDPO)的輸出型別的唯一方法。

要更詳細地瞭解該介面及其關聯的預設實現,請點選此處

請注意,預設實現包含 LoRA 實現路徑和非 LoRA 實現路徑。LoRA 標誌預設啟用,可以透過傳遞標誌來關閉。基於 LORA 的訓練速度更快,並且 LORA 相關的模型超引數對模型收斂的影響不像非 LORA 訓練那樣挑剔。

此外,還期望提供一個獎勵函式和一個提示函式。獎勵函式用於評估生成的影像,提示函式用於生成用於生成影像的提示。

ddpo.py 示例指令碼入門

`ddpo.py` 指令碼是使用 `DDPO` 訓練器微調 Stable Diffusion 模型的工作示例。此示例明確配置了與配置物件 (`DDPOConfig`) 相關聯的整體引數的一小部分。

**注意:** 建議使用一塊 A100 GPU 來執行此示例。低於 A100 的顯示卡將無法執行此示例指令碼,即使能透過相對較小的引數執行,結果也可能不盡如人意。

幾乎每個配置引數都有一個預設值。使用者只需要一個命令列標誌引數即可啟動和執行。使用者需要擁有一個 huggingface 使用者訪問令牌,該令牌將用於在微調後將模型上傳到 HuggingFace Hub。以下是要輸入的 bash 命令以啟動和執行:

python ddpo.py --hf_user_access_token <token>

要獲取 `stable_diffusion_tuning.py` 的文件,請執行 `python stable_diffusion_tuning.py --help`

在配置訓練器時(除了使用示例指令碼的情況),請記住以下幾點(程式碼也會為您檢查這些):

  • 可配置的取樣批大小 (`--ddpo_config.sample_batch_size=6`) 應大於或等於可配置的訓練批大小 (`--ddpo_config.train_batch_size=3`)
  • 可配置的取樣批次大小(`--ddpo_config.sample_batch_size=6`)必須能夠被可配置的訓練批次大小(`--ddpo_config.train_batch_size=3`)整除。
  • 可配置的取樣批次大小 (`--ddpo_config.sample_batch_size=6`) 必須能被可配置的梯度累積步數 (`--ddpo_config.train_gradient_accumulation_steps=1`) 和可配置的加速器程序數同時整除。

設定影像日誌鉤子函式

期望函式以列表的形式接收一個列表列表:

[[image, prompt, prompt_metadata, rewards, reward_metadata], ...]

並且 `image`、`prompt`、`prompt_metadata`、`rewards`、`reward_metadata` 都是批處理的。列表列表中的最後一個列表表示最後一個樣本批次。您可能希望記錄這一個。雖然您可以隨意記錄,但建議使用 `wandb` 或 `tensorboard`。

關鍵術語

  • `rewards`:獎勵/分數是與生成的影像相關的數值,是指導強化學習過程的關鍵。
  • `reward_metadata`:獎勵元資料是與獎勵相關的元資料。可以將其理解為隨獎勵一起提供的額外資訊負載。
  • `prompt`:提示是用於生成影像的文字。
  • `prompt_metadata`:提示元資料是與提示相關的元資料。當獎勵模型包含 `FLAVA` 設定時,即生成的影像(請參見此處:https://github.com/kvablack/ddpo-pytorch/blob/main/ddpo_pytorch/rewards.py#L45)預期附帶問題和地面答案(連結到生成的影像),這種情況下的元資料將不為空。
  • `image`:由 Stable Diffusion 模型生成的影像

以下是使用 `wandb` 記錄取樣影像的示例程式碼。

# for logging these images to wandb

def image_outputs_hook(image_data, global_step, accelerate_logger):
    # For the sake of this example, we only care about the last batch
    # hence we extract the last element of the list
    result = {}
    images, prompts, _, rewards, _ = image_data[-1]
    for i, image in enumerate(images):
        pil = Image.fromarray(
            (image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8)
        )
        pil = pil.resize((256, 256))
        result[f"{prompts[i]:.25} | {rewards[i]:.2f}"] = [pil]
    accelerate_logger.log_images(
        result,
        step=global_step,
    )

使用微調模型

假設您已完成所有 epoch 並將模型推送到 hub,您可以按如下方式使用微調模型:


import torch
from trl import DefaultDDPOStableDiffusionPipeline

pipeline = DefaultDDPOStableDiffusionPipeline("metric-space/ddpo-finetuned-sd-model")

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# memory optimization
pipeline.vae.to(device, torch.float16)
pipeline.text_encoder.to(device, torch.float16)
pipeline.unet.to(device, torch.float16)

prompts = ["squirrel", "crab", "starfish", "whale","sponge", "plankton"]
results = pipeline(prompts)

for prompt, image in zip(prompts,results.images):
    image.save(f"{prompt}.png")

鳴謝

這項工作深受此處的倉庫以及相關論文《透過強化學習訓練擴散模型》(作者:Kevin Black、Michael Janner、Yilan Du、Ilya Kostrikov、Sergey Levine)此處的影響。

DDPOTrainer

trl.DDPOTrainer

< >

( config: DDPOConfig reward_function: typing.Callable[[torch.Tensor, tuple[str], tuple[typing.Any]], torch.Tensor] prompt_function: typing.Callable[[], tuple[str, typing.Any]] sd_pipeline: DDPOStableDiffusionPipeline image_samples_hook: typing.Optional[typing.Callable[[typing.Any, typing.Any, typing.Any], typing.Any]] = None )

引數

  • **config** (`DDPOConfig`) — `DDPOTrainer` 的配置物件。請檢視 `PPOConfig` 文件以獲取更多詳細資訊。
  • **reward_function** (Callable[[torch.Tensor, tuple[str], tuple[Any]], torch.Tensor]) — 要使用的獎勵函式 —
  • **prompt_function** (Callable[[], tuple[str, Any]]) — 用於生成指導模型的提示的函式 —
  • **sd_pipeline** (`DDPOStableDiffusionPipeline`) — 用於訓練的 Stable Diffusion 管道。 —
  • **image_samples_hook** (Optional[Callable[[Any, Any, Any], Any]]) — 用於記錄影像的鉤子 —

DDPOTrainer 使用深度擴散策略最佳化來最佳化擴散模型。請注意,此訓練器深受此處工作的影響:https://github.com/kvablack/ddpo-pytorch。目前僅支援基於 Stable Diffusion 的管道

計算損失

< >

( 潛空間向量 時間步 下一潛空間向量 對數機率 優勢 嵌入 )

引數

  • **latents** (torch.Tensor) — 從擴散模型取樣的潛空間向量,形狀:[batch_size, num_channels_latents, height, width]
  • **timesteps** (torch.Tensor) — 從擴散模型取樣的時間步,形狀:[batch_size]
  • **next_latents** (torch.Tensor) — 從擴散模型取樣的下一個潛在變數,形狀:[batch_size, num_channels_latents, height, width]
  • **log_probs** (torch.Tensor) — 潛在變數的對數機率,形狀:[batch_size]
  • **advantages** (torch.Tensor) — 潛在變數的優勢,形狀:[batch_size]
  • **embeds** (torch.Tensor) — 提示的嵌入,形狀:[2*batch_size 或 batch_size, ...] 注意:“或”是因為如果 train_cfg 為 True,則預期負面提示會與嵌入連線。

計算一批解包樣本的損失

建立模型卡片

< >

( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )

引數

  • **model_name** (`str` 或 `None`,*可選*,預設為 `None`) — 模型名稱。
  • **dataset_name** (`str` 或 `None`,*可選*,預設為 `None`) — 用於訓練的資料集名稱。
  • **tags** (`str`、`list[str]` 或 `None`,*可選*,預設為 `None`) — 與模型卡關聯的標籤。

使用 Trainer 可用的資訊建立模型卡片的草稿。

步驟

< >

( epoch: int global_step: int ) global_step (int)

引數

  • **epoch** (int) — 當前的 epoch。
  • **global_step** (int) — 當前的全域性步。

返回

全域性步數 (int)

更新後的全域性步。

執行單步訓練。

副作用

  • 模型權重已更新
  • 將統計資料記錄到加速器跟蹤器。
  • 如果 `self.image_samples_callback` 不為空,它將與 `prompt_image_pairs`、`global_step` 和加速器跟蹤器一起被呼叫。

訓練

< >

( epochs: typing.Optional[int] = None )

訓練模型指定數量的 epoch

DDPOConfig

trl.DDPOConfig

< >

( exp_name: str = 'doc-buil' run_name: str = '' seed: int = 0 log_with: typing.Optional[str] = None tracker_kwargs: dict = <factory> accelerator_kwargs: dict = <factory> project_kwargs: dict = <factory> tracker_project_name: str = 'trl' logdir: str = 'logs' num_epochs: int = 100 save_freq: int = 1 num_checkpoint_limit: int = 5 mixed_precision: str = 'fp16' allow_tf32: bool = True resume_from: str = '' sample_num_steps: int = 50 sample_eta: float = 1.0 sample_guidance_scale: float = 5.0 sample_batch_size: int = 1 sample_num_batches_per_epoch: int = 2 train_batch_size: int = 1 train_use_8bit_adam: bool = False train_learning_rate: float = 0.0003 train_adam_beta1: float = 0.9 train_adam_beta2: float = 0.999 train_adam_weight_decay: float = 0.0001 train_adam_epsilon: float = 1e-08 train_gradient_accumulation_steps: int = 1 train_max_grad_norm: float = 1.0 train_num_inner_epochs: int = 1 train_cfg: bool = True train_adv_clip_max: float = 5.0 train_clip_range: float = 0.0001 train_timestep_fraction: float = 1.0 per_prompt_stat_tracking: bool = False per_prompt_stat_tracking_buffer_size: int = 16 per_prompt_stat_tracking_min_count: int = 16 async_reward_computation: bool = False max_workers: int = 2 negative_prompts: str = '' push_to_hub: bool = False )

引數

  • **exp_name** (`str`,*可選*,預設為 `os.path.basename(sys.argv[0])[ -- -len(".py")]`):此實驗的名稱(預設情況下是檔名,不帶副檔名)。
  • **run_name** (`str`,*可選*,預設為 `""`) — 此執行的名稱。
  • **seed** (`int`,*可選*,預設為 `0`) — 隨機種子。
  • **log_with** (`Literal["wandb", "tensorboard"]]` 或 `None`,*可選*,預設為 `None`) — 使用 'wandb' 或 'tensorboard' 記錄,請檢視 https://huggingface.co/docs/accelerate/usage_guides/tracking 獲取更多詳細資訊。
  • **tracker_kwargs** (`Dict`,*可選*,預設為 `{}`) — 跟蹤器的關鍵字引數(例如 wandb_project)。
  • **accelerator_kwargs** (`Dict`,*可選*,預設為 `{}`) — 加速器的關鍵字引數。
  • **project_kwargs** (`Dict`,*可選*,預設為 `{}`) — 加速器專案配置的關鍵字引數(例如 `logging_dir`)。
  • **tracker_project_name** (`str`,*可選*,預設為 `"trl"`) — 用於跟蹤的專案名稱。
  • **logdir** (`str`,*可選*,預設為 `"logs"`) — 用於儲存檢查點的頂級日誌目錄。
  • num_epochs (int, optional, defaults to 100) — 訓練的 epoch 數量。
  • save_freq (int, optional, defaults to 1) — 儲存模型檢查點之間的 epoch 數量。
  • num_checkpoint_limit (int, optional, defaults to 5) — 在覆蓋舊檢查點之前保留的檢查點數量。
  • mixed_precision (str, optional, defaults to "fp16") — 混合精度訓練。
  • allow_tf32 (bool, optional, defaults to True) — 允許在 Ampere GPU 上使用 tf32
  • resume_from (str, optional, defaults to "") — 從檢查點恢復訓練。
  • sample_num_steps (int, optional, defaults to 50) — 取樣器推理步數。
  • sample_eta (float, optional, defaults to 1.0) — DDIM 取樣器的 Eta 引數。
  • sample_guidance_scale (float, optional, defaults to 5.0) — 無分類器指導權重。
  • sample_batch_size (int, optional, defaults to 1) — 用於取樣的批大小(每 GPU)。
  • sample_num_batches_per_epoch (int, optional, defaults to 2) — 每個 epoch 取樣的批次數量。
  • train_batch_size (int, optional, defaults to 1) — 用於訓練的批大小(每 GPU)。
  • train_use_8bit_adam (bool, optional, defaults to False) — 使用 bitsandbytes 中的 8 位 Adam 最佳化器。
  • train_learning_rate (float, optional, defaults to 3e-4) — 學習率。
  • train_adam_beta1 (float, optional, defaults to 0.9) — Adam beta1。
  • train_adam_beta2 (float, optional, defaults to 0.999) — Adam beta2。
  • train_adam_weight_decay (float, optional, defaults to 1e-4) — Adam 權重衰減。
  • train_adam_epsilon (float, optional, defaults to 1e-8) — Adam epsilon。
  • train_gradient_accumulation_steps (int, optional, defaults to 1) — 梯度累積步數。
  • train_max_grad_norm (float, optional, defaults to 1.0) — 梯度裁剪的最大梯度範數。
  • train_num_inner_epochs (int, optional, defaults to 1) — 每個外部 epoch 的內部 epoch 數量。
  • train_cfg (bool, optional, defaults to True) — 訓練期間是否使用無分類器指導。
  • train_adv_clip_max (float, optional, defaults to 5.0) — 將優勢剪輯到範圍。
  • train_clip_range (float, optional, defaults to 1e-4) — PPO 裁剪範圍。
  • train_timestep_fraction (float, optional, defaults to 1.0) — 訓練時間步長的比例。
  • per_prompt_stat_tracking (bool, optional, defaults to False) — 是否為每個提示單獨跟蹤統計資訊。
  • per_prompt_stat_tracking_buffer_size (int, optional, defaults to 16) — 為每個提示在緩衝區中儲存的獎勵值數量。
  • per_prompt_stat_tracking_min_count (int, optional, defaults to 16) — 在緩衝區中儲存的最小獎勵值數量。
  • async_reward_computation (bool, optional, defaults to False) — 是否非同步計算獎勵。
  • max_workers (int, optional, defaults to 2) — 用於非同步獎勵計算的最大工作器數量。
  • negative_prompts (str, optional, defaults to "") — 用作負面示例的提示的逗號分隔列表。
  • push_to_hub (bool, optional, defaults to False) — 是否將最終模型檢查點推送到 Hub。

用於 DDPOTrainer 的配置類。

使用 HfArgumentParser,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.