SetFit 文件
知識蒸餾
並獲得增強的文件體驗
開始使用
知識蒸餾
如果您有未標記的資料,則可以使用知識蒸餾來提高小型 SetFit 模型的效能。該方法涉及訓練一個更大的模型,並使用未標記的資料將其效能蒸餾到您的小型 SetFit 模型中。因此,您的 SetFit 模型將變得更強大。
此外,您還可以使用知識蒸餾來用更高效的模型替換訓練過的 SetFit 模型,同時減少效能下降。
本指南將向您展示如何進行知識蒸餾。
資料準備
讓我們考慮一個場景,只有少量標記訓練資料(例如 64 個句子)。本指南將使用 ag_news 資料集來模擬此場景。
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
基線模型
我們可以使用標準的 SetFit 訓練方法來準備模型。
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
batch_size=64,
num_epochs=5,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
***** Running training *****
Num examples = 48
Num epochs = 5
Total optimization steps = 240
Total train batch size = 64
{'embedding_loss': 0.4173, 'learning_rate': 8.333333333333333e-07, 'epoch': 0.02}
{'embedding_loss': 0.1756, 'learning_rate': 1.7592592592592595e-05, 'epoch': 1.04}
{'embedding_loss': 0.119, 'learning_rate': 1.2962962962962964e-05, 'epoch': 2.08}
{'embedding_loss': 0.0872, 'learning_rate': 8.333333333333334e-06, 'epoch': 3.12}
{'embedding_loss': 0.0542, 'learning_rate': 3.7037037037037037e-06, 'epoch': 4.17}
{'train_runtime': 26.0837, 'train_samples_per_second': 588.873, 'train_steps_per_second': 9.201, 'epoch': 5.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 240/240 [00:20<00:00, 11.97it/s]
***** Running evaluation *****
{'accuracy': 0.7818421052631579}
該模型在我們的資料集上達到了 78.18% 的準確率。考慮到微小的訓練資料量,這當然是值得稱讚的,但我們可以使用知識蒸餾從我們的模型中榨取更多的效能。
未標記資料準備
除了標記訓練資料之外,我們可能還有大量未標記的訓練資料(例如 500 個句子)。讓我們準備一下
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
教師模型
然後,我們將準備一個更大的訓練好的 SetFit 模型,作為我們較小學生模型的教師。強大的 sentence-transformers/paraphrase-mpnet-base-v2
句子轉換器模型將用於初始化 SetFit 模型。
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
我們首先需要在標記資料集上訓練此模型
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
batch_size=16,
num_epochs=2,
)
teacher_trainer = Trainer(
model=teacher_model,
args=teacher_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
***** Running training *****
Num examples = 192
Num epochs = 2
Total optimization steps = 384
Total train batch size = 16
{'embedding_loss': 0.4093, 'learning_rate': 5.128205128205128e-07, 'epoch': 0.01}
{'embedding_loss': 0.1087, 'learning_rate': 1.9362318840579713e-05, 'epoch': 0.26}
{'embedding_loss': 0.001, 'learning_rate': 1.6463768115942028e-05, 'epoch': 0.52}
{'embedding_loss': 0.0006, 'learning_rate': 1.3565217391304348e-05, 'epoch': 0.78}
{'embedding_loss': 0.0003, 'learning_rate': 1.0666666666666667e-05, 'epoch': 1.04}
{'embedding_loss': 0.0004, 'learning_rate': 7.768115942028987e-06, 'epoch': 1.3}
{'embedding_loss': 0.0002, 'learning_rate': 4.869565217391305e-06, 'epoch': 1.56}
{'embedding_loss': 0.0003, 'learning_rate': 1.9710144927536233e-06, 'epoch': 1.82}
{'train_runtime': 84.3703, 'train_samples_per_second': 72.822, 'train_steps_per_second': 4.551, 'epoch': 2.0}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 384/384 [01:24<00:00, 4.55it/s]
***** Running evaluation *****
{'accuracy': 0.8378947368421052}
這個大型教師模型達到了 83.79% 的準確率,對於這麼少的資料來說,這相當強大,並且明顯強於我們更小(但更高效)的模型獲得的 78.18%。
知識蒸餾
可以使用 DistillationTrainer 將更強大的 teacher_model 的效能蒸餾到較小的模型中。它接受一個教師模型和一個學生模型,以及一個未標記資料集。
請注意,此訓練器使用句子之間的配對作為訓練樣本,因此訓練步驟的數量會隨著未標記樣本的數量呈指數級增長。為避免過擬合,請考慮將 max_steps
設定得相對較低。
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
batch_size=16,
max_steps=500,
)
distillation_trainer = DistillationTrainer(
teacher_model=teacher_model,
student_model=model,
args=distillation_args,
train_dataset=unlabeled_train_dataset,
eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)
***** Running training *****
Num examples = 7829
Num epochs = 1
Total optimization steps = 7829
Total train batch size = 16
{'embedding_loss': 0.5048, 'learning_rate': 2.554278416347382e-08, 'epoch': 0.0}
{'embedding_loss': 0.4514, 'learning_rate': 1.277139208173691e-06, 'epoch': 0.01}
{'embedding_loss': 0.33, 'learning_rate': 2.554278416347382e-06, 'epoch': 0.01}
{'embedding_loss': 0.1218, 'learning_rate': 3.831417624521073e-06, 'epoch': 0.02}
{'embedding_loss': 0.0213, 'learning_rate': 5.108556832694764e-06, 'epoch': 0.03}
{'embedding_loss': 0.016, 'learning_rate': 6.385696040868455e-06, 'epoch': 0.03}
{'embedding_loss': 0.0054, 'learning_rate': 7.662835249042147e-06, 'epoch': 0.04}
{'embedding_loss': 0.0049, 'learning_rate': 8.939974457215838e-06, 'epoch': 0.04}
{'embedding_loss': 0.002, 'learning_rate': 1.0217113665389528e-05, 'epoch': 0.05}
{'embedding_loss': 0.0019, 'learning_rate': 1.1494252873563218e-05, 'epoch': 0.06}
{'embedding_loss': 0.0012, 'learning_rate': 1.277139208173691e-05, 'epoch': 0.06}
{'train_runtime': 22.2725, 'train_samples_per_second': 359.188, 'train_steps_per_second': 22.449, 'epoch': 0.06}
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [00:22<00:00, 22.45it/s]
***** Running evaluation *****
{'accuracy': 0.8084210526315789}
透過知識蒸餾,我們能夠在幾分鐘的訓練時間內將模型的效能從 78.18% 提高到 80.84%。
端到端
此程式碼片段展示了端到端知識蒸餾策略的完整示例
from datasets import load_dataset
from setfit import sample_dataset
# Load a dataset from the Hugging Face Hub
dataset = load_dataset("ag_news")
# Create a sample few-shot dataset to train with
train_dataset = sample_dataset(dataset["train"], label_column="label", num_samples=16)
# Dataset({
# features: ['text', 'label'],
# num_rows: 64
# })
# Dataset for evaluation
eval_dataset = dataset["test"]
# Dataset({
# features: ['text', 'label'],
# num_rows: 7600
# })
from setfit import SetFitModel, TrainingArguments, Trainer
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-MiniLM-L3-v2")
args = TrainingArguments(
batch_size=64,
num_epochs=5,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
trainer.train()
metrics = trainer.evaluate()
print(metrics)
# Create a dataset of unlabeled examples to perform knowledge distillation
unlabeled_train_dataset = dataset["train"].shuffle(seed=0).select(range(500))
unlabeled_train_dataset = unlabeled_train_dataset.remove_columns("label")
# Dataset({
# features: ['text'],
# num_rows: 500
# })
from setfit import SetFitModel
teacher_model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
from setfit import TrainingArguments, Trainer
teacher_args = TrainingArguments(
batch_size=16,
num_epochs=2,
)
teacher_trainer = Trainer(
model=teacher_model,
args=teacher_args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
)
# Train teacher model
teacher_trainer.train()
teacher_metrics = teacher_trainer.evaluate()
print(teacher_metrics)
from setfit import DistillationTrainer
distillation_args = TrainingArguments(
batch_size=16,
max_steps=500,
)
distillation_trainer = DistillationTrainer(
teacher_model=teacher_model,
student_model=model,
args=distillation_args,
train_dataset=unlabeled_train_dataset,
eval_dataset=eval_dataset,
)
# Train student with knowledge distillation
distillation_trainer.train()
distillation_metrics = distillation_trainer.evaluate()
print(distillation_metrics)