Bamba:推理高效的混合 Mamba2 模型 🐍

釋出於 2024 年 12 月 18 日
在 GitHub 上更新

Bamba

摘要

我們介紹 Bamba-9B,這是一個由 IBM、普林斯頓大學、卡內基梅隆大學和伊利諾伊大學厄巴納-香檳分校在完全開放的資料上訓練的推理高效混合 Mamba2 模型。在推理時,與 vLLM 中的標準 Transformer 相比,該模型的吞吐量提升了 2.5 倍,延遲降低了 2 倍。為促進社群實驗,該模型可立即在 transformersvLLMTRLllama.cpp 中使用。我們還發布了帶有狀態化資料載入器的微調、訓練和擴充套件預訓練方案,並邀請社群進一步改進該模型。讓我們一起克服 KV 快取瓶頸!

產出物 📦

  1. Hugging Face Bamba 合集
  2. 包含推理、訓練和微調指令碼的 GitHub 倉庫
  3. 資料載入器
  4. 量化
  5. 用於叢集監控的 Auto-pilot

動機 🌟

Transformer 模型在實際應用中越來越廣泛,但在推理過程中面臨記憶體頻寬瓶頸,尤其是在長上下文長度模型中進行逐個 Token 解碼時。低精度、層剪枝和壓縮等技術可以緩解此問題,但並未解決根本原因,即隨著上下文長度的增加,KV 快取所需的記憶體量不斷增長。新興架構如 MambaGriffinDeltaNet 透過使 KV 快取大小恆定來消除這一瓶頸。Mamba 架構最近在社群中獲得了極大的關注。例如,JambaSamba 將 Mamba 層與 Transformer 層交錯,探索由此產生的混合 Mamba 模型。Codestral Mamba,一個純 Mamba2 模型,在編碼任務上展示了最先進(SOTA)的結果,而 NVIDIA 的混合 Mamba2 模型在長上下文和傳統 LLM 基準測試中取得了有競爭力的效能。近期的創新,如 Falcon MambaFalcon 3 Mamba 在釋出時在 Hugging Face 排行榜上取得了 SOTA 排名。

我們介紹了 Bamba-9B,這是一個在 2.2T Token 上訓練的混合 Mamba2 模型,進一步驗證了這些新興架構。這項由 IBM、普林斯頓大學、卡內基梅隆大學和伊利諾伊大學厄巴納-香檳分校合作的專案提供了完整的訓練沿襲、模型檢查點和預訓練程式碼,以支援可復現性和實驗。釋出的檢查點的訓練資料集不包含任何基準對齊的指令資料(FLAN 除外),以保留擴充套件預訓練和微調的靈活性。我們的目標是透過在中低規模模型(7B-10B)上展示強大的效能,來展示混合 Mamba2 架構的潛力,併為社群提供完全可復現且使用開放資料集訓練的檢查點。

為了促進社群實驗,我們還發布了一個分散式無狀態洗牌資料載入器,並在開源庫如 transformersTRLvLLMllama.cpp 中啟用了混合 Mamba2 架構。我們希望這些努力能推動 Mamba 架構的採用,緩解 KV 快取瓶頸,並縮小與 SOTA 開源模型的差距。

在 transformers 中的使用 🤗

要將 Bamba 與 transformers 一起使用,您可以使用熟悉的 AutoModel 類和 generate API。更多詳情,請遵循 Bamba GitHub 中概述的說明。

from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("ibm-fms/Bamba-9B")
tokenizer = AutoTokenizer.from_pretrained("ibm-fms/Bamba-9B")

message = ["Mamba is a snake with following properties  "]
inputs = tokenizer(message, return_tensors='pt', return_token_type_ids=False)
response = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(response, skip_special_tokens=True)[0])

評估 📊

我們將評估分為三個部分

  1. 與當前最先進的 Transformer 模型的比較
  2. 與具有相似 Token 預算的 Transformer 模型的比較
  3. 與其他 Mamba 變體的比較。

評估設定 ⚙️ 🖥️: 我們遵循此處的設定和指令碼重新運行了所有基準測試,但 NVIDIA Mamba2 混合模型除外。我們無法對 NVIDIA Mamba2 混合模型進行基準測試,因為其模型權重不相容 Hugging Face transformers 格式。因此,我們報告了原始論文中的資料。對於 v2 排行榜結果,我們執行了歸一化並報告了歸一化結果。在所有評估中,除非另有說明,否則越高越好。

