Safetensors 文件
Torch 共享張量
並獲得增強的文件體驗
開始使用
Torch 共享張量
總結
使用特定的函式,這在大多數情況下應該能滿足您的需求。但這並非沒有副作用。
from safetensors.torch import load_model, save_model
save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")
load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))
什麼是共享張量?
Pytorch 使用共享張量進行一些計算。這對於普遍減少記憶體使用非常有幫助。
一個非常經典的用例是在 Transformers 中,embeddings
與 lm_head
共享。透過使用相同的矩陣,模型使用的引數更少,並且梯度能更好地流向 embeddings
(因為 embeddings
位於模型的起始部分,梯度不容易流到那裡,而 lm_head
位於模型的末端,梯度在那裡非常強。由於它們是同一個張量,兩者都能受益)。
from torch import nn
class Model(nn.Module):
def __init__(self):
super().__init__()
self.a = nn.Linear(100, 100)
self.b = self.a
def forward(self, x):
return self.b(self.a(x))
model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer
為什麼共享張量不儲存在 safetensors 中?
原因有多種:
並非所有框架都支援它們,例如
tensorflow
就不支援。所以如果有人在 torch 中儲存了共享張量,就無法以類似的方式在其他框架中載入它們,因此我們無法保持相同的Dict[str, Tensor]
API。這使得延遲載入變得非常快。延遲載入是指只加載檔案中的部分張量或張量的部分。在沒有張量共享的情況下,這很容易實現,但有了張量共享就不一樣了。
with safe_open("model.safetensors", framework="pt") as f: a = f.get_tensor("a") b = f.get_tensor("b")
現在,使用給定的程式碼,事後“重新共享”緩衝區是不可能的。一旦我們給出了
a
張量,當您請求b
時,我們無法返回相同的記憶體。(在這個特定的例子中,我們可以跟蹤已給出的緩衝區,但通常情況並非如此,因為您可能在請求b
之前對a
進行了任意操作,比如將其傳送到另一個裝置)。它可能導致檔案比必要的大得多。如果您要儲存的共享張量只是一個更大張量的一部分,那麼用 PyTorch 儲存會導致儲存整個緩衝區,而不是隻儲存需要的部分。
a = torch.zeros((100, 100)) b = a[:1, :] torch.save({"b": b}, "model.bin") # File is 41k instead of the expected 400 bytes # In practice it could happen that you save several 10GB instead of 1GB.
儘管提到了所有這些原因,但目前還沒有定論。共享張量不會導致不安全或潛在的拒絕服務攻擊,因此如果當前的解決方案不盡人意,這個決定可能會被重新考慮。
它是如何工作的?
設計相當簡單。我們將查詢所有共享張量,然後查詢所有覆蓋整個緩衝區的張量(可能有多個這樣的張量)。這樣我們就得到了多個可以儲存的名稱,我們只需選擇第一個。
在 load_model
期間,我們的載入方式有點像 load_state_dict
,但我們會檢查模型本身,以查詢共享緩衝區,並忽略那些因緩衝區共享而被覆蓋的“缺失鍵”(它們實際上已正確載入,因為有一個緩衝區在底層載入了)。所有其他錯誤都將按原樣引發。
注意事項:這意味著我們正在丟棄檔案中的一些鍵。也就是說,如果您檢查儲存在磁碟上的鍵,您會看到一些“缺失的張量”,或者如果您正在使用 load_state_dict
。除非我們開始直接在格式中支援共享張量,否則沒有真正的解決辦法。