為 AMD MI300 建立自定義核函式

釋出於 2025 年 7 月 9 日
在 GitHub 上更新

AMD 核函式

Title card

導言

每天超過十億次:這是對 ChatGPT 每日處理請求數量的保守估計,而且這個數字短期內不太可能下降。對於每個請求和每個生成的詞元 (token),我們都會對一個擁有數十億引數的模型進行一次推理。這就是為什麼模型最佳化在每個層面上都至關重要:當處理如此巨大的規模時,即使是 1% 的延遲或功耗提升也能帶來巨大的成本節約。

但是,這些提升能從何而來?模型架構已經相當成熟,流行的模型也早已實現了權重化。然而,還有一個關鍵層面可以最佳化模型推理:核函式 (kernel) 層面。核函式是你在網路中執行任何操作時執行的演算法:有矩陣乘法核函式、卷積核函式、批次歸一化核函式等。核函式是低階的、高度最佳化的演算法,通常是為它們將要執行的裝置量身定製的。它們編寫起來 notoriously 長且困難,並且需要對 GPU 的內部工作原理有很好的理解。

核函式對於在神經網路中執行操作至關重要——沒有核函式,一個操作實際上就無法使用。因此,新的創新產品通常會推出一個“day 0”核函式,該核函式通常只為最新的 Nvidia 硬體最佳化。這種方法排除了許多其他裝置,特別是 AMD GPU,儘管它們提供相當甚至更優的規格,卻常常被核函式開發者忽視。Hugging Face 與 AMD 合作,在 AMD 平臺上提供最先進的效能,並讓開源社群受益。作為這次合作的一部分,我們與 AMD 決定專注於提供開源的最佳化核函式,以提升在 8 個 MI300X 節點上使用 VLLM 以 FP8 格式服務 Llama 3.1 405B 的效能。

在這篇部落格文章中,我們將探討我們如何為 MI300X 最佳化效能,以及每個核函式是如何被單獨微調的。但首先,讓我們看看使用我們的自定義核函式所實現的效能提升。透過結合以下三個最佳化核函式:

  • 融合殘差連線、RMS 範數和 FP8 轉換的核函式
  • 融合 SwiGLU 啟用和 FP8 轉換的核函式
  • Skinny GEMM 核函式

我們在由 MI300X GPU 驅動的節點上執行 VLLM 時,實現了顯著的加速。

Latency gains

測量是在輸入大小為 1、輸出大小為 128 的情況下進行的,以模擬解碼模式。我們使用 30 次迭代的中位數來測量解碼延遲。

這些效能提升是在 VLLM 中測量的,但你也可以單獨使用這些核函式,具體方法見下文的“如何使用”部分。

如何使用這些核函式

hf-rocm-kernels 倉庫

前面描述的所有核函式都可以在 hf-rocm-kernels 倉庫中找到,地址在這裡。在該倉庫中,你會找到如何安裝該包的說明、每個核函式的原始碼、它們各自的 Python 繫結、各種基準測試指令碼和一個測試套件。使用基準測試指令碼和 MI300X,你甚至可以復現這篇部落格文章中的結果。為了確保 Torch 或 VLLM 的結果一致,你可以使用與我們相同的容器。你也可以將該倉庫作為基礎來構建自己的核函式:它包含了如何將一個 CUDA 風格的核函式繫結到 Python 的說明和一個簡單的示例核函式。你甚至可以檢視正在開發中的分支,以瞭解新的核函式,比如這裡描述的計算與通訊核函式。

在 VLLM 中的整合

所描述的核函式很快將被整合到 VLLM 專案的 AMD 分支中,但如果你想自己看看如何實現類似的功能,可以檢視這個分支和這份文件

最佳化過程

我們首先將快速回顧一下我們正在使用的裝置架構:MI300X。然後,我們將看看在最佳化之前模型推理的狀態。這將幫助我們識別瓶頸,並確定需要編寫哪些自定義核函式。接著,我們將逐一審視我們編寫的每個核函式,這將為我們提供一個從多個角度探討核函式最佳化如何進行的機會。

MI300X 簡介

在我們深入最佳化 GPU 程式碼之前,我們需要了解 GPU 是如何工作的。已經有很多資源對 GPU 的內部工作原理做了很好的解釋,我將連結放在這裡這裡這裡。我們仍然會快速過一遍 GPU 的不同層次,作為一個簡短的回顧。如果你想跳過回顧,直接進入我們自定義核函式的細節,請點選這裡

執行緒 (Threads)

GPU 中最小的工作單元是執行緒 (thread)。GPU 上完成的任何工作都是因為一個執行緒執行了一條指令。指令是基本操作,如加法、乘法、從一種資料型別到另一種的轉換,或載入和儲存。每個執行緒都有自己的記憶體,稱為暫存器 (registers,或 VGPRs),只有它自己可以訪問。一個執行緒最多可以有 256 個暫存器,每個暫存器 32 位寬。下面是一個執行緒及其可訪問的 256 個 VGPRs 的示意圖。

