將 LLM 微調至 1.58 位:輕鬆實現極限量化

釋出於 2024 年 9 月 18 日
在 GitHub 上更新

隨著大型語言模型 (LLM) 規模和複雜性的增長,尋找降低其計算和能源成本的方法已成為一項關鍵挑戰。一種流行的解決方案是量化,即將引數的精度從標準的 16 位浮點 (FP16) 或 32 位浮點 (FP32) 降低到 8 位或 4 位等低位格式。雖然這種方法顯著減少了記憶體使用並加快了計算速度,但通常會以犧牲精度為代價。過度降低精度可能導致模型丟失關鍵資訊,從而導致效能下降。

BitNet 是一種特殊的 Transformer 架構,它僅用三個值來表示每個引數:(-1, 0, 1),從而實現了每引數 1.58 (log2(3) log_2(3) ) 位的極限量化。然而,它需要從頭開始訓練模型。雖然結果令人印象深刻,但並非所有人都擁有預訓練 LLM 的預算。為了克服這一限制,我們探索了一些技巧,允許將現有模型微調到 1.58 位!繼續閱讀以瞭解如何實現!

目錄

TL;DR

BitNet 是微軟研究院推出的一種架構,它採用極限量化,每個引數僅用三個值表示:-1、0 和 1。這使得模型每個引數僅使用 1.58 位,顯著降低了計算和記憶體需求。

與 LLaMA LLM 的 FP16 加法和乘法運算相比,該架構在執行矩陣乘法時使用 INT8 加法計算。