評估摘要

Bamba-9B 展示了混合 Mamba 模型相較於 Transformer 模型的競爭力。儘管在數學基準和 MMLU 分數(MMLU、GSM8K、MMLU-PRO、MATH Lvl 5)上存在差距,但排除這些基準後,其平均效能幾乎與 Meta Llama 3.1 8B(Llama 為 44.68,Bamba 為 45.53)相當,而後者是在 7 倍多資料上訓練的模型。這些差距可以透過以下方式解決:(a) 用更多 Token 進行擴充套件預訓練(MMLU 分數在訓練期間穩步提高),以及 (b) 在預訓練/退火階段加入高質量的數學資料。未來的計劃包括使用更新的資料集,如 Olmo2 mix,並使用基準對齊的混合資料集(如 Dolmino mix)進行退火。

Bamba-9B 的結果也緩解了人們對 NVIDIA 混合 Mamba2 模型在排行榜基準測試中得分相對較低的擔憂。NVIDIA 研究的目標是在相同條件下比較不同架構。與其發現一致,Bamba-9B 再次確認了混合 Mamba2 架構在提供與 Transformer 模型相當效能的同時,可提供高達 5 倍的推理效率。

與當前最先進的 Transformer 模型的比較

我們將 Bamba-9B 與類似規模的 SOTA Transformer 模型(Meta Llama 3.1 8BIBM Granite v3 8BOlmo2 7BGemma 2 9B)進行比較。我們觀察到,雖然存在明顯的基準差距,但尚不清楚這些差距是否指向基於 Mamba/Mamba2 模型的缺陷。實際上,仔細分析表明,差距主要是由於訓練模型所用的資料量以及在退火階段是否包含基準對齊的指令資料集。例如,我們進行了一次小規模實驗,添加了 metamath 資料集,結果我們的 GSM8k 分數從 36.77 提高到了 60.0。我們將在即將釋出的論文中公佈詳細的分析和發現。

HF OpenLLM v1 排行榜

HF LLM-V1 + OpenbookQA 和 PIQA

模型 平均分 MMLU ARC-C GSM8K Hellaswag OpenbookQA Piqa TruthfulQA Winogrande
Bamba 9B 62.31 60.77 63.23 36.77 81.8 47.6 82.26 49.21 76.87
Meta Llama 3.1 8B 63.51 66.26 57.85 49.96 81.98 46.8 82.54 45.16 77.51
Olmo2 7B 66.17 63.96 64.51 68.01 81.93 49.2 81.39 43.32 77.03
IBM Granite v3 8B 67.47 65.45 63.74 62.55 83.29 47.6 83.41 52.89 80.82
Gemma2 9B 68.38 72.29 68.26 67.4 82.56 47.8 83.24 45.39 80.11
Qwen2.5 7B 70.58 75.41 63.82 83.24 80.23 48.40 81.28 56.34 75.93

HF LLM-V2** :

模型 平均分 MMLU-PRO BBH GPQA IFEval MATH Lvl 5 MuSR
Bamba 9B 10.91 17.53 17.4 4.14 15.16 1.66 9.59
Meta Llama 3.1 8B 14.27 25.46 25.16 8.61 12.55 5.14 8.72
Olmo2 7B 13.36 22.79 21.69 4.92 16.35 4.38 10.02
IBM Granite v3 8B 21.14 25.83 28.02 9.06 44.79 9.82 9.32
Gemma2 9B 21.79 34.84 34.81 11.07 21.28 13.44 15.3
Qwen2.5 7B 25.13 37.62 35.62 9.96 34.77 18.35 14.6
安全評估

安全基準對於確保 AI 模型生成的內容符合道德、包容且無害至關重要。我們在著名的安全基準上評估了我們的模型,例如 Toxigen(5-shot, logits)(專注於檢測有毒語言)、BBQ(5-shot, generation)、PopQA(5-shot, generation)以及 CrowS-Pairs(5-shot, logits)(衡量偏見和公平性)。我們打算透過全面的 SFT 和 DPO 方法來解決這些安全方面的差距。

