TRL 文件

SFT 訓練器

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

SFT 訓練器

All_models-SFT-blue smol_course-Chapter_1-yellow

概覽

TRL 支援用於訓練語言模型的監督式微調 (Supervised Fine-Tuning, SFT) 訓練器。

這種後訓練方法由 Younes Belkada 貢獻。

快速入門

本示例演示瞭如何使用 TRL 中的 SFTTrainer 來訓練語言模型。我們將在 Capybara 資料集上訓練一個 Qwen 3 0.6B 模型,這是一個緊湊、多樣化的多輪對話資料集,用於基準測試推理和泛化能力。

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B",
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

預期的資料集型別和格式

SFT 支援語言建模提示-補全兩種型別的資料集。SFTTrainer 相容標準對話式兩種資料集格式。當提供對話式資料集時,訓練器會自動將聊天模板應用於資料集。

# Standard language modeling
{"text": "The sky is blue."}

# Conversational language modeling
{"messages": [{"role": "user", "content": "What color is the sky?"},
              {"role": "assistant", "content": "It is blue."}]}

# Standard prompt-completion
{"prompt": "The sky is",
 "completion": " blue."}

# Conversational prompt-completion
{"prompt": [{"role": "user", "content": "What color is the sky?"}],
 "completion": [{"role": "assistant", "content": "It is blue."}]}

如果你的資料集不屬於這些格式之一,你可以對其進行預處理,將其轉換為預期格式。以下是使用 FreedomIntelligence/medical-o1-reasoning-SFT 資料集的一個示例。

from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")

def preprocess_function(example):
    return {
        "prompt": [{"role": "user", "content": example["Question"]}],
        "completion": [
            {"role": "assistant", "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}"}
        ],
    }

dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])
print(next(iter(dataset["train"])))
{
    "prompt": [
        {
            "content": "Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?",
            "role": "user",
        }
    ],
    "completion": [
        {
            "content": "<think>Okay, let's see what's going on here. We've got sudden weakness [...] clicks into place!</think>The specific cardiac abnormality most likely to be found in [...] the presence of a PFO facilitating a paradoxical embolism.",
            "role": "assistant",
        }
    ],
}

深入探究 SFT 方法

監督式微調(SFT)是使語言模型適應目標資料集的最簡單、最常用的方法。該模型以完全監督的方式,使用輸入和輸出序列對進行訓練。目標是最小化目標序列的負對數似然(NLL),並以輸入為條件。

本節將分解 SFT 在實踐中如何工作,涵蓋關鍵步驟:**預處理**、**分詞**和**損失計算**。

預處理和分詞

在訓練期間,根據資料集格式,每個示例預計包含一個**文字欄位**或一個**(提示,補全)**對。有關預期格式的更多詳細資訊,請參閱資料集格式SFTTrainer 使用模型的分詞器對每個輸入進行分詞。如果提示和補全是分開提供的,它們會在分詞前被拼接起來。

計算損失

sft_figure

SFT 中使用的損失是詞元級交叉熵損失,定義為LSFT(θ)=t=1Tlogpθ(yty<t), \mathcal{L}_{\text{SFT}}(\theta) = - \sum_{t=1}^{T} \log p_\theta(y_t \mid y_{<t}),

其中其中 yt y_t 是時間步t t 的目標詞元,模型被訓練來預測給定前面所有詞元的下一個詞元。在實踐中,填充詞元在損失計算中被掩碼掉。

標籤移位和掩碼

在訓練期間,損失是使用**單詞元移位**計算的:模型被訓練來基於所有先前的詞元預測序列中的每個詞元。具體來說,輸入序列向右移動一個位置以形成目標標籤。填充詞元(如果存在)透過在相應位置應用忽略索引(預設為 -100)在損失計算中被忽略。這確保了損失只關注有意義的、非填充的詞元。

