TRL 文件

Unsloth 整合

Hugging Face's logo
加入 Hugging Face 社群

並獲得增強的文件體驗

開始使用

Unsloth 整合

此部分正在建設中。歡迎貢獻!

Unsloth 是一個用於微調和強化學習的開源框架,它能將 LLM(如 Llama、Mistral、Gemma、DeepSeek 等)的訓練速度提升高達 2 倍,同時減少高達 70% 的 VRAM 佔用,併為訓練、評估和部署提供了一個精簡的、與 Hugging Face 相容的工作流程。Unsloth 庫與 SFTTrainer 完全相容。下面列出了在 1 x A100 上的部分基準測試結果

1 A100 40GB 資料集 🤗 🤗 + FlashAttention 2 🦥 Unsloth 🦥 節省的 VRAM
Code Llama 34b Slim Orca 1 倍 1.01x 1.94x -22.7%
Llama-2 7b Slim Orca 1 倍 0.96x 1.87x -39.3%
Mistral 7b Slim Orca 1 倍 1.17x 1.88x -65.9%
Tiny Llama 1.1b Alpaca 1 倍 1.55x 2.74x -57.8%

首先,根據官方文件安裝 unsloth。安裝後,你可以非常簡單地將 unsloth 整合到你的工作流程中;你只需載入一個 FastLanguageModel,而不是載入 AutoModelForCausalLM,如下所示

import torch
from trl import SFTConfig, SFTTrainer
from unsloth import FastLanguageModel

max_length = 2048 # Supports automatic RoPE Scaling, so choose any number

# Load model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/mistral-7b",
    max_seq_length=max_length,
    dtype=None,  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
    load_in_4bit=True,  # Use 4bit quantization to reduce memory usage. Can be False
)

# Do model patching and add fast LoRA weights
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=16,
    lora_dropout=0,  # Dropout = 0 is currently optimized
    bias="none",  # Bias = "none" is currently optimized
    use_gradient_checkpointing=True,
    random_state=3407,
)

training_args = SFTConfig(output_dir="./output", max_length=max_length)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

儲存的模型與 Hugging Face 的 transformers 庫完全相容。在他們的官方倉庫中瞭解更多關於 unsloth 的資訊。

< > 在 GitHub 上更新

© . This site is unofficial and not affiliated with Hugging Face, Inc.