歡迎 Stable-baselines3 加入 Hugging Face Hub 🤗
在 Hugging Face,我們正致力於為深度強化學習研究人員和愛好者們打造一個良好的生態系統。因此,我們很高興地宣佈,我們將 Stable-Baselines3 整合到了 Hugging Face Hub。
Stable-Baselines3 是最受歡迎的 PyTorch 深度強化學習庫之一,它能讓你在各種環境(Gym、Atari、MuJoco、Procgen 等)中輕鬆訓練和測試你的智慧體。透過這次整合,你現在可以託管你儲存的模型 💾,並從社群中載入強大的模型。
在本文中,我們將向你展示如何操作。
安裝
要將 stable-baselines3 與 Hugging Face Hub 一起使用,你只需安裝這兩個庫即可
pip install huggingface_hub
pip install huggingface_sb3
尋找模型
我們目前正在上傳玩《太空侵略者 (Space Invaders)》、《打磚塊 (Breakout)》、《月球著陸器 (LunarLander)》等遊戲的智慧體模型。除此之外,你可以在這裡找到社群中所有 stable-baselines-3 模型。
當你找到所需的模型時,只需複製倉庫 ID 即可。
從 Hub 下載模型
本次整合最酷的功能是,你現在可以非常輕鬆地將 Hub 上儲存的模型載入到 Stable-baselines3 中。
為此,你只需複製包含已儲存模型的倉庫 ID (repo-id) 以及倉庫中已儲存模型的 zip 檔名。
例如:sb3/demo-hf-CartPole-v1
import gym
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Retrieve the model from the hub
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename = name of the model zip file from the repository including the extension .zip
checkpoint = load_from_hub(
repo_id="sb3/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)
# Evaluate the agent and watch it
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")
將模型分享到 Hub
只需一分鐘,你就可以將儲存的模型上傳到 Hub。
首先,你需要登入 Hugging Face 才能上傳模型。
- 如果你正在使用 Colab/Jupyter Notebooks
from huggingface_hub import notebook_login
notebook_login()
- 否則
huggingface-cli login
然後,在這個例子中,我們訓練一個 PPO 智慧體來玩 CartPole-v1,並將其推送到一個新的倉庫 `ThomasSimonini/demo-hf-CartPole-v1`。
from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO
# Define a PPO model with MLP policy network
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)
# Train it for 10000 timesteps
model.learn(total_timesteps=10_000)
# Save the model
model.save("ppo-CartPole-v1")
# Push this saved model to the hf repo
# If this repo does not exists it will be created
## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
## filename: the name of the file == "name" inside model.save("ppo-CartPole-v1")
push_to_hub(
repo_id="ThomasSimonini/demo-hf-CartPole-v1",
filename="ppo-CartPole-v1.zip",
commit_message="Added Cartpole-v1 model trained with PPO",
)
快來試試並與社群分享你的模型吧!
下一步?
在接下來的幾周和幾個月裡,我們將透過以下方式擴充套件生態系統:
- 整合 RL-baselines3-zoo
- 將 RL-trained-agents 模型上傳到 Hub:這是一個使用 stable-baselines3 預訓練的強化學習智慧體的龐大集合。
- 整合其他深度強化學習庫
- 實現 Decision Transformers 🔥
- 以及更多即將推出的內容 🥳
保持聯絡的最佳方式是加入我們的 discord 伺服器,與我們以及社群進行交流。
如果你想更深入地瞭解,我們編寫了一篇教程,你將學到:
- 如何訓練一個深度強化學習著陸器智慧體,使其正確地在月球上著陸 🌕
- 如何將其上傳到 Hub 🚀
- 如何從 Hub 下載並使用一個玩《太空侵略者》的已儲存模型 👾。
👉 教程
結論
我們很高興看到你使用 Stable-baselines3 進行的工作,並期待在 Hub 中試用你的模型 😍。
我們也很樂意聽到你的反饋 💖。 📧 歡迎隨時聯絡我們。
最後,我們要感謝 SB3 團隊,特別是 Antonin Raffin,感謝他們為該庫的整合提供的寶貴幫助 🤗。
你想將你的庫整合到 Hub 嗎?
這次整合是藉助 huggingface_hub
庫實現的,該庫包含了我們所有的元件以及所有支援的庫的 API。如果你想將你的庫整合到 Hub,我們為你準備了一份指南!