The new computation paradigm of BitNet b1.58
BitNet b1.58 的新計算正規化(來源:BitNet 論文 https://arxiv.org/abs/2402.17764)

這導致理論上能耗降低,BitNet b1.58 在矩陣乘法方面比 Llama 基線節省了 71.4 倍的算術運算能耗。

Energy consumption of BitNet b1.58 compared to LLaMA
BitNet b1.58 與 Llama 的能耗比較(來源:BitNet 論文 https://arxiv.org/abs/2402.17764)

我們已成功使用 BitNet 架構微調了 Llama3 8B 模型,並在下游任務中取得了出色的表現。我們開發的 8B 模型在 HF1BitLLM 組織下發布。其中兩個模型在 100 億個 Token 上進行了不同訓練設定的微調,而第三個模型在 1000 億個 Token 上進行了微調。值得注意的是,我們的模型在 MMLU 基準測試中超越了 Llama 1 7B 模型。

如何與 Transformers 配合使用

為了將 BitNet 架構整合到 Transformers 中,我們引入了一種新的量化方法,名為“bitnet”(PR)。這種方法涉及將標準線性層替換為與 BitNet 架構相容的專用 BitLinear 層,並進行適當的動態啟用量化、權重解包和矩陣乘法。

在 Transformers 中載入和測試模型非常簡單,API 沒有任何變化

model = AutoModelForCausalLM.from_pretrained(
    "HF1BitLLM/Llama3-8B-1.58-100B-tokens",
    device_map="cuda",
    torch_dtype=torch.bfloat16
)    
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

input_text = "Daniel went back to the the the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:"

input_ids = tokenizer.encode(input_text, return_tensors="pt").cuda()
output = model.generate(input_ids, max_new_tokens=10)
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
print(generated_text)

透過這段程式碼,一切都在後臺無縫管理,因此無需擔心額外的複雜性,您只需安裝最新版本的 transformers 即可。

要快速測試模型,請檢視此 Notebook

BitNet 深入解讀

BitNet 用名為 BitLinear 的特殊層取代了多頭注意力和前饋網路中的傳統線性層,這些特殊層使用三元精度(或二進位制,在初始版本中)。我們在這個專案中使用的 BitLinear 層使用三元精度(值分別為 -1、0 和 1)量化權重,並將啟用量化為 8 位精度。我們用於訓練的 BitLinear 實現與用於推理的不同,這將在下一節中看到。

三元精度訓練的主要障礙是權重值是離散的(透過 round() 函式),因此不可微分。BitLinear 透過一個巧妙的技巧解決了這個問題:STE (Straight Through Estimator)。STE 允許梯度透過不可微分的舍入操作,透過將其梯度近似為 1(將 round() 視為等同於恆等函式)。另一種看法是,STE 不會在舍入步驟停止梯度,而是讓梯度透過,就像舍入從未發生過一樣,從而可以使用標準基於梯度的最佳化技術更新權重。

The architecture of BitNet with BitLinear layers
帶 BitLinear 層的 BitNet 架構(來源:BitNet 論文 https://arxiv.org/pdf/2310.11453)

訓練

我們以全精度進行訓練,但在此過程中使用對稱的每張量量化將權重量化為三元值。首先,我們計算權重矩陣的絕對值平均值並將其用作比例。然後,我們將權重除以比例,對值進行舍入,將其限制在 -1 和 1 之間,最後將其重新縮放以繼續以全精度進行計算。

scalew=11nmijWij scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}

Wq=clamp[1,1](round(Wscale)) W_q = \text{clamp}_{[-1,1]}(\text{round}(W*scale))

Wdequantized=Wqscalew W_{dequantized} = W_q*scale_w

然後,啟用量化為指定的位寬(在本例中為 8 位),使用 absmax 每 token 量化(有關量化方法的全面介紹,請檢視此帖子)。這涉及將啟用縮放到 8 位位寬的範圍 [−128, 127]。量化公式為:

scalex=127Xmax,dim=1 scale_x = \frac{127}{|X|_{\text{max}, \, \text{dim}=-1}}

Xq=clamp[128,127](round(Xscale)) X_q = \text{clamp}_{[-128,127]}(\text{round}(X*scale))

Xdequantized=Xqscalex X_{dequantized} = X_q * scale_x

為了使公式更清晰,這裡有使用 3x3 矩陣進行權重和啟用量化的示例


示例 1:權重矩陣量化

設權重矩陣 ( W ) 為:

W=[0.80.51.21.50.40.91.30.70.2] W = \begin{bmatrix} 0.8 & -0.5 & 1.2 \\ -1.5 & 0.4 & -0.9 \\ 1.3 & -0.7 & 0.2 \end{bmatrix}

步驟 1:計算權重比例

使用公式

scalew=11nmijWij scale_w = \frac{1}{\frac{1}{nm} \sum_{ij} |W_{ij}|}

我們計算 ( W ) 的平均絕對值:

1nmijWij=19(0.8+0.5+1.2+1.5+0.4+0.9+1.3+0.7+0.2)=19(7.5)=0.8333 \frac{1}{nm} \sum_{ij} |W_{ij}| = \frac{1}{9}(0.8 + 0.5 + 1.2 + 1.5 + 0.4 + 0.9 + 1.3 + 0.7 + 0.2) = \frac{1}{9}(7.5) = 0.8333

現在,比例因子是

scalew=10.83331.2 scale_w = \frac{1}{0.8333} \approx 1.2

步驟 2:量化權重矩陣

使用公式

Wq=clamp[1,1](round(W×scalew)) W_q = \text{clamp}_{[-1, 1]}(\text{round}(W \times scale_w))

我們首先將權重按 scalew1.2 scale_w \approx 1.2 縮放。

W×scalew=[0.8×1.20.5×1.21.2×1.21.5×1.20.4×1.20.9×1.21.3×1.20.7×1.20.2×1.2]=[0.960.61.441.80.481.081.560.840.24] W \times scale_w = \begin{bmatrix} 0.8 \times 1.2 & -0.5 \times 1.2 & 1.2 \times 1.2 \\ -1.5 \times 1.2 & 0.4 \times 1.2 & -0.9 \times 1.2 \\ 1.3 \times 1.2 & -0.7 \times 1.2 & 0.2 \times 1.2 \end{bmatrix} = \begin{bmatrix} 0.96 & -0.6 & 1.44 \\ -1.8 & 0.48 & -1.08 \\ 1.56 & -0.84 & 0.24 \end{bmatrix}

接下來,我們對值進行舍入並將其限制在範圍 [1,1] [-1, 1] 內。

Wq=[111101110] W_q = \begin{bmatrix} 1 & -1 & 1 \\ -1 & 0 & -1 \\ 1 & -1 & 0 \end{bmatrix}

步驟 3:反量化權重

最後,我們使用以下公式反量化權重:

Wdequantized=Wq×scalew W_{dequantized} = W_q \times scale_w

代入 scale_w,我們得到:

Wdequantized=[1×1.21×1.21×1.21×1.20×1.21×1.21×1.21×1.20×1.2]=[1.21.21.21.201.21.21.20] W_{dequantized} = \begin{bmatrix} 1 \times 1.2 & -1 \times 1.2 & 1 \times 1.2 \\ -1 \times 1.2 & 0 \times 1.2 & -1 \times 1.2 \\ 1 \times 1.2 & -1 \times 1.2 & 0 \times 1.2 \end{bmatrix} = \begin{bmatrix} 1.2 & -1.2 & 1.2 \\ -1.2 & 0 & -1.2 \\ 1.2 & -1.2 & 0 \end{bmatrix}

示例 2:啟用矩陣量化

設啟用矩陣 ( X ) 為:

X=[1.00.60.70.90.41.20.80.50.3] X = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix}

步驟1:計算啟用的尺度

對於每一行(或通道),計算最大絕對值

  • 第1行:最大絕對值 = 1.0
  • 第2行:最大絕對值 = 1.2
  • 第3行:最大絕對值 = 0.8

計算每一行的尺度因子

scale=[1271.01271.21270.8]=[127105.83158.75] \text{scale} = \begin{bmatrix} \frac{127}{1.0} \\ \frac{127}{1.2} \\ \frac{127}{0.8} \end{bmatrix} = \begin{bmatrix} 127 \\ 105.83 \\ 158.75 \end{bmatrix}

步驟2:量化啟用矩陣

使用公式

Xq=clamp[128,127](round(X×scale)) X_q = \text{clamp}_{[-128,127]}(\text{round}(X \times \text{scale}))

縮放啟用

X×scale=[1.0×1270.6×1270.7×1270.9×105.830.4×105.831.2×105.830.8×158.750.5×158.750.3×158.75]=[12776.288.995.242.312712779.447.6] X \times \text{scale} = \begin{bmatrix} 1.0 \times 127 & -0.6 \times 127 & 0.7 \times 127 \\ -0.9 \times 105.83 & 0.4 \times 105.83 & -1.2 \times 105.83 \\ 0.8 \times 158.75 & -0.5 \times 158.75 & 0.3 \times 158.75 \end{bmatrix} = \begin{bmatrix} 127 & -76.2 & 88.9 \\ -95.2 & 42.3 & -127 \\ 127 & -79.4 & 47.6 \end{bmatrix}

將值四捨五入並將其限制在 [128,127][-128, 127] 範圍內

Xq=[127768995421271277948] X_q = \begin{bmatrix} 127 & -76 & 89 \\ -95 & 42 & -127 \\ 127 & -79 & 48 \end{bmatrix}

步驟3:反量化啟用

最後,使用以下公式反量化啟用:

Xdequantized=Xq×1scale X_{dequantized} = X_q \times \frac{1}{\text{scale}}

代入尺度

Xdequantized=[127×112776×112789×112795×1105.8342×1105.83127×1105.83127×1158.7579×1158.7548×1158.75]=[1.00.60.70.90.41.20.80.50.3] X_{dequantized} = \begin{bmatrix} 127 \times \frac{1}{127} & -76 \times \frac{1}{127} & 89 \times \frac{1}{127} \\ -95 \times \frac{1}{105.83} & 42 \times \frac{1}{105.83} & -127 \times \frac{1}{105.83} \\ 127 \times \frac{1}{158.75} & -79 \times \frac{1}{158.75} & 48 \times \frac{1}{158.75} \end{bmatrix} = \begin{bmatrix} 1.0 & -0.6 & 0.7 \\ -0.9 & 0.4 & -1.2 \\ 0.8 & -0.5 & 0.3 \end{bmatrix}


我們在量化啟用之前應用層歸一化(LN)以保持輸出的方差。

LN(x)=xE(x)Var(x)+ϵ \text{LN}(x) = \frac{x - E(x)}{\sqrt{\text{Var}(x) + \epsilon}}

其中 $\epsilon$ 是一個很小的數字,以防止溢位。

如前所述,`round()` 函式是不可微分的。我們使用 `detach()` 作為一種技巧,在反向傳播中實現可微分的直通估計器。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant(x):
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127) / scale
    return y
 
