Transformers 文件

訓練指令碼

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

訓練指令碼

Transformers 為深度學習框架(PyTorch、TensorFlow、Flax)和任務在 transformers/examples 中提供了許多示例訓練指令碼。在 transformers/research projectstransformers/legacy 中還有其他指令碼,但這些指令碼沒有積極維護,並且需要特定版本的 Transformers。

示例指令碼只是示例,你可能需要根據你的用例調整指令碼。為了幫助你,大多數指令碼在資料預處理方面都非常透明,允許你根據需要進行編輯。

對於你想在示例指令碼中實現的任何功能,請在提交拉取請求之前在 論壇issue 中討論。雖然我們歡迎貢獻,但增加更多功能但犧牲可讀性的拉取請求不太可能被接受。

本指南將向你展示如何在 PyTorchTensorFlow 中執行示例摘要訓練指令碼。

設定

在新虛擬環境中從原始碼安裝 Transformers,以執行最新版本的示例指令碼。

git clone https://github.com/huggingface/transformers
cd transformers
pip install .

執行以下命令以從特定或舊版本 Transformers 中檢出指令碼。

git checkout tags/v3.5.1

設定正確版本後,導航到你選擇的示例資料夾並安裝示例特定的依賴項。

pip install -r requirements.txt

執行指令碼

透過包含 `max_train_samples`、`max_eval_samples` 和 `max_predict_samples` 引數來截斷資料集到最大樣本數,從而使用較小的資料集開始。這有助於確保訓練按預期進行,然後再提交整個資料集,這可能需要數小時才能完成。

並非所有示例指令碼都支援 `max_predict_samples` 引數。執行以下命令以檢查指令碼是否支援它。

examples/pytorch/summarization/run_summarization.py -h

以下示例在 CNN/DailyMail 資料集上對 T5-small 進行微調。T5 需要一個額外的 `source_prefix` 引數來提示它進行摘要。

PyTorch
TensorFlow

示例指令碼下載並預處理資料集,然後使用 Trainer 和受支援的模型架構對其進行微調。

如果訓練中斷,從檢查點恢復訓練非常有用,因為你無需從頭開始。有兩種方法可以從檢查點恢復訓練。

  • `--output dir previous_output_dir` 從儲存在 `output_dir` 中的最新檢查點恢復訓練。如果使用此方法,請刪除 `--overwrite_output_dir` 引數。
  • `--resume_from_checkpoint path_to_specific_checkpoint` 從特定的檢查點資料夾恢復訓練。

使用 `--push_to_hub` 引數在 Hub 上分享你的模型。它會建立一個倉庫並將模型上傳到 `--output_dir` 中指定的資料夾名稱。你也可以使用 `--push_to_hub_model_id` 引數來指定倉庫名稱。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    # remove the `max_train_samples`, `max_eval_samples` and `max_predict_samples` if everything works
    --max_train_samples 50 \
    --max_eval_samples 50 \
    --max_predict_samples 50 \
    --do_train \
    --do_eval \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --push_to_hub \
    --push_to_hub_model_id finetuned-t5-cnn_dailymail \
    # remove if using `output_dir previous_output_dir`
    # --overwrite_output_dir \
    --output_dir previous_output_dir \
    # --resume_from_checkpoint path_to_specific_checkpoint \
    --predict_with_generate \

對於混合精度和分散式訓練,請包含以下引數並使用 torchrun 啟動訓練。

  • 新增 `fp16` 或 `bf16` 引數以啟用混合精度訓練。XPU 裝置僅支援 `bf16`。
  • 新增 `nproc_per_node` 引數以設定要訓練的 GPU 數量。
torchrun \
    --nproc_per_node 8 pytorch/summarization/run_summarization.py \
    --fp16 \
    ...
    ...

PyTorch 透過 PyTorch/XLA 包支援 TPU,這是一種旨在加速效能的硬體。啟動 `xla_spawn.py` 指令碼並使用 `num_cores` 設定要訓練的 TPU 核心數量。

python xla_spawn.py --num_cores 8 pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    ...
    ...

Accelerate

Accelerate 旨在簡化分散式訓練,同時提供對 PyTorch 訓練迴圈的完全可見性。如果你計劃使用 Accelerate 訓練指令碼,請使用指令碼的 `_no_trainer.py` 版本。

從原始碼安裝 Accelerate,以確保你擁有最新版本。

pip install git+https://github.com/huggingface/accelerate

執行 accelerate config 命令,回答有關你的訓練設定的幾個問題。這將建立並儲存一個關於你係統的配置檔案。

accelerate config

你可以使用 accelerate test 確保你的系統已正確配置。

accelerate test

執行 accelerate launch 以開始訓練。

accelerate launch run_summarization_no_trainer.py \
    --model_name_or_path google-t5/t5-small \
    --dataset_name cnn_dailymail \
    --dataset_config "3.0.0" \
    --source_prefix "summarize: " \
    --output_dir ~/tmp/tst-summarization \

自定義資料集

摘要指令碼支援自定義資料集,只要它們是 CSV 或 JSONL 檔案。使用你自己的資料集時,你需要指定以下附加引數。

  • `train_file` 和 `validation_file` 指定訓練和驗證檔案的路徑。
  • `text_column` 是要摘要的輸入文字。
  • `summary_column` 是要輸出的目標文字。

下面顯示了摘要自定義資料集的示例命令。

python examples/pytorch/summarization/run_summarization.py \
    --model_name_or_path google-t5/t5-small \
    --do_train \
    --do_eval \
    --train_file path_to_csv_or_jsonlines_file \
    --validation_file path_to_csv_or_jsonlines_file \
    --text_column text_column_name \
    --summary_column summary_column_name \
    --source_prefix "summarize: " \
    --output_dir /tmp/tst-summarization \
    --overwrite_output_dir \
    --per_device_train_batch_size=4 \
    --per_device_eval_batch_size=4 \
    --predict_with_generate \
< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.