Representation of a thread

除了使用載入或儲存指令外,執行緒只能在自己的暫存器上執行指令。例如,要將兩個向量 A 和 B 相加,每個執行緒將:1) 將 A 中的一個元素載入到其暫存器中,2) 將 B 中的另一個元素載入到暫存器中,然後 3) 執行加法並將結果儲存在另一個暫存器中,最後 4) 將該暫存器中的值儲存到記憶體中。總共是 4 條指令。

執行緒束 (Warps)

下一個工作單元是執行緒束 (warp):每個執行緒束由 64 個執行緒組成。執行緒束沒有自己的記憶體,但它們對我們很重要,因為一個執行緒束中的所有執行緒必須在同一時間執行相同的指令。這既是一種保證,也是一種約束。

Representation of a warp

執行緒束還允許不同的執行緒與同一執行緒束中的其他執行緒交換來自其暫存器的資訊。儘管一個執行緒束中的不同執行緒可以訪問不同的資料,但它們都必須執行相同的指令,這意味著在編寫核函式時,你需要考慮的是執行緒束級別的行為。

計算單元 (Compute units)

執行緒束被捆綁成執行緒塊 (thread blocks):執行緒塊是軟體抽象,但執行在稱為計算單元 (CU) 的硬體元件上。一個計算單元可以同時執行多個執行緒塊,但它最多隻能容納 16 個執行緒束。每個計算單元都有一個專用的 L1 快取和共享記憶體。L1 快取無法控制或分配,它有助於位於該 CU 上的所有執行緒束的資料重用。相反,共享記憶體可以被分配和用作所有執行緒束共享的儲存空間。例如,當我們希望一個計算單元中的所有執行緒束(因此也是所有執行緒)訪問同一個緩衝區時,我們會在共享記憶體中分配它。共享記憶體和 L1 快取的訪問速度都很快,因為它們“靠近”執行緒。

Representation of a compute unit

執行緒塊還提供了同步其內部所有正在執行的執行緒的能力:這在處理影響共享記憶體的操作時非常有用,比如在共享記憶體中將一個數組初始化為零或進行歸約操作。總的來說,在編寫核函式時,執行緒塊是需要考慮的最高級別:很難同步不同的執行緒塊或讓它們以任何方式進行互動。核函式的吞吐量與 GPU 上存在的計算單元數量緊密相關:CU 越多,可以同時執行的執行緒塊就越多,如果你能充分利用所有 CU,吞吐量就會增加。

XCD

計算單元隨後被分組為加速器複合晶片 (XCD),每個 XCD 包含 38 個計算單元。儘管 CU 之間可能無法直接互動,但它們都共享一個 L2 快取,你無法控制這個快取,但在重用資料時它可能非常有用。例如,在訪問記憶體時,讓位於同一個 XCD 上的兩個計算單元訪問相同的資料將大大減少載入延遲。L2 快取相當大:大小為 4MB,而共享記憶體大小為 64kB,L1 快取包含 32kB。

Representation of a XCD

整個 GPU (MI300X)

透過組裝 8 個 XCD(這給了我們 8 * 38 = 304 個 CU),並增加最後一級快取(稱為 infinity cache,大小為 256MB)和大量的視訊記憶體(192GB),我們就得到了 MI300X。

Representation of a MI300

所有的 XCD,因此所有的執行緒,都可以訪問視訊記憶體 (VRAM),但到達那裡的速度相當慢。隨著你離執行緒級別越遠,記憶體訪問速度變得越慢,但其大小和作用範圍也越大,意味著它服務於更多的執行緒。在最佳化核函式時,總需要在執行大量操作和載入大量資料之間取得平衡,但總的來說,你應該儘可能少地訪問視訊記憶體(通常稱為全域性記憶體)。

當看這張圖時,我們可以理解為什麼 GPU 被稱為“大規模並行”:在這裡,我們有 304 個計算單元,每個計算單元可以執行 16 個執行緒束,每個執行緒束有 64 個執行緒。這意味著我們最多可以同時執行 311,296 個執行緒,每個執行緒執行自己的指令。請記住,一條指令是像加法這樣的基本操作,所以像牛頓法這樣的簡單例程對於單個執行緒來說可能執行時間很長。GPU 並非為指令快速執行而最佳化,即不是為了降低每條指令的延遲:那是延遲導向的裝置。它們被最佳化為讓許多執行緒一起執行,消耗和輸出大量資料:它是一個吞吐量導向的裝置。在為 GPU 最佳化核函式時,我們相應地進行調整:最好是讓一個演算法在許多執行緒上同時執行幾條指令,而不是讓它在少數執行緒上執行許多指令。因此,將在 GPU 上執行的演算法稱為“並行”的。

