TRL 文件

多介面卡強化學習 (MARL) - 一個基礎模型搞定一切

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

多介面卡強化學習 (MARL) - 一個基礎模型搞定一切

這裡我們提出一種方法,它使用單個基礎模型來完成整個 PPO 演算法——包括檢索參考 logits、計算活動 logits 以及計算獎勵。此功能尚處於實驗階段,因為我們尚未測試該方法的收斂性。我們鼓勵社群成員在遇到潛在問題時告知我們。

環境要求

您只需安裝 peft,如果想使用 8 位基礎模型以實現更高效的記憶體微調,還可以選擇安裝 bitsandbytes

概要

您需要分三個階段來實施此方法,我們總結如下:

1- 在目標領域(例如 IMDB 資料集)上訓練一個基礎模型——這是監督式微調(SFT)階段——可以利用 TRL 中的 SFTTrainer。 2- 使用 peft 訓練一個獎勵模型。這是為了在強化學習最佳化過程(下面的步驟 3)中複用介面卡所必需的。我們在這個例子中展示瞭如何利用 TRL 中的 RewardTrainer。 3- 使用 PPO 和獎勵介面卡在基礎模型上微調新的介面卡。(“零抽象強化學習”)

請確保在第 2 階段和第 3 階段使用相同的模型(即相同的架構和相同的權重)。

快速入門

假設您已經使用 RewardTrainerllama-7b 模型上訓練了獎勵介面卡,並將其權重推送到了 Hub 上的 trl-lib/llama-7b-hh-rm-adapter。在進行 PPO 訓練時,在將模型傳遞給 PPOTrainer 之前,請按如下方式建立您的模型:

model_name = "huggyllama/llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)

...

然後在您的 PPO 訓練迴圈中,透過訪問 PPOTrainermodel 屬性來呼叫 compute_reward_score 方法。

rewards = trainer.model.compute_reward_score(**inputs)

高階用法

控制介面卡名稱

如果您熟悉 peft 庫,您會知道可以在同一個模型中使用多個介面卡。您可以做的是,在同一個基礎模型上訓練多個介面卡,以針對不同的策略進行微調。在這種情況下,您希望能夠在檢索到獎勵後,控制要重新啟用的介面卡名稱。為此,在呼叫 compute_reward_score 時,只需將相應的 adapter_name 傳遞給 ppo_adapter_name 引數即可。

adapter_name_policy_1 = "policy_1"
rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1)
...

使用 4 位和 8 位基礎模型

為了實現更高效的記憶體微調,您可以將基礎模型載入為 8 位或 4 位,同時保持介面卡為預設精度(float32)。只需將適當的引數(即 load_in_8bit=Trueload_in_4bit=True)傳遞給 AutoModelForCausalLMWithValueHead.from_pretrained 即可,如下所示(假設您已安裝 bitsandbytes):

model_name = "llama-7b"
rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter"

# PPO adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_name,
    peft_config=lora_config,
    reward_adapter=rm_adapter_id,
    load_in_8bit=True,
)

...
trainer = PPOTrainer(
    model=model,
    ...
)
...
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.