The Ettin Suite: State-of-the-Art Paired Encoders and Decoders
TL;DR
What happens if you apply the ModernBERT training recipe to a decoder-only model? It turns out you get a state-of-the-art decoder language model that beats Llama 3.2 1B and SmolLM2!
We introduce a new open-data training recipe that reproduces the encoder-only model ModernBERT (and actually surpasses it!). We then apply exactly the same recipe to decoder-only models. For the first time, two state-of-the-art models have been trained in an identical setup with two different training objectives: masked language modeling (MLM) and causal language modeling (CLM).
This blog post introduces Ettin, the first suite of state-of-the-art paired encoder-only and decoder-only models (from 17 million to 1 billion parameters) trained with exactly the same data (2 trillion tokens), architecture, and training recipe. Ettin enables true apples-to-apples comparisons between the two architectures and delivers the best-performing open-data models in both categories. We also go a step further and explore whether you can get a competitive encoder from a decoder, and vice versa.
If you want to try these models, some boilerplate code is provided at the end of this post!
Encoders vs. Decoders: The Battle of Architectures
The large language model (LLM) community has largely converged on decoder-only models such as GPT, Llama, and Qwen. Their generative abilities are impressive, but this focus has pulled attention away from other model classes, such as encoder-only models like BERT.
Yet BERT-style encoder models remain the workhorses of production systems for classification, retrieval, and embedding tasks. They are faster, more memory-efficient, and often more accurate on discriminative tasks. The key difference lies in their attention patterns (see the sketch after this list):
- Encoder models use bidirectional attention, allowing every token to "see" all other tokens in the sequence (full visibility).
- Decoder models use causal attention, where each token can only "see" the tokens before it, enabling autoregressive generation.
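As a quick illustration of that difference, here is a minimal sketch (not code from the Ettin repository) that builds the two attention masks with PyTorch:
import torch

seq_len = 5

# Encoder-style (bidirectional): every position may attend to every other position
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)

# Decoder-style (causal): each position may attend only to itself and earlier positions
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

print(bidirectional_mask.int())  # all ones
print(causal_mask.int())         # lower triangular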
Despite rapid innovation on decoder models, encoder development had stagnated until recent efforts such as ModernBERT modernized it. So which architecture is better? Previous encoder-versus-decoder comparisons used different datasets, architectures, and training recipes, making it hard to say.
Ettin, named after the two-headed giant of Norse mythology, provides a controlled comparison by training both architectures on the same data, the same model shapes, and the same training recipe. The only differences are the attention pattern and the training objective!
Training Recipe: Modern Techniques for Both Architectures
We build on the ModernBERT training recipe, which borrows modern techniques from decoder-only models and brings them to encoder training. This gives us a solid foundation for training both architectures.
Model Sizes
We train six model sizes ranging from 17 million to 1 billion parameters. This lets us study scaling effects and gives you a range of models to choose from: whether you need a blazing-fast on-device model or a powerful but slower one, we have you covered!
Three-Phase Training Process
We use a comprehensive three-phase training approach to maximize performance:
Phase 1 - Pre-training (1.7 trillion tokens): We start with a diverse mix of high-quality data sources, training on shorter contexts (1024 tokens) to build solid foundational knowledge.
Phase 2 - Context extension (250 billion tokens): We increase the context length to 8K tokens with higher-quality filtered data, enabling the models to handle longer documents and more complex relationships.
Phase 3 - Decay (100 billion tokens): We finish training on premium data sources, including scientific papers, textbooks, and curated content, while gradually lowering the learning rate.
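For intuition, a recipe like this typically pairs the phases with a warmup-stable-decay learning-rate shape. The sketch below is only illustrative; the warmup length, peak value, and linear decay are assumptions for the example, not the values used to train Ettin:
# Illustrative warmup-stable-decay schedule; the actual warmup length, peak LR,
# and decay curve used for Ettin may differ (hypothetical values below).
def lr_at(step, total_steps, peak_lr=1e-3, warmup_frac=0.02, decay_frac=0.05):
    warmup_end = int(total_steps * warmup_frac)
    decay_start = int(total_steps * (1 - decay_frac))
    if step < warmup_end:                      # short warmup
        return peak_lr * step / warmup_end
    if step < decay_start:                     # long stable phase (pre-training + context extension)
        return peak_lr
    progress = (step - decay_start) / (total_steps - decay_start)
    return peak_lr * (1 - progress)            # final decay phase on premium data

total = 100_000
for s in (0, 2_000, 50_000, 97_000, 100_000):
    print(f"step {s:>7}: lr = {lr_at(s, total):.2e}")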
Modern Architecture Components
Our encoder models inherit all of ModernBERT's speed advantages, making them much faster than previous generations of encoders.
Data Sources and Quality
Unlike ModernBERT, all of our training data is public and reproducible.
You can keep training these models on new data, or propose new recipes to push the results even further!
Encoder Results: Beating ModernBERT
Our encoder models outperform ModernBERT across all tasks and model sizes, while using fully open training data. And because we offer multiple sizes, you can now use ModernBERT-style models at smaller scales (great for on-device use or fast inference), or crush the competition with an encoder at the 1-billion-parameter scale.
Decoder Results: Beating Llama 3.2 and SmolLM2
Applying the same training recipe to decoder models is just as effective, with our models outperforming or matching existing baselines such as Llama 3.2 and SmolLM2.
The gains are especially pronounced on knowledge-intensive tasks like SciQ, reflecting the strength of our high-quality training data mix. These results show that our training recipe produces genuinely strong models in both architectural paradigms.
A Fair Fight: Encoders and Decoders Head to Head
For the first time, we can fairly compare encoder and decoder architectures trained on the same data with the same recipe. The results reveal fundamental architectural advantages that persist even when everything else is controlled for.
Architecture-Specific Advantages Persist
The results show clear patterns:
Encoders dominate classification and retrieval: On MNLI classification, a 150M-parameter encoder (89.2) outperforms even a 400M-parameter decoder (88.2). On retrieval the gap is smaller but still meaningful, especially when the decoder has not had MNTP training.
Decoders excel at generation: Decoders keep a consistent edge on generative tasks, and the gap actually widens as model size grows.
Size is not always decisive: A 400M encoder beats a 1B decoder on classification tasks, while a 400M decoder beats a 1B encoder on generative tasks.
Cross-Objective Training Falls Short
Given the shortage of recent encoder models, works such as LLM2Vec have proposed continuing to pre-train decoders with MLM. We can now test how well this strategy works!
We switched objectives and continued training our models with the opposite objective for an additional 50 billion tokens. Here is what we found:
- Decoder-to-encoder: still generally trails native encoders on classification/retrieval tasks.
- Encoder-to-decoder: much worse than native decoders, especially at larger scales. This may be because the encoders were trained with MLM rather than MNTP (masked next-token prediction), which LLM2Vec (and our decoder-to-encoder recipe) recommends.
This suggests that the architecture choice itself matters, not just the training objective.
Beyond Performance: Understanding Model Behavior
Because the training data is identical, we can study how the different objectives shape what the models learn. For example, analyzing gender bias with the WinoGender benchmark reveals:
- Encoder models favor gender-neutral pronouns more often (60%+ neutral pronouns vs. 30%+ for decoders).
- Both architectures show a male bias, with decoders slightly more so.
- Cross-objective training shifts bias patterns in measurable ways.
This opens the door to systematic studies of how training objectives influence model behavior, not just accuracy metrics.
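To give a flavor of this kind of probing, the minimal sketch below (with a made-up sentence, not the paper's WinoGender evaluation code) compares the probabilities an encoder assigns to different pronouns at a masked position:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/ettin-encoder-150m")
model = AutoModelForMaskedLM.from_pretrained("jhu-clsp/ettin-encoder-150m")

# Hypothetical example sentence, not taken from WinoGender
text = "The engineer told the client that [MASK] would finish the design soon."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
probs = logits[0, mask_pos[0]].softmax(dim=-1)
for pronoun in [" he", " she", " they"]:
    token_id = tokenizer.encode(pronoun, add_special_tokens=False)[0]
    print(f"{pronoun.strip():>4}: {probs[token_id].item():.4f}")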
Usage Examples
You can use these models with just a few lines of code!
Encoders
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Load the encoder with its masked language modeling head
tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/ettin-encoder-150m")
model = AutoModelForMaskedLM.from_pretrained("jhu-clsp/ettin-encoder-150m")

def predict_masked_token(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Get predictions for [MASK] tokens
    mask_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)
    predictions = outputs.logits[mask_indices]

    # Get top 5 predictions
    top_tokens = torch.topk(predictions, 5, dim=-1)
    return [tokenizer.decode(token) for token in top_tokens.indices[0]]

# Example
masked_text = "The capital of France is [MASK]."
predictions = predict_masked_token(masked_text)
print(f"Predictions: {predictions}")
For classification and retrieval tasks, use the encoder models. You may also want to use versions that have been fine-tuned for those tasks.
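Before reaching for a fine-tuned variant, here is a minimal sketch (not an official recipe) of mean-pooled sentence embeddings from the base encoder:
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/ettin-encoder-150m")
model = AutoModel.from_pretrained("jhu-clsp/ettin-encoder-150m")

def embed(texts):
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state          # (batch, seq, dim)
    mask = inputs["attention_mask"].unsqueeze(-1).float()   # ignore padding tokens
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1)     # mean pooling

emb = embed(["What is the capital of France?", "Paris is the capital of France."])
print(torch.nn.functional.cosine_similarity(emb[0], emb[1], dim=0).item())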
Decoders
For text generation tasks, use the decoder models:
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load decoder for generation
tokenizer = AutoTokenizer.from_pretrained("jhu-clsp/ettin-decoder-150m")
model = AutoModelForCausalLM.from_pretrained("jhu-clsp/ettin-decoder-150m")
# Generate text
prompt = "The future of artificial intelligence is"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(inputs.input_ids, max_length=50, do_sample=True, temperature=0.7)  # temperature requires sampling
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)
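The same generation can also be run through the transformers pipeline API, which handles tokenization and decoding for you:
from transformers import pipeline

generator = pipeline("text-generation", model="jhu-clsp/ettin-decoder-150m")
output = generator(
    "The future of artificial intelligence is",
    max_new_tokens=40,
    do_sample=True,
    temperature=0.7,
)
print(output[0]["generated_text"])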
Fine-tuning Examples
Encoders
Click to see how to fine-tune a dense embedding model with Sentence Transformers
import argparse
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformer,
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import CachedMultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers
def main():
# parse the lr & model name
parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=8e-5)
parser.add_argument("--model_name", type=str, default="jhu-clsp/ettin-encoder-150m")
args = parser.parse_args()
lr = args.lr
model_name = args.model_name
model_shortname = model_name.split("/")[-1]
# 1. Load a model to finetune
model = SentenceTransformer(model_name)
# 2. Load a dataset to finetune on
dataset = load_dataset(
"sentence-transformers/msmarco-co-condenser-margin-mse-sym-mnrl-mean-v1",
"triplet-hard",
split="train",
)
dataset_dict = dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"].select(range(1_250_000))
eval_dataset = dataset_dict["test"]
# 3. Define a loss function
loss = CachedMultipleNegativesRankingLoss(model, mini_batch_size=16) # Increase mini_batch_size if you have enough VRAM
run_name = f"{model_shortname}-DPR-{lr}"
# 4. (Optional) Specify training arguments
args = SentenceTransformerTrainingArguments(
# Required parameter:
output_dir=f"output/{model_shortname}/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=512,
per_device_eval_batch_size=512,
warmup_ratio=0.05,
fp16=False, # Set to False if GPU can't handle FP16
bf16=True, # Set to True if GPU supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # (Cached)MultipleNegativesRankingLoss benefits from no duplicates
learning_rate=lr,
# Optional tracking/debugging parameters:
save_strategy="steps",
save_steps=500,
save_total_limit=2,
logging_steps=500,
run_name=run_name, # Used in `wandb`, `tensorboard`, `neptune`, etc. if installed
)
# 5. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = TripletEvaluator(
anchors=eval_dataset["query"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
name="msmarco-co-condenser-dev",
)
dev_evaluator(model)
# 6. Create a trainer & train
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 7. (Optional) Evaluate the trained model on the evaluator after training
dev_evaluator(model)
# 8. Save the model
model.save_pretrained(f"output/{model_shortname}/{run_name}/final")
# 9. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name, private=False)
if __name__ == "__main__":
main()
Click to see how to fine-tune a multi-vector embedding model with PyLate
from datasets import load_dataset
from pylate import losses, models, utils
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)
def main():
# Load the datasets required for knowledge distillation (train, queries, documents)
train = load_dataset(
path="lightonai/ms-marco-en-bge",
name="train",
)
queries = load_dataset(
path="lightonai/ms-marco-en-bge",
name="queries",
)
documents = load_dataset(
path="lightonai/ms-marco-en-bge",
name="documents",
)
# Set the transformation to load the documents/queries texts using the corresponding ids on the fly
train.set_transform(
utils.KDProcessing(queries=queries, documents=documents).transform,
)
# Define the base model, training parameters, and output directory
num_train_epochs = 1
lr = 8e-5
batch_size = 16
accum_steps = 1
model_name = "jhu-clsp/ettin-encoder-150m"
model_shortname = model_name.split("/")[-1]
# Set the run name for logging and output directory
run_name = f"{model_shortname}-colbert-KD-{lr}"
output_dir = f"output/{model_shortname}/{run_name}"
# Initialize the ColBERT model from the base model
model = models.ColBERT(model_name_or_path=model_name)
# Configure the training arguments (e.g., epochs, batch size, learning rate)
args = SentenceTransformerTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
run_name=run_name,
logging_steps=10,
learning_rate=lr,
gradient_accumulation_steps=accum_steps,
warmup_ratio=0.05,
)
# Use the Distillation loss function for training
train_loss = losses.Distillation(model=model)
# Initialize the trainer
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train,
loss=train_loss,
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
)
# Start the training process
trainer.train()
model.save_pretrained(f"{output_dir}/final")
if __name__ == "__main__":
main()
Click to see how to fine-tune a sparse retrieval model with Sentence Transformers
import logging
from datasets import load_dataset
from sentence_transformers import (
SparseEncoder,
SparseEncoderModelCardData,
SparseEncoderTrainer,
SparseEncoderTrainingArguments,
)
from sentence_transformers.sparse_encoder.evaluation import SparseNanoBEIREvaluator
from sentence_transformers.sparse_encoder.losses import SparseMultipleNegativesRankingLoss, SpladeLoss
from sentence_transformers.training_args import BatchSamplers
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
# 1. Load a model to finetune with 2. (Optional) model card data
model = SparseEncoder(
"jhu-clsp/ettin-encoder-150m",
model_card_data=SparseEncoderModelCardData(
language="en",
license="apache-2.0",
)
)
# 3. Load a dataset to finetune on
full_dataset = load_dataset("sentence-transformers/natural-questions", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
# 4. Define a loss function
loss = SpladeLoss(
model=model,
loss=SparseMultipleNegativesRankingLoss(model=model),
query_regularizer_weight=5e-5,
document_regularizer_weight=3e-5,
)
# 5. (Optional) Specify training arguments
run_name = "splade-ettin-encoder-150m-nq"
args = SparseEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=1,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
batch_sampler=BatchSamplers.NO_DUPLICATES, # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
save_total_limit=2,
logging_steps=200,
run_name=run_name, # Will be used in W&B if `wandb` is installed
)
# 6. (Optional) Create an evaluator & evaluate the base model
dev_evaluator = SparseNanoBEIREvaluator(dataset_names=["msmarco", "nfcorpus", "nq"], batch_size=16)
# 7. Create a trainer & train
trainer = SparseEncoderTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=loss,
evaluator=dev_evaluator,
)
trainer.train()
# 8. Evaluate the model performance again after training
dev_evaluator(model)
# 9. Save the trained model
model.save_pretrained(f"models/{run_name}/final")
# 10. (Optional) Push it to the Hugging Face Hub
model.push_to_hub(run_name)
Click to see how to fine-tune a reranker model with Sentence Transformers
import logging
import traceback
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.cross_encoder import (
CrossEncoder,
CrossEncoderModelCardData,
CrossEncoderTrainer,
CrossEncoderTrainingArguments,
)
from sentence_transformers.cross_encoder.evaluation import (
CrossEncoderNanoBEIREvaluator,
CrossEncoderRerankingEvaluator,
)
from sentence_transformers.cross_encoder.losses import BinaryCrossEntropyLoss
from sentence_transformers.evaluation import SequentialEvaluator
from sentence_transformers.util import mine_hard_negatives
# Set the log level to INFO to get more information
logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
def main():
model_name = "jhu-clsp/ettin-encoder-150m"
train_batch_size = 64
num_epochs = 1
num_hard_negatives = 5 # How many hard negatives should be mined for each question-answer pair
# 1a. Load a model to finetune with 1b. (Optional) model card data
model = CrossEncoder(
model_name,
model_card_data=CrossEncoderModelCardData(
language="en",
license="apache-2.0",
),
)
print("Model max length:", model.max_length)
print("Model num labels:", model.num_labels)
# 2a. Load the GooAQ dataset: https://huggingface.co/datasets/sentence-transformers/gooaq
logging.info("Read the gooaq training dataset")
full_dataset = load_dataset("sentence-transformers/gooaq", split="train").select(range(100_000))
dataset_dict = full_dataset.train_test_split(test_size=1_000, seed=12)
train_dataset = dataset_dict["train"]
eval_dataset = dataset_dict["test"]
logging.info(train_dataset)
logging.info(eval_dataset)
# 2b. Modify our training dataset to include hard negatives using a very efficient embedding model
embedding_model = SentenceTransformer("sentence-transformers/static-retrieval-mrl-en-v1", device="cpu")
hard_train_dataset = mine_hard_negatives(
train_dataset,
embedding_model,
num_negatives=num_hard_negatives, # How many negatives per question-answer pair
margin=0, # Similarity between query and negative samples should be x lower than query-positive similarity
range_min=0, # Skip the x most similar samples
range_max=100, # Consider only the x most similar samples
sampling_strategy="top", # Sample the top negatives from the range
batch_size=4096, # Use a batch size of 4096 for the embedding model
output_format="labeled-pair", # The output format is (query, passage, label), as required by BinaryCrossEntropyLoss
use_faiss=True,
)
logging.info(hard_train_dataset)
# 2c. (Optionally) Save the hard training dataset to disk
# hard_train_dataset.save_to_disk("gooaq-hard-train")
# Load again with:
# hard_train_dataset = load_from_disk("gooaq-hard-train")
# 3. Define our training loss.
# pos_weight is recommended to be set as the ratio between positives to negatives, a.k.a. `num_hard_negatives`
loss = BinaryCrossEntropyLoss(model=model, pos_weight=torch.tensor(num_hard_negatives))
# 4a. Define evaluators. We use the CrossEncoderNanoBEIREvaluator, which is a light-weight evaluator for English reranking
nano_beir_evaluator = CrossEncoderNanoBEIREvaluator(
dataset_names=["msmarco", "nfcorpus", "nq"],
batch_size=train_batch_size,
)
# 4b. Define a reranking evaluator by mining hard negatives given query-answer pairs
# We include the positive answer in the list of negatives, so the evaluator can use the performance of the
# embedding model as a baseline.
hard_eval_dataset = mine_hard_negatives(
eval_dataset,
embedding_model,
corpus=full_dataset["answer"], # Use the full dataset as the corpus
num_negatives=30, # How many documents to rerank
batch_size=4096,
include_positives=True,
output_format="n-tuple",
use_faiss=True,
)
logging.info(hard_eval_dataset)
reranking_evaluator = CrossEncoderRerankingEvaluator(
samples=[
{
"query": sample["question"],
"positive": [sample["answer"]],
"documents": [sample[column_name] for column_name in hard_eval_dataset.column_names[2:]],
}
for sample in hard_eval_dataset
],
batch_size=train_batch_size,
name="gooaq-dev",
# Realistic setting: only rerank the positives that the retriever found
# Set to True to rerank *all* positives
always_rerank_positives=False,
)
# 4c. Combine the evaluators & run the base model on them
evaluator = SequentialEvaluator([reranking_evaluator, nano_beir_evaluator])
evaluator(model)
# 5. Define the training arguments
short_model_name = model_name if "/" not in model_name else model_name.split("/")[-1]
run_name = f"reranker-{short_model_name}-gooaq-bce"
args = CrossEncoderTrainingArguments(
# Required parameter:
output_dir=f"models/{run_name}",
# Optional training parameters:
num_train_epochs=num_epochs,
per_device_train_batch_size=train_batch_size,
per_device_eval_batch_size=train_batch_size,
learning_rate=2e-5,
warmup_ratio=0.1,
fp16=False, # Set to False if you get an error that your GPU can't run on FP16
bf16=True, # Set to True if you have a GPU that supports BF16
dataloader_num_workers=4,
load_best_model_at_end=True,
metric_for_best_model="eval_gooaq-dev_ndcg@10",
# Optional tracking/debugging parameters:
eval_strategy="steps",
eval_steps=1000,
save_strategy="steps",
save_steps=1000,
save_total_limit=2,
logging_steps=200,
logging_first_step=True,
run_name=run_name, # Will be used in W&B if `wandb` is installed
seed=12,
)
# 6. Create the trainer & start training
trainer = CrossEncoderTrainer(
model=model,
args=args,
train_dataset=hard_train_dataset,
loss=loss,
evaluator=evaluator,
)
trainer.train()
# 7. Evaluate the final model, useful to include these in the model card
evaluator(model)
# 8. Save the final model
final_output_dir = f"models/{run_name}/final"
model.save_pretrained(final_output_dir)
# 9. (Optional) save the model to the Hugging Face Hub!
# It is recommended to run `huggingface-cli login` to log into your Hugging Face account first
try:
model.push_to_hub(run_name)
except Exception:
logging.error(
f"Error uploading model to the Hugging Face Hub:\n{traceback.format_exc()}To upload it manually, you can run "
f"`huggingface-cli login`, followed by loading the model using `model = CrossEncoder({final_output_dir!r})` "
f"and saving it using `model.push_to_hub('{run_name}')`."
)
if __name__ == "__main__":
main()
Decoders
Click to expand the decoder training code
Full fine-tuning
python trl/scripts/sft.py \
--model_name_or_path jhu-clsp/ettin-decoder-17m \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-5 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--eval_strategy steps \
--eval_steps 100 \
--output_dir ettin-decoder-17m \
--push_to_hub
LoRA
python trl/scripts/sft.py \
--model_name_or_path jhu-clsp/ettin-decoder-17m \
--dataset_name trl-lib/Capybara \
--learning_rate 2.0e-4 \
--num_train_epochs 1 \
--packing \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 8 \
--gradient_checkpointing \
--eos_token '<|im_end|>' \
--eval_strategy steps \
--eval_steps 100 \
--use_peft \
--lora_r 32 \
--lora_alpha 16 \
--output_dir ettin-decoder-17m \
--push_to_hub
The sft.py script used above:
import argparse
from datasets import load_dataset
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES
from trl import (
ModelConfig,
ScriptArguments,
SFTConfig,
SFTTrainer,
TrlParser,
clone_chat_template,
get_kbit_device_map,
get_peft_config,
get_quantization_config,
)
def main(script_args, training_args, model_args):
################
# Model init kwargs & Tokenizer
################
quantization_config = get_quantization_config(model_args)
model_kwargs = dict(
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
attn_implementation=model_args.attn_implementation,
torch_dtype=model_args.torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
# Create model
config = AutoConfig.from_pretrained(model_args.model_name_or_path)
valid_image_text_architectures = MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES.values()
if config.architectures and any(arch in valid_image_text_architectures for arch in config.architectures):
from transformers import AutoModelForImageTextToText
model_kwargs.pop("use_cache", None) # Image models do not support cache
model = AutoModelForImageTextToText.from_pretrained(model_args.model_name_or_path, **model_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
# Create tokenizer
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
)
# Set default chat template if needed
if tokenizer.chat_template is None:
# TODO: source should be passed as an argument
model, tokenizer = clone_chat_template(model, tokenizer, "Qwen/Qwen3-0.6B")
################
# Dataset
################
dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
################
# Training
################
trainer = SFTTrainer(
model=model,
args=training_args,
train_dataset=dataset[script_args.dataset_train_split],
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
processing_class=tokenizer,
peft_config=get_peft_config(model_args),
)
trainer.train()
# Save and push to hub
trainer.save_model(training_args.output_dir)
if training_args.push_to_hub:
trainer.push_to_hub(dataset_name=script_args.dataset_name)
def make_parser(subparsers: argparse._SubParsersAction = None):
dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
if subparsers is not None:
parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
else:
parser = TrlParser(dataclass_types)
return parser
if __name__ == "__main__":
parser = make_parser()
# When using the trl cli, this script may be run with additional arguments, corresponding accelerate arguments.
# To ensure that their parsing does not interfere with the script arguments, parse the arguments with
# `return_remaining_strings=True`, then ignore the remaining strings.
script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
main(script_args, training_args, model_args)
Model Family and Links
The complete Ettin suite includes models at six sizes, each available as both an encoder and a decoder:
Standard Models
- ettin-encoder-17m / ettin-decoder-17m (17M parameters)
- ettin-encoder-32m / ettin-decoder-32m (32M parameters)
- ettin-encoder-68m / ettin-decoder-68m (68M parameters)
- ettin-encoder-150m / ettin-decoder-150m (150M parameters)
- ettin-encoder-400m / ettin-decoder-400m (400M parameters)
- ettin-encoder-1b / ettin-decoder-1b (1B parameters)
Research Resources
- 🤗 Ettin model collection
- 📝 Paper
- 🗂️ Training data (2T+ tokens, fully open)
- 💻 GitHub repository
- 📊 250+ training checkpoints for studying training dynamics and knowledge acquisition