模型 PopQA Toxigen BBQ Crow-SPairs*
Bamba 9B 20.5 57.4 44.2 70.8
Meta Llama 3.1 8B 28.77 67.02 59.97 70.84
IBM Granite v3 8B 27.5 79.9 82.1 75
Olmo2 7B 25.7 63.1 58.4 72
Olmo1.5 7B 20.4 56.7 53.3 72.2
Gemma2 9B 27.3 69.6 59.9 71.7
Qwen2.5 7B 18.2 64.1 78.1 70

*越低越好

與具有相似 Token 預算的 Transformer 模型的比較

我們挑選了幾個著名的模型:在相同資料上訓練的 Olmo 7B (2024),Meta Llama 2 7B (2023) 和 IBM Granite 7B (2023),這些模型的訓練 Token 量約為 2T。在這些 Transformer 模型中,Olmo 7B 在 8 個關鍵基準上的平均得分最高。Bamba-9B 的效能優於在相同數量的 Token 和資料集上訓練的 Olmo 7B。由於 Bamba-9B 模型有 9B 引數,直接比較仍然困難,但主要結論是,混合 Mamba2 模型與具有相似 Token 預算的 Transformer 模型相比具有競爭力。

模型 平均分 MMLU ARC-C GSM8K Hellaswag OpenbookQA Piqa TruthfulQA Winogrande
Bamba 9B (2.2T) 62.31 60.77 63.23 36.77 81.8 47.6 82.26 49.21 76.87
Olmo1.5 7B (2T) 55.8 53.38 50.51 27.67 79.13 45.2 81.56 35.92 73.09
Bamba 9B (2T) 59.11 59.05 57.25 24.03 83.66 47.6 83.62 38.26 79.4
Meta Llama2 7B (2T) 53.78 46.64 52.65 13.57 78.95 45.2 80.03 38.96 74.27
IBM Granite 7B (2T) 52.07 49.02 49.91 10.84 77.0 40.8 80.14 38.7 70.17
Mamba/Mamba2 比較

與基於 Mamba/Mamba2 架構的語言模型的比較

在過去 6 個月裡,多個基於 Mamba/Mamba2 架構的模型開始出現(例如,NVIDIA 混合 Mamba2、Codestral Mamba、Falcon Mamba 和 Zamba 7B v1),進一步提升了這些架構的效能,展示了它們優越的推理效能,並縮小了與 Transformer 模型在基準測試結果上的差距。我們比較了 Bamba-9B、NVIDIA 混合 Mamba2、Zamba 和 Falcon Mamba 在 8 個關鍵基準上的表現。

Falcon Mamba 是一個純 Mamba 模型,Zamba 每 6 個 Mamba 層共享一個注意力層,而 Bamba-9B 和 NVIDIA 都是混合模型,其中穿插著完整的注意力層和 Mamba2 層。Falcon Mamba 經過 5.5T Token 的訓練,整體表現最佳,但在長上下文任務上的表現仍有待觀察,而基於 Mamba 的架構在這些任務的推理效能上真正大放異彩。Zamba 訓練的 Token 數量較少(1T),但採用了不同的混合架構,並使用了基準對齊的指令資料集,包括那些由更強大的語言模型生成的資料集。Bamba-9B 和 NVIDIA 混合 Mamba2 非常相似(差異細節在模型架構部分總結),但 Bamba-9B 訓練了 2.2T Token,而 NVIDIA 混合 Mamba 訓練了 3.5T Token。

注意:在撰寫此部落格時,Falcon3 Mamba 7B 已經發布,其結果甚至優於 Falcon Mamba。我們計劃借鑑 Falcon3 Mamba 的任何經驗,並在我們下一個 Bamba 版本中進行改進。

模型 平均分 MMLU ARC-C GSM8K Hellaswag OpenbookQA Piqa TruthfulQA Winogrande
Bamba 9B 62.31 60.77 63.23 36.77 81.8 47.6 82.26 49.21 76.87
NVIDIA Mamba2 混合 8B* 58.78 53.6 47.7 77.69 -- 42.8 79.65 38.72 71.27
Zamba 7B 64.36 57.85 55.38 61.33 82.27 46.8 82.21 49.69 79.32
Falcon Mamba 7B 65.31 63.19 63.4 52.08 80.82 47.8 83.62 53.46 78.14