日誌指標

  • global_step:到目前為止已執行的最佳化器步驟總數。
  • epoch:當前的 epoch 數,基於資料集的迭代。
  • num_tokens:到目前為止已處理的詞元總數。
  • loss:在當前日誌記錄間隔內,對非掩碼詞元計算的平均交叉熵損失。
  • mean_token_accuracy:模型的 top-1 預測與真實詞元匹配的非掩碼詞元的比例。
  • learning_rate:當前學習率,如果使用排程器,可能會動態變化。
  • grad_norm:梯度的 L2 範數,在梯度裁剪之前計算。

自定義

模型初始化

你可以直接將 `from_pretrained()` 方法的關鍵字引數傳遞給 SFTConfig。例如,如果你想以不同的精度載入模型,類似於

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B", torch_dtype=torch.bfloat16)

你可以透過將 `model_init_kwargs={"torch_dtype": torch.bfloat16}` 引數傳遞給 SFTConfig 來實現。

from trl import SFTConfig

training_args = SFTConfig(
    model_init_kwargs={"torch_dtype": torch.bfloat16},
)

請注意,`from_pretrained()` 的所有關鍵字引數都受支援。

打包

SFTTrainer 支援*示例打包*,即在同一個輸入序列中打包多個示例以提高訓練效率。要啟用打包功能,只需在 SFTConfig 建構函式中傳遞 `packing=True`。

training_args = SFTConfig(packing=True)

有關打包的更多詳細資訊,請參閱打包

只在助手訊息上訓練

要只在助手訊息上訓練,請使用一個對話式資料集,並在 SFTConfig 中設定 `assistant_only_loss=True`。此設定確保損失**只**在助手回覆上計算,而忽略使用者或系統訊息。

training_args = SFTConfig(assistant_only_loss=True)

train_on_assistant

此功能僅適用於支援透過 `{% generation %}` 和 `{% endgeneration %}` 關鍵字返回助手詞元掩碼的聊天模板。有關此類模板的示例,請參閱 HugggingFaceTB/SmolLM3-3B

只在補全部分訓練

要只在補全部分訓練,請使用提示-補全資料集。預設情況下,訓練器僅在補全詞元上計算損失,忽略提示詞元。如果你想在完整序列上訓練,請在 SFTConfig 中設定 `completion_only_loss=False`。

train_on_completion

只在補全部分訓練與只在助手訊息上訓練相容。在這種情況下,請使用[對話式](dataset_formats#conversational)[提示-補全](dataset_formats#prompt-completion)資料集,並在 [SFTConfig](/docs/trl/v0.21.0/en/sft_trainer#trl.SFTConfig) 中設定 `assistant_only_loss=True`。

使用 PEFT 訓練介面卡

我們支援與 🤗 PEFT 庫的緊密整合,允許任何使用者方便地訓練介面卡並在 Hub 上分享,而不是訓練整個模型。

from datasets import load_dataset
from trl import SFTTrainer
from peft import LoraConfig

dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    "Qwen/Qwen3-0.6B",
    train_dataset=dataset,
    peft_config=LoraConfig()
)

trainer.train()

你也可以繼續訓練你的 `peft.PeftModel`。為此,首先在 SFTTrainer 外部載入一個 `PeftModel`,然後將其直接傳遞給訓練器,而不需要傳遞 `peft_config` 引數。

from datasets import load_dataset
from trl import SFTTrainer
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained("trl-lib/Qwen3-4B-LoRA", is_trainable=True)
dataset = load_dataset("trl-lib/Capybara", split="train")

trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
)

trainer.train()

訓練介面卡時,通常使用更高的學習率(≈1e-4),因為只學習新的引數。

SFTConfig(learning_rate=1e-4, ...)

使用 Liger Kernel 進行訓練

Liger Kernel 是一系列用於 LLM 訓練的 Triton 核心,可將多 GPU 吞吐量提升 20%,記憶體使用量減少 60%(支援高達 4 倍的上下文長度),並與 FlashAttention、PyTorch FSDP 和 DeepSpeed 等工具無縫協作。更多資訊,請參閱Liger Kernel 整合

使用 Unsloth 進行訓練

