TRL 文件
多介面卡強化學習 (MARL) - 一個基礎模型搞定一切
並獲得增強的文件體驗
開始使用
多介面卡強化學習 (MARL) - 一個基礎模型搞定一切
這裡我們提出一種方法,它使用單個基礎模型來完成整個 PPO 演算法——包括檢索參考 logits、計算活動 logits 以及計算獎勵。此功能尚處於實驗階段,因為我們尚未測試該方法的收斂性。我們鼓勵社群成員在遇到潛在問題時告知我們。
環境要求
您只需安裝 peft
,如果想使用 8 位基礎模型以實現更高效的記憶體微調,還可以選擇安裝 bitsandbytes
。
概要
您需要分三個階段來實施此方法,我們總結如下:
1- 在目標領域(例如 IMDB 資料集)上訓練一個基礎模型——這是監督式微調(SFT)階段——可以利用 TRL 中的 SFTTrainer
。 2- 使用 peft
訓練一個獎勵模型。這是為了在強化學習最佳化過程(下面的步驟 3)中複用介面卡所必需的。我們在這個例子中展示瞭如何利用 TRL 中的 RewardTrainer
。 3- 使用 PPO 和獎勵介面卡在基礎模型上微調新的介面卡。(“零抽象強化學習”)
請確保在第 2 階段和第 3 階段使用相同的模型(即相同的架構和相同的權重)。
快速入門
假設您已經使用 RewardTrainer
在 llama-7b
模型上訓練了獎勵介面卡,並將其權重推送到了 Hub 上的 trl-lib/llama-7b-hh-rm-adapter
。在進行 PPO 訓練時,在將模型傳遞給 PPOTrainer
之前,請按如下方式建立您的模型:
model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
)
...
trainer = PPOTrainer(
model=model,
...
)
...
然後在您的 PPO 訓練迴圈中,透過訪問 PPOTrainer
的 model
屬性來呼叫 compute_reward_score
方法。
rewards = trainer.model.compute_reward_score(**inputs)
高階用法
控制介面卡名稱
如果您熟悉 peft
庫,您會知道可以在同一個模型中使用多個介面卡。您可以做的是,在同一個基礎模型上訓練多個介面卡,以針對不同的策略進行微調。在這種情況下,您希望能夠在檢索到獎勵後,控制要重新啟用的介面卡名稱。為此,在呼叫 compute_reward_score
時,只需將相應的 adapter_name
傳遞給 ppo_adapter_name
引數即可。
adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...
使用 4 位和 8 位基礎模型
為了實現更高效的記憶體微調,您可以將基礎模型載入為 8 位或 4 位,同時保持介面卡為預設精度(float32)。只需將適當的引數(即 load_in_8bit=True
或 load_in_4bit=True
)傳遞給 AutoModelForCausalLMWithValueHead.from_pretrained
即可,如下所示(假設您已安裝 bitsandbytes
):
model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"
# PPO adapter
lora_config = LoraConfig(
r=16,
lora_alpha=32,
lora_dropout=0.05,
bias="none",
task_type="CAUSAL_LM",
)
model = AutoModelForCausalLMWithValueHead.from_pretrained(
model_name,
peft_config=lora_config,
reward_adapter=rm_adapter_id,
load_in_8bit=True,
)
...
trainer = PPOTrainer(
model=model,
...
)
...