從“RL for LLM”視角重新理解KL近似:關於“近似KL散度”的筆記

社群文章 釋出於2025年8月11日

PPO和GRPO中使用的KL散度估計方法有什麼區別?

John Schulman的部落格文章“近似KL散度”討論瞭如何透過取樣(蒙特卡羅)近似KL散度,並介紹了三種估計器(\(k_1\)、k2k_2k3k_3)及其偏差-方差行為。但原始文章是在一般機率分佈的背景下提出的,並未涉及大型語言模型(LLM)的強化學習訓練設定。本文記錄了我在閱讀時遇到的問題、將內容對映到RL for LLM後形成的思考,以及一些我認為原始解釋可以進一步闡述的地方。

“近似KL散度”說了什麼(用我自己的話)

在本節中,我假定讀者尚未閱讀原始文章,因此我們快速瀏覽最重要的部分。簡單來說,這篇文章是關於當我們無法直接計算KL散度時,如何構建合理的蒙特卡羅式估計器。

KL(q,p)=xq(x)logq(x)p(x)=Exq ⁣[logq(x)p(x)]. \mathrm{KL}(q, p) = \sum_x q(x)\,\log\frac{q(x)}{p(x)} = \mathbb{E}_{x\sim q}\!\left[\log\frac{q(x)}{p(x)}\right].

如公式所示:當估計兩個(複雜)分佈之間的KL散度時,人們常用一種編碼技巧:即透過從qq中抽樣,用log ⁣(q(x)p(x))\log\!\big(\frac{q(x)}{p(x)}\big)的樣本均值來近似KL(而不是試圖精確評估完整的期望)。文章接著指出另一種方法:使用12(logr)2\tfrac{1}{2}(\log r)^2的樣本均值來替代更“標準”的logr\log r形式,其中r=q(x)p(x)r=\frac{q(x)}{p(x)}。本文解釋了為什麼這個表示式可以成為KL的良好(儘管有偏)估計器,以及如何在保持低方差的同時使其無偏。

我們計算KL的方式取決於我們如何訪問ppqq。這裡我們假設我們可以評估任何xxp(x)p(x)q(x)q(x)(機率或密度),但我們**無法**對xx進行解析求和/積分。我們為什麼不能進行解析求和/積分呢?可能是因為精確計算在計算或記憶體方面過於昂貴,可能沒有閉合形式,或者我們為了簡化程式碼,只儲存對數機率而不是完整的分佈,尤其是在KL僅用於診斷時(強化學習中常出現這種情況)。近似求和或積分最常見的策略是**蒙特卡羅**。給定從qq中抽取的樣本x1,x2,,xnqx_1, x_2, \dots, x_n \sim q,我們如何構建一個好的估計器?

一個好的估計器應該**無偏**(平均值正確)且**方差低**。我們知道一個無偏估計器

k1=logq(x)p(x). k_1 = \log\frac{q(x)}{p(x)}.

但它的方差很高:根據定義,KL是一個非負量,然而對於上述估計器,大約“一半”的樣本值可能是負的(如果我們不對ppqq做任何先驗假設),這使得平均值波動很大,因此方差很高。為了符號方便,設r=q(x)p(x)r = \frac{q(x)}{p(x)}。那麼原始的KL可以寫成

KL[q,p]  =  Exq[logr]. \mathrm{KL}[q, p] \;=\; \mathbb{E}_{x\sim q}\,[\log r].

為了減少方差,我們可以設計一個替代估計器:k2=12(logr)2. k_2 = \frac{1}{2}(\log r)^2. 它的方差較低,但有偏。直觀上,k2k_2感覺更好,因為每個樣本都給出了ppqq之間的非負“距離”,因此它保持正值。經驗上,k2k_2的方差確實比k1k_1低得多,而且偏差可以很小。至於為什麼k2k_2相比k1k_1能夠大幅降低方差,原始文章使用了f-散度檢視給出了分析解釋,這裡我不再贅述。

現在,我們能否得到一個既**無偏**又**低方差**的估計器呢?一個通用的技巧是使用**控制變數**:從無偏的k1k_1開始,並新增一個期望值為零且與它負相關的量以降低方差。這裡一個非常方便的零均值量是r1r-1。因此,對於任意λ\lambdak  =  logr+λ(r1) k \;=\; -\log r + \lambda\,(r-1) 仍然是一個無偏的KL估計器。理論上,我們可以在λ\lambda上最小化方差,但其閉合形式取決於ppqq,不容易得到。然而請注意,由於log(x)\log(x)是凹函式,log(x)    x1, \log(x) \;\le\; x-1, 所以如果我們選擇λ=1\lambda=1,該表示式保證非負。在這裡,r1r-1logr\log rr=1r=1處的切線。因此,當λ=1\lambda=1時,我們實際上測量的是log(x)\log(x)與其切線之間的垂直距離。這導致了估計器k3  =  (r1)    logr, k_3 \;=\; (r - 1) \;-\; \log r, 它總是非負的。而k3k_3正是實際中GRPO與PPO在KL估計方式上有所不同的地方(PPO使用k1k_1)。

