Google TPU 文件
在 Google TPU 上微調 Gemma
並獲得增強的文件體驗
開始使用
在 Google TPU 上微調 Gemma
本教學將教您如何在 Google Cloud 的 TPU 上微調開放式大型語言模型(LLM),例如 Google Gemma。在範例中,我們將運用 Hugging Face Optimum TPU、🤗 Transformers 及 datasets 函式庫。
Google TPU
Google Cloud TPU 是專門設計的 AI 加速器,針對大型 AI 模型的訓練與推論進行了最佳化。它們非常適合多種使用情境,例如聊天機器人、程式碼生成、媒體內容生成、合成語音、視覺服務、推薦引擎、個人化模型等。
使用 TPU 的優點包括:
- 設計用於以符合成本效益的方式擴展各種 AI 工作負載,涵蓋訓練、微調及推論。
- 針對 TensorFlow、PyTorch 和 JAX 進行了最佳化,並提供多種外型規格,包括邊緣裝置、工作站以及雲端架構。
- TPU 可在 Google Cloud 上使用,並已與 Vertex AI 和 Google Kubernetes Engine (GKE) 整合。
環境設定
對於此範例,單一主機的 v5litepod8 TPU 即已足夠。若要使用 Pytorch XLA 設定 TPU 環境,請參考此 Google Cloud 指南。
我們可以使用 ssh 或 gcloud 指令登入遠端 TPU,並啟用 8888 連接埠的轉發(Port-forwarding),例如:
gcloud compute tpus tpu-vm ssh $TPU_NAME \
--zone=$ZONE \
-- -L 8888:localhost:8888一旦我們能存取 TPU VM,就可以複製包含相關筆記本的 optimum-tpu 儲存庫。接著,我們安裝本教學所需的幾個套件,並啟動筆記本。
git clone https://github.com/huggingface/optimum-tpu.git
# Install Optimum tpu
pip install -e . -f https://storage.googleapis.com/libtpu-releases/index.html
# Install TRL and PEFT for training (see later how they are used)
pip install trl peft
# Install Jupyter notebook
pip install -U jupyterlab notebook
# Optionally, install widgets extensions for better rendering
pip install ipywidgets widgetsnbextension
# Change directory and launch Jupyter notebook
cd optimum-tpu/examples/language-modeling
jupyter notebook --port 8888之後應該會看到熟悉的 Jupyter 輸出畫面,顯示可從瀏覽器存取的位址。
http://:8888/tree?token=3ceb24619d0a2f99acf5fba41c51b475b1ddce7cadb2a133由於我們將使用受存取權限制的 gemma 模型,因此需要使用 Hugging Face 權杖(Token)進行登入。
!huggingface-cli login --token YOUR_HF_TOKEN
啟用 FSDPv2
為了微調 LLM,可能有必要將模型切分(Shard)至各個 TPU 上,以避免記憶體問題並提升微調效能。Fully Sharded Data Parallel(完全分片資料並行)是一種已在 PyTorch 上實作的演算法,允許封裝模組以進行分佈。當在 TPU 上使用 PyTorch/XLA 時,FSDPv2 是一種利用 SPMD(單程式多資料)重新演繹著名 FSDP 演算法的公用程式。在 optimum-tpu 中,可以使用專用的輔助工具來運用 FSDPv2。若要啟用它,您可以使用專用函式,且該函式應在執行開始時呼叫。
from optimum.tpu import fsdp_v2
fsdp_v2.use_fsdp_v2()載入並準備資料集
我們將使用 Dolly,這是一個開源資料集,包含 InstructGPT 論文中概述的指令遵循紀錄類別,包括腦力激盪、分類、封閉式問答、生成、資訊擷取、開放式問答及摘要。
我們將從 Hugging Face Hub 載入該資料集。
from datasets import load_dataset
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")我們可以查看其中一個樣本。
dataset[321]我們將獲得類似這樣的結果。
{
"instruction": "When was the 8088 processor released?",
"context": "The 8086 (also called iAPX 86) is a 16-bit microprocessor chip designed by Intel between early 1976 and June 8, 1978, when it was released. The Intel 8088, released July 1, 1979, is a slightly modified chip with an external 8-bit data bus (allowing the use of cheaper and fewer supporting ICs),[note 1] and is notable as the processor used in the original IBM PC design.",
"response": "The Intel 8088 processor was released July 1, 1979.",
"category": "information_extraction",
}我們將定義一個格式化函式,將 instruction(指令)、context(上下文)和 response(回覆)欄位組合成一個完整的提示(Prompt)並進行 Tokenize。我們將使用與我們打算使用的模型相容的 Tokenizer。
from transformers import AutoTokenizer
model_id = "google/gemma-2b"
tokenizer = AutoTokenizer.from_pretrained(model_id)
def preprocess_function(sample):
instruction = f"### Instruction\n{sample['instruction']}"
context = f"### Context\n{sample['context']}" if len(sample["context"]) > 0 else None
response = f"### Answer\n{sample['response']}"
# join all the parts together
prompt = "\n\n".join([i for i in [instruction, context, response] if i is not None])
prompt += tokenizer.eos_token
sample["prompt"] = prompt
return sample現在可以使用此函式來映射(map)資料集,並移除原本的欄位。
data = dataset.map(preprocess_function, remove_columns=list(dataset.features))準備微調模型
我們現在可以載入將用於微調的模型。資料集現在已準備好進行微調。
import torch
from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained(model_id, use_cache=False, torch_dtype=torch.bfloat16)我們將使用 PEFT (Parameter Efficient FineTuning) 和 LoRA (Low-Rank Adaptation) 有效地在準備好的資料集上微調模型。在 LoraConfig 執行個體中,我們將定義要進行微調的 nn.Linear 操作。
from peft import LoraConfig
# Set up PEFT LoRA for fine-tuning.
lora_config = LoraConfig(
r=8,
target_modules=["k_proj", "v_proj"],
task_type="CAUSAL_LM",
)optimum-tpu 的專用函式將協助我們取得引數,以便建立訓練器(Trainer)執行個體。
from transformers import TrainingArguments
from trl import SFTTrainer
# Set up the FSDP arguments
fsdp_training_args = fsdp_v2.get_fsdp_training_args(model)
# Set up the trainer
trainer = SFTTrainer(
model=model,
train_dataset=data,
args=TrainingArguments(
per_device_train_batch_size=64,
num_train_epochs=32,
max_steps=-1,
output_dir="./output",
optim="adafactor",
logging_steps=1,
dataloader_drop_last=True, # Required for FSDPv2.
**fsdp_training_args,
),
peft_config=lora_config,
dataset_text_field="prompt",
max_seq_length=1024,
packing=True,
)一切準備就緒後,微調模型就像呼叫一個函式一樣簡單!
trainer.train()
完成此步驟後,我們已成功在 Dolly 資料集上微調模型。