def weight_quant(w):
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    u = (w * scale).round().clamp_(-1, 1) / scale
    return u

class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight
        x_norm = LN(x)
        
        # A trick for implementing Straight−Through−Estimator (STE) using detach()
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        
        # Perform quantized linear transformation
        y = F.linear(x_quant, w_quant)
        return y

推理

在推理過程中,我們只需將權重量化為三元值,無需重新縮放。我們對啟用應用相同的8位精度方法,然後使用高效核心執行矩陣乘法,再除以權重和啟用尺度。這應該能顯著提高推理速度,尤其是在最佳化硬體上。你可以看到,訓練期間的重新縮放過程有所不同,因為矩陣乘法保持在fp16/bf16/fp32以進行正確訓練。

# Adapted from https://github.com/microsoft/unilm/blob/master/bitnet/The-Era-of-1-bit-LLMs__Training_Tips_Code_FAQ.pdf
import torch
import torch.nn as nn 
import torch.nn.functional as F

def activation_quant_inference(x):
    x = LN(x)
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    y = (x * scale).round().clamp_(-128, 127)
    return y, scale
 
class BitLinear(nn.Linear):
    """
    Only for training
    """
    def forward(self, x):
        w = self.weight # weights here are already quantized to (-1, 0, 1)    
        w_scale = self.w_scale  
        x_quant, x_scale = activation_quant_inference(x)
        y = efficient_kernel(x_quant, w) / w_scale / x_scale
        return y

1.58b 預訓練結果

在嘗試微調之前,我們首先嚐試使用預訓練重現 BitNet 論文的結果。我們從一個小型資料集 tinystories 和一個 Llama3 8B 模型開始。我們確認,像論文那樣新增歸一化函式可以提高效能。例如,經過2000步訓練後,驗證集上的困惑度在沒有歸一化的情況下為6.3,在有歸一化的情況下為5.9。兩種情況下的訓練都是穩定的。

Pre-training plots without (blue) & with (green) layer normalisation
預訓練圖(不含層歸一化為藍色,含層歸一化為橙色)

