Press "Enter" to skip to content

欢迎 Stable-baselines3 加入 Hugging Face Hub 🤗

在Hugging Face,我们为深度强化学习的研究人员和爱好者贡献了生态系统。这就是为什么我们很高兴地宣布我们将Stable-Baselines3集成到Hugging Face Hub中。

Stable-Baselines3是最流行的PyTorch深度强化学习库之一,它可以轻松训练和测试各种环境中的智能体(Gym、Atari、MuJoco、Procgen等)。通过这个集成,您现在可以托管您的保存模型💾并从社区中加载强大的模型。

在本文中,我们将展示如何实现这一点。

安装

要在Hugging Face Hub中使用stable-baselines3,您只需要安装这两个库:

pip install huggingface_hub
pip install huggingface_sb3

查找模型

我们目前正在上传玩Space Invaders、Breakout、LunarLander等游戏的智能体的保存模型。除此之外,您还可以在这里找到社区中的所有stable-baselines-3模型

当您找到所需的模型时,您只需要复制存储库ID:

欢迎 Stable-baselines3 加入 Hugging Face Hub 🤗 四海 第1张

从Hub下载模型

这个集成的最酷功能是您现在可以非常容易地从Hub加载一个保存的模型到Stable-baselines3。

为了做到这一点,您只需要复制包含您保存的模型和存储库中保存的模型zip文件的repo-id。

例如:sb3/demo-hf-CartPole-v1

import gym

from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy

# 从Hub中获取模型
## repo_id = Hugging Face Hub中模型存储库的ID(repo_id = {organization}/{repo_name})
## filename = 存储库中模型zip文件的名称,包括扩展名.zip
checkpoint = load_from_hub(
    repo_id="sb3/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
)
model = PPO.load(checkpoint)

# 评估智能体并观察
eval_env = gym.make("CartPole-v1")
mean_reward, std_reward = evaluate_policy(
    model, eval_env, render=True, n_eval_episodes=5, deterministic=True, warn=False
)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

将模型共享到Hub

只需一分钟,您就可以将保存的模型放入Hub中。

首先,您需要登录到Hugging Face以上传模型:

  • 如果您使用的是Colab/Jupyter Notebooks:
from huggingface_hub import notebook_login
notebook_login()
  • 否则:
huggingface-cli login

然后,在此示例中,我们训练了一个PPO智能体来玩CartPole-v1,并将其推送到一个新的存储库ThomasSimonini/demo-hf-CartPole-v1`

from huggingface_sb3 import push_to_hub
from stable_baselines3 import PPO

# 使用MLP策略网络定义PPO模型
model = PPO("MlpPolicy", "CartPole-v1", verbose=1)

# 训练10000个时间步长
model.learn(total_timesteps=10_000)

# 保存模型
model.save("ppo-CartPole-v1")

# 将此保存的模型推送到hf存储库
# 如果此存储库不存在,则将创建它
## repo_id = Hugging Face Hub中模型存储库的ID(repo_id = {organization}/{repo_name})
## filename: 文件的名称==model.save("ppo-CartPole-v1")中的"name"
push_to_hub(
    repo_id="ThomasSimonini/demo-hf-CartPole-v1",
    filename="ppo-CartPole-v1.zip",
    commit_message="Added Cartpole-v1 model trained with PPO",
)

试一试并与社区分享您的模型!

下一步是什么?

在未来几周和几个月内,我们将通过以下方式扩展生态系统:

  • 整合 RL-baselines3-zoo
  • 将 RL-trained-agents 模型上传到 Hub:这是一个使用 stable-baselines3 预训练的强化学习智能体的大集合
  • 整合其他深度强化学习库
  • 实现决策 Transformer 🔥
  • 等等,敬请期待 🥳

保持联系的最佳方式是加入我们的 Discord 服务器,与我们和社区交流。

如果您想深入了解,我们写了一个教程,您将学到:

  • 如何训练一个深度强化学习的登月器智能体,以正确着陆在月球上 🌕
  • 如何将其上传到 Hub 🚀

欢迎 Stable-baselines3 加入 Hugging Face Hub 🤗 四海 第2张

  • 如何下载和使用在 Hub 上玩太空侵略者的保存模型 👾。

欢迎 Stable-baselines3 加入 Hugging Face Hub 🤗 四海 第3张

👉 教程

结论

我们很期待看到您在 Stable-baselines3 上的工作,并在 Hub 中尝试您的模型 😍。

我们很乐意听取您的反馈 💖。 📧 随时联系我们。

最后,我们要感谢 SB3 团队,特别是 Antonin Raffin,在整合该库时给予了宝贵的帮助 🤗。

您想将您的库整合到 Hub 中吗?

这种整合得益于 huggingface_hub 库,它具有我们所有支持的库的小部件和 API。如果您想将您的库整合到 Hub 中,我们有一个指南供您参考!

Leave a Reply

Your email address will not be published. Required fields are marked *