Skip to content

Commit

Permalink
Update PyBullet example (#2049)
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Nov 29, 2024
1 parent 9836692 commit 897d01d
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 41 deletions.
77 changes: 37 additions & 40 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -397,14 +397,14 @@ PyBullet: Normalizing input features
------------------------------------

Normalizing input features may be essential to successful training of an RL agent
(by default, images are scaled but not other types of input),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments. For that, a wrapper exists and
will compute a running average and standard deviation of input features (it can do the same for rewards).
(by default, images are scaled, but other types of input are not),
for instance when training on `PyBullet <https://github.com/bulletphysics/bullet3/>`__ environments.
For this, there is a wrapper ``VecNormalize`` that will compute a running average and standard deviation of the input features (it can do the same for rewards).


.. note::

you need to install pybullet with ``pip install pybullet``
you need to install pybullet envs with ``pip install pybullet_envs_gymnasium``


.. image:: ../_static/img/colab-badge.svg
Expand All @@ -413,44 +413,41 @@ will compute a running average and standard deviation of input features (it can

.. code-block:: python
import os
import gymnasium as gym
import pybullet_envs
from pathlib import Path
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3 import PPO
import pybullet_envs_gymnasium
from stable_baselines3.common.vec_env import VecNormalize
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3 import PPO
# Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4"
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
# Automatically normalize the input features and reward
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True, clip_obs=10.0)
model = PPO("MlpPolicy", vec_env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = Path("/tmp/")
model.save(log_dir / "ppo_halfcheetah")
stats_path = log_dir / "vec_normalize.pkl"
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
# Load the saved statistics
vec_env = make_vec_env("HalfCheetahBulletEnv-v0", n_envs=1)
vec_env = VecNormalize.load(stats_path, vec_env)
# do not update them at test time
vec_env.training = False
# reward normalization is not needed at test time
vec_env.norm_reward = False
# Note: pybullet is not compatible yet with Gymnasium
# you might need to use `import rl_zoo3.gym_patches`
# and use gym (not Gymnasium) to instantiate the env
# Alternatively, you can use the MuJoCo equivalent "HalfCheetah-v4"
vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
# Automatically normalize the input features and reward
vec_env = VecNormalize(vec_env, norm_obs=True, norm_reward=True,
clip_obs=10.)
model = PPO("MlpPolicy", vec_env)
model.learn(total_timesteps=2000)
# Don't forget to save the VecNormalize statistics when saving the agent
log_dir = "/tmp/"
model.save(log_dir + "ppo_halfcheetah")
stats_path = os.path.join(log_dir, "vec_normalize.pkl")
vec_env.save(stats_path)
# To demonstrate loading
del model, vec_env
# Load the saved statistics
vec_env = DummyVecEnv([lambda: gym.make("HalfCheetahBulletEnv-v0")])
vec_env = VecNormalize.load(stats_path, vec_env)
# do not update them at test time
vec_env.training = False
# reward normalization is not needed at test time
vec_env.norm_reward = False
# Load the agent
model = PPO.load(log_dir + "ppo_halfcheetah", env=vec_env)
# Load the agent
model = PPO.load(log_dir / "ppo_halfcheetah", env=vec_env)
Hindsight Experience Replay (HER)
Expand Down
3 changes: 2 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Others:

Documentation:
^^^^^^^^^^^^^^
Added Decisions and Dragons to resources. (@jmacglashan)
- Added Decisions and Dragons to resources. (@jmacglashan)
- Updated PyBullet example, now compatible with Gymnasium

Release 2.4.0 (2024-11-18)
--------------------------
Expand Down

0 comments on commit 897d01d

Please sign in to comment.