雖然這種預訓練方法看起來非常有趣,但只有少數機構能夠以必要的規模進行。然而,目前已經有各種強大的預訓練模型,如果它們能在預訓練後轉換為1.58位,那將非常有益。其他小組報告稱,微調結果不如預訓練實現的結果強,因此我們著手研究如何使1.58微調起作用。

1.58位微調

當我們開始從預訓練的Llama3 8B權重進行微調時,模型表現略好,但不如我們預期。

注意:我們所有的實驗都是使用 Nanotron 進行的。如果您對嘗試 1.58 位預訓練或微調感興趣,可以檢視此 PR

Fine-tuning plot compared to pre-training plot
微調圖與預訓練圖對比

為了理解原因,我們嘗試檢查隨機初始化模型和預訓練模型的權重分佈,以找出潛在問題。

Random weights distribution (2 merged stds)
隨機權重分佈(2個合併的標準差)
Pre-trained Llama3 weights distribution
預訓練的 Llama3 權重分佈

兩個分佈的尺度值分別為

Random weights scales distribution
隨機權重尺度分佈
Pre-trained Llama3 weights distribution
預訓練的 Llama3 權重分佈

初始隨機權重分佈是兩個正態分佈的混合

  • 一個標準差 (std) 為 0.025 0.025
  • 另一個標準差為 0.0252num_hidden_layers=0.00325 \frac{0.025}{\sqrt{2 \cdot \text{num\_hidden\_layers}}} = 0.00325

這是由於在 `nanotron` 中,列線性權重和行線性權重使用不同的標準差。在量化版本中,所有矩陣只有兩個權重尺度(50.25 和 402),它們是每個矩陣權重絕對值的平均值的倒數:`scale = 1.0 / w.abs().mean().clamp_(min=1e-5)`。

  • 對於 scale=50.25\text{scale} = 50.25 w.abs().mean()=0.0199 w.abs().mean() = 0.0199 ,這與我們的第一個標準差 std=0.025\text{std} = 0.025 相匹配。用於推導標準差的公式基於 w |w| 的半正態分佈期望。
    E(w)=std(w)2π \mathbb{E}(|w|) = \text{std}(w) \cdot \sqrt{\frac{2}{\pi}}
  • 對於 scale=402 \text{scale} = 402 w.abs().mean()=0.0025 w.abs().mean() = 0.0025 ,導致 std=0.00325\text{std} = 0.00325

另一方面,預訓練權重的分佈看起來像一個標準差為 std=0.013 \text{std} = 0.013 的正態分佈。

顯然,預訓練模型以更多資訊(尺度)開始,而隨機初始化模型幾乎沒有資訊,並隨著時間的推移增加資訊。我們的結論是,使用隨機權重開始會給模型提供最小的初始資訊,從而實現漸進式學習過程,而在微調過程中,BitLinear層的引入會使模型不堪重負,從而丟失所有先前的的資訊。

為了改善微調結果,我們嘗試了不同的技術。例如,我們嘗試了逐行和逐列量化,而不是逐張量量化,以保留Llama 3權重中的更多資訊。我們還嘗試改變尺度計算方式:不再僅僅將權重的平均絕對值作為尺度,而是將離群值的平均絕對值作為尺度(離群值是超出k*平均絕對值的值,其中k是我們嘗試在實驗中改變的常數),但我們沒有發現顯著改進。

def scale_outliers(tensor, threshold_factor=1):
    mean_absolute_value = torch.mean(torch.abs(tensor))
    threshold = threshold_factor * mean_absolute_value
    outliers = tensor[torch.abs(tensor) > threshold]
    mean_outlier_value = torch.mean(torch.abs(outliers))
    return mean_outlier_value

def weight_quant_scaling(w):
    scale = 1.0 / scale_outliers(w).clamp_(min=1e-5)
    quantized_weights = (w * scale).round().clamp_(-1, 1) / scale
    return quantized_weights

我們觀察到,隨機權重和Llama 3權重都導致損失從大約相同的13值開始。這表明Llama 3模型在引入量化時丟失了其所有先驗資訊。為了進一步研究模型在此過程中丟失了多少資訊,我們嘗試了逐組量化。

作為一項健全性檢查,我們首先將組大小設定為1,這實際上意味著沒有量化。在這種情況下,損失從1.45開始,與我們在正常微調期間看到的情況相同。然而,當我們將組大小增加到2時,損失躍升至約11。這表明即使組大小最小為2,模型仍然幾乎丟失了所有資訊。

為了解決這個問題,我們考慮了逐步引入量化的可能性,而不是對每個張量的權重和啟用突然應用量化。為此,我們實現了一個lambda值來控制該過程。

lambda_ = ?
x_quant = x + lambda_ * (activation_quant(x) - x).detach()
w_quant = w + lambda_ * (weight_quant(w) - w).detach()

當 `lambda` 設定為 0 時,基本上沒有發生量化,而當 `lambda=1` 時,則應用完全量化。