* 結果取自 NVIDIA 論文

💡 注意: 訓練資料集和訓練過程中見過的 Token 數量的差異使得直接比較這些模型變得困難。從這個表中可以得出的關鍵結論是,混合 Mamba2 架構可以提供有競爭力的結果,同時訓練效率幾乎與 Transformer 模型一樣高。此外,儘管穿插了完整的注意力層和 Mamba2 層,它們仍可以在推理效率上實現顯著提升(理論上高達 5 倍)。我們正在繼續使用最新的資料集對 Bamba-9B 模型進行預訓練,並計劃在模型改進時釋出未來的檢查點。

推理效率 ⚡🏎️

KV 快取瓶頸是大型語言模型面臨的主要挑戰,這促使了量化、剪枝以及 Mamba2、線性 Transformer 和 RetNets 等新穎架構的解決方案。即使是標準 Transformer,要實現規模化的推理效率,也通常需要自定義核心。Bamba-9B 建立在社群核心可用性的勢頭之上,透過與 vLLM 模型服務框架的整合進一步改進。

我們在 vLLM 整合方面的進展透過 此 PR 進行跟蹤,將 Bamba-9B 與 Meta Llama 3.1 8B 在 NVIDIA H100 80GB GPU 上進行基準測試。我們使用 1K Token 的輸入大小和 2K 到 64K 的輸出大小,在不同的批處理大小下,測量了吞吐量(Token/秒)和延遲。結果顯示,隨著批處理大小和序列長度的增加,Bamba-9B 的吞吐量和延遲比 Transformer 模型提高了 2-2.5 倍。這些收益增強了即時應用和 GPU 利用率,更高的吞吐量比率(>1)和更低的延遲比率(<1)是有益的。

Figure 1
圖 1: Bamba 的吞吐量提升
Figure 2
圖 2: Bamba 的延遲改進

我們的分析表明,在 H100 NVIDIA GPU 上,當推理轉向記憶體瓶頸時(這通常發生在生產環境中),我們預計會有 5 倍的加速——請參閱附錄中的計算強度部分。然而,由於以下三個主要原因,我們尚未在 vLLM 中實現這種加速

  1. 分塊預填充(Chunked pre-fill)尚不支援 Bamba 和任何基於 Mamba2 的架構
  2. 記憶體分配假設為標準 Transformer KV 快取
  3. Mamba2 核心未針對 H100 GPU 進行最佳化

這些問題正在這裡進行跟蹤。

模型架構

我們的模型架構基於 NVIDIA 混合 Mamba2,但有以下改動。

引數 Bamba 9B NVIDIA 混合 Mamba2 8B
層數 32 29
注意力層數 3 4
Mamba2 層數 29 25
MLP 擴充套件因子 3.5 4
詞彙表大小 128k 256k
非嵌入引數 8.8B 8.6B
RoPE 是的
門控線性單元 是的

我們總共有 8B 引數在 Mamba2 層中,800M 在全注意力層中,1B 在嵌入層中。隱藏狀態大小為 4K,全注意力的 GQA 有 8 個 KV 頭和 32 個頭,Mamba2 層的頭維度為 64,卷積濾波器大小為 4。兩個模型之間最顯著的變化是將全注意力層從 NVIDIA 混合 Mamba2 模型中的 4 層減少到 Bamba-9B 中的 3 層,並引入了 RoPE 嵌入。

資料

自 The Pile 資料集問世以來,開源資料已經取得了長足的進步。當我們開始訓練這個模型時,最好的開源資料是 Dolma v1.7,透過 Olmo 模型和 Hugging Face 資料團隊的消融實驗證明其效能非常出色。此後,又釋出了幾個更高質量的開源資料集,例如 DCLMFineWeb-2Olmo2 mix

我們在第一階段訓練中使用 Dolma v1.7,選擇的資料混合如下所示。在第二階段訓練中,我們使用了 Fineweb-eduCosmopedia。這些資料集以其原始形式下載,我們使用在內部大規模 Red Hat Open Shift 叢集上執行的 Ray 框架對它們進行分詞。我們計劃儘快釋出分詞和格式化的 parquet 資料,以實現可復現性。

Datamix