Unsloth 是一個開源的微調和強化學習框架,可以使 LLMs(如 Llama、Mistral、Gemma、DeepSeek 等)的訓練速度提高 2 倍,VRAM 使用量減少高達 70%,同時為訓練、評估和部署提供了簡化的、與 Hugging Face 相容的工作流程。更多資訊,請參閱Unsloth 整合

指令調優示例

指令調優教導基礎語言模型遵循使用者指令並進行對話。這需要

  1. 聊天模板:定義如何將對話構造成文字序列,包括角色標記(使用者/助手)、特殊詞元和對話輪次邊界。在聊天模板中閱讀更多關於聊天模板的資訊。
  2. 對話資料集:包含指令-響應對

此示例展示瞭如何使用 Capybara 資料集和來自 HuggingFaceTB/SmolLM3-3B 的聊天模板,將 Qwen 3 0.6B Base 模型轉換為一個指令遵循模型。SFT 訓練器會自動處理分詞器更新和特殊詞元配置。

from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

trainer = SFTTrainer(
    model="Qwen/Qwen3-0.6B-Base",
    args=SFTConfig(
        output_dir="Qwen3-0.6B-Instruct",
        chat_template_path="HuggingFaceTB/SmolLM3-3B",
    ),
    train_dataset=load_dataset("trl-lib/Capybara", split="train"),
)
trainer.train()

一些基礎模型,如 Qwen 系列模型,在其分詞器中預定義了聊天模板。在這些情況下,不必應用 `clone_chat_template()`,因為分詞器已經處理了格式化。但是,有必要將 EOS 詞元與聊天模板對齊,以確保模型的響應正確終止。在這些情況下,在 SFTConfig 中指定 `eos_token`;例如,對於 `Qwen/Qwen2.5-1.5B`,應設定 `eos_token="<|im_end|>"`。

訓練完成後,你的模型現在可以使用其新的聊天模板來遵循指令並進行對話。

>>> from transformers import pipeline
>>> pipe = pipeline("text-generation", model="Qwen3-0.6B-Instruct/checkpoint-5000")
>>> prompt = "<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\n"
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
'<|im_start|>user\nWhat is the capital of France? Answer in one word.<|im_end|>\n<|im_start|>assistant\nThe capital of France is Paris.'

或者,使用結構化對話格式(推薦)

>>> prompt = [{"role": "user", "content": "What is the capital of France? Answer in one word."}]
>>> response = pipe(prompt)
>>> response[0]["generated_text"]
[{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'The capital of France is Paris.'}]

使用 SFT 進行工具呼叫

SFT 訓練器完全支援對具有*工具呼叫*能力的模型進行微調。在這種情況下,每個資料集示例應包含

  • 對話訊息,包括任何工具呼叫(`tool_calls`)和工具響應(`tool` 角色訊息)
  • `tools` 列中的可用工具列表,通常以 JSON 模式提供

有關預期資料集結構的詳細資訊,請參閱資料集格式 — 工具呼叫部分。

為視覺語言模型擴充套件 SFTTrainer

SFTTrainer 目前尚未原生支援視覺語言資料。但是,我們提供了一個關於如何調整訓練器以支援視覺語言資料的指南。具體來說,您需要使用一個與視覺語言資料相容的自定義資料整理器。本指南概述了進行這些調整的步驟。有關具體示例,請參閱指令碼 examples/scripts/sft_vlm.py,該指令碼演示瞭如何在 HuggingFaceH4/llava-instruct-mix-vsft 資料集上微調 LLaVA 1.5 模型。

準備資料

資料格式是靈活的,只要它與我們稍後將定義的自定義整理器相容即可。一種常見的方法是使用對話資料。鑑於資料包含文字和影像,格式需要相應調整。以下是一個涉及文字和影像的對話資料格式示例

images = ["obama.png"]
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Who is this?"},
            {"type": "image"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "Barack Obama"}
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is he famous for?"}
        ]
    },
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": "He is the 44th President of the United States."}
        ]
    }
]

為了說明如何使用 LLaVA 模型處理此資料格式,您可以使用以下程式碼

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(processor.apply_chat_template(messages, tokenize=False))

