TRL 文件

廣義知識蒸餾訓練器

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

廣義知識蒸餾訓練器

概述

廣義知識蒸餾 (GKD) 由 Rishabh Agarwal、Nino Vieillard、Yongchao Zhou、Piotr Stanczyk、Sabela Ramos、Matthieu Geist 和 Olivier Bachem 在論文 《On-Policy Distillation of Language Models: Learning from Self-Generated Mistakes》 中提出。

論文摘要如下:

知識蒸餾 (KD) 廣泛用於壓縮教師模型,以降低其推理成本和記憶體佔用,其方法是訓練一個較小的學生模型。然而,當前用於自迴歸序列模型的知識蒸餾方法存在訓練期間看到的輸出序列與學生模型在推理時生成的輸出序列之間的分佈不匹配問題。為了解決這個問題,我們引入了廣義知識蒸餾 (GKD)。GKD 不僅僅依賴於一組固定的輸出序列,而是透過利用教師模型對學生模型自生成的輸出序列的反饋來訓練學生模型。與監督式 KD 方法不同,GKD 還提供了在學生和教師模型之間使用替代損失函式的靈活性,這在學生模型缺乏模仿教師模型分佈的表達能力時非常有用。此外,GKD 促進了蒸餾與強化學習微調 (RLHF) 的無縫整合。我們展示了 GKD 在摘要、翻譯和算術推理任務上蒸餾自迴歸語言模型以及在指令調優中進行與任務無關的蒸餾的有效性。

GKD 的關鍵方面是

  1. 它透過在學生模型自生成的輸出序列上進行訓練,解決了自迴歸序列模型中訓練-推理分佈不匹配的問題。
  2. GKD 允許透過廣義 Jensen-Shannon 散度 (JSD) 靈活選擇學生模型和教師模型之間的不同散度度量,這在學生模型能力不足以完全模仿教師模型時非常有用。

此後訓練方法由 Kashif RasulLewis Tunstall 貢獻。

使用技巧

GKDTrainerSFTTrainer 類的一個包裝器,它接受一個教師模型引數。它需要透過 GKDConfig 設定三個引數,即

  • lmbda:控制學生資料比例,即同策略 (on-policy) 學生生成輸出的比例。當 `lmbda=0.0` 時,損失函式簡化為監督式 JSD,學生模型使用教師模型的詞元級機率進行訓練。當 `lmbda=1.0` 時,損失函式簡化為同策略 JSD,學生模型生成輸出序列,並從教師模型處獲得對這些序列的特定於詞元的反饋。對於 [0, 1] 之間的值,它會根據每個批次的 `lmbda` 值在這兩者之間隨機選擇。
  • seq_kd:控制是否執行序列級 KD (可視為在教師生成的輸出上進行監督式微調)。當 `seq_kd=True` 且 `lmbda=0.0` 時,損失函式簡化為監督式 JSD,教師模型生成輸出序列,學生模型從教師模型處獲得對這些序列的特定於詞元的反饋。
  • beta:控制廣義 Jensen-Shannon 散度中的插值。當 `beta=0.0` 時,損失函式近似於前向 KL 散度,而當 `beta=1.0` 時,損失函式近似於反向 KL 散度。對於 [0, 1] 之間的值,它在這兩者之間進行插值。

作者發現,同策略資料 (高 `lmbda`) 表現更好,而最優的 `beta` 值則因任務和評估方法而異。

在訓練 Gemma 模型 時,請確保設定 `attn_implementation="flash_attention_2"`。否則,由於該架構採用的軟上限技術,您將在 logits 中遇到 NaN。

基本 API 如下

from datasets import Dataset
from trl import GKDConfig, GKDTrainer
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)

NUM_DUMMY_SAMPLES = 100

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The model to optimise
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
# The teacher model to calculate the KL divergence against
teacher_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-1.5B-Instruct")

train_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "Hi, how are you?"},
                {"role": "assistant", "content": "I'm great thanks"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)
eval_dataset = Dataset.from_dict(
    {
        "messages": [
            [
                {"role": "user", "content": "What colour is the sky?"},
                {"role": "assistant", "content": "The sky is blue"},
            ]
        ]
        * NUM_DUMMY_SAMPLES
    }
)

training_args = GKDConfig(output_dir="gkd-model", per_device_train_batch_size=1)
trainer = GKDTrainer(
    model=model,
    teacher_model=teacher_model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

預期資料集型別

資料集應格式化為“訊息”列表,其中每個訊息是包含以下鍵的字典列表

  • role:`system`、`assistant` 或 `user`
  • content:訊息內容

GKDTrainer

class trl.GKDTrainer

< >

( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None teacher_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str] = None args: typing.Optional[trl.trainer.gkd_config.GKDConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalPrediction], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[ForwardRef('PeftConfig')] = None formatting_func: typing.Optional[typing.Callable] = None )

train

< >