有三件事可能會阻礙此類演算法以最佳化方式執行:當需要載入大量資料時(記憶體受限)、當需要執行許多操作時(計算受限)或當執行緒必須協同工作時(同步開銷)。

Day 0 效能分析

在最佳化工作負載時,寫下第一行程式碼之前要做的第一件事就是對當前工作負載的狀態進行效能分析。在我們的案例中,我們將對 VLLM 中的模型推理進行效能分析,以瞭解每個操作佔用的時間。這有助於識別主要瓶頸以及我們可以首先處理哪些核函式以獲得最大加速。例如,以下是批次大小為 32 時的分解圖:

Disk plot ok kernels latency

我們可以透過每個切片看到網路的不同部分:

  • “Attention*”切片,我們將 RoPE、注意力 (attention) 和 KV 快取核函式分組在一起;
  • “Attention GEMMs”,包括兩個投影,QKV 和 Output;
  • “Communications”,由兩個 all-reduce 操作組成,一個在 Attention 塊之後,一個在 MLP 塊之後,它們的存在是因為我們正在進行張量並行(TP8)工作;
  • “MLP GEMMs”,包括在 MLP 中進行的兩個投影,Gate / Up 和 Down;
  • “RMS norm”和“SwiGLU”切片,每個核函式一個——請注意,RMS norm 核函式每個塊被呼叫兩次,一次在 Attention 之前,一次在 MLP 之前;
  • “Other”切片,重新組合了我們沒有標記為更大類別的核函式,因為它們的影響較小。

我們已經可以看到,大部分延遲來自 GEMM 和通訊,但注意力及其周圍的操作對延遲的貢獻並不大。這可能有點令人驚訝,因為許多論文都關注於注意力並降低其成本,但似乎透過 KV 快取和 FlashAttention 的結合(VLLM 中已經進行了最佳化),這部分可能不再是首要任務。令人驚訝的是,對“RMS norm”核函式的兩次呼叫成本相當高,因此最佳化該核函式可能會帶來很大的好處。連同 SwiGLU 核函式,它們佔總延遲的 15%,這是不可忽視的。總而言之,我們最好的行動方案可能是致力於這兩個核函式,並嘗試在 GEMM 上獲得 небольшое 加速。為了確認這種效能分解不是偶然現象,我們可以看看其他批次大小:

Latency distribution over batch sizes

我們可以看到,在批次大小為 32 時出現的模式在其他批次大小下也成立,儘管隨著批次大小的增加,GEMM 和通訊的延遲貢獻變得更大。此外,批次大小為 32 在 GEMM 的延遲方面似乎是一個異常值:這可能是因為當批次大小為 32 時選擇的 GEMM 經過了手動調整,或者因為批次大小為 32 呈現出良好的記憶體對齊模式,所以批次大小為 32 的 GEMM 比批次大小為 24 或 28 的更快。

現在我們已經確定了一些需要最佳化的熱點,讓我們來看看我們編寫的第一個核函式:RMS norm 核函式。


RMS norm 核函式

在每個解碼器塊中,我們有兩個主要部分:一個注意力塊和一個 MLP 塊。兩者都以兩個輸入之間的殘差連線開始:當前隱藏狀態 x x 和殘差 r r 。它們具有相同的形狀,即 n n 行(與詞元數量相同)和 d d 列。將它們相加後,我們對 x x 應用逐行的均方根 (RMS) 範數,並且由於模型採用 FP8,我們使用一個縮放因子 s s x x 量化為 FP8。僅僅將這三個操作融合到一個核函式中就可以帶來不錯的效能提升。在數學上,我們需要執行的操作如下:

i+j+kxx+rrxV=i=1dxi2xxV+ϵxQ=Qfp8(sxw) \begin{align} \phantom{i + j + k} &\begin{aligned} x &\leftarrow x + r\\ r &\leftarrow x \end{aligned}\\ &\begin{aligned} V &= \sum_{i=1}^{d} x_i^2 \end{aligned}\\ &\begin{aligned} x &\leftarrow \frac{x}{\sqrt{V + \epsilon}} \\ x_Q &= Q_{\text{fp8}} \left( s * x * w\right) \end{aligned} \end{align}

其中 w w 是一個大小為 d d 的權重向量。步驟 (1) 和 (3) 都非常基礎。對於步驟 (1),我們只需將每個執行緒定位到張量中的不同位置,載入 x x r r 的一些元素,將它們相加並存回 r r 。對於步驟 (3),每個執行緒執行一些標量操作(加法、平方根、除法)和一次到 FP8 的轉換。所有這些,每個執行緒都可以獨立完成:這完全符合 GPU 的並行特性。需要注意的步驟是 (2):我們需要對 d d 進行求和,這意味著要麼每個執行緒將訪問 d d 列中的每一列,要麼我們需要線上程之間交換資料。d d 越大,第一種方案需要載入的資料就越多,因此可行性越低。我們將選擇第二種方案:在塊級別同步執行緒,它們將使用共享記憶體交換資料。每個執行緒將獨立累加 V V 的一部分,然後我們將在整個執行緒塊中對所有這些部分求和,這就是我們所說的歸約 (reduction)。由於 V V 是跨整行計算的,我們將為每一行分配一個執行緒塊。