輸出將格式化如下

Who is this? ASSISTANT: Barack Obama USER: What is he famous for? ASSISTANT: He is the 44th President of the United States. 

用於處理多模態資料的自定義整理器

SFTTrainer 的預設行為不同,多模態資料的處理是在資料整理過程中動態完成的。為此,您需要定義一個自定義整理器來處理文字和影像。該整理器必須接受一個示例列表作為輸入(有關資料格式的示例,請參見上一節)並返回一批處理過的資料。以下是此類整理器的一個示例

def collate_fn(examples):
    # Get the texts and images, and apply the chat template
    texts = [processor.apply_chat_template(example["messages"], tokenize=False) for example in examples]
    images = [example["images"][0] for example in examples]

    # Tokenize the texts and process the images
    batch = processor(images=images, text=texts, return_tensors="pt", padding=True)

    # The labels are the input_ids, and we mask the padding tokens in the loss computation
    labels = batch["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    batch["labels"] = labels

    return batch

我們可以透過執行以下程式碼來驗證整理器是否按預期工作

from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
examples = [dataset[0], dataset[1]]  # Just two examples for the sake of the example
collated_data = collate_fn(examples)
print(collated_data.keys())  # dict_keys(['input_ids', 'attention_mask', 'pixel_values', 'labels'])

訓練視覺語言模型

現在我們已經準備好資料並定義了整理器,我們可以繼續訓練模型了。為了確保資料不被僅作為文字處理,我們需要在 SFTConfig 中設定幾個引數,特別是將 `remove_unused_columns` 和 `skip_prepare_dataset` 設定為 `True` 以避免資料集的預設處理。以下是如何設定 `SFTTrainer` 的示例。

training_args.remove_unused_columns = False
training_args.dataset_kwargs = {"skip_prepare_dataset": True}

trainer = SFTTrainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    train_dataset=train_dataset,
    processing_class=processor,
)

有關在 HuggingFaceH4/llava-instruct-mix-vsft 資料集上訓練 LLaVa 1.5 的完整示例,請參閱指令碼 examples/scripts/sft_vlm.py

SFTTrainer

class trl.SFTTrainer

< >

( model: typing.Union[str, torch.nn.modules.module.Module, transformers.modeling_utils.PreTrainedModel] args: typing.Union[trl.trainer.sft_config.SFTConfig, transformers.training_args.TrainingArguments, NoneType] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Union[datasets.arrow_dataset.Dataset, datasets.iterable_dataset.IterableDataset, NoneType] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_loss_func: typing.Optional[typing.Callable] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) optimizer_cls_and_kwargs: typing.Optional[tuple[type[torch.optim.optimizer.Optimizer], dict[str, typing.Any]]] = None preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Optional[typing.Callable[[dict], str]] = None )

