From fe76e2c0e1cdba1000c3c00ee4ac5326f8bfce13 Mon Sep 17 00:00:00 2001 From: Antonin RAFFIN Date: Fri, 29 Nov 2024 14:48:10 +0100 Subject: [PATCH] Update PyBullet example --- docs/guide/examples.rst | 77 ++++++++++++++++++++--------------------- docs/misc/changelog.rst | 3 +- 2 files changed, 39 insertions(+), 41 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 32158172b..608eadcba 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -397,14 +397,14 @@ PyBullet: Normalizing input features ------------------------------------ Normalizing input features may be essential to successful training of an RL agent -(by default, images are scaled but not other types of input), -for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and -will compute a running average and standard deviation of input features (it can do the same for rewards). +(by default, images are scaled, but other types of input are not), +for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. +For this, there is a wrapper ``VecNormalize`` that will compute a running average and standard deviation of the input features (it can do the same for rewards). .. note:: - you need to install pybullet with ``pip install pybullet`` + you need to install pybullet envs with ``pip install pybullet_envs_gymnasium`` .. image:: ../_static/img/colab-badge.svg @@ -413,44 +413,41 @@ will compute a running average and standard deviation of input features (it can .. 
code-block:: python - import os - import gymnasium as gym - import pybullet_envs + from pathlib import Path - from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize - from stable_baselines3 import PPO + import pybullet_envs_gymnasium + + from stable_baselines3.common.vec_env import VecNormalize + from stable_baselines3.common.env_util import make_vec_env + from stable_baselines3 import PPO + + # Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4" + vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1) + # Automatically normalize the input features and reward + vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0) + + model = PPO("MlpPolicy", vec_env) + model.learn(total_timesteps=2000) + + # Don't forget to save the VecNormalize statistics when saving the agent + log_dir = Path("/tmp/") + model.save(log_dir / "ppo_halfcheetah") + stats_path = log_dir / "vec_normalize.pkl" + vec_env.save(stats_path) + + # To demonstrate loading + del model, vec_env + + # Load the saved statistics + vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1) + vec_env = VecNormalize.load(stats_path, vec_env) + # do not update them at test time + vec_env.training = False + # reward normalization is not needed at test time + vec_env.norm_reward = False - # Note: pybullet is not compatible yet with Gymnasium - # you might need to use `import rl_zoo3.gym_patches` - # and use gym (not Gymnasium) to instantiate the env - # Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4" - vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) - # Automatically normalize the input features and reward - vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, - clip_obs=10.) 
- - model = PPO("MlpPolicy", vec_env) - model.learn(total_timesteps=2000) - - # Don't forget to save the VecNormalize statistics when saving the agent - log_dir = "/tmp/" - model.save(log_dir + "ppo_halfcheetah") - stats_path = os.path.join(log_dir, "vec_normalize.pkl") - vec_env.save(stats_path) - - # To demonstrate loading - del model, vec_env - - # Load the saved statistics - vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")]) - vec_env = VecNormalize.load(stats_path, vec_env) - # do not update them at test time - vec_env.training = False - # reward normalization is not needed at test time - vec_env.norm_reward = False - - # Load the agent - model = PPO.load(log_dir + "ppo_halfcheetah", env=vec_env) + # Load the agent + model = PPO.load(log_dir / "ppo_halfcheetah", env=vec_env) Hindsight Experience Replay (HER) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 4cbd0ec04..91386170e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -36,7 +36,8 @@ Others: Documentation: ^^^^^^^^^^^^^^ -Added Decisions and Dragons to resources. (@jmacglashan) +- Added Decisions and Dragons to resources. (@jmacglashan) +- Updated PyBullet example, now compatible with Gymnasium Release 2.4.0 (2024-11-18) --------------------------