與開箱即用的 PyTorch 相比,這個核函式的最基本版本帶來了大約 10 倍的加速。但這還不夠:在此基礎上我們還可以新增許多最佳化。

最佳化:記憶體相關

就延遲而言,成本最高的操作之一是訪問視訊記憶體,也稱為全域性記憶體。幸運的是,有一些易於遵循的原則可以顯著降低載入資料的成本。

首先,我們可以看看單個執行緒在單個指令中能載入多少資料:使用 MI300X 指令指南,我們看到從全域性記憶體進行的最大載入是 128 位寬。由於我們載入的是 FP16 資料,我們將每次載入 128b / 16b = 8 個元素。對於 FP32 元素,這將對應於每次載入 4 個元素。

其次,我們確保記憶體訪問是合併的。由於每個執行緒都是執行緒束的一部分,當一個執行緒到達“載入”指令時,執行緒束中的所有其他執行緒也同時到達。為了提高效率,這些“載入”指令會在整個執行緒束中被捆綁在一起。然後,執行緒束集體獲取所需的資料,每個執行緒得到它需要的資料。當執行緒束獲取一個沒有任何間隙的單個數據塊時,就達到了最高效率:這就是我們所說的連續資料。當我們需要的載入資料量超過一次“載入”指令所能載入時,就會出現問題,如下圖所示。

Two loading scenarios

在這個假設的場景中,我們在同一個執行緒束中有兩個執行緒。它們需要共同載入 16 個 fp32 元素,對於哪個執行緒載入哪個元素沒有限制。這是一個典型的“歸約”情況。由於一個執行緒每個指令只能載入 4 個 fp32 元素,我們至少有兩種讀取資料的方式,如場景 (a) 和 (b) 所示。要決定哪個場景最好,我們需要從執行緒束的角度來看,而不是執行緒的角度。在場景 (a) 中,第一次載入獲取元素 0,1,2,3,8,9,10,11:我們看到資料不是連續的,因為元素 3 和 8 之間有間隙。而在場景 (b) 中,第一次載入獲取元素 0,1,2,3,4,5,6,7:我們載入了連續的資料。第二次載入也是如此。因此場景 (b) 更好。儘管在場景 (a) 中,每個執行緒最終得到 8 個連續的元素,但這並不重要:重要的是執行緒束是否載入了連續的資料。這很重要,因為如果執行緒束在一個週期內只能載入 8 個連續元素,那麼場景 (a) 的每次載入都需要兩個週期來處理,而在場景 (b) 中,每次載入只需要一個週期。

第三,我們減少儲存次數:當我們看步驟 (1) 和 (3) 時,可以看到只需要兩次儲存:一次是 r r ,一次是 xQ x_Q 。在步驟 (1) 之後,我們已經可以儲存 r r 並完成該操作。但我們仍然需要在步驟 (2) 完成後訪問 x x 的修改版本。為此,我們可以將 x x 的修改版本儲存在全域性記憶體中,並在步驟 (2) 完成後重新載入它,並依賴於重新載入時的快取命中。或者,如果 x x 足夠小,我們可以將其修改版本儲存在共享記憶體中:如果 x x 是 FP16 格式,並且每個 CU 只有一個執行緒塊,那麼我們每個執行緒塊可以在共享記憶體中儲存 64KB / 2B = 32 * 1024 個元素。在 Llama 405B 的情況下,d d 等於 16384,所以這能放得下。使用共享記憶體比依賴快取命中提供了更好的加速,特別是當許多執行緒塊同時活動時:如果 L1 快取不夠大,無法容納整個 x x ,那麼我們必須依賴 L2 快取,而 L2 快取是由 38 個 CU 共享的。

除了記憶體訪問,我們還可以最佳化計算效率,但我們將把這部分留到下一個核函式,因為兩種情況下的最佳化是相似的。

結果

當我們應用上述最佳化後,我們得到以下結果:

Latency of RMS norm kernels

行數 Torch (μs) VLLM (μs) 我們的 (μs)
1 38.8998 5.5145 4.18138
2 43.2469 5.65645 4.36976
4 41.1304 5.6893 4.37628
8 43.8883 5.72275 4.39081
16 46.8876 5.85667 4.48165
32 55.2276 6.08502 4.72017
64 75.6086 6.4629 5.54214
128 98.1122 7.49166 6.27341
256 119.727 11.8812 10.739
512 195.782 23.1595 18.5549
1024 355.42 44.8143 34.7204
2048 671.513 81.2089 73.35

