Optimum 文件

最佳化

您正在檢視的是需要從原始碼安裝。如果你想要常規的 pip 安裝,請檢視最新穩定版本 (v1.27.0)。
Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

最佳化

optimum.fx.optimization 模組提供了一組 torch.fx 圖轉換,以及用於編寫和組合您自己的轉換的類和函式。

轉換指南

在 🤗 Optimum 中,有兩種轉換:可逆轉換和不可逆轉換。

編寫不可逆轉換

最基本的轉換是不可逆轉換。這些轉換無法逆轉,這意味著在將它們應用於圖模組後,無法恢復原始模型。要在 🤗 Optimum 中實現此類轉換,非常簡單:您只需繼承 Transformation 並實現 transform() 方法。

例如,以下轉換將所有乘法更改為加法

>>> import operator
>>> from optimum.fx.optimization import Transformation

>>> class ChangeMulToAdd(Transformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 node.target = operator.add
...         return graph_module

實現後,您的轉換可以作為常規函式使用

>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace

>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
...     model,
...     input_names=["input_ids", "attention_mask", "token_type_ids"],
... )

>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)

編寫可逆轉換

可逆轉換實現轉換及其逆轉換,允許從轉換後的模型中檢索原始模型。要實現此類轉換,您需要繼承 ReversibleTransformation 並實現 transform()reverse() 方法。

例如,以下轉換是可逆的

>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation

>>> class MulToMulTimesTwo(ReversibleTransformation):
...     def transform(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (2 * x, y)
...         return graph_module
...
...     def reverse(self, graph_module):
...         for node in graph_module.graph.nodes:
...             if node.op == "call_function" and node.target == operator.mul:
...                 x, y = node.args
...                 node.args = (x / 2, y)
...         return graph_module

組合轉換

由於經常需要鏈式應用多個轉換,因此提供了 compose()。它是一個實用函式,允許您透過鏈式連線多個其他轉換來建立轉換。

>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.