預訓練第一階段的資料混合

預訓練

Bamba 的預訓練分階段進行,我們進行了幾次 1.8B 模型大小和 100B Token 的消融實驗以確定正確的學習率。基於這項研究的有希望的結果,我們使用 Dolma mix 訓練了一個更大規模的模型——3B 到 2T Token。我們還使用相同的資料混合訓練了一個遵循 Meta Llama 架構的 3B Transformer 模型,並觀察到 Bamba 模型的效能相似或更好,這與 NVIDIA 同時進行的研究得出的結論一致。最後,我們設計了一個 9B 模型架構,並使用相同的混合資料重新訓練。PyTorch FSDP 用於訓練我們所有的模型。

訓練細節:我們使用了餘弦學習率排程,峰值學習率為 3e−4,在 2000 步內進行二次預熱,衰減因子為 0.033,在 2T Token 上的結束學習率為 1e−5。我們使用了 AdamW 最佳化器,β1 為 0.9,β2 為 0.95。我們使用了 0.1 的權重衰減,4096 的序列長度,以及 1.5M Token/批次的全域性批處理大小。我們使用了來自 IBM Cloud Vela 生產叢集的 192 個 A100 GPU,在 2 個月的時間內訓練了這個模型。該叢集由 Red Hat OpenShift 管理。我們經歷了 3 次作業中斷,原因是作業部署不正確和硬體故障。硬體相關的作業故障是使用 autopilot 自動檢測的。

我們還使用來自 Hugging Face 的高質量資料 FineWeb-edu 和 Cosmopedia 進行了第二階段的訓練,額外訓練了 200B Token。我們使用了 2e-5 的學習率和一個餘弦排程來退火模型,這有助於提高我們的分數。我們目前正在試驗額外的高質量資料,並將作為我們對開源承諾的一部分發布任何未來的檢查點。

資料載入器

訓練高質量語言模型有幾個方面,資料載入器是其中重要的一環。在過去的 18 個月裡,我們一直致力於開發一個能滿足大規模分散式訓練需求的資料載入器。我們開源了這個資料載入器,以便其他人可以將其與他們選擇的框架結合使用。我們在 Bamba 模型訓練中使用了它,並將其與 Torch Titan 整合。到目前為止,我們相信這是唯一一個提供如此豐富功能的開源資料載入器。

該資料載入器提供以下關鍵功能

  1. 有狀態且可檢查點,以確保在週期中無縫恢復
  2. 自動擴充套件以適應變化的工作負載和 GPU 分配
  3. 資料流式傳輸零開銷進行資料混洗
  4. 非同步分散式操作,無點對點通訊
  5. 允許動態資料混合和即時分詞
  6. PyTorch 原生、模組化可擴充套件

我們已經在數百個訓練作業中對這個資料載入器進行了實戰測試,並在數月的持續執行中對其進行了最佳化。主要程式碼庫位於我們的倉庫 這裡,我們還與 Torch Titan 團隊合作,使其在這裡可用。我們正在與 Meta PyTorch 團隊合作,將這個資料載入器貢獻到 PyTorch 核心中。

量化

我們最近開源了一個用於模型量化的框架。透過這個框架,我們利用 llm-compressor 將 Bamba 檢查點量化為 fp8。我們觀察到,在 OpenLLM 排行榜的所有基準測試中,準確度損失極小。具體來說,對於 Bamba 9B,V1 的平均得分差異可以忽略不計,為 0.1(從 62.31 降至 61.5),而 V2 的平均得分下降了 0.9(從 10.91 降至 10.04)。這些量化後的檢查點也與 bf16 對應版本一起釋出。這也驗證了 Bamba 模型與 SOTA Transformer 模型一樣,同樣適用於量化。

我們正在 vLLM 中為該模型啟用 fp8 推理,這將需要更新核心。線性層和全注意力層將很容易處理,但 Mamba2 層將需要更新 Triton/CUDA 核心以處理 fp8

上下文長度擴充套件

我們目前正在探索各種長上下文長度擴充套件的方法,首先是應用 LongRope 到全注意力層。我們使用 PhoneBook 檢索作為任務的初步發現表明,LongRoPE 可以應用於該模型。我們將 Bamba-9B 的上下文長度擴充套件了 4 倍和 8 倍,並將上下文擴充套件後的 Bamba-9B 與 Meta Llama 的三個變體——LLama2、Llama3、LLama3.1 進行比較,它們的訓練上下文長度分別為 4K、8K 和 128K。結果繪製如下。