輸入張量為形狀為 [X, 16384] 的 FP16。我們核函式的最基本版本,稱為“Pointwise”,沒有任何與記憶體相關的最佳化,但已經比 Torch 快了至少 4 倍。它不如 VLLM 的核函式實現,但我們的“Vectorized”實現超過了“Pointwise”和 VLLM。這是實現了合併 128 位載入的核函式版本,僅次於“Vectorized + SMEM”(SMEM 代表共享記憶體)實現,後者在低和高批次大小下都提供了比 VLLM 明顯更好的加速比。


SwiGLU 核函式

在 MLP 塊中,在我們剛才討論的核函式之後,是一個我們之前稱之為“Gate / Up”投影的投影。我們之所以這樣稱呼它,是因為“Gate / Up”投影實際上是兩個具有相同輸入的投影的拼接:“Gate”和“Up”。因此,我們將“Gate / Up”投影的結果 x x 寫為 x=xGxU x = x_G | x_U ,其中 | 是沿列軸應用的拼接運算子。xG x_G xU x_U 具有相同的維度。我們需要這兩個投影的原因是緊隨其後的 SwiGLU 啟用函式,其結果 y y 由方程 (4) 定義。SwiGLU 啟用函式之後是“Down”投影,在我們的案例中是 FP8 格式,所以我們還需要如方程 (5) 所示對 y y 進行量化。

i+j+ky=σ(xG)xUyQ=QFP8(sy) \begin{align} \phantom{i + j + k}& \begin{aligned} y = \sigma \left( x_G \right) \cdot x_U \\\end{aligned}\\ &\begin{aligned} y_Q = Q_\text{FP8} \left( s * y \right) \end{aligned} \end{align}

其中 σ \sigma 是 sigmoid 函式:σ(x)=ex/(1+x) \sigma (x) = e^{-x} / (1 + x) 。我們將編寫一個融合核函式 (fused kernel) 來處理所有這些操作。對於這個核函式,除了共享記憶體緩衝區外,為 RMS 核函式描述的最佳化仍然適用。這裡我們將重點關注與計算相關的最佳化。

最佳化:與計算相關

我們將透過兩種方式來提高核函式的速度:增加每條執行指令完成的工作量,以及使用更快的指令。

為了增加每條指令完成的工作量,我們可以使用 打包 (packed) 指令。當我們要對多個元素應用相同操作時,打包指令非常有用:我們不是對每個元素執行一條指令,而是在一個元素向量上執行一條指令。在 CPU 中,打包(或向量化)指令是單執行緒最佳化的基礎,AVX 指令集家族就是明證。GPU 上也有一些打包指令,在適當的地方它們可以非常有用。在 MI300X 上,除其他外,還有用於 FP16 加法和乘法的打包指令,我們將在兩個步驟中都使用它們。還存在從 FP32 到 FP8 的打包轉換,與非打包轉換相比,這可以顯著提升效能。事實上,除了從 FP32,沒有任何其他資料型別可以轉換為 FP8,因此對於 RMS norm 核函式和這個核函式,我們必須先轉到 FP32 精度才能轉換為 FP8。

然而,在這個核函式中這不成問題:sigmoid 函式 σ \sigma 需要我們計算一個指數,這是一個能從 FP32 精度中獲益匪淺的操作。這是一個我們可以透過使用更快的指令來最佳化計算的例子:我們不使用 exp 指令,而是將輸入乘以 log(2) \text{log}(2) 並使用 exp2 指令,這要快得多。我們只遭受幾乎可以忽略不計的精度損失,但卻降低了延遲。

結果

對於形狀為 [X, 16384] 的 FP16 輸入張量,我們得到下表

行數 1 2 4 8 16 32 64 128 256 512 1024 2048
Torch (μs) 40.2731 29.923 35.305 23.5763 22.4738 25.3445 31.5829 40.3194 53.5369 79.8037 124.873 243.202
VLLM (μs) 3.84116 3.86192 3.92937 3.94151 4.01047 4.02421 4.08943 4.20317 4.48755 7.48465 13.7389 25.4306
我們的 (μs) 1.92981 1.93904 1.93524 1.99316 2.00415 1.91563 2.04498 2.61763 3.57726 5.47608 10.0482 19.8957
加速比 (VLLM / 我們的) 1.990434291 1.991665979 2.030430334 1.977518112 2.001082753 2.100724044 1.999740829 1.605715857 1.254465708 1.366789747 1.367299616 1.278195791

透過針對 MI300X 定製的記憶體和計算最佳化,我們得到的核函式平均比 Torch 快 14 倍以上,比 VLLM 的核函式快 27% 到 100%。


瘦 GEMM 核函式