從“RL for LLM”的角度討論KL估計

在強化學習(例如PPO、GRPO等)中,我們通常會在損失函式中加入一個KL散度項,以防止新策略偏離舊策略太遠。這裡,qq是舊策略分佈(πold\pi_{\text{old}}),pp是新策略分佈(πnew\pi_{\text{new}}),而xx是一個完整的動作樣本(在LLM中,這表示一個token或一個token序列)。我們通常用ss表示狀態(在LLM中,這是提示或上下文),xx是在該上下文中生成的特定token。當我們計算KL時,我們實際上是在**給定狀態下的動作分佈**上計算KL,然後對狀態進行平均:

KL[p,q]=Es[xp(xs)logp(xs)]. \mathrm{KL}[p, q] = \mathbb{E}_{s} \left[ \sum_x p(x|s) \log \frac{p(x|s)}{q(x|s)} \right].

在取樣時,我們通常會固定一個提示(狀態),然後為該提示估計此KL散度。

那麼**為什麼我們不能直接精確計算KL散度,而非要估計它呢?**原因與原始部落格文章中列出的完全相同;在LLM的強化學習中,主要癥結在於**原因1**:*動作空間(token空間)太大,無法對所有可能的xx進行求和/積分*。例如,如果一個分詞器有50,000個詞彙條目,即使計算單個token的KL散度也意味著對50,000個動作求和;而在強化學習中,我們通常進行多步(序列)生成,因此空間呈指數級增長,這完全不切實際。還有一個實用原因:在訓練過程中,我們通常不儲存完整的分佈(所有token的機率);我們只保留沿軌跡實際生成的token的對數機率,以節省GPU記憶體和I/O。因此,我們必須使用**蒙特卡羅取樣**:從某個分佈(通常是qq,即舊策略)中抽取xx,並使用這些樣本來近似KL散度。這就把我們直接帶入了部落格文章所討論的領域。

在該文章中,我們一直談論的**估計器**實際上只是樣本的一個函式:它接收某個取樣xxp(x)p(x)q(x)q(x)(或它們的比率r=q(x)p(x)r = \frac{q(x)}{p(x)}),並輸出一個數字。然後,我們對這些數字在樣本上求平均,以近似KL散度。例如:

  • k1(x)=logrk_1(x) = -\log r
  • k2(x)=(logr)2k_2(x) = \frac12 (\log r)^2
  • k3(x)=(r1)logrk_3(x) = (r - 1) - \log r

這些kik_i只是不同的KL估計器公式。它們都透過**對樣本求平均**來近似KL散度,但在偏差和方差上有所不同。一旦我們選擇了一個估計器,我們實際上就承諾使用一個特定的公式來近似KL散度。這個過程看起來像這樣:

  1. 取樣
    從舊策略qq中取樣一批token(或序列)x1,x2,,xNx_1, x_2, \dots, x_N
  2. 計算對數機率
    對於每個樣本,計算新舊策略下的對數機率

logp(xi), logq(xi) \log p(x_i),\ \log q(x_i)

並得到ri=q(xi)p(xi)r_i = \frac{q(x_i)}{p(x_i)}logri\log r_i。3. **代入估計器公式**
例如,如果我們選擇k3k_3

k3(xi)=(ri1)logri k_3(x_i) = (r_i - 1) - \log r_i

  1. 平均分

KL^1Ni=1Nk3(xi) \widehat{\mathrm{KL}} \approx \frac1N \sum_{i=1}^N k_3(x_i)

這是近似的 KL 值,代表了真實的 KL。

如果我們將這與離散機率分佈(LLM 單令牌步長)的真實 KL 計算(無估計)進行比較:我們需要遍歷每個可能的令牌 xxKL(pq)=xp(x)logp(x)q(x) \mathrm{KL}(p\|q) = \sum_x p(x) \log \frac{p(x)}{q(x)} 您可以立即看到,使用估算器,計算量比進行完整求和小得多,尤其是在高維動作空間中。

談論不同 KL 估計器的方差

重要提示:我們這裡討論的“方差”是估計器在樣本上輸出值的方差: Varxq[k(x)] \mathrm{Var}_{x \sim q}[k(x)] 也就是說, k(x)k(x) 在樣本空間中的波動程度。一個**無偏**估計器意味著在無限多的樣本下,其均值等於真實 KL。但高方差估計器意味著即使均值正確(無偏),在少量樣本下,平均值也可能偏差很大。在 LLM 的強化學習中,KL 項通常是損失中的正則化因子(例如, βKL\beta \cdot \mathrm{KL})。如果 KL 估計器的方差很大,會使損失變得嘈雜,進而使梯度嘈雜並導致訓練不穩定。

