Safetensors 文件

Torch 共享張量

您正在檢視的是需要從原始碼安裝. 如果您想透過 regular pip install,可以檢視最新穩定版本 (v0.5.0-rc.0)。
Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Torch 共享張量

簡而言之 (TL;DR)

使用特定的函式,這些函式在大多數情況下都應該對您有用。但這並非沒有副作用。

from safetensors.torch import load_model, save_model

save_model(model, "model.safetensors")
# Instead of save_file(model.state_dict(), "model.safetensors")

load_model(model, "model.safetensors")
# Instead of model.load_state_dict(load_file("model.safetensors"))

什麼是共享張量?

PyTorch 使用共享張量進行一些計算。這在一般情況下對於減少記憶體使用非常有用。

一個非常經典的用例是在 Transformer 中,embeddingslm_head 共享。透過使用相同的矩陣,模型使用的引數更少,並且梯度流向 embeddings 也更加順暢(因為 embeddings 是模型的開頭,梯度不容易流到那裡,而 lm_head 是模型的尾部,梯度在那裡非常好,由於它們是相同的張量,因此它們都受益)。

from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Linear(100, 100)
        self.b = self.a

    def forward(self, x):
        return self.b(self.a(x))


model = Model()
print(model.state_dict())
# odict_keys(['a.weight', 'a.bias', 'b.weight', 'b.bias'])
torch.save(model.state_dict(), "model.bin")
# This file is now 41k instead of ~80k, because A and B are the same weight hence only 1 is saved on disk with both `a` and `b` pointing to the same buffer

為什麼共享張量不儲存在 safetensors 中?

有幾個原因:

  • 並非所有框架都支援它們,例如 tensorflow 不支援。因此,如果有人在 torch 中儲存共享張量,將無法以類似的方式載入它們,因此我們無法保持相同的 Dict[str, Tensor] API。

  • 它使延遲載入變得非常困難。 延遲載入是指僅載入給定檔案中的某些張量或張量部分的能力。在沒有共享張量的情況下,這很容易實現,但使用張量共享

    with safe_open("model.safetensors", framework="pt") as f:
        a = f.get_tensor("a")
        b = f.get_tensor("b")

    現在,使用此特定程式碼,不可能在事後“重新共享”緩衝區。一旦我們提供了 a 張量,我們就無法在您要求 b 時提供相同的記憶體。(在此特定示例中,我們可以跟蹤提供的緩衝區,但在一般情況下並非如此,因為您可以對 a 進行任意操作,例如將其傳送到另一個裝置,然後再請求 b)。

  • 它可能導致檔案比必需的要大得多。如果您正在儲存一個僅佔較大張量一部分的共享張量,那麼使用 pytorch 儲存它會導致儲存整個緩衝區,而不是僅儲存所需內容。

    a = torch.zeros((100, 100))
    b = a[:1, :]
    torch.save({"b": b}, "model.bin")
    # File is 41k instead of the expected 400 bytes
    # In practice it could happen that you save several 10GB instead of 1GB.

現在,儘管提到了所有這些原因,但其中並沒有什麼是一成不變的。共享張量不會導致不安全或拒絕服務,因此如果當前的工作繞方式不令人滿意,可以重新考慮這一決定。

它是如何工作的?

設計相當簡單。我們將查詢所有共享張量,然後查詢覆蓋整個緩衝區的(可能存在多個此類張量)所有張量。這給了我們多個可以儲存的名稱,我們只選擇第一個。

load_model 期間,我們的載入方式類似於 load_state_dict,但我們檢查模型本身,以查詢共享緩衝區,並忽略因緩衝區共享(它們被正確載入,因為有一個緩衝區在後臺載入)而實際上被覆蓋的“丟失鍵”。其他所有錯誤都會按原樣引發。

注意:這意味著我們正在丟棄檔案中的一些鍵。這意味著如果您正在檢查磁碟上儲存的鍵,您會看到一些“丟失的張量”,或者如果您使用的是 load_state_dict。除非我們直接在格式中支援共享張量,否則沒有真正的辦法解決這個問題。

在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.