透過 TRL 使用 DDPO 微調 Stable Diffusion 模型
引言
擴散模型 (例如 DALL-E 2, Stable Diffusion) 是一類生成模型,在生成影像方面取得了巨大成功,尤其是在照片級真實感影像方面。然而,這些模型生成的影像可能並不總是符合人類的偏好或意圖。因此出現了對齊問題,即如何確保模型的輸出與人類偏好(如“質量”)或難以透過提示表達的意圖保持一致?這就是強化學習發揮作用的地方。
在大型語言模型 (LLMs) 的世界裡,強化學習 (RL) 已被證明是一種非常有效的工具,用於將所述模型與人類偏好對齊。它是像 ChatGPT 這樣系統表現優異的主要秘訣之一。更準確地說,RL 是從人類反饋中進行強化學習 (RLHF) 的關鍵組成部分,它讓 ChatGPT 能夠像人類一樣聊天。
在 《使用強化學習訓練擴散模型》 一文中,Black 等人展示瞭如何透過一種名為去噪擴散策略最佳化 (Denoising Diffusion Policy Optimization, DDPO) 的方法來增強擴散模型,利用 RL 根據目標函式對其進行微調。
在這篇博文中,我們討論了 DDPO 的由來、其工作原理的簡要描述,以及如何將 DDPO 融入 RLHF 工作流程,以實現更符合人類審美的模型輸出。然後,我們迅速轉向討論如何使用 `trl` 庫中新整合的 `DDPOTrainer` 將 DDPO 應用到你的模型上,並分享我們在 Stable Diffusion 上執行 DDPO 的發現。
DDPO 的優勢
對於如何嘗試用 RL 微調擴散模型這個問題,DDPO 並非唯一可行的答案。
在深入探討之前,理解不同 RL 方案優劣時有兩個關鍵點需要記住:
- 計算效率是關鍵。資料分佈越複雜,計算成本就越高。
- 近似是好的,但由於近似並非真實情況,相關誤差會累積。
在 DDPO 之前,獎勵加權迴歸 (RWR) 是一種已確立的、使用強化學習微調擴散模型的方法。RWR 複用擴散模型的去噪損失函式,並使用從模型本身取樣的訓練資料,以及每個樣本的損失權重,該權重取決於最終樣本的關聯獎勵。該演算法忽略了中間的去噪步驟/樣本。雖然這能行得通,但有兩點需要注意:
- 透過加權關聯損失(這是一個最大似然目標)進行最佳化是一種近似最佳化。
- 關聯損失並非精確的最大似然目標,而是從一個重新加權的變分界推匯出的近似值。
這兩層近似對效能和處理複雜目標的能力都有顯著影響。
DDPO 以此方法為起點。DDPO 不像 RWR 那樣只關注最終樣本,將去噪步驟視為單一步驟,而是將整個去噪過程構建為一個多步馬爾可夫決策過程 (MDP),其中獎勵在最後才收到。這種形式化,加上使用固定的取樣器,為智慧體策略鋪平了道路,使其成為一個各向同性的高斯分佈,而不是任意複雜的分佈。因此,DDPO 不使用最終樣本的近似似然(這是 RWR 的路徑),而是使用每個去噪步驟的精確似然,這非常容易計算 ().
如果你有興趣瞭解更多關於 DDPO 的細節,我們鼓勵你檢視原始論文和相關的部落格文章。
DDPO 演算法簡介
鑑於用於模擬去噪過程順序性的 MDP 框架以及隨之而來的其他考慮,解決最佳化問題的首選工具是策略梯度方法。具體來說,是近端策略最佳化 (PPO)。整個 DDPO 演算法與近端策略最佳化 (PPO) 基本相同,但其中一個高度定製化的部分是 PPO 的軌跡收集部分。
這裡有一個圖表來總結流程
DDPO 和 RLHF:融合以增強美學
RLHF 的一般訓練流程大致可以分為以下幾個步驟:
- 對一個“基礎”模型進行有監督的微調,使其學習新資料的分佈。
- 收集偏好資料並用其訓練一個獎勵模型。
- 使用獎勵模型作為訊號,透過強化學習對模型進行微調。
需要注意的是,在 RLHF 的背景下,偏好資料是捕獲人類反饋的主要來源。
當我們將 DDPO 加入進來時,工作流程會變成如下形式:
- 從一個預訓練的擴散模型開始。
- 收集偏好資料並用其訓練一個獎勵模型。
- 使用獎勵模型作為訊號,透過 DDPO 對模型進行微調。
請注意,一般 RLHF 工作流程中的第 3 步在後者的步驟列表中缺失了,這是因為經驗證明(正如你將親眼看到的那樣)這一步並非必要。
為了讓擴散模型輸出更符合人類審美觀念的影像,我們遵循以下步驟:
- 從一個預訓練的 Stable Diffusion (SD) 模型開始。
- 在美學視覺分析 (AVA) 資料集上訓練一個凍結的 CLIP 模型,該模型帶有一個可訓練的迴歸頭,用於預測人們對輸入影像的平均喜好程度。
- 使用美學預測模型作為獎勵訊號,透過 DDPO 對 SD 模型進行微調。
我們在接下來的章節中將牢記這些步驟,實際執行它們,具體描述如下。
使用 DDPO 訓練 Stable Diffusion
設定
首先,在硬體方面,對於這個 DDPO 實現,至少需要一個 A100 NVIDIA GPU 才能成功訓練。任何低於此 GPU 型別的裝置很快就會遇到記憶體不足的問題。
使用 pip 安裝 `trl` 庫
pip install trl[diffusers]
這應該會安裝主庫。以下依賴項用於跟蹤和影像日誌記錄。安裝 `wandb` 後,請務必登入以將結果儲存到個人賬戶。
pip install wandb torchvision
注意:你也可以選擇使用 `tensorboard` 而不是 `wandb`,為此你需要透過 `pip` 安裝 `tensorboard` 包。
詳細步驟
`trl` 庫中負責 DDPO 訓練的主要類是 `DDPOTrainer` 和 `DDPOConfig`。有關 `DDPOTrainer` 和 `DDPOConfig` 的更多通用資訊,請參閱文件。在 `trl` 倉庫中有一個示例訓練指令碼。它將這兩個類與所需輸入的預設實現和預設引數結合使用,以微調來自 `RunwayML` 的預設預訓練 Stable Diffusion 模型。
此示例指令碼使用 `wandb` 進行日誌記錄,並使用一個美學獎勵模型,其權重從一個公開的 HuggingFace 倉庫中讀取(因此,收集資料和訓練美學獎勵模型的工作已經為你完成)。預設使用的提示資料集是一個動物名稱列表。
使用者只需提供一個命令列標誌引數即可開始執行。此外,使用者需要有一個 Hugging Face 使用者訪問令牌,該令牌將在微調後用於將模型上傳到 Hugging Face Hub。
以下 bash 命令可以啟動執行
python ddpo.py --hf_user_access_token <token>
下表包含了與積極結果直接相關的關鍵超引數。
引數 | 描述 | 單 GPU 訓練的推薦值(截至目前) |
---|---|---|
num_epochs |
訓練的輪數 | 200 |
train_batch_size |
用於訓練的批次大小 | 3 |
sample_batch_size |
用於取樣的批次大小 | 6 |
gradient_accumulation_steps |
要使用的基於加速器的梯度累積步數 | 1 |
sample_num_steps |
取樣的步數 | 50 |
sample_num_batches_per_epoch |
每輪取樣的批次數 | 4 |
per_prompt_stat_tracking |
是否按提示跟蹤統計資料。如果為 false,將使用整個批次的均值和標準差計算優勢,而不是按提示跟蹤 | True |
per_prompt_stat_tracking_buffer_size |
用於按提示跟蹤統計資料的緩衝區大小 | 32 |
mixed_precision |
混合精度訓練 | True |
train_learning_rate |
學習率 | 3e-4 |
提供的指令碼僅僅是一個起點。請隨意調整超引數,甚至徹底修改指令碼以適應不同的目標函式。例如,可以整合一個衡量 JPEG 可壓縮性的函式,或一個使用多模態模型評估視覺-文字對齊的函式,以及其他可能性。
經驗教訓
- 儘管訓練提示詞數量極少,但結果似乎在各種提示詞上都能很好地泛化。對於獎勵美學的目標函式,這一點已得到充分驗證。
- 嘗試透過增加訓練提示詞數量和改變提示詞來明確泛化,至少對於美學目標函式而言,似乎會減慢收斂速度,而學到的泛化行為幾乎察覺不到(如果存在的話)。
- 雖然 LoRA 是推薦的,並且經過多次測試,但非 LoRA 方案也值得考慮,原因之一是根據經驗證據,非 LoRA 似乎能生成比 LoRA 更復雜的影像。然而,為穩定的非 LoRA 執行找到合適的超引數要更具挑戰性。
- 對於非 LoRA 配置引數的建議是:將學習率設定得相對較低,大約 `1e-5` 應該可以,並將 `mixed_precision` 設定為 `None`。
結果
以下是對於提示詞 `bear`(熊)、`heaven`(天堂)和 `dune`(沙丘),微調前(左)和微調後(右)的輸出(每行對應一個提示詞的輸出)。
侷限性
- 目前 `trl` 的 DDPOTrainer 僅限於微調原版 SD 模型;
- 在我們的實驗中,我們主要關注 LoRA,它效果很好。我們也進行了一些全量訓練的實驗,這可以帶來更好的質量,但找到合適的超引數更具挑戰性。
結論
像 Stable Diffusion 這樣的擴散模型,在使用 DDPO 進行微調後,可以在人類感知或任何其他可被恰當概念化為目標函式的指標上,顯著提升生成影像的質量。
DDPO 的計算效率及其在不依賴近似的情況下進行最佳化的能力,尤其是在與早期實現相同微調擴散模型目標的方法相比時,使其成為微調像 Stable Diffusion 這樣的擴散模型的合適選擇。
`trl` 庫的 `DDPOTrainer` 實現了用於微調 SD 模型的 DDPO。
我們的實驗結果強調了 DDPO 在廣泛提示詞上泛化的能力,儘管透過變化提示詞進行顯式泛化的嘗試結果好壞參半。為非 LoRA 設定找到合適的超引數的困難也成為一個重要的經驗教訓。
DDPO 是一種很有前途的技術,可以將擴散模型與任何獎勵函式對齊,我們希望透過在 TRL 中的釋出,能讓它更容易地被社群所用!
致謝
感謝 Chunte Lee 為這篇博文製作縮圖。