( resume_from_checkpoint: typing.Union[str, bool, NoneType] = None trial: typing.Union[ForwardRef('optuna.Trial'), dict[str, typing.Any], NoneType] = None ignore_keys_for_eval: typing.Optional[list[str]] = None **kwargs )

引數

  • resume_from_checkpoint (strbool, 可選) — 如果是 str,則為 `Trainer` 的先前例項儲存的檢查點的本地路徑。如果是 `bool` 且等於 `True`,則載入 *args.output_dir* 中由 `Trainer` 的先前例項儲存的最後一個檢查點。如果存在,訓練將從此處載入的模型/最佳化器/排程器狀態恢復。
  • trial (optuna.Trialdict[str, Any], 可選) — 用於超引數搜尋的試驗執行或超引數字典。
  • ignore_keys_for_eval (list[str], 可選) — 在訓練期間收集評估預測時,應忽略的模型輸出中的鍵列表(如果輸出是字典)。
  • kwargs (dict[str, Any], 可選) — 用於隱藏已棄用引數的附加關鍵字引數。

主訓練入口點。

save_model

< >

( output_dir: typing.Optional[str] = None _internal_call: bool = False )

將儲存模型,以便您可以使用 `from_pretrained()` 重新載入它。

僅從主程序儲存。

push_to_hub

< >

( commit_message: typing.Optional[str] = 'End of training' blocking: bool = True token: typing.Optional[str] = None revision: typing.Optional[str] = None **kwargs )

引數

  • commit_message (str, 可選, 預設為 "End of training") — 推送時使用的提交資訊。
  • blocking (bool, 可選, 預設為 True) — 函式是否應該在 `git push` 完成後才返回。
  • token (str, 可選, 預設為 None) — 具有寫入許可權的令牌,用於覆蓋 Trainer 的原始引數。
  • revision (str, 可選) — 要提交的 git 修訂版本。預設為“main”分支的頭部。
  • kwargs (dict[str, Any], 可選) — 傳遞給 ~Trainer.create_model_card 的其他關鍵字引數。

將 `self.model` 和 `self.processing_class` 上傳到 🤗 模型中心的 `self.args.hub_model_id` 儲存庫。

GKDConfig

class trl.GKDConfig

< >

( output_dir: typing.Optional[str] = None overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 2e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: str = 'passive' log_level_replica: str = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 10 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: typing.Optional[bool] = None fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, list[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[list[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[list[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict[str, typing.Any], str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, list[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False hub_revision: typing.Optional[str] = None gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict[str, typing.Any], str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: list = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: int = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, list[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False liger_kernel_config: typing.Optional[dict[str, bool]] = None eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: bool = True model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None chat_template_path: typing.Optional[str] = None dataset_text_field: str = 'text' dataset_kwargs: typing.Optional[dict[str, typing.Any]] = None dataset_num_proc: typing.Optional[int] = None eos_token: typing.Optional[str] = None pad_token: typing.Optional[str] = None max_length: typing.Optional[int] = 1024 packing: bool = False packing_strategy: str = 'bfd' padding_free: bool = False pad_to_multiple_of: typing.Optional[int] = None eval_packing: typing.Optional[bool] = None completion_only_loss: typing.Optional[bool] = None assistant_only_loss: bool = False activation_offloading: bool = False temperature: float = 0.9 lmbda: float = 0.5 beta: float = 0.5 max_new_tokens: int = 128 teacher_model_name_or_path: typing.Optional[str] = None teacher_model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None disable_dropout: bool = True seq_kd: bool = False )

引數

  • temperature (float, 可選, 預設為 0.9) — 用於取樣的溫度。溫度越高,補全結果的隨機性越強。
  • lmbda (float, 可選, 預設為 0.5) — Lambda 引數,用於控制學生資料部分(即同策略下學生生成輸出的比例)。
  • beta (float, 可選, 預設為 0.5) — 廣義 Jensen-Shannon 散度損失的插值係數,介於 0.01.0 之間。當 beta 為 0.0 時,損失為 KL 散度。當 beta 為 1.0 時,損失為逆 KL 散度。
  • max_new_tokens (int, 可選, 預設為 128) — 每次補全生成的最大詞元數。
  • teacher_model_name_or_path (strNone, 可選, 預設為 None) — 教師模型的模型名稱或路徑。如果為 None,則教師模型將與正在訓練的模型相同。
  • teacher_model_init_kwargs (dict[str, Any]]None, 可選, 預設為 None) — 從字串例項化教師模型時,傳遞給 `AutoModelForCausalLM.from_pretrained` 的關鍵字引數。
  • disable_dropout (bool, 可選, 預設為 True) — 是否在模型中停用 dropout。
  • seq_kd (bool, 可選, 預設為 False) — Seq_kd 引數,用於控制是否執行序列級知識蒸餾(Sequence-Level KD),可視為在教師生成的輸出上進行監督微調。

GKDTrainer 的配置類。

此類僅包含特定於 GKD 訓練的引數。有關訓練引數的完整列表,請參閱 TrainingArgumentsSFTConfig 文件。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.