正如我們之前所見,模型推理延遲的大約 60% 來自於投影,而投影依賴於 GEMM 核函式。GEMM 核函式在 AMD 的 hipBLASLT rocBLAS 等專用庫中被高度最佳化,因此編寫一個在所有情況下都表現更好的自定義核函式相當困難。但如果我們專注於一些與我們相關的邊緣情況,併為這些特定情況編寫一個 GEMM 核函式,那麼我們的自定義核函式就有可能比專用庫中的更快。

在預填充和解碼階段,網路中任何投影的輸入行數都與正在處理的 token 數量相同。而在解碼期間,正在處理的 token 數量等於批處理大小。因此,在解碼期間,所有 GEMM 核函式的輸入行數都等於批處理大小,為了我們的目的,這個範圍在 1 到 256 之間。我們將關注非常小的批處理大小。當我們有一個 GEMM AB=C A * B = C A A 的行數很少而列數很多時,我們稱之為 瘦 (skinny) GEMM。我們為這種 GEMM 使用一個特定術語的原因是,它們不適合我們在 GPU 上執行的經典 GEMM 演算法。通常,GEMM 核函式的效率來自於 分塊 (tiling):我們將結果矩陣分成許多子矩陣,稱為塊 (tile),並將每個塊分配給一個不同的計算單元 (CU)。如果我們有很多塊,就可以使用很多 CU,GPU 使用率就會很高。下圖對此進行了說明。

Classic GEMM dimensions

但是如果輸入 A A 的行數非常少,那麼只能形成少數幾個塊,這導致只有少數計算單元處於活動狀態,因此 GPU 利用率很低。

Skinny GEMM dimensions

瘦 GEMM 對 GPU 來說是天生不便的。在下一部分,我們將看到如何透過一個假設我們處於瘦 GEMM 上下文中的自定義核函式,使它們變得更方便。

最佳化:split-K

由於瘦 GEMM 的主要問題是我們使用的計算單元太少,所以我們首先要做的就是找出一種方法來使用更多的計算單元。為此,我們可以利用以下這個令人拍案叫絕的公式:

cij=k=1Kaikbkj=(k=1K/2aikbkj)+(k=1+K/2Kaikbkj) c_{ij} = \sum_{k=1}^K a_{ik} b_{kj} = \left( \sum_{k=1}^{K/2} a_{ik} b_{kj} \right) + \left( \sum_{k=1+K/2}^{K} a_{ik} b_{kj} \right)

藉助和的結合律,我們可以沿著共享軸(通常稱為 K 軸)拆分主 GEMM,並用幾個併發執行的子 GEMM 替換一個 GEMM。每個子 GEMM 將使用與主 GEMM 一樣多的 CU,因此使用的 CU 數量將乘以我們拆分 K 軸的次數。下圖對此進行了說明。

Split-K algorithm

在這裡,我們將 split-K 設定為 2,從而使一次性使用的 CU 數量增加了一倍。由於我們得到的是部分結果,我們需要在兩個子 GEMM 都完成後將它們相加。可能看起來違反直覺的是,我們增加了一個操作——對部分結果求和,但我們聲稱這減少了整個過程的延遲。但由於每個 CU 都需要遍歷整個 K 軸來計算結果,因為我們將其一分為二,所以每個 CU 完成的工作量也減少了一半。如果以這種方式節省的工作量能夠抵消對最終結果求和所增加的工作量,那麼我們就能實現整體最佳化。只要 K 很大且原始 GEMM 使用的 GPU 不到 50%,這通常是成立的。

最佳化:移除填充

如果我們假設透過 split-K,大多數計算單元都在忙於處理自己的塊,我們就可以將最佳化範圍集中在計算單元級別。我們將看一下實際的矩陣乘法是如何完成的,以及我們如何加速它。

在像 MI300X 這樣的頂級 GPU 中,矩陣乘法由一個稱為張量核心 (tensor core) 的專用硬體單元處理。張量核心只執行矩陣乘法,但速度非常快。張量核心指令的格式是 mfma_MxNxK...,其中 mfma 代表矩陣融合乘加 (matrix fused multiply-add),M 是左側矩陣的行數,N 是右側矩陣的列數,K 是兩者的共享維度。我們在下面展示一個假設的指令 mfma_2x2x4

MFMA dense version

張量核心指令只有少數幾種,但對於任何三元組 MxNxK,使用專用的張量核心指令都比任何其他替代方案快得多。張量核心指令還有兩種型別:“密集 (dense)” 和 “稀疏 (sparse)”。密集指令對應於標準矩陣乘法。稀疏指令假設左側矩陣 A A 具有 4:2 結構化稀疏模式,這意味著沿矩陣 K 軸每 4 個元素中就有兩個是零。在數學上,對於任何 i,j i, j 使得 ai,4j+3 a_{i, 4j+3} A A 的一個元素,我們在 (ai,4j,ai,4j+1,ai,4j+2,ai,4j+3) \left( a_{i,4j}, a_{i,4j+1}, a_{i,4j+2}, a_{i,4j+3} \right) 中至少有兩個零。下面是一個稀疏矩陣的例子。

