Transformers 文件

超引數搜尋

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

超引數搜尋

超引數搜尋旨在發現能夠產生最佳模型效能的超引數集。Trainer 透過 hyperparameter_search() 支援多種超引數搜尋後端,包括 OptunaSigOptWeights & BiasesRay Tune,以最佳化單個目標或多個目標。

本指南將介紹如何為每個後端設定超引數搜尋。

[!WARNING][SigOpt](https://github.com/sigopt/sigopt-server) 已處於公共歸檔模式,不再積極維護。請嘗試使用 Optuna、Weights & Biases 或 Ray Tune 代替。

pip install optuna/sigopt/wandb/ray[tune]

要使用 hyperparameter_search(),您需要建立一個 model_init 函式。此函式包含基本的模型資訊(引數和配置),因為在每次搜尋試驗執行中都需要重新初始化。

model_init 函式與 optimizers 引數不相容。請繼承 Trainer 並重寫 create_optimizer_and_scheduler() 方法來建立自定義最佳化器和排程器。

以下是一個 model_init 函式的示例。

def model_init(trial):
    return AutoModelForSequenceClassification.from_pretrained(
        model_args.model_name_or_path,
        from_tf=bool(".ckpt" in model_args.model_name_or_path),
        config=config,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        token=True if model_args.use_auth_token else None,
    )

model_init 連同訓練所需的其他所有內容傳遞給 Trainer。然後,您可以呼叫 hyperparameter_search() 來開始搜尋。

hyperparameter_search() 接受一個 direction 引數,用於指定是最小化、最大化還是同時最小化和最大化多個目標。您還需要設定正在使用的 後端,一個包含要最佳化的超引數的 物件,要執行的 試驗次數,以及一個用於返回目標值的 compute_objective

如果未定義 compute_objective,則會呼叫預設的 compute_objective,它是評估指標(如 F1)的總和。

from transformers import Trainer

trainer = Trainer(
    model=None,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
    processing_class=tokenizer,
    model_init=model_init,
    data_collator=data_collator,
)
trainer.hyperparameter_search(...)

以下示例演示瞭如何使用不同的後端對學習率和訓練批次大小執行超引數搜尋。

Optuna
Ray Tune
SigOpt
Weights & Biases

Optuna 最佳化類別、整數和浮點數。

def optuna_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-6, 1e-4, log=True),
        "per_device_train_batch_size": trial.suggest_categorical("per_device_train_batch_size", [16, 32, 64, 128]),
    }

best_trials = trainer.hyperparameter_search(
    direction=["minimize", "maximize"],
    backend="optuna",
    hp_space=optuna_hp_space,
    n_trials=20,
    compute_objective=compute_objective,
)

分散式資料並行

Trainer 僅支援 Optuna 和 SigOpt 後端的分散式資料並行 (DDP) 超引數搜尋。只有 rank-zero 程序用於生成搜尋試驗,結果引數會傳遞給其他 ranks。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.