Bitsandbytes 文件
8 位最佳化器
並獲得增強的文件體驗
開始使用
8 位最佳化器
使用 8 位最佳化器,微調大型模型可以減少 75% 的 GPU 記憶體使用,且與使用標準 32 位最佳化器訓練相比不會損失任何精度。減少記憶體需求意味著 8 位最佳化器比標準最佳化器快 4 倍,並且無需進行超引數調整。
本指南將向您展示如何使用 8 位最佳化器。
8 位最佳化器可在多種任務中減少記憶體使用並加速最佳化。然而,由於 8 位最佳化器減少的記憶體與引數數量成正比,因此對於使用大量啟用記憶體的模型(如卷積網路)來說,它們並不能帶來真正的益處。8 位最佳化器對於在記憶體極其受限的 GPU 上訓練或微調具有大量引數的模型最為有利。
8 位最佳化器是常規最佳化器的直接替代品,這意味著它們也接受與常規最佳化器相同的引數。對於 NLP 模型,建議使用 StableEmbedding 類來提高穩定性和結果。
import bitsandbytes as bnb
- adam = torch.optim.Adam(...)
+ adam = bnb.optim.Adam8bit(...)
# recommended for NLP models
- before: torch.nn.Embedding(...)
+ bnb.nn.StableEmbedding(...)
預設情況下,即使您使用 8 位最佳化器初始化引數,所有元素數量少於 4096 的引數張量仍將保持 32 位。這樣做是因為小張量節省的記憶體不多,並且通常包含高度可變的引數(偏置)或需要高精度的引數(批次歸一化、層歸一化)。
您可以使用 min_8bit_size
引數更改此值。例如,如果您希望僅在最小大小為 16384 個值時才將引數最佳化為 8 位(建議使用 4096 的倍數)
import bitsandbytes as bnb
adam = bnb.optim.Adam8bit(model.parameters(), min_8bit_size=16384)
您可以配置的其他引數包括學習率 (lr
)、衰減率 (betas
)、最佳化器狀態的位數 (optim_bits
) 和百分位裁剪 (percentile_clipping
),後者可以增加穩定性。例如,要初始化一個帶有第 5 百分位裁剪的 32 位 Adam 最佳化器
import bitsandbytes as bnb
adam = bnb.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.995), optim_bits=32, percentile_clipping=5)
最佳化不穩定引數
要使用 32 位 Adam 最佳化一些不穩定引數,同時使用 8 位 Adam 最佳化其他引數,請使用 GlobalOptimManager 類為特定層覆蓋特定的超引數。您需要:
- 在引數仍在 CPU 上時註冊它們。
import torch
import bitsandbytes as bnb
mng = bnb.optim.GlobalOptimManager.get_instance()
model = MyModel()
mng.register_parameters(model.parameters())
- 使用新的期望超引數覆蓋配置。例如,讓我們覆蓋
model.fc1.weight
層以使用 32 位 Adam。
有關您可以覆蓋的其他超引數的更多資訊,請查閱最佳化器 API 文件。
model = model.cuda()
# use 8-bit optimizer states for all parameters
adam = bnb.optim.Adam(model.parameters(), lr=0.001, optim_bits=8)
# override the parameter model.fc1.weight now uses 32-bit Adam
mng.override_config(model.fc1.weight, "optim_bits", 32)
您還可以透過將多個層作為列表傳遞,並將新的超引數作為字典傳遞來一次性覆蓋多個層。例如,讓我們覆蓋 model.special.weight
和 model.also_special.weight
層以使用稀疏最佳化以及更低的學習率和衰減率。
mng.override_config([model.special.weight, model.also_special.weight],
key_value_dict ={'is_sparse': True, 'lr': 1e-5, 'betas'=(0.9, 0.98)})
對於特定層,我們建議在每個模組中進行區域性覆蓋。將模組、引數及其屬性名稱傳遞給 GlobalOptimManager
class MyModule(torch.nn.Module):
def __init__(d_in, d_out):
super(MyModule, self).__init__()
self.linear = torch.nn.Linear(d_in, d_out)
# optimization will happen in 32-bit and
# learning rate will be set to 0.0001 independent of the main learning rate
config = {'optim_bits': 32, 'lr' : 0.0001}
GlobalOptimManager.get_instance().register_module_override(self, 'weight', config)
後續步驟
有關 8 位最佳化器的更多概念細節和解釋,請參閱 8 位最佳化器 指南。
< > 在 GitHub 上更新