引數

  • model (Union[str, PreTrainedModel]) — 要訓練的模型。可以是:

    • 一個字串,即 huggingface.co 上模型倉庫中預訓練模型的*模型 ID*,或包含使用 `save_pretrained` 儲存的模型權重的*目錄*路徑,例如 `'./my_model_directory/'`。模型使用 `from_pretrained` 和 `args.model_init_kwargs` 中的關鍵字引數載入。
    • 一個 PreTrainedModel 物件。僅支援因果語言模型。
  • args (SFTConfig, *可選*, 預設為 None) — 此訓練器的配置。如果為 None,則使用預設配置。
  • data_collator (DataCollator, *可選*) — 用於從處理過的 `train_dataset` 或 `eval_dataset` 的元素列表中形成批次的函式。將預設為自定義的 `DataCollatorForLanguageModeling`。
  • train_dataset (DatasetIterableDataset) — 用於訓練的資料集。SFT 支援語言建模型別和提示-補全型別。樣本的格式可以是:

    • 標準:每個樣本包含純文字。
    • 對話式:每個樣本包含結構化訊息(例如,角色和內容)。

    訓練器還支援已處理(已分詞)的資料集,只要它們包含一個 `input_ids` 欄位。

  • eval_dataset (DatasetIterableDataset 或 `dict[str, Union[Dataset, IterableDataset]]`) — 用於評估的資料集。它必須滿足與 `train_dataset` 相同的要求。
  • processing_class (PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixinProcessorMixin, *可選*, 預設為 None) — 用於處理資料的處理類。如果為 None,則從模型的名稱使用 from_pretrained 載入處理類。
  • callbacks (TrainerCallback 列表, *可選*, 預設為 None) — 用於自定義訓練迴圈的回撥列表。將新增到詳見此處的預設回撥列表中。

    如果要刪除使用的預設回撥之一,請使用 `remove_callback` 方法。

  • optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR], *可選*, 預設為 (None, None)) — 包含要使用的最佳化器和排程器的元組。將預設為模型上的 `AdamW` 例項和由 `args` 控制的 `get_linear_schedule_with_warmup` 提供的排程器。
  • optimizer_cls_and_kwargs (Tuple[Type[torch.optim.Optimizer], Dict[str, Any]], *可選*, 預設為 None) — 包含最佳化器類和要使用的關鍵字引數的元組。覆蓋 `args` 中的 `optim` 和 `optim_args`。與 `optimizers` 引數不相容。

    與 `optimizers` 不同,此引數避免了在初始化 Trainer 之前將模型引數放置在正確裝置上的需要。

  • preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor], *可選*, 預設為 None) — 一個函式,用於在每個評估步驟快取 logits 之前對其進行預處理。必須接受兩個張量,logits 和標籤,並返回處理後所需的 logits。此函式所做的修改將反映在 `compute_metrics` 接收的預測中。

    請注意,如果資料集沒有標籤,則標籤(第二個引數)將為 `None`。

  • peft_config (~peft.PeftConfig, *可選*, 預設為 None) — 用於包裝模型的 PEFT 配置。如果為 `None`,則不包裝模型。
  • formatting_func (Optional[Callable]) — 在分詞前應用於資料集的格式化函式。顯式應用格式化函式會將資料集轉換為語言建模型別。

用於監督式微調(SFT)方法的訓練器。

此類是 `transformers.Trainer` 類的包裝器,並繼承其所有屬性和方法。

示例

from datasets import load_dataset
from trl import SFTTrainer

dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")

trainer = SFTTrainer(model="Qwen/Qwen2-0.5B-Instruct", train_dataset=dataset)
trainer.train()

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

引數

  • resume_from_checkpoint (str or bool, optional) — 如果是 str,則為先前 Trainer 例項儲存的檢查點的本地路徑。如果是 bool 且等於 True,則載入先前 Trainer 例項儲存在 args.output_dir 中的最後一個檢查點。如果提供此引數,訓練將從載入的模型/最佳化器/排程器狀態恢復。
  • trial (optuna.Trialdict[str, Any], optional) — 用於超引數搜尋的試驗執行或超引數字典。
  • ignore_keys_for_eval (list[str], optional) — 模型輸出(如果為字典)中的一個鍵列表,在訓練期間收集評估預測時應忽略這些鍵。
  • kwargs (dict[str, Any], optional) — 用於隱藏已棄用引數的附加關鍵字引數。

主訓練入口點。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。

僅從主程序儲存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

引數

  • commit_message (str, optional, 預設為 "End of training") — 推送時要提交的訊息。
  • blocking (bool, optional, 預設為 True) — 函式是否僅在 git push 完成後返回。
  • token (str, optional, 預設為 None) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始引數。
  • revision (str, optional) — 要提交的 git 修訂版本。預設為“main”分支的頭部。
  • kwargs (dict[str, Any], optional) — 傳遞給 ~Trainer.create_model_card 的附加關鍵字引數。

將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。

SFTConfig