在原帖中,為了讓讀者直觀理解為什麼 k1k_1 不是低方差的,作者寫道:

然而,它(k1k_1)具有高方差,因為它對一半的樣本是負的,而 KL 始終是正的。

作者指出,儘管 k1k_1 是無偏的,但如果沒有對 ppqq 的先驗約束,一半的樣本會一個比另一個大,所以一半的 k1k_1 值是正的,一半是負的。到目前為止,我都同意。但隨後作者說:因為 KL 總是大於 0(一個基本不等式),所以 k1k_1 因此必須具有高方差。而在這裡,我認為因果關係並不成立:你不能用期望的符號來決定單個樣本的符號。一個簡單的反例:在計算期望時, p(x)logp(x)q(x)p(x) \log \frac{p(x)}{q(x)} 也時而為正,時而為負;這個事實本身並不能說明方差。實際上,單樣本的**對數比率**(無論是 logq(x)p(x)\log \frac{q(x)}{p(x)}logp(x)q(x)\log \frac{p(x)}{q(x)})都可以是正的或負的,就像 k1k_1 一樣,所以**單獨的符號翻轉並不是高方差的唯一原因**。

根據 KL 定義: KL(qp)=Exq[logq(x)p(x)] \mathrm{KL}(q \| p) = \mathbb{E}_{x\sim q}\left[ \log \frac{q(x)}{p(x)} \right] 期望值**保證非負**,但被積函式 logq(x)p(x)\log\frac{q(x)}{p(x)} 可以對單個樣本是正的或負的。而 k1k_1 正是這個被積函式: k1(x)=logq(x)p(x) k_1(x) = \log \frac{q(x)}{p(x)} 所以每個樣本值確實可以是正的或負的,與 KL 定義中的被積函式相同。

那麼為什麼 k1k_1 會有高方差?

這不僅僅是“符號翻轉”。真正的原因是 k1k_1 的值分佈通常很寬(重尾)。例如,如果 p(x)p(x) 對於某些樣本來說很小,那麼 logqp\log\frac{q}{p} 可能會非常大(正或負)。這些極端值主導有限樣本平均值,推高了方差。換句話說,它是**極端值 + 正負抵消**的組合:抵消意味著你需要更多的樣本才能收斂到真實平均值,而極端值會使樣本方差本身更大。因此,部落格中“一半為負”的評論更多的是一種直覺提示,而不是完整的解釋。

從這個角度來看,如果我們看其他估計器 k2k_2k3k_3,我們發現: k2=12(logr)2k_2 = \frac12 (\log r)^2 總是正的,所以沒有抵消,但這引入了偏差;平方也平滑了幅度,降低了方差。k3k_3 使用控制變數來消除部分波動源,在保持無偏性的同時降低方差(詳細資訊見下文)。

在 PPO/GRPO 中,如果您使用 k1k_1 並且批次很小或分佈相距很遠,KL 估計值將跳來跳去(因為少數極端樣本會使平均值劇烈波動)。這使得 KL 懲罰係數不穩定:它可能突然變得過強或過弱。切換到低方差估計器( k2k_2k3k_3 )使每個樣本的 KL 貢獻更穩定,更不容易被少數極端樣本主導。

為什麼 k3k_3 既能無偏又能低方差?

乍一看, k3k_3 總是正的,所以你可能會認為它的平均值必須大於 k1k_1 的平均值。
但請記住: k3k_3 是透過**控制變數**從 k1k_1 匯出的。部落格的推理如下: k~(x)=k1(x)+λh(x) \tilde{k}(x) = k_1(x) + \lambda \cdot h(x) 其中 h(x)=r1h(x) = r - 1,並且在 xqx\sim q 下,其期望值為: Exq[h(x)]=Eq[p(x)q(x)1]=xp(x)1=11=0. \mathbb{E}_{x\sim q}[h(x)] = \mathbb{E}_q\left[\frac{p(x)}{q(x)} - 1\right] = \sum_x p(x) - 1 = 1 - 1 = 0. 因此,新增任何 h(x)h(x) 的倍數都不會改變期望值。當 λ=1\lambda = 1 時: k~(x)=logr+(r1)=(r1)logr=k3(x). \tilde{k}(x) = -\log r + (r - 1) = (r - 1) - \log r = k_3(x). 這解釋了為什麼 k3k_3 的期望值等於 k1k_1 的期望值,並等於 KL,使其成為一個無偏估計器。

k3k_3k1k_1 具有更低方差的原因是: k1k_1 只有 logr-\log r,其值可能劇烈波動(既有正有負,偶爾出現巨大值)。但是 r1r - 1logr-\log r 在數值上高度相關(一個增長,另一個也增長/收縮),並且這種相關性是**負的**。新增 (r1)(r - 1) 就像引入一個**負相關項來抵消波動**。抵消後, k3k_3 中剩下的值範圍更緊密,始終為正,因此樣本方差更低。

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.