從“RL for LLM”視角重新理解KL近似:關於“近似KL散度”的筆記
PPO和GRPO中使用的KL散度估計方法有什麼區別?
John Schulman的部落格文章“近似KL散度”討論瞭如何透過取樣(蒙特卡羅)近似KL散度,並介紹了三種估計器(\(k_1\)、、)及其偏差-方差行為。但原始文章是在一般機率分佈的背景下提出的,並未涉及大型語言模型(LLM)的強化學習訓練設定。本文記錄了我在閱讀時遇到的問題、將內容對映到RL for LLM後形成的思考,以及一些我認為原始解釋可以進一步闡述的地方。
“近似KL散度”說了什麼(用我自己的話)
在本節中,我假定讀者尚未閱讀原始文章,因此我們快速瀏覽最重要的部分。簡單來說,這篇文章是關於當我們無法直接計算KL散度時,如何構建合理的蒙特卡羅式估計器。
如公式所示:當估計兩個(複雜)分佈之間的KL散度時,人們常用一種編碼技巧:即透過從中抽樣,用的樣本均值來近似KL(而不是試圖精確評估完整的期望)。文章接著指出另一種方法:使用的樣本均值來替代更“標準”的形式,其中。本文解釋了為什麼這個表示式可以成為KL的良好(儘管有偏)估計器,以及如何在保持低方差的同時使其無偏。
我們計算KL的方式取決於我們如何訪問和。這裡我們假設我們可以評估任何的和(機率或密度),但我們**無法**對進行解析求和/積分。我們為什麼不能進行解析求和/積分呢?可能是因為精確計算在計算或記憶體方面過於昂貴,可能沒有閉合形式,或者我們為了簡化程式碼,只儲存對數機率而不是完整的分佈,尤其是在KL僅用於診斷時(強化學習中常出現這種情況)。近似求和或積分最常見的策略是**蒙特卡羅**。給定從中抽取的樣本,我們如何構建一個好的估計器?
一個好的估計器應該**無偏**(平均值正確)且**方差低**。我們知道一個無偏估計器
但它的方差很高:根據定義,KL是一個非負量,然而對於上述估計器,大約“一半”的樣本值可能是負的(如果我們不對和做任何先驗假設),這使得平均值波動很大,因此方差很高。為了符號方便,設。那麼原始的KL可以寫成
為了減少方差,我們可以設計一個替代估計器:它的方差較低,但有偏。直觀上,感覺更好,因為每個樣本都給出了和之間的非負“距離”,因此它保持正值。經驗上,的方差確實比低得多,而且偏差可以很小。至於為什麼相比能夠大幅降低方差,原始文章使用了f-散度檢視給出了分析解釋,這裡我不再贅述。
現在,我們能否得到一個既**無偏**又**低方差**的估計器呢?一個通用的技巧是使用**控制變數**:從無偏的開始,並新增一個期望值為零且與它負相關的量以降低方差。這裡一個非常方便的零均值量是。因此,對於任意,仍然是一個無偏的KL估計器。理論上,我們可以在上最小化方差,但其閉合形式取決於和,不容易得到。然而請注意,由於是凹函式,所以如果我們選擇,該表示式保證非負。在這裡,是在處的切線。因此,當時,我們實際上測量的是與其切線之間的垂直距離。這導致了估計器它總是非負的。而正是實際中GRPO與PPO在KL估計方式上有所不同的地方(PPO使用)。
從“RL for LLM”的角度討論KL估計
在強化學習(例如PPO、GRPO等)中,我們通常會在損失函式中加入一個KL散度項,以防止新策略偏離舊策略太遠。這裡,是舊策略分佈(),是新策略分佈(),而是一個完整的動作樣本(在LLM中,這表示一個token或一個token序列)。我們通常用表示狀態(在LLM中,這是提示或上下文),是在該上下文中生成的特定token。當我們計算KL時,我們實際上是在**給定狀態下的動作分佈**上計算KL,然後對狀態進行平均:
在取樣時,我們通常會固定一個提示(狀態),然後為該提示估計此KL散度。
那麼**為什麼我們不能直接精確計算KL散度,而非要估計它呢?**原因與原始部落格文章中列出的完全相同;在LLM的強化學習中,主要癥結在於**原因1**:*動作空間(token空間)太大,無法對所有可能的進行求和/積分*。例如,如果一個分詞器有50,000個詞彙條目,即使計算單個token的KL散度也意味著對50,000個動作求和;而在強化學習中,我們通常進行多步(序列)生成,因此空間呈指數級增長,這完全不切實際。還有一個實用原因:在訓練過程中,我們通常不儲存完整的分佈(所有token的機率);我們只保留沿軌跡實際生成的token的對數機率,以節省GPU記憶體和I/O。因此,我們必須使用**蒙特卡羅取樣**:從某個分佈(通常是,即舊策略)中抽取,並使用這些樣本來近似KL散度。這就把我們直接帶入了部落格文章所討論的領域。
在該文章中,我們一直談論的**估計器**實際上只是樣本的一個函式:它接收某個取樣的和(或它們的比率),並輸出一個數字。然後,我們對這些數字在樣本上求平均,以近似KL散度。例如:
這些只是不同的KL估計器公式。它們都透過**對樣本求平均**來近似KL散度,但在偏差和方差上有所不同。一旦我們選擇了一個估計器,我們實際上就承諾使用一個特定的公式來近似KL散度。這個過程看起來像這樣:
- 取樣
從舊策略中取樣一批token(或序列)。 - 計算對數機率
對於每個樣本,計算新舊策略下的對數機率
並得到或。3. **代入估計器公式**
例如,如果我們選擇
- 平均分
這是近似的 KL 值,代表了真實的 KL。
如果我們將這與離散機率分佈(LLM 單令牌步長)的真實 KL 計算(無估計)進行比較:我們需要遍歷每個可能的令牌 : 您可以立即看到,使用估算器,計算量比進行完整求和小得多,尤其是在高維動作空間中。
談論不同 KL 估計器的方差
重要提示:我們這裡討論的“方差”是估計器在樣本上輸出值的方差: 也就是說, 在樣本空間中的波動程度。一個**無偏**估計器意味著在無限多的樣本下,其均值等於真實 KL。但高方差估計器意味著即使均值正確(無偏),在少量樣本下,平均值也可能偏差很大。在 LLM 的強化學習中,KL 項通常是損失中的正則化因子(例如, )。如果 KL 估計器的方差很大,會使損失變得嘈雜,進而使梯度嘈雜並導致訓練不穩定。
在原帖中,為了讓讀者直觀理解為什麼 不是低方差的,作者寫道:
然而,它()具有高方差,因為它對一半的樣本是負的,而 KL 始終是正的。
作者指出,儘管 是無偏的,但如果沒有對 和 的先驗約束,一半的樣本會一個比另一個大,所以一半的 值是正的,一半是負的。到目前為止,我都同意。但隨後作者說:因為 KL 總是大於 0(一個基本不等式),所以 因此必須具有高方差。而在這裡,我認為因果關係並不成立:你不能用期望的符號來決定單個樣本的符號。一個簡單的反例:在計算期望時, 也時而為正,時而為負;這個事實本身並不能說明方差。實際上,單樣本的**對數比率**(無論是 或 )都可以是正的或負的,就像 一樣,所以**單獨的符號翻轉並不是高方差的唯一原因**。
根據 KL 定義: 期望值**保證非負**,但被積函式 可以對單個樣本是正的或負的。而 正是這個被積函式: 所以每個樣本值確實可以是正的或負的,與 KL 定義中的被積函式相同。
那麼為什麼 會有高方差?
這不僅僅是“符號翻轉”。真正的原因是 的值分佈通常很寬(重尾)。例如,如果 對於某些樣本來說很小,那麼 可能會非常大(正或負)。這些極端值主導有限樣本平均值,推高了方差。換句話說,它是**極端值 + 正負抵消**的組合:抵消意味著你需要更多的樣本才能收斂到真實平均值,而極端值會使樣本方差本身更大。因此,部落格中“一半為負”的評論更多的是一種直覺提示,而不是完整的解釋。
從這個角度來看,如果我們看其他估計器 和 ,我們發現: 總是正的,所以沒有抵消,但這引入了偏差;平方也平滑了幅度,降低了方差。 使用控制變數來消除部分波動源,在保持無偏性的同時降低方差(詳細資訊見下文)。
在 PPO/GRPO 中,如果您使用 並且批次很小或分佈相距很遠,KL 估計值將跳來跳去(因為少數極端樣本會使平均值劇烈波動)。這使得 KL 懲罰係數不穩定:它可能突然變得過強或過弱。切換到低方差估計器( 或 )使每個樣本的 KL 貢獻更穩定,更不容易被少數極端樣本主導。
為什麼 既能無偏又能低方差?
乍一看, 總是正的,所以你可能會認為它的平均值必須大於 的平均值。
但請記住: 是透過**控制變數**從 匯出的。部落格的推理如下: 其中 ,並且在 下,其期望值為: 因此,新增任何 的倍數都不會改變期望值。當 時: 這解釋了為什麼 的期望值等於 的期望值,並等於 KL,使其成為一個無偏估計器。
比 具有更低方差的原因是: 只有 ,其值可能劇烈波動(既有正有負,偶爾出現巨大值)。但是 和 在數值上高度相關(一個增長,另一個也增長/收縮),並且這種相關性是**負的**。新增 就像引入一個**負相關項來抵消波動**。抵消後, 中剩下的值範圍更緊密,始終為正,因此樣本方差更低。