A 4:2 sparse matrix

讓我們回到我們的模型,FP8 精度的 Llama 405B。對於 FP8,我們只有兩個密集張量核心指令:16x16x3232x32x16。我們還有一個大小為 16x16x64 的稀疏指令。對於一個有 8 行的輸入,即使使用最小的密集指令 16x16x32 也意味著我們必須為輸入新增 8 行填充,這是對計算資源的浪費。人們可能會想,我們是否可以改用稀疏指令:畢竟,如果一個 16 行矩陣的一半是 4:2 稀疏的,我們可以用一個密集的 8 行矩陣完全描述其非零係數。反之,如果我們有一個 8 行的密集矩陣,我們可以將其所有資料放入一個具有 4:2 稀疏性的 16 行矩陣中。而使用稀疏指令的好處是顯而易見的:密集指令的 K=32,而稀疏指令的 K=64。在相同的週期數內,稀疏指令的深度是原來的兩倍。我們在下圖中用一個 1 行輸入和 2x2x4 密集指令及其稀疏的 2x2x8 對應指令來說明這個稀疏技巧。

Using sparsity for skinny inputs

利用這個技巧,我們可以顯著加快任何行數小於等於 8 的輸入的 GEMM 速度,這導致任何批處理請求數少於 8 的解碼的每 token 延遲降低。

最佳化:Warp 專用化和非同步執行

我們已經看到,在瘦 GEMM 中,行數少的事實限制了輸出塊的數量,這反過來又限制了 GPU 的利用率。但行數少也限制了每個輸出塊的行數,這反過來又減少了我們所說的 算術強度 (arithmetic intensity)。簡單地說,算術強度是完成的工作量除以為完成該工作而載入的資料量。讓我們比較兩個例子

sn=i=1nxitn=i=1nyi=y (1+tn1) s_n = \sum_{i=1}^{n} x_i \\ t_n = \sum_{i=1}^n y^i = y ~( 1 + t_{n-1})

其中 x x 是一個大小為 n n 的向量,而 y y 是一個標量。要計算 sn s_n ,我們載入 n n 個元素並執行 n1 n-1 次加法。要計算 tn t_n ,我們載入 1 個元素並執行 2n1 2n-1 次加法和乘法。因此,計算 sn s_n 的“算術強度”是 n1n \frac{n-1}{n} tn t_n 的是 2n1 2n - 1 :計算 tn t_n 比計算 sn s_n “算術強度”更高。我們在這裡看到的是,當 算術強度越低,我們需要載入更多資料來執行工作

這對我們來說為什麼重要?嗯,我們已經看到從 VRAM 載入資料有很高的延遲成本,這對 GPU 來說不是好事。換句話說,算術強度低的工作負載不適合 GPU,而事實證明,瘦 GEMM 的算術強度比它們的非瘦對應物要低。當看下面的圖時,這一點變得直觀:我們可以看到,當我們將載入的資料量減半時,由於 GEMM 維度的二次性質,輸出係數的數量減少了四倍。

The arithmetic intensity of two GEMMs

在瘦 GEMM 中,輸出塊的行數是有限的,因此算術強度也是有限的。這已經意味著我們需要載入大量資料來計算一個輸出塊。此外,由於我們使用的是 FP8 算術,計算速度相當快,所以我們不能依靠計算時間來隱藏資料載入的延遲。總而言之,理想情況是讓負責載入資料的執行緒多於負責計算結果的執行緒。

為了實現這一點,我們將使用一種稱為 warp 專用化 (warp specialization) 的技術。我們不再讓執行緒塊中的所有 warp 執行相同的指令,而是將一些 warp 專門用於僅載入資料,另一些專門用於僅計算結果。負責載入資料的 warp 稱為 生產者 (producers),計算結果的 warp 稱為 消費者 (consumers)。生產者和消費者非同步工作:生產者首先從 VRAM 載入資料(這很慢),然後透過將其儲存在共享記憶體緩衝區中使其對消費者可用。在資料在共享記憶體中可用之前,消費者是空閒的。資料可用後,消費者從共享記憶體載入資料(這很快)並計算結果。生產者和消費者的協調是透過儲存在共享記憶體中的佇列來實現的。當生產者完成在共享記憶體緩衝區 i i 中儲存資料時,它會改變佇列的第 i i 個變數的狀態,以表示資料在那裡可用。消費者正在監視這一點,然後開始載入資料。當它完成後,它會改變佇列的第 i i 個變數的狀態,以表示資料可以被寫入緩衝區 i i 。在下圖中,我們展示了一個簡單的非同步 GEMM 中涉及的步驟,其中有一個生產者、一個消費者和一個大小為 2 的佇列。

Async GEMM mechanism