我們最初測試了一些離散的 `lambda` 值,例如 0.25、0.5、0.75 和 1。然而,這種方法並沒有導致結果的顯著改善,主要是因為 `lambda=0.25` 已經足夠高,導致損失從一開始就非常高。

Fine-tuning plot with lambda = 0.25->0.5->0.75->1
lambda = 0.25->0.5->0.75->1的微調圖

因此,我們決定嘗試根據訓練步數動態調整的 `lambda` 值。

lambda_ = training_step / total_training_steps

使用這個動態的 `lambda` 值導致了更好的損失收斂,但當 `lambda` 設定為 1 時,推理期間的困惑度(ppl)結果仍然遠不盡如人意。我們意識到這可能是因為模型在 `lambda=1` 下訓練的時間不夠長。為了解決這個問題,我們調整了 `lambda` 值以改進訓練過程。

lambda_ = min(2 * training_step / total_training_steps, 1)

在此配置下,經過2000步訓練後,我們得到

Fine-tuning plot with lambda = min(2*training_step/total_training_steps, 1)
lambda = min(2*training_step/total_training_steps, 1) 的微調圖

我們的微調方法整體顯示出更好的收斂性。您會注意到在1000步左右,損失曲線略有增加,這對應於我們開始接近`lambda=1`或完全量化的時候。然而,在此之後,損失立即開始再次收斂,導致困惑度提高到大約4。

儘管取得了這一進展,但當我們在WikiText資料集(而不是我們用於微調的tinystories資料集)上測試量化模型時,它顯示出非常高的困惑度。這表明在特定資料集上以低位模式微調模型會導致它失去大部分通用知識。這個問題可能是因為我們用三元權重追求的最小表示在不同資料集之間可能存在顯著差異。為了解決這個問題,我們擴大了訓練過程,包括更大的FineWeb-edu資料集。我們保持了

lambda_ = min(training_step/1000, 1)

我們選擇這個 `lambda` 值是因為它似乎是預熱模型的一個良好起點。然後,我們使用 1e-4 的學習率在 FineWeb-edu 資料集上訓練了模型 5,000 步。訓練涉及 200 萬的批處理大小 (BS),總計 100 億個 token。

找到合適的學習率和衰減率極具挑戰性;這似乎是模型效能的關鍵因素。

Fine-tuning plot with warmup quantization on Fineweb-edu
使用Fineweb-edu進行預熱量化的微調圖

在 Fineweb-Edu 上進行微調後,WikiText 資料集上的困惑度達到了 12.2,考慮到我們只使用了 100 億個標記,這相當令人印象深刻。其他評估指標也顯示出強勁的效能,考慮到有限的資料量(參見結果)。

我們還嘗試在 lambda 接近 1 時平滑急劇增加。為此,我們考慮使用 lambda 排程器,它們首先呈指數增長,然後隨著接近 1 而趨於平穩。

def scheduler(step, total_steps, k):
    normalized_step = step / total_steps
    return 1 - (1 - normalized_step)**k

對於不同的 k 值,在總預熱步數為 1 的情況下,我們有如下所示的圖:

Exponential scheduler for different k values
不同k值的指數排程器

我們使用表現最佳的學習率1e-4進行了4次實驗,測試了k在[4, 6, 8, 10]中的值。

Fine-tuning plots with exponential scheduler
指數排程器微調圖

平滑處理效果良好,沒有像線性排程器那樣出現尖峰。然而,困惑度並不理想,保持在15左右,並且下游任務的效能也沒有更好。

我們還注意到一開始的尖峰,模型很難從中恢復過來。當lambda=0時,基本上沒有量化,所以損失從低點開始,大約在2左右。但緊接著第一步,就出現了一個尖峰,類似於線性排程器的情況(如上圖藍色曲線所示)。因此,我們嘗試了另一種排程器——S形排程器——它開始緩慢,急劇上升到1,然後隨著接近1而趨於平穩。

def sigmoid_scheduler(step, total_steps, k):
    # Sigmoid-like curve: slow start, fast middle, slow end
    normalized_step = step / total_steps
    return 1 / (1 + np.exp(-k * (normalized_step - 0.5)))

對於不同的 k 值,我們有以下曲線

Sigmoid scheduler for different k values
不同k值的S形排程器

我們這次運行了 5 次實驗,k 的取值範圍是 [15, 20, 25, 40, 100]。

Finetuning plots with sigmoid scheduler
S形排程器微調圖

lambda的急劇增加導致了第500步左右的不穩定性,並且沒有解決第一個發散問題。然而,對於 k=100 k = 100 ,我們觀察到下游任務有所改善(參見結果表),儘管困惑度仍保持在13.5左右。儘管如此,它並未顯示出比線性排程器更明顯的效能提升。

此外,我們還嘗試從頭開始,使用隨機權重和不同的學習率來訓練模型。這使我們能夠比較微調方法與傳統預訓練方法的有效性。

