Optimum 文件
最佳化
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
最佳化
optimum.fx.optimization
模組提供了一組 torch.fx 圖轉換,以及用於編寫和組合您自己的轉換的類和函式。
轉換指南
在 🤗 Optimum 中,有兩種轉換:可逆轉換和不可逆轉換。
編寫不可逆轉換
最基本的轉換是不可逆轉換。這些轉換無法逆轉,這意味著在將它們應用於圖模組後,無法恢復原始模型。要在 🤗 Optimum 中實現此類轉換,非常簡單:您只需繼承 Transformation 並實現 transform() 方法。
例如,以下轉換將所有乘法更改為加法
>>> import operator
>>> from optimum.fx.optimization import Transformation
>>> class ChangeMulToAdd(Transformation):
... def transform(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... node.target = operator.add
... return graph_module
實現後,您的轉換可以作為常規函式使用
>>> from transformers import BertModel
>>> from transformers.utils.fx import symbolic_trace
>>> model = BertModel.from_pretrained("bert-base-uncased")
>>> traced = symbolic_trace(
... model,
... input_names=["input_ids", "attention_mask", "token_type_ids"],
... )
>>> transformation = ChangeMulToAdd()
>>> transformed_model = transformation(traced)
編寫可逆轉換
可逆轉換實現轉換及其逆轉換,允許從轉換後的模型中檢索原始模型。要實現此類轉換,您需要繼承 ReversibleTransformation 並實現 transform() 和 reverse() 方法。
例如,以下轉換是可逆的
>>> import operator
>>> from optimum.fx.optimization import ReversibleTransformation
>>> class MulToMulTimesTwo(ReversibleTransformation):
... def transform(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... x, y = node.args
... node.args = (2 * x, y)
... return graph_module
...
... def reverse(self, graph_module):
... for node in graph_module.graph.nodes:
... if node.op == "call_function" and node.target == operator.mul:
... x, y = node.args
... node.args = (x / 2, y)
... return graph_module
組合轉換
由於經常需要鏈式應用多個轉換,因此提供了 compose()。它是一個實用函式,允許您透過鏈式連線多個其他轉換來建立轉換。
>>> from optimum.fx.optimization import compose
>>> composition = compose(MulToMulTimesTwo(), ChangeMulToAdd())