使整個過程奏效的是,一旦緩衝區 0 0 被生產者填充,它就可以開始處理緩衝區 1 1 ,而無需等待消費者從緩衝區 0 0 載入資料。目標是擁有一個足夠大的佇列,以便生產者不斷填充緩衝區,而消費者不斷消費它們。佇列的大小受共享記憶體大小的限制。

我們還需要調整生產者與消費者的比例:我們說過我們的算術強度低,所以我們需要載入大量資料來做一個相對快速的計算。因此,我們將有大量的生產者 warp(通常是 8 或 10 個)對應少數消費者 warp(比如 2 或 3 個)。此外,我們可以利用 GEMM 是瘦的事實,為輸入(瘦矩陣)和權重(非瘦矩陣)設定不同的生產者。為了使輸出塊在不受約束的維度(即列維度)上更大,我們為權重分配更多的生產者。

關於非同步 GEMM 更深入的部落格文章,我鼓勵您檢視這篇部落格文章。不過,其中的許多內容在我們的情況下不適用:MI300X 沒有 warp 級別的屏障,只有一個執行緒塊級別的屏障。這導致了一些“有趣”的把戲,比如使用 ASM 來確保 warp 在其屏障處等待,共享記憶體載入和儲存在檢查屏障狀態之前得到解決,以及對佇列的模組化特性進行仔細處理。所有這些在這裡都會顯得不合時宜,但我鼓勵您檢視程式碼或在評論中提問。未來可能會有關於非同步處理細節的深入探討。

透過 warp 專用化和非同步工作,我們可以使我們的核函式適應低算術強度的負載,但這是否足以超越像 hipBLASLT 這樣的庫?答案是肯定的,在某些情況下。

結果

由於 Torch 已經綁定了取自 AMD 線性代數庫的高度最佳化的 GEMM,我們不會得到與最後兩個核函式相同範圍的加速。我們首先將看一下我們感興趣的三個 GEMM 維度:即與 QKV 投影、Gate / Up 投影和 Down 投影相關的 GEMM 維度。輸出投影被排除在外,因為它的維度不符合瘦 GEMM 的情況。

M (行) N (列) K (深度) Torch 時間 (μs) SkG 時間 (μs) 加速比
1 2304 16384 14.938 ± 0.292 11.685 ± 0.299 127.84 %
8 2304 16384 16.300 ± 0.282 12.342 ± 0.375 132.07 %
16 2304 16384 16.693 ± 0.233 13.909 ± 0.295 120.02 %
32 2304 16384 16.817 ± 0.124 17.021 ± 0.133 98.80 %
1 13312 16384 77.636 ± 0.364 54.717 ± 0.628 141.88 %
8 13312 16384 80.031 ± 0.449 58.355 ± 0.612 137.15 %
16 13312 16384 75.236 ± 0.378 59.973 ± 1.922 125.45 %
32 13312 16384 82.198 ± 0.590 69.483 ± 1.672 118.30 %
1 16384 6656 31.066 ± 0.193 27.613 ± 0.218 112.51 %
8 16384 6656 31.559 ± 0.200 28.134 ± 0.209 112.17 %
16 16384 6656 31.671 ± 0.250 30.233 ± 0.267 104.76 %
32 16384 6656 35.561 ± 0.335 35.052 ± 1.365 101.45 %

測量是在 500 次預熱迭代後,在 2000 次效能分析迭代中進行的,使用 CUDA graph 和多個權重以避免快取命中。上面顯示的 GEMM 維度按順序對應 QKV 投影 (N = 2304 和 K = 16384)、Gate / Up 投影 (N = 13312 和 K = 16384) 和 Down 投影 (N = 16384 和 K = 6656)。我們可以看到,對於那些經過調整的維度,在行數較少 (M = 1, 8, 16) 的情況下有顯著的加速,但在行數較多 (M = 32) 的情況下則不那麼明顯。特別是在我們可以使用稀疏技巧的維度 (M = 1, 8) 中,我們看到了比 Torch 顯著的加速,Torch 可能將所有內容都填充到 16 行以使用最小的 MFMA 指令。

結論

在這篇文章中,我們只探討了眾多可用核函式最佳化技術中的一小部分。如果您有興趣嘗試它們,請隨時深入 hf-rocm-kernels 倉庫並開始動手實驗!如果您開發了自己喜歡的核函式並希望分發它,請務必檢視 kernel-builderkernels — 這兩個 Hugging Face 軟體包旨在幫助核函式構建者將其工作廣泛提供併產生更大影響。

社群

幹得漂亮!對於 Skinny GEMM,您有分別來自稀疏 mfma 和 warp 專用化 (WS) 的改進百分比資料嗎?想了解在沒有稀疏 mfma 的情況下,WS 對不同形狀的影響。

·

看了程式碼,我的理解是稀疏性只用於 M = 8。這正確嗎?

註冊登入 以發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.