Different Pre-training plots with different learning rates
不同學習率下的預訓練曲線圖

所有從隨機權重訓練的模型表現都不如我們微調後的模型。這些模型的最佳困惑度為26,遠低於我們微調方法的結果。

擴充套件到 1000 億個 Token!

我們將實驗擴充套件到 1000 億個 Token,以觀察能否與 Llama 3 8B 的效能相匹配。我們進行了更長時間的訓練,從線性排程器下較短執行中表現最佳的檢查點開始,並繼續微調 45,000 步。我們嘗試了不同的學習率,雖然模型在某些指標上與 Llama 3 模型表現接近,但平均而言仍有所落後。

以下是我們在訓練過程中不同檢查點評估的指標示例:

Metrics evaluations during the training for different lrs
不同學習率下訓練過程中的指標評估

平均得分如下:

Average evaluation during the training for different lrs
不同學習率下訓練過程中的平均評估

小型模型實驗

在我們對 SmolLM 等小型模型的初步實驗中,我們觀察到熱身量化技術並未像在大型模型中那樣帶來顯著改進。這表明熱身量化的有效性可能與模型大小和複雜性更密切相關。

例如,這裡是 SmolLM 135M 模型的損失曲線,比較了熱身量化與從一開始就進行完全量化的效果。有趣的是,曲線非常吻合,並且最終的困惑度沒有顯著差異。

Smoll LLm fine-tuning experiment with & without warmup quantization
Smol LLm 微調實驗(帶/不帶熱身量化)

結果與比較

BitNet 在提供強大效能方面,尤其是低位元級別,與基線方法相比非常有效。根據論文所述,BitNet 取得了與 8 位元模型相當的分數,但推理成本顯著降低。對於 4 位元模型,僅量化權重的方法優於同時量化權重和啟用的方法,因為啟用更難量化。然而,使用 1.58 位元權重的 BitNet 超越了僅權重和權重與啟用量化這兩種方法。

下表展示了 Llama3 8B 經過 100 億次微調後的各項指標結果。這些結果與其他模型架構的結果進行了比較,以提供全面的效能概覽(所有評估均使用 LightevalNanotron 格式模型上進行)

Metrics comparison with Llama models
與 Llama 模型的指標比較:Linear 表示線性 lambda 排程器,Sigmoid 表示 Sigmoid lambda 排程器(在我們的例子中 k = 100)

僅使用三元權重對模型進行 100 億個 Token 的微調後,模型展現出令人印象深刻的效能,尤其是在與經過更廣泛訓練的其他模型進行比較時。例如,它優於 Bitnet 7B 模型,後者在規模顯著更大的 1000 億個 Token 資料集上進行了訓練。此外,它也優於 FBI LLM (Fully Binarized LLM),一個在更龐大的 1.26 萬億個 Token 上進行蒸餾的模型。這凸顯了該模型儘管微調規模相對較小,但其效率和有效性。

對於 1000 億個 Token 的實驗,我們表現最好的檢查點如下:

Metrics comparison with Llama models for the model trained on 100B tokens
在 1000 億個 Token 上訓練的模型與 Llama 模型的指標比較

為了復現這些結果,您可以檢視此 PR 以將模型轉換為 nanotron 格式,解包權重(檢視函式 unpack_weights),並使用 lighteval

請注意,儘管這些模型是從 Instruct-tuned 模型微調而來,但它們仍然需要使用 Instruct 資料集進行微調。這些可以被視為基礎模型。

自定義核心與基準測試

為了利用 BitNet 的低精度權重,我們將它們打包成 int8 張量(這將引數數量從 8B 減少到 2.8B!)。在推理過程中,這些權重必須在進行矩陣乘法之前解包。我們用 Cuda 和 Triton 實現了自定義核心,以處理矩陣乘法過程中的即時解包。對於矩陣乘法本身,我們採用了快取平鋪矩陣乘法技術。為了充分理解這種方法,我們首先回顧一些 Cuda 程式設計基礎知識。

基本 GPU 概念:執行緒、塊和共享記憶體

在深入研究快取平鋪矩陣乘法之前,理解一些基本的 GPU 概念非常重要:

  • 執行緒和塊:GPU 同時執行數千個執行緒。這些執行緒被分組到塊中,每個塊獨立執行。網格由這些塊組成,它代表整個問題空間。例如,在矩陣乘法中,每個執行緒可能負責計算輸出矩陣的一個元素。
  • 共享記憶體:每個塊都可以訪問有限的共享記憶體,其速度遠快於全域性記憶體(GPU 上的主記憶體)。然而,共享記憶體的大小有限,並且在塊內的所有執行緒之間共享。有效地使用共享記憶體是提高 GPU 程式效能的關鍵。

矩陣乘法中的挑戰