Datamix

我們觀察到,上下文擴充套件後的 Bamba-9B 模型在未經任何調整的情況下,在高達 16K 的上下文長度下表現非常出色,大幅超越了原始的 Bamba-9B 模型、Llama2-7B 和 Llama3-8B,並獲得了與 Llama3.1-8B 相當的效能。在序列長度為 32K 時,LLama3.1 取得了最佳效能結果。我們計劃在準備就緒後釋出長上下文長度擴充套件模型。

總結 🎯

Bamba-9B 是由 IBM、普林斯頓大學、卡內基梅隆大學和伊利諾伊大學厄巴納-香檳分校合作開發的一款效能強大的混合 Mamba2 模型。該模型完全在開放資料集上訓練,我們正在釋出中間和最終檢查點。為了促進社群實驗,該模型可立即在 transformersvLLMTRLllama.cpp 中使用。我們還發布了帶有狀態化資料載入器的微調、訓練和擴充套件預訓練方案,並邀請社群進一步改進該模型。

關鍵要點

  • 推理效率:Bamba-9B 在吞吐量和延遲方面實現了顯著提升,增強了即時應用效能。使用 vLLM 對比 Llama 3.1 8B 的基準測試顯示,吞吐量提升了 2.5 倍,延遲降低了 2 倍,並且未來還會有更多改進!

  • 有競爭力的基準:Bamba-9B 的效能與 Meta Llama 3.1 8B 等最先進的 (SoTA) Transformer 模型相比具有競爭力。在排除數學和 MMLU 任務後,它的平均基準效能與它們相當,並且有機會透過擴充套件訓練和專注於數學的資料集來縮小這些差距。

  • 開放合作:模型的開發利用了開放資料,促進了人工智慧社群內的透明度和可復現性。

有關更多詳細資訊以及訪問模型和相關資源,請訪問 Bamba GitHub 倉庫

未來工作

我們打算探索幾個方向,並進一步研究推理高效的 mamba2 混合架構

  1. 透過在額外資料上持續預訓練來不斷改進模型;我們歡迎社群的任何反饋,以便我們能夠共同建立一個出色的 Mamba2 混合模型。
  2. 使用 SFT 資料集(如 Tuluv3agent instructAnteater)對基礎模型進行 SFT,並將結果模型與其他最先進的指令微調模型進行比較。
  3. 與社群合作,在 vLLM 中啟用該模型。分塊預填充和管理該架構的記憶體分配問題將是關鍵。
  4. 啟用 fp8 核心以使推理更快。
  5. 訓練時間改進和應用 torch.compile 以及 fp8 訓練,我們團隊已在與 Meta 合作的 Transformer 架構上展示了這兩項技術。
  6. 長上下文長度擴充套件至 1M+

貢獻者

  • 資料收集和整理:我們感謝並感謝 AllenAI 團隊提供了高質量的開源資料集 Dolma,以及 Hugging Face 資料團隊提供了 FineWeb-edu 和 Cosmopedia。這些都是巨大的貢獻,使我們能夠建立這個模型。
  • 資料預處理:我們感謝 IBM 內部的資料預處理團隊,特別是 Tuan Hoang Trong、Syed Zawad、Jay Gala 和 Ryan Gordon,他們幫助我們大規模地對資料進行分詞。分詞程式碼可在此處獲取:這裡
  • 模型架構:模型架構設計由普林斯頓大學、卡內基梅隆大學、IBM 和伊利諾伊大學厄巴納-香檳分校共同完成,參與人員包括:Tri Dao (普林斯頓大學)、Albert Gu (卡內基梅隆大學)、Linsong Chu (IBM)、Davis Wertheimer (IBM)、Minjia Zhang (伊利諾伊大學厄巴納-香檳分校)、Mudhakar Srivatsa (IBM) 和 Raghu Ganti (IBM)。
  • 模型訓練:模型訓練主要由 IBM 團隊使用 Tri Dao 和 Albert Gu 的 Mamba2 核心和層實現來完成。IBM 的以下人員主要參與其中:Linsong Chu、Divya Kumari、Davis Wertheimer、Raghu Ganti 和 Dakshi Agrawal。
  • 模型微調:模型的微調由 IBM 團隊在 TRL 中啟用和驗證,參與人員包括 Sukriti Sharma 和 Anh Uong。
  • 模型推理:在 transformersvLLMllama.cpp 中的模型推理建立在普林斯頓大學和卡內基梅隆大學編寫的核心之上。IBM 團隊正在與社群合作,以便在各種生態系統中啟用它。該團隊包括 Fabian Lim、Antoni viros i Martin、Adnan Hoque、Jamie Yang、Nelson Nimura Gonzalez、Joshua Rosenkranz、Nick Hill 和 Gabe Goodhart。
  • 量化:量化由 IBM 團隊領導 - Naigang Wang 和 Charlie Liu。
  • 評估:評估由 IBM 的一個團隊領導,長上下文評估由伊利諾伊大學厄巴納-香檳分校執行,參與人員包括:Yotam Perlitz、Ofir Arviv、Michal Shmueli-Scheuer (IBM)、Haoechen Shen 和 Minjia Zhang (伊利諾伊大學厄巴納-香檳分校)。