class trl.SFTConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None chat_template_path: typing.Optional[str] = None dataset_text_field: str = 'text' dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None eos_token: typing.Optional[str] = None pad_token: typing.Optional[str] = None max_length: typing.Optional[int] = 1024 packing: bool = False packing_strategy: str = 'bfd' padding_free: bool = False pad_to_multiple_of: typing.Optional[int] = None eval_packing: typing.Optional[bool] = None completion_only_loss: typing.Optional[bool] = None assistant_only_loss: bool = False activation_offloading: bool = False )

控制模型的引數

  • model_init_kwargs (dict[str, Any]None, optional, 預設為 None) — 當 SFTTrainermodel 引數以字串形式提供時,用於 from_pretrained 的關鍵字引數。
  • chat_template_path (strNone, optional, 預設為 None) — 如果指定,則設定模型的聊天模板。這可以是一個分詞器(本地目錄或 Hugging Face Hub 模型)的路徑,也可以是一個 Jinja 模板檔案的直接路徑。使用 Jinja 檔案時,必須確保模板中引用的任何特殊令牌都已新增到分詞器中,並相應地調整模型的嵌入層大小。

控制資料預處理的引數

  • dataset_text_field (str, optional, 預設為 "text") — 資料集中包含文字資料的列名。
  • dataset_kwargs (dict[str, Any]None, optional, 預設為 None) — 資料集準備的可選關鍵字引數字典。唯一支援的鍵是 skip_prepare_dataset
  • dataset_num_proc (intNone, optional, 預設為 None) — 用於處理資料集的程序數。
  • eos_token (strNone, optional, 預設為 None) — 用於指示一輪對話或序列結束的令牌。如果為 None,則預設為 processing_class.eos_token
  • pad_token (intNone, optional, 預設為 None) — 用於填充的令牌。如果為 None,則預設為 processing_class.pad_token,如果該值也為 None,則回退到 processing_class.eos_token
  • max_length (intNone, optional, 預設為 1024) — 標記化序列的最大長度。超過 max_length 的序列將從右側截斷。如果為 None,則不應用截斷。啟用打包時,此值設定序列長度。
  • packing (bool, optional, 預設為 False) — 是否將多個序列分組到固定長度的塊中,以提高計算效率並減少填充。使用 max_length 定義序列長度。
  • packing_strategy (str, optional, 預設為 "bfd") — 打包序列的策略。可以是 "bfd"(最佳擬合遞減,預設值)或 "wrapped"
  • padding_free (bool, optional, 預設為 False) — 是否透過將批次中的所有序列展平為單個連續序列來執行無填充的前向傳播。這透過消除填充開銷來減少記憶體使用。目前,這僅在 FlashAttention 2 或 3 中受支援,因為它們可以高效處理展平的批次結構。當使用 "bfd" 策略啟用打包時,無論此引數的值如何,都會啟用無填充。
  • pad_to_multiple_of (intNone, optional, 預設為 None) — 如果設定,序列將被填充到該值的倍數。
  • eval_packing (boolNone, optional, 預設為 None) — 是否打包評估資料集。如果為 None,則使用與 packing 相同的值。

控制訓練的引數

  • completion_only_loss (boolNone, optional, 預設為 None) — 是否僅對序列的補全部分計算損失。如果設定為 True,則僅對補全部分計算損失,這僅支援提示-補全資料集。如果為 False,則對整個序列計算損失。如果為 None(預設),行為取決於資料集:對於提示-補全資料集,對補全部分計算損失;對於語言建模資料集,對整個序列計算損失。
  • assistant_only_loss (bool, optional, 預設為 False) — 是否僅對序列的助手部分計算損失。如果設定為 True,則僅對助手響應計算損失,這僅支援對話資料集。如果為 False,則對整個序列計算損失。
  • activation_offloading (bool, optional, 預設為 False) — 是否將啟用解除安裝到 CPU。

用於 SFTTrainer 的配置類。

此類僅包含特定於 SFT 訓練的引數。有關訓練引數的完整列表,請參閱 TrainingArguments 文件。請注意,此類中的預設值可能與 TrainingArguments 中的不同。

使用 HfArgumentParser,我們可以將此類別轉換為可在命令列上指定的 argparse 引數。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.