在 GPU 上實現矩陣乘法的簡單方法可能涉及每個執行緒透過直接從全域性記憶體讀取所需元素來計算結果矩陣的單個元素。然而,這種方法可能效率低下,原因如下:

  • 記憶體頻寬:與 GPU 核心執行計算的速度相比,訪問全域性記憶體相對較慢。如果每個執行緒直接從全域性記憶體讀取矩陣元素,記憶體訪問時間可能會成為瓶頸。
  • 冗餘資料訪問:在矩陣乘法中,輸入矩陣的許多元素被多次使用。如果每個執行緒獨立地從全域性記憶體獲取所需資料,則相同的資料可能會被多次載入到 GPU 中,從而導致效率低下。例如,如果每個執行緒用於計算輸出矩陣中的單個元素,則負責計算位置 (i, j) 處元素的執行緒將需要從全域性記憶體載入矩陣 A 的第 i 行和矩陣 B 的第 j 列。然而,其他執行緒,例如計算位置 (i+1, j) 處元素的執行緒,無法重用此資料,並且必須再次從全域性記憶體中重新載入相同的第 j 列。

平鋪的思想

平鋪是一種用於解決這些挑戰的技術,它主要用於 FlashAttention 以提高核心的效率。基本思想是將矩陣分成更小的子矩陣,稱為瓦片(tiles),它們可以放入 GPU 的共享記憶體中。計算不再一次性完成整個輸出矩陣,而是分解成更小的部分,逐瓦片處理。

在矩陣乘法中,這意味著將矩陣 A 和 B 分成塊(瓦片),將這些瓦片載入到共享記憶體中,然後對這些較小的塊進行乘法運算。這種方法允許執行緒重用儲存在快速共享記憶體中的資料,從而減少重複訪問全域性記憶體的需求。

工作原理如下:

  • 將瓦片載入到共享記憶體:每個執行緒塊協同地將矩陣 A 的一個瓦片和矩陣 B 的一個對應瓦片從全域性記憶體載入到共享記憶體中。此操作每個瓦片只執行一次,然後該瓦片由塊中的執行緒多次重用。
  • 計算部分積:一旦瓦片載入到共享記憶體中,每個執行緒計算一個部分積。由於塊中的所有執行緒都在共享記憶體中的相同瓦片上工作,它們可以有效地重用資料而無需額外的全域性記憶體訪問。
  • 累積結果:計算完一個瓦片的部分積後,執行緒將矩陣 A 和 B 的下一個瓦片載入到共享記憶體中,並重復該過程。結果累積在暫存器(或本地記憶體)中,一旦所有瓦片都處理完畢,輸出矩陣元素的最終值將寫回全域性記憶體。
Tiled Matrix multiplication illustration
平鋪矩陣乘法示意圖(來源:https://cnugteren.github.io/tutorial/pages/page4.html)

實際考量

在實現快取平鋪矩陣乘法時,需要考慮幾個因素:

  • 瓦片大小:瓦片的大小應在可放入共享記憶體的資料量和全域性記憶體訪問次數之間取得平衡。
  • 記憶體合併:全域性記憶體訪問被合併,這意味著相鄰執行緒訪問相鄰記憶體位置。
  • 佔用率:每個塊的執行緒數和網格中的塊數應選擇,以確保高佔用率,這意味著在 GPU 上儘可能多地擁有活動的 warp(warp 是 32 個執行緒的集合),以隱藏記憶體延遲。

Triton 核心

這是我們進行基準測試的 Triton 核心:

@triton.autotune(
    configs=get_cuda_autotune_config(),
    key=['M', 'N', 'K'],
)
@triton.jit
def matmul_kernel(
        a_ptr, b_ptr, c_ptr,
        M, N, K,
        stride_am, stride_ak,
        stride_bk, stride_bn, 
        stride_cm, stride_cn,
        BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,  
        GROUP_SIZE_M: tl.constexpr,
):

    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)

    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)

    for i in range(4) : 
        b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
        for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K) ):
            k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j 

            # BLOCK_SIZE_K must be a divisor of K / 4 
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0)
            b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K // 4 - j * BLOCK_SIZE_K, other=0)
            mask = 3<<(2*i)
            b = ((b_uint8 & mask) >> (2*i))

            # We accumulate the tiles along the K dimension.
            tensor_full = tl.full((1,), 1, dtype=tl.int8)

            accumulator += tl.dot(a, (b.to(tl.int8) - tensor_full), out_dtype=tl.int32)

            a_ptrs += BLOCK_SIZE_K * stride_ak
            b_ptrs += BLOCK_SIZE_K * stride_bk

    c = accumulator

    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
    tl.store(c_ptrs, c, mask=c_mask)


def matmul(a, b):
    assert a.shape[1] == b.shape[0] * 4, "Incompatible dimensions, the weight matrix need to be packed"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    M, K = a.shape
    _, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=torch.float16)
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
    )
    return c

程式碼分解

  1. 確定瓦片位置

