宣佈 Hugging Face 和 KerasHub 新整合
Hugging Face Hub 是一個龐大的儲存庫,目前託管著 75 萬多個公共模型,為各種機器學習框架提供了多樣化的預訓練模型。其中,346,268 個模型(截至撰寫本文時)是使用流行的 Transformers 庫構建的。KerasHub 庫最近新增了一個與 Hub 的整合,首批相容 33 個模型。
在第一個版本中,KerasHub 使用者*僅限於*使用 Hugging Face Hub 上可用的基於 KerasHub 的模型。
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b-keras"
)
他們能夠訓練/微調模型並將其上傳回 Hub(請注意,該模型仍然是 Keras 模型)。
model.save_to_preset("./gemma-2b-finetune")
keras_hub.upload_preset(
"hf://username/gemma-2b-finetune",
"./gemma-2b-finetune"
)
他們錯過了使用 transformers 庫建立的超過 30 萬個模型的龐大集合。圖 1 展示了 Hub 中的 4k Gemma 模型。
![]() |
---|
圖 1:Hugging Face Hub 中的 Gemma 模型(來源:https://huggingface.co/models?other=gemma) |
然而,如果現在我們告訴您,您可以使用 KerasHub 訪問和使用這 30 多萬個模型,這將顯著擴充套件您的模型選擇和功能,您會作何感想?
from keras_hub.models import GemmaCausalLM
gemma_lm = GemmaCausalLM.from_preset(
"hf://google/gemma-2b" # this is not a keras model!
)
我們很高興地宣佈 Hub 社群邁出了重要一步:Transformers 和 KerasHub 現在擁有**共享**的模型儲存格式。這意味著 Hugging Face Hub 上的 transformers 庫模型現在也可以直接載入到 KerasHub 中——立即為 KerasHub 使用者提供了大量微調模型。最初,此整合側重於啟用 Gemma(1 和 2)、Llama 3 和 PaliGemma 模型的使用,並計劃在不久的將來將相容性擴充套件到更廣泛的架構。
使用更廣泛的框架
由於 KerasHub 模型可以無縫使用 **TensorFlow**、**JAX** 或 **PyTorch** 後端,這意味著大量的模型檢查點現在可以透過一行程式碼載入到任何這些框架中。在 Hugging Face 上找到了一個很棒的檢查點,但您希望將其部署到 TFLite 進行服務或將其移植到 JAX 進行研究?現在您可以了!
如何使用
使用此整合需要更新您的 Keras 版本
$ pip install -U -q keras-hub
$ pip install -U keras>=3.3.3
更新後,嘗試整合就像以下程式碼一樣簡單:
from keras_hub.models import Llama3CausalLM
# this model was not fine-tuned with Keras but can still be loaded
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
causal_lm.summary()
幕後:工作原理
Transformers 模型以 JSON 格式的配置檔案的形式儲存,一個分詞器(通常也是一個 .JSON 檔案),以及一組 safetensors 權重檔案。實際的模型程式碼包含在 Transformers 庫本身中。這意味著,只要兩個庫都有相關架構的模型程式碼,將 Transformers 檢查點交叉載入到 KerasHub 中就相對簡單。我們所需要做的就是將配置變數、權重名稱和分詞器詞彙從一種格式對映到另一種格式,然後我們就可以從 Transformers 檢查點建立 KerasHub 檢查點,反之亦然。
所有這些都在內部為您處理,因此您可以專注於嘗試模型,而不是轉換它們!
常見用例
生成
語言模型的第一個用例是生成文字。下面是一個示例,演示如何載入 transformer 模型並使用 KerasHub 的 .generate
方法生成新 token。
from keras_hub.models import Llama3CausalLM
# Get the model
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
prompts = [
"""<|im_start|>system
You are a sentient, superintelligent artificial general intelligence, here to teach and assist me.<|im_end|>
<|im_start|>user
Write a short story about Goku discovering kirby has teamed up with Majin Buu to destroy the world.<|im_end|>
<|im_start|>assistant""",
]
# Generate from the model
causal_lm.generate(prompts, max_length=200)[0]
更改精度
您可以使用 keras.config
更改模型的精度,如下所示
import keras
keras.config.set_dtype_policy("bfloat16")
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
在 JAX 後端使用檢查點
要使用 JAX 試用模型,您可以利用 Keras 在 JAX 後端執行模型。這可以透過簡單地將 Keras 的後端切換到 JAX 來實現。以下是您在 JAX 環境中使用模型的方法。
import os
os.environ["KERAS_BACKEND"] = "jax"
from keras_hub.models import Llama3CausalLM
causal_lm = Llama3CausalLM.from_preset(
"hf://NousResearch/Hermes-2-Pro-Llama-3-8B"
)
Gemma 2
我們很高興地通知您,Gemma 2 模型也與此整合相容。
from keras_hub.models import GemmaCausalLM
causal_lm = keras_hub.models.GemmaCausalLM.from_preset(
"hf://google/gemma-2-9b" # This is Gemma 2!
)
PaliGemma
您還可以在 KerasHub 管道中使用任何 PaliGemma safetensor 檢查點。
from keras_hub.models import PaliGemmaCausalLM
pali_gemma_lm = PaliGemmaCausalLM.from_preset(
"hf://gokaygokay/sd3-long-captioner" # A finetuned version of PaliGemma
)
接下來呢?
這僅僅是個開始。我們設想將此整合擴充套件到更廣泛的 Hugging Face 模型和架構。請繼續關注更新,並務必探索此次合作帶來的巨大潛力!
我想借此機會感謝 Matthew Carrigan 和 Matthew Watson 在整個過程中給予的幫助。