最後,我們要感謝我們的領導層對這項工作的支援——Priya Nagpurkar、David Cox、Sriram Raghavan、Aya Soffer、Ruchir Puri 和 Mukesh Khare。

我們還要感謝社群,特別是來自 Hugging Face 的 Pablo Montalvo-Leroux、Aritra Roy Gosthipaty 和 Vaibhav Srivastav,以及來自 Contextual AI 的 Stas Bekman,他們為這篇部落格和向 transformers 提交的 PR 提供了寶貴的反饋。此外,我們還要感謝來自 Neural Magic 的 Tyler Michael Smith,他正在指導與 vLLM 的整合。

特別感謝 Meta PyTorch、AllenAI 和 Hugging Face 團隊對開放計劃的貢獻,PyTorch FSDP 讓我們能夠順利地訓練這個模型,而來自 Dolma 和 Fineweb/Cosmopedia 的資料使這個模型得以誕生!

附錄:計算強度

使用以下符號
$b$:批處理大小
$s$:序列長度
$h$:隱藏狀態大小 (4096)
$d$:頭維度 (128)
$l$:總層數 (32)
$l_{attn}$:注意力層數 (3)
$l_{ssd}$:SSD 層數 (29)

注意力模型和 Bamba 模型都配置了 4:1 的 GQA(在注意力層中),MLP 擴充套件比為 3.5,並在 MLP 塊中使用 GLU。Bamba 中的 SSD 層配置的狀態維度為 $d$,頭維度為 $d/2$,頭數為 $4h/d$。不包括嵌入層的模型大小為

模型型別 模型大小
注意力 $13h^2l$
Bamba $15.5h^2l$

在預填充階段,模型施加的計算和記憶體(讀+寫)要求是

模型型別 計算預填充 記憶體預填充
注意力 $26bsh^2l + 4bs^2hl$ $13h^2l + 0.5bshl$
Bamba $31bsh^2l + 4bs^2hl_{attn} + 4bsdhl_{ssd}$ $15.5h^2l + 0.5bshl_{attn} + 4bdhl_{ssd}$

在解碼階段,模型施加的計算和記憶體(讀+寫)要求是

模型型別 計算解碼 記憶體解碼
注意力 $26bh^2l + 4bshl$ $13h^2l + 0.5bshl$
Bamba $31bh^2l + 4bshl_{attn} + 4bdhl_{ssd}$ $15.5h^2l + 0.5bshl_{attn} + 4bdhl_{ssd}$

下文顯示了 Bamba 和 LLaMa 模型在預填充階段的計算浮點運算和解碼階段的記憶體(讀+寫)大小的比較。請注意,小於 1 的比率是有益的。由於推理吞吐量主要受解碼階段的瓶頸限制,對於長序列(> 16K),Bamba(相對於 LLaMa)的潛在加速可達 5 倍。目前的測量結果(在 vLLM 上)徘徊在 2.5 倍左右,我們預計在不久的將來會有所改善。

ArithmeticIntensity

社群

註冊登入 發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.