修復梯度累積

釋出於 2024 年 10 月 16 日
在 GitHub 上更新

我們的朋友 Unsloth 昨天分享了一個關於梯度累積的問題,該問題影響了 transformers Trainer。最初的報告來自 @bnjmn_marie(向他致敬!)。

梯度累積應該在數學上等同於全批次訓練;然而,在啟用和停用該設定的訓練執行之間,損失不匹配。

問題根源何在?

在每個模型的建模程式碼中,transformers 提供了一個“預設”損失函式,這是模型任務最常用的損失函式。它由建模類應 SONY 的任務決定:問答、token 分類、因果語言模型、掩碼語言模型。

這是預設的損失函式,它不應該被定製:只有當 labelsinput_ids 作為輸入傳遞給模型時,才會計算它,這樣使用者就不必計算損失。預設損失函式很有用,但設計上是有限制的:對於任何不同的操作,我們期望標籤不直接傳遞,並且使用者從模型中獲取 logits 並使用它們在模型外部計算損失。

然而,transformers Trainer 以及許多 Trainer,都嚴重依賴這些方法,因為它提供了簡單性:這是一把雙刃劍。提供一個隨著用例不同而變得不同的簡單 API 並不是一個經過深思熟慮的 API,我們自己也感到驚訝。

準確地說,對於像因果語言模型訓練這樣的 token 級任務的梯度累積,正確的損失應該透過梯度累積步驟中所有批次的總損失除以這些批次中所有非填充 token 的總數來計算。這與每批次損失值的平均值不同。修復方法很簡單,請參閱以下內容

def ForCausalLMLoss(logits, labels, vocab_size, **kwargs):
    # Upcast to float if we need to compute the loss to avoid potential precision issues
    logits = logits.float()
    # Shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Flatten the tokens
    shift_logits = shift_logits.view(-1, vocab_size)
    shift_labels = shift_labels.view(-1)
    # Enable model parallelism
    shift_labels = shift_labels.to(shift_logits.device)

    num_items = kwargs.pop("num_items", None)
+        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100, reduction="sum")
+        loss = loss / num_items
-        loss = nn.functional.cross_entropy(shift_logits, shift_labels, ignore_index=-100)
    return loss

我們如何修復它

為了解決這個問題,我們正在以兩種方式改變模型的訓練方式

  • 如果使用者使用“預設”損失函式,我們將在使用梯度累積時自動考慮所需的更改,以確保報告和使用了正確的損失,從而解決核心問題。
  • 為了確保將來計算損失的任何問題都不會阻礙使用者,我們將公開一個 API,讓使用者可以直接將自己的損失函式傳遞給 Trainer,這樣他們就可以輕鬆使用自己的修復,直到我們內部修復任何問題併發布新的 transformers 版本。

所有繼承自 PreTrainedModel 的模型現在都有一個 loss_function 屬性,它由以下任一方式確定:

  • config.loss_type:這是為了確保任何人都可以使用他們的自定義損失。您可以透過修改 LOSS_MAPPING 來實現這一點
def my_super_loss(logits, labels):
    return loss = nn.functional.cross_entropy(logits, labels, ignore_index=-100)

LOSS_MAPPING["my_loss_type"] = my_super_loss

我們正在努力在本次 PR 中為最流行的模型提供第一個更改:https://github.com/huggingface/transformers/pull/34191#pullrequestreview-2372725010。在此之後,將呼籲貢獻者幫助將此更改傳播到其餘模型,以便大多數模型在下一次釋出時得到支援。

我們還在積極努力在此 PR 中提供第二個更改:https://github.com/huggingface/transformers/pull/34198,這將允許使用者使用自己的損失函式,並利用每批次看到的樣本數量來幫助計算其損失(並且隨著更多模型得到先前更改的支援,將在梯度累積期間執行正確的損失計算)

到明天,您應該會看到 Trainer 在梯度累積方面表現正常。屆時請從 main 安裝以享受此修復

pip install git+https://github.com/huggingface/transformers

通常,我們對提交到我們的問題跟蹤器中的錯誤報告非常及時響應:https://github.com/huggingface/transformers/issues

這個問題已經在 Transformers 中存在了一段時間,因為它主要是一個應該由終端使用者更新的預設設定;然而,當預設設定變得不直觀時,它們就註定要被更改。在這種情況下,我們在不到 24 小時內更新了程式碼併發布了修復,這是我們針對 Transformers 中的此類問題所追求的目標。請您如果有任何問題,隨時提交;這是我們改進 Transformers 並使其更好地適應您的不同用例的唯一途徑。

Transformers 團隊 🤗

社群

註冊登入評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.