TRL 文件
Unsloth 整合
加入 Hugging Face 社群
並獲得增強的文件體驗
開始使用
Unsloth 整合
此部分正在建設中。歡迎貢獻!
Unsloth 是一個用於微調和強化學習的開源框架,它能將 LLM(如 Llama、Mistral、Gemma、DeepSeek 等)的訓練速度提升高達 2 倍,同時減少高達 70% 的 VRAM 佔用,併為訓練、評估和部署提供了一個精簡的、與 Hugging Face 相容的工作流程。Unsloth 庫與 SFTTrainer 完全相容。下面列出了在 1 x A100 上的部分基準測試結果
1 A100 40GB | 資料集 | 🤗 | 🤗 + FlashAttention 2 | 🦥 Unsloth | 🦥 節省的 VRAM |
---|---|---|---|---|---|
Code Llama 34b | Slim Orca | 1 倍 | 1.01x | 1.94x | -22.7% |
Llama-2 7b | Slim Orca | 1 倍 | 0.96x | 1.87x | -39.3% |
Mistral 7b | Slim Orca | 1 倍 | 1.17x | 1.88x | -65.9% |
Tiny Llama 1.1b | Alpaca | 1 倍 | 1.55x | 2.74x | -57.8% |
首先,根據官方文件安裝 unsloth
。安裝後,你可以非常簡單地將 unsloth 整合到你的工作流程中;你只需載入一個 FastLanguageModel
,而不是載入 AutoModelForCausalLM
,如下所示
import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel
max_length = 2048 # Supports automatic RoPE Scaling, so choose any number
# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
model_name="unsloth/mistral-7b",
max_seq_length=max_length,
dtype=None, # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit=True, # Use 4bit quantization to reduce memory usage. Can be False
)
# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
model,
r=16,
target_modules=[
"q_proj",
"k_proj",
"v_proj",
"o_proj",
"gate_proj",
"up_proj",
"down_proj",
],
lora_alpha=16,
lora_dropout=0, # Dropout = 0 is currently optimized
bias="none", # Bias = "none" is currently optimized
use_gradient_checkpointing=True,
random_state=3407,
)
training_args = SFTConfig(output_dir="./output", max_length=max_length)
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset,
)
trainer.train()
儲存的模型與 Hugging Face 的 transformers 庫完全相容。在他們的官方倉庫中瞭解更多關於 unsloth 的資訊。
< > 在 GitHub 上更新