透過 TRL 使用 DDPO 微調 Stable Diffusion 模型

釋出於 2023 年 9 月 29 日
在 GitHub 上更新

引言

擴散模型 (例如 DALL-E 2, Stable Diffusion) 是一類生成模型,在生成影像方面取得了巨大成功,尤其是在照片級真實感影像方面。然而,這些模型生成的影像可能並不總是符合人類的偏好或意圖。因此出現了對齊問題,即如何確保模型的輸出與人類偏好(如“質量”)或難以透過提示表達的意圖保持一致?這就是強化學習發揮作用的地方。

在大型語言模型 (LLMs) 的世界裡,強化學習 (RL) 已被證明是一種非常有效的工具,用於將所述模型與人類偏好對齊。它是像 ChatGPT 這樣系統表現優異的主要秘訣之一。更準確地說,RL 是從人類反饋中進行強化學習 (RLHF) 的關鍵組成部分,它讓 ChatGPT 能夠像人類一樣聊天。

《使用強化學習訓練擴散模型》 一文中,Black 等人展示瞭如何透過一種名為去噪擴散策略最佳化 (Denoising Diffusion Policy Optimization, DDPO) 的方法來增強擴散模型,利用 RL 根據目標函式對其進行微調。

在這篇博文中,我們討論了 DDPO 的由來、其工作原理的簡要描述,以及如何將 DDPO 融入 RLHF 工作流程,以實現更符合人類審美的模型輸出。然後,我們迅速轉向討論如何使用 `trl` 庫中新整合的 `DDPOTrainer` 將 DDPO 應用到你的模型上,並分享我們在 Stable Diffusion 上執行 DDPO 的發現。

DDPO 的優勢

對於如何嘗試用 RL 微調擴散模型這個問題,DDPO 並非唯一可行的答案。

在深入探討之前,理解不同 RL 方案優劣時有兩個關鍵點需要記住:

  1. 計算效率是關鍵。資料分佈越複雜,計算成本就越高。
  2. 近似是好的,但由於近似並非真實情況,相關誤差會累積。

在 DDPO 之前,獎勵加權迴歸 (RWR) 是一種已確立的、使用強化學習微調擴散模型的方法。RWR 複用擴散模型的去噪損失函式,並使用從模型本身取樣的訓練資料,以及每個樣本的損失權重,該權重取決於最終樣本的關聯獎勵。該演算法忽略了中間的去噪步驟/樣本。雖然這能行得通,但有兩點需要注意:

  1. 透過加權關聯損失(這是一個最大似然目標)進行最佳化是一種近似最佳化。
  2. 關聯損失並非精確的最大似然目標,而是從一個重新加權的變分界推匯出的近似值。

這兩層近似對效能和處理複雜目標的能力都有顯著影響。

DDPO 以此方法為起點。DDPO 不像 RWR 那樣只關注最終樣本,將去噪步驟視為單一步驟,而是將整個去噪過程構建為一個多步馬爾可夫決策過程 (MDP),其中獎勵在最後才收到。這種形式化,加上使用固定的取樣器,為智慧體策略鋪平了道路,使其成為一個各向同性的高斯分佈,而不是任意複雜的分佈。因此,DDPO 不使用最終樣本的近似似然(這是 RWR 的路徑),而是使用每個去噪步驟的精確似然,這非常容易計算 ((μ,σ2;x)=n2log(2π)n2log(σ2)12σ2i=1n(xiμ)2 \ell(\mu, \sigma^2; x) = -\frac{n}{2} \log(2\pi) - \frac{n}{2} \log(\sigma^2) - \frac{1}{2\sigma^2} \sum_{i=1}^n (x_i - \mu)^2 ).

如果你有興趣瞭解更多關於 DDPO 的細節,我們鼓勵你檢視原始論文相關的部落格文章

DDPO 演算法簡介

鑑於用於模擬去噪過程順序性的 MDP 框架以及隨之而來的其他考慮,解決最佳化問題的首選工具是策略梯度方法。具體來說,是近端策略最佳化 (PPO)。整個 DDPO 演算法與近端策略最佳化 (PPO) 基本相同,但其中一個高度定製化的部分是 PPO 的軌跡收集部分。

這裡有一個圖表來總結流程

dppo rl schematic

DDPO 和 RLHF:融合以增強美學

RLHF 的一般訓練流程大致可以分為以下幾個步驟:

  1. 對一個“基礎”模型進行有監督的微調,使其學習新資料的分佈。
  2. 收集偏好資料並用其訓練一個獎勵模型。
  3. 使用獎勵模型作為訊號,透過強化學習對模型進行微調。

需要注意的是,在 RLHF 的背景下,偏好資料是捕獲人類反饋的主要來源。

當我們將 DDPO 加入進來時,工作流程會變成如下形式:

  1. 從一個預訓練的擴散模型開始。
  2. 收集偏好資料並用其訓練一個獎勵模型。
  3. 使用獎勵模型作為訊號,透過 DDPO 對模型進行微調。

請注意,一般 RLHF 工作流程中的第 3 步在後者的步驟列表中缺失了,這是因為經驗證明(正如你將親眼看到的那樣)這一步並非必要。

為了讓擴散模型輸出更符合人類審美觀念的影像,我們遵循以下步驟:

  1. 從一個預訓練的 Stable Diffusion (SD) 模型開始。
  2. 美學視覺分析 (AVA) 資料集上訓練一個凍結的 CLIP 模型,該模型帶有一個可訓練的迴歸頭,用於預測人們對輸入影像的平均喜好程度。
  3. 使用美學預測模型作為獎勵訊號,透過 DDPO 對 SD 模型進行微調。

