Safetensors 文件

Torch 共享張量

您正在檢視的是需要從原始碼安裝。如果您想使用常規的 pip 安裝,請檢視最新的穩定版本 (v0.5.0-rc.0)。
Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Torch 共享張量

總結

使用特定的函式,這在大多數情況下應該能滿足您的需求。但這並非沒有副作用。

from safetensors.torch import load_model, save_model

save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")

load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))

什麼是共享張量?

Pytorch 使用共享張量進行一些計算。這對於普遍減少記憶體使用非常有幫助。

一個非常經典的用例是在 Transformers 中,embeddingslm_head 共享。透過使用相同的矩陣,模型使用的引數更少,並且梯度能更好地流向 embeddings(因為 embeddings 位於模型的起始部分,梯度不容易流到那裡,而 lm_head 位於模型的末端,梯度在那裡非常強。由於它們是同一個張量,兩者都能受益)。

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = self.a

    def forward(self, x):
        return self.b(self.a(x))


model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer

為什麼共享張量不儲存在 safetensors 中?

原因有多種:

  • 並非所有框架都支援它們,例如 tensorflow 就不支援。所以如果有人在 torch 中儲存了共享張量,就無法以類似的方式在其他框架中載入它們,因此我們無法保持相同的 Dict[str, Tensor] API。

  • 這使得延遲載入變得非常快。延遲載入是指只加載檔案中的部分張量或張量的部分。在沒有張量共享的情況下,這很容易實現,但有了張量共享就不一樣了。

    with safe_open("model.safetensors", framework="pt") as f:
        a = f.get_tensor("a")
        b = f.get_tensor("b")

    現在,使用給定的程式碼,事後“重新共享”緩衝區是不可能的。一旦我們給出了 a 張量,當您請求 b 時,我們無法返回相同的記憶體。(在這個特定的例子中,我們可以跟蹤已給出的緩衝區,但通常情況並非如此,因為您可能在請求 b 之前對 a 進行了任意操作,比如將其傳送到另一個裝置)。

  • 它可能導致檔案比必要的大得多。如果您要儲存的共享張量只是一個更大張量的一部分,那麼用 PyTorch 儲存會導致儲存整個緩衝區,而不是隻儲存需要的部分。

    a = torch.zeros((100, 100))
    b = a[:1, :]
    torch.save({"b": b}, "model.bin")
    # File is 41k instead of the expected 400 bytes
    # In practice it could happen that you save several 10GB instead of 1GB.

儘管提到了所有這些原因,但目前還沒有定論。共享張量不會導致不安全或潛在的拒絕服務攻擊,因此如果當前的解決方案不盡人意,這個決定可能會被重新考慮。

它是如何工作的?

設計相當簡單。我們將查詢所有共享張量,然後查詢所有覆蓋整個緩衝區的張量(可能有多個這樣的張量)。這樣我們就得到了多個可以儲存的名稱,我們只需選擇第一個。

load_model 期間,我們的載入方式有點像 load_state_dict,但我們會檢查模型本身,以查詢共享緩衝區,並忽略那些因緩衝區共享而被覆蓋的“缺失鍵”(它們實際上已正確載入,因為有一個緩衝區在底層載入了)。所有其他錯誤都將按原樣引發。

注意事項:這意味著我們正在丟棄檔案中的一些鍵。也就是說,如果您檢查儲存在磁碟上的鍵,您會看到一些“缺失的張量”,或者如果您正在使用 load_state_dict。除非我們開始直接在格式中支援共享張量,否則沒有真正的解決辦法。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.