SetFit documentation

Knowledge Distillation

If you have unlabeled data, you can use knowledge distillation to improve the performance of a small SetFit model. The approach involves training a larger model and, using the unlabeled data, distilling its performance into your smaller SetFit model, making that model stronger as a result.

Additionally, you can use knowledge distillation to replace a trained SetFit model with a more efficient one at only a small cost in performance.

This guide shows you how to perform knowledge distillation.

Data preparation

Consider a scenario with only a small amount of labeled training data (e.g. 64 sentences). This guide uses the ag_news dataset to simulate that scenario.

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })
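For reference, sample_dataset draws a fixed number of examples per label. The stratified sampling idea behind it can be sketched in plain Python (a hypothetical sample_per_label helper for illustration, not SetFit's actual implementation):

```python
import random
from collections import defaultdict

def sample_per_label(texts, labels, num_samples, seed=42):
    """Pick `num_samples` examples for each distinct label (illustrative only)."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for text, label in zip(texts, labels):
        by_label[label].append(text)
    sampled = []
    for label, group in sorted(by_label.items()):
        sampled.extend((text, label) for text in rng.sample(group, num_samples))
    return sampled

# With ag_news's 4 classes and num_samples=16, this yields 64 rows,
# matching the Dataset summary above.
texts = [f"text-{i}" for i in range(400)]
labels = [i % 4 for i in range(400)]
print(len(sample_per_label(texts, labels, 16)))  # 64
```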

Baseline model

We can prepare a model using the standard SetFit training approach.

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)
***** Running training *****
  Num examples = 48
  Num epochs = 5
  Total optimization steps = 240
  Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}                                                                                  
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}                                                                                  
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}                                                                                 
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}                                                     
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:20<00:00, 11.97it/s] 
***** Running evaluation *****
{'accuracy': 0.7818421052631579}

This model reaches 78.18% accuracy on our evaluation dataset. That is certainly respectable given the tiny amount of training data, but we can squeeze more performance out of the model with knowledge distillation.

Unlabeled data preparation

Beyond the labeled training data, we may also have a larger amount of unlabeled training data (e.g. 500 sentences). Let's prepare it:

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

Teacher model

Next, we prepare a larger, trained SetFit model to act as the teacher for our smaller student model. The strong sentence-transformers/paraphrase-mpnet-base-v2 Sentence Transformer model is used to initialize the SetFit model.

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

We first need to train this model on the labeled dataset:

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
  Num examples = 192
  Num epochs = 2
  Total optimization steps = 384
  Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}                                                                                 
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}                                                                                  
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}                                                                                 
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}                                                                                 
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}                                                                                   
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}                                                                                  
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}                                                                                 
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}                                                      
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [01:24<00:00,  4.55it/s] 
***** Running evaluation *****
{'accuracy': 0.8378947368421052}

This large teacher model reaches 83.79% accuracy, which is quite strong for so little data and notably higher than the 78.18% obtained by our smaller (but more efficient) model.

Knowledge Distillation

The performance of the stronger teacher_model can be distilled into the smaller model using the DistillationTrainer. It accepts a teacher model, a student model, and an unlabeled dataset.

Note that this trainer uses pairs of sentences as training samples, so the number of potential training steps grows much faster than the number of unlabeled samples. To avoid overfitting, consider setting max_steps relatively low.
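As a rough illustration of why capping max_steps matters (an arithmetic sketch, not SetFit's exact pair-sampling strategy): if a trainer formed every unique pair from n unlabeled sentences, the number of candidate training samples would grow quadratically with n.

```python
# Illustrative upper bound only: SetFit's actual pair sampling differs,
# but the rapid growth in candidate pairs is the point.
def unique_pairs(n: int) -> int:
    return n * (n - 1) // 2

for n in (100, 500, 1000):
    print(f"{n} sentences -> up to {unique_pairs(n)} pairs")
```

Even 500 unlabeled sentences would allow over 100,000 unique pairs, far more than the 500 steps used below.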

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
  Num examples = 7829
  Num epochs = 1
  Total optimization steps = 7829
  Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}                                                                                   
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}                                                                                  
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}                                                                                    
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}                                                                                  
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}                                                                                  
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}                                                                                   
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}                                                                                  
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}                                                                                  
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}                                                                                 
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}                                                                                  
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}                                                   
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 22.45it/s] 
***** Running evaluation *****
{'accuracy': 0.8084210526315789}

Through knowledge distillation, we improved the model's performance from 78.18% to 80.84% with just a few minutes of training time.
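To put those numbers in perspective, a quick back-of-the-envelope check (using only the accuracies reported above) of how much of the student-to-teacher gap distillation recovered:

```python
# Accuracies reported earlier in this guide
student = 0.7818    # baseline small model
distilled = 0.8084  # small model after distillation
teacher = 0.8379    # large teacher model

# Fraction of the student-to-teacher gap closed by distillation
gap_closed = (distilled - student) / (teacher - student)
print(f"{gap_closed:.0%} of the gap to the teacher recovered")
```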

End-to-end

This snippet shows a complete example of the end-to-end knowledge distillation strategy:

from datasets import load_dataset
from setfit import sample_dataset

# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")

# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 64
# })

# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
#     features: ['text', 'label'],
#     num_rows: 7600
# })

from setfit import SetFitModel, TrainingArguments, Trainer

model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")

args = TrainingArguments(
    batch_size=64,
    num_epochs=5,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

metrics = trainer.evaluate()
print(metrics)

# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
#     features: ['text'],
#     num_rows: 500
# })

from setfit import SetFitModel

teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

from setfit import TrainingArguments, Trainer

teacher_args = TrainingArguments(
    batch_size=16,
    num_epochs=2,
)

teacher_trainer = Trainer(
    model=teacher_model,
    args=teacher_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)

from setfit import DistillationTrainer

distillation_args = TrainingArguments(
    batch_size=16,
    max_steps=500,
)

distillation_trainer = DistillationTrainer(
    teacher_model=teacher_model,
    student_model=model,
    args=distillation_args,
    train_dataset=unlabeled_train_dataset,
    eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)