我們在接下來的章節中將牢記這些步驟,實際執行它們,具體描述如下。

使用 DDPO 訓練 Stable Diffusion

設定

首先,在硬體方面,對於這個 DDPO 實現,至少需要一個 A100 NVIDIA GPU 才能成功訓練。任何低於此 GPU 型別的裝置很快就會遇到記憶體不足的問題。

使用 pip 安裝 `trl` 庫

pip install trl[diffusers]

這應該會安裝主庫。以下依賴項用於跟蹤和影像日誌記錄。安裝 `wandb` 後,請務必登入以將結果儲存到個人賬戶。

pip install wandb torchvision

注意:你也可以選擇使用 `tensorboard` 而不是 `wandb`,為此你需要透過 `pip` 安裝 `tensorboard` 包。

詳細步驟

`trl` 庫中負責 DDPO 訓練的主要類是 `DDPOTrainer` 和 `DDPOConfig`。有關 `DDPOTrainer` 和 `DDPOConfig` 的更多通用資訊,請參閱文件。在 `trl` 倉庫中有一個示例訓練指令碼。它將這兩個類與所需輸入的預設實現和預設引數結合使用,以微調來自 `RunwayML` 的預設預訓練 Stable Diffusion 模型。

此示例指令碼使用 `wandb` 進行日誌記錄,並使用一個美學獎勵模型,其權重從一個公開的 HuggingFace 倉庫中讀取(因此,收集資料和訓練美學獎勵模型的工作已經為你完成)。預設使用的提示資料集是一個動物名稱列表。

使用者只需提供一個命令列標誌引數即可開始執行。此外,使用者需要有一個 Hugging Face 使用者訪問令牌,該令牌將在微調後用於將模型上傳到 Hugging Face Hub。

以下 bash 命令可以啟動執行

python ddpo.py --hf_user_access_token <token>

下表包含了與積極結果直接相關的關鍵超引數。

引數 描述 單 GPU 訓練的推薦值(截至目前)
num_epochs 訓練的輪數 200
train_batch_size 用於訓練的批次大小 3
sample_batch_size 用於取樣的批次大小 6
gradient_accumulation_steps 要使用的基於加速器的梯度累積步數 1
sample_num_steps 取樣的步數 50
sample_num_batches_per_epoch 每輪取樣的批次數 4
per_prompt_stat_tracking 是否按提示跟蹤統計資料。如果為 false,將使用整個批次的均值和標準差計算優勢,而不是按提示跟蹤 True
per_prompt_stat_tracking_buffer_size 用於按提示跟蹤統計資料的緩衝區大小 32
mixed_precision 混合精度訓練 True
train_learning_rate 學習率 3e-4

提供的指令碼僅僅是一個起點。請隨意調整超引數,甚至徹底修改指令碼以適應不同的目標函式。例如,可以整合一個衡量 JPEG 可壓縮性的函式,或一個使用多模態模型評估視覺-文字對齊的函式,以及其他可能性。

經驗教訓

  1. 儘管訓練提示詞數量極少,但結果似乎在各種提示詞上都能很好地泛化。對於獎勵美學的目標函式,這一點已得到充分驗證。
  2. 嘗試透過增加訓練提示詞數量和改變提示詞來明確泛化,至少對於美學目標函式而言,似乎會減慢收斂速度,而學到的泛化行為幾乎察覺不到(如果存在的話)。
  3. 雖然 LoRA 是推薦的,並且經過多次測試,但非 LoRA 方案也值得考慮,原因之一是根據經驗證據,非 LoRA 似乎能生成比 LoRA 更復雜的影像。然而,為穩定的非 LoRA 執行找到合適的超引數要更具挑戰性。
  4. 對於非 LoRA 配置引數的建議是:將學習率設定得相對較低,大約 `1e-5` 應該可以,並將 `mixed_precision` 設定為 `None`。

結果

以下是對於提示詞 `bear`(熊)、`heaven`(天堂)和 `dune`(沙丘),微調前(左)和微調後(右)的輸出(每行對應一個提示詞的輸出)。

微調前 微調後
nonfinetuned_bear.png finetuned_bear.png
nonfinetuned_heaven.png finetuned_heaven.png
nonfinetuned_dune.png finetuned_dune.png

侷限性

  1. 目前 `trl` 的 DDPOTrainer 僅限於微調原版 SD 模型;
  2. 在我們的實驗中,我們主要關注 LoRA,它效果很好。我們也進行了一些全量訓練的實驗,這可以帶來更好的質量,但找到合適的超引數更具挑戰性。

結論

像 Stable Diffusion 這樣的擴散模型,在使用 DDPO 進行微調後,可以在人類感知或任何其他可被恰當概念化為目標函式的指標上,顯著提升生成影像的質量。

DDPO 的計算效率及其在不依賴近似的情況下進行最佳化的能力,尤其是在與早期實現相同微調擴散模型目標的方法相比時,使其成為微調像 Stable Diffusion 這樣的擴散模型的合適選擇。

`trl` 庫的 `DDPOTrainer` 實現了用於微調 SD 模型的 DDPO。

我們的實驗結果強調了 DDPO 在廣泛提示詞上泛化的能力,儘管透過變化提示詞進行顯式泛化的嘗試結果好壞參半。為非 LoRA 設定找到合適的超引數的困難也成為一個重要的經驗教訓。

DDPO 是一種很有前途的技術,可以將擴散模型與任何獎勵函式對齊,我們希望透過在 TRL 中的釋出,能讓它更容易地被社群所用!

致謝

感謝 Chunte Lee 為這篇博文製作縮圖。

社群

註冊登入以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.