核心首先確定每個執行緒塊負責的輸出矩陣瓦片(塊)

  • pid 是每個執行緒塊的唯一識別符號,透過 tl.program_id(axis=0) 獲取。
  • 網格被劃分為執行緒塊組(GROUP_SIZE_M)。每個組處理輸出矩陣的一部分。
  • pid_mpid_n 分別是瓦片在 M 和 N 維度上的座標。
  • 計算偏移量(offs_am, offs_bn, offs_k)以確定塊中每個執行緒將處理矩陣 A 和 B 的哪些元素。
  1. 載入和計算瓦片

核心使用迴圈以 BLOCK_SIZE_K 的塊大小迭代 K 維度。對於每個塊:

  • 載入瓦片:從全域性記憶體載入矩陣 A 和 B 的瓦片。
  • 解包矩陣 B:核心假定矩陣 B 用 int8 值打包,這意味著每個元素實際上代表四個打包成一個位元組的較小值。解包在迴圈中進行:
    • b_uint8 作為打包的 int8 從全域性記憶體載入。
    • 每個打包的值都被解包以獲取用於計算的實際權重值。
  • 點積:核心計算從 A 和 B 載入的瓦片的點積,並將結果累積到 accumulator 中。accumulator 儲存輸出矩陣 C 的瓦片的部分結果。
  1. 儲存結果

在沿 K 維度處理完所有瓦片後,儲存在 accumulator 中的最終結果將轉換為 float16 並寫回全域性記憶體中矩陣 C 的相應瓦片。寫入過程透過掩碼遵循記憶體邊界,以確保只寫入有效元素。

有關程式碼的更詳細說明,請檢視此 PR

基準測試

我們對自定義核心與使用 @torch.compile 解包權重後執行 BF16 精度的矩陣乘法進行了基準測試,發現兩種方法均取得了大致相同的效能。為確保基準測試的準確性,我們對矩陣乘法操作進行了 2000 次迭代,並對最後 1000 次迭代所花費的時間取平均值,以消除任何與初始載入或編譯相關的低效率。下圖顯示了基準測試結果。我們還測試了各種矩陣大小,其中 x 軸表示乘法次數(對數刻度),y 軸表示平均時間(毫秒)。

Triton kernel compared to torch.compile
Triton 核心與 torch.compile 的比較

我們還嘗試了使用 BitBlas,這是一個旨在執行混合精度矩陣運算的軟體庫。它透過允許以 INT8、INT4 甚至 INT2 等低精度格式進行計算,而不是傳統的 FP32 或 FP16 格式,從而幫助最佳化這些操作。

基準測試結果令人鼓舞,如折線圖所示,BitBlas 在低精度方面優於我們的自定義核心和 Torch 的 matmul 函式。

Bitblas benchmark
Bitblas 基準測試

然而,在模型載入過程中,BitBlas 需要編譯針對權重矩陣形狀定製的核心並將其儲存在本地資料庫中,這可能會增加初始載入時間。

結論

總而言之,隨著大型語言模型(LLM)的不斷擴充套件,透過量化降低其計算需求至關重要。本文探討了 1.58 位量化方法,該方法使用三元權重。雖然以 1.58 位預訓練模型是資源密集型的,但我們已經證明,透過一些技巧,可以對現有模型進行此精度級別的微調,從而在不犧牲準確性的前提下實現高效效能。透過專用核心最佳化推理速度,BitNet 為使 LLM 更加實用和可擴充套件開闢了新的可能性。

致謝

我們衷心感謝 Leandro von Werra、Thomas Wolf 和 Marc Sun 在整個專案中提供的寶貴幫助和見解。我們還要感謝 Omar Sanseviero 和 Pedro Cuenca 在完善這篇部落格文章方面的貢獻,幫助我們清晰有效地向人工智慧社群傳達我們的發現。此外,我們還要感謝 GeneralAI 團隊在 BitNet 專案上的開創性工作。他們的研究是我們努力的基礎,我們特別感謝他們在論文中提供了清晰精確的圖表。

額外資源

  1. H. Wang 等人,《BitNet: Scaling 1-bit Transformers for Large Language Models》。arxiv 論文
  2. S. Ma 等人,《1 位元 LLM 時代:所有大型語言模型均為 1.58 位元》。arxiv 論文
  3. S. Ma 等人,《1 位元 LLM 時代:訓練技巧、程式碼和常見問題解答》。連結
  4. RJ. Honicky,《所有大型語言模型真的都是 1.58 位嗎?》。部落格文章
  5. L. Mao,《CUDA 矩陣乘法最佳化》。部落格文章
  6. 《教程:針對 Kepler 的 OpenCL SGEMM 調優》。連結
  7. 《CUDAMODE》。githubyoutube
  8. Wen-mei W. Hwu, David B. Kirk, Izzat El Hajj, 《大規模並行處理器程式設計:實踐方法》

社群

好文章,謝謝分享。

我想知道為什麼微調是在指令模型而不是基礎模型上進行的?

註冊登入發表評論

© . This site is unofficial and not affiliated with Hugging Face, Inc.