diff --git a/README.md b/README.md
index 509633cc5..836e10e37 100644
--- a/README.md
+++ b/README.md
@@ -28,6 +28,7 @@ you can take a look at the issues [#48](https://github.com/DLR-RM/stable-baselin
 | Custom environments | :heavy_check_mark: |
 | Custom policies | :heavy_check_mark: |
 | Common interface | :heavy_check_mark: |
+| `Dict` observation space support | :heavy_check_mark: |
 | Ipython / Notebook friendly | :heavy_check_mark: |
 | Tensorboard support | :heavy_check_mark: |
 | PEP8 code style | :heavy_check_mark: |
diff --git a/docs/Makefile b/docs/Makefile
index 47f98cd72..5a4c19e51 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -2,6 +2,7 @@
 #
 # You can set these variables from the command line.
+# For debug: SPHINXOPTS = -nWT --keep-going -vvv
 SPHINXOPTS = -W # make warnings fatal
 SPHINXBUILD = sphinx-build
 SPHINXPROJ = StableBaselines
@@ -17,4 +18,4 @@ help:
 # Catch-all target: route all unknown targets to Sphinx using the new
 # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
 %: Makefile
-	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
\ No newline at end of file
+	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
diff --git a/docs/common/envs.rst b/docs/common/envs.rst
new file mode 100644
index 000000000..4c20ce4c7
--- /dev/null
+++ b/docs/common/envs.rst
@@ -0,0 +1,24 @@
+.. _envs:
+
+.. automodule:: stable_baselines3.common.envs
+
+
+
+Custom Environments
+===================
+
+Those environments were created for testing purposes.
+
+
+BitFlippingEnv
+--------------
+
+.. autoclass:: BitFlippingEnv
+  :members:
+
+
+SimpleMultiObsEnv
+-----------------
+
+.. autoclass:: SimpleMultiObsEnv
+  :members:
\ No newline at end of file
diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst
index 7cd5f3e12..78ccb9a89 100644
--- a/docs/guide/algos.rst
+++ b/docs/guide/algos.rst
@@ -19,7 +19,9 @@ TD3 ✔️ ❌ ❌ ❌
 
 .. note::
 
-    Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
+    ``Tuple`` observation spaces are not supported by any algorithm,
+    however single-level ``Dict`` spaces are (cf. :ref:`Examples `).
+
 
 Actions ``gym.spaces``:
 
@@ -41,6 +43,15 @@ Actions ``gym.spaces``:
 
 See `Issue #339 `_ for more info.
 
+.. note::
+
+    When using off-policy algorithms, `Time Limits `_ (aka timeouts) are handled
+    properly (cf. `issue #284 `_).
+    You can revert to SB3 < 1.1.0 behavior by passing ``handle_timeout_termination=False``
+    via the ``replay_buffer_kwargs`` argument.
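+
+    For instance (an illustrative sketch: ``Pendulum-v0`` is only a placeholder environment,
+    and any off-policy algorithm accepts these keyword arguments):
+
+    .. code-block:: python
+
+        from stable_baselines3 import SAC
+
+        # Keep the pre-1.1.0 behavior: treat timeouts as true episode terminations
+        model = SAC("MlpPolicy", "Pendulum-v0", replay_buffer_kwargs=dict(handle_timeout_termination=False))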
+
+
+
 Reproducibility
 ---------------
 
diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst
index cbcad96b8..2b1e4b988 100644
--- a/docs/guide/custom_env.rst
+++ b/docs/guide/custom_env.rst
@@ -85,5 +85,5 @@ that will allow you to create the RL agent in one line (and use ``gym.make()`` t
 
 In the project, for testing purposes, we use a custom environment named ``IdentityEnv``
-defined `in this file `_.
-An example of how to use it can be found `here `_.
+defined `in this file `_.
+An example of how to use it can be found `here `_.
diff --git a/docs/guide/custom_policy.rst b/docs/guide/custom_policy.rst
index 03829ffd7..3da932b33 100644
--- a/docs/guide/custom_policy.rst
+++ b/docs/guide/custom_policy.rst
@@ -3,8 +3,8 @@
 Custom Policy Network
 =====================
 
-Stable Baselines3 provides policy networks for images (CnnPolicies)
-and other type of input features (MlpPolicies).
+Stable Baselines3 provides policy networks for images (CnnPolicies),
+other types of input features (MlpPolicies) and multiple different inputs (MultiInputPolicies).
 
 .. warning::
 
@@ -149,6 +149,70 @@ that derives from ``BaseFeaturesExtractor`` and then pass it to the model when t
     model.learn(1000)
 
 
+Multiple Inputs and Dictionary Observations
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Stable Baselines3 supports handling of multiple inputs by using ``Dict`` Gym space. This can be done using
+``MultiInputPolicy``, which by default uses the ``CombinedExtractor`` feature extractor to turn multiple
+inputs into a single vector, handled by the ``net_arch`` network.
+
+By default, ``CombinedExtractor`` processes multiple inputs as follows:
+
+1. If input is an image (automatically detected, see ``common.preprocessing.is_image_space``), process image with Nature Atari CNN network and
+   output a latent vector of size ``256``.
+2. If input is not an image, flatten it (no layers).
+3. Concatenate all previous vectors into one long vector and pass it to policy.
+
+Much like above, you can define custom feature extractors. The following example assumes the environment has two keys in the
+observation space dictionary: "image" is a (1,H,W) image (channel first), and "vector" is a (D,) dimensional vector. We process "image" with a simple
+downsampling and "vector" with a single linear layer.
+
+.. code-block:: python
+
+    import gym
+    import torch as th
+    from torch import nn
+
+    from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
+
+    class CustomCombinedExtractor(BaseFeaturesExtractor):
+        def __init__(self, observation_space: gym.spaces.Dict):
+            # We do not know features-dim here before going over all the items,
+            # so put something dummy for now. PyTorch requires calling
+            # nn.Module.__init__ before adding modules
+            super(CustomCombinedExtractor, self).__init__(observation_space, features_dim=1)
+
+            extractors = {}
+
+            total_concat_size = 0
+            # We need to know size of the output of this extractor,
+            # so go over all the spaces and compute output feature sizes
+            for key, subspace in observation_space.spaces.items():
+                if key == "image":
+                    # We will just downsample one channel of the image by 4x4 and flatten.
+                    # Assume the image is single-channel (subspace.shape[0] == 1)
+                    extractors[key] = nn.Sequential(nn.MaxPool2d(4), nn.Flatten())
+                    total_concat_size += subspace.shape[1] // 4 * subspace.shape[2] // 4
+                elif key == "vector":
+                    # Run through a simple MLP
+                    extractors[key] = nn.Linear(subspace.shape[0], 16)
+                    total_concat_size += 16
+
+            self.extractors = nn.ModuleDict(extractors)
+
+            # Update the features dim manually
+            self._features_dim = total_concat_size
+
+        def forward(self, observations) -> th.Tensor:
+            encoded_tensor_list = []
+
+            # self.extractors contain nn.Modules that do all the processing.
+            for key, extractor in self.extractors.items():
+                encoded_tensor_list.append(extractor(observations[key]))
+            # Return a (B, self._features_dim) PyTorch tensor, where B is batch dimension.
+            return th.cat(encoded_tensor_list, dim=1)
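+
+To use this custom extractor, you would then pass it through ``policy_kwargs``, in the same way as for the
+single-input feature extractors above (a minimal sketch; ``env`` is assumed to be an environment whose
+observation space contains the "image" and "vector" keys described above):
+
+.. code-block:: python
+
+    from stable_baselines3 import PPO
+
+    model = PPO(
+        "MultiInputPolicy",
+        env,
+        policy_kwargs=dict(features_extractor_class=CustomCombinedExtractor),
+        verbose=1,
+    )
+    model.learn(1000)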
+
+
 
 On-Policy Algorithms
 ^^^^^^^^^^^^^^^^^^^^
 
diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst
index 6a0a5737a..35df576d4 100644
--- a/docs/guide/examples.rst
+++ b/docs/guide/examples.rst
@@ -149,6 +149,27 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
         env.render()
 
 
+Dict Observations
+-----------------
+
+You can use environments with dictionary observation spaces. This is useful in the case where one can't directly
+concatenate observations such as an image from a camera combined with a vector of servo sensor data (e.g., rotation angles).
+Stable Baselines3 provides ``SimpleMultiObsEnv`` as an example of this kind of setting.
+The environment is a simple grid world but the observations for each cell come in the form of dictionaries.
+These dictionaries are randomly initialized on the creation of the environment and contain a vector observation and an image observation.
+
+.. code-block:: python
+
+    from stable_baselines3 import PPO
+    from stable_baselines3.common.envs import SimpleMultiObsEnv
+
+
+    # Stable Baselines3 provides SimpleMultiObsEnv as an example environment with Dict observations
+    env = SimpleMultiObsEnv(random_start=False)
+
+    model = PPO("MultiInputPolicy", env, verbose=1)
+    model.learn(total_timesteps=1e5)
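+
+Once trained, the model can be queried as usual; the observation passed to ``predict()`` is simply the
+dictionary returned by the environment (a minimal sketch of an evaluation loop):
+
+.. code-block:: python
+
+    obs = env.reset()
+    for _ in range(100):
+        action, _states = model.predict(obs, deterministic=True)
+        obs, reward, done, info = env.step(action)
+        if done:
+            obs = env.reset()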
+
 
 Using Callback: Monitoring Training
 -----------------------------------
 
@@ -375,7 +396,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     import highway_env
     import numpy as np
 
-    from stable_baselines3 import HER, SAC, DDPG, TD3
+    from stable_baselines3 import HerReplayBuffer, SAC, DDPG, TD3
     from stable_baselines3.common.noise import NormalActionNoise
 
     env = gym.make("parking-v0")
 
@@ -384,21 +405,23 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     n_sampled_goal = 4
 
     # SAC hyperparams:
-    model = HER(
-        "MlpPolicy",
+    model = SAC(
+        "MultiInputPolicy",
         env,
-        SAC,
-        n_sampled_goal=n_sampled_goal,
-        goal_selection_strategy="future",
-        # IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
-        # we have to manually specify the max number of steps per episode
-        max_episode_length=100,
+        replay_buffer_class=HerReplayBuffer,
+        replay_buffer_kwargs=dict(
+            n_sampled_goal=n_sampled_goal,
+            goal_selection_strategy="future",
+            # IMPORTANT: because the env is not wrapped with a TimeLimit wrapper
+            # we have to manually specify the max number of steps per episode
+            max_episode_length=100,
+            online_sampling=True,
+        ),
         verbose=1,
         buffer_size=int(1e6),
         learning_rate=1e-3,
         gamma=0.95,
         batch_size=256,
-        online_sampling=True,
         policy_kwargs=dict(net_arch=[256, 256, 256]),
     )
 
@@ -408,7 +431,7 @@ The parking env is a goal-conditioned continuous control task, in which the vehi
     # Load saved model
     # Because it needs access to `env.compute_reward()`
     # HER must be loaded with the env
-    model = HER.load("her_sac_highway", env=env)
+    model = SAC.load("her_sac_highway", env=env)
 
     obs = env.reset()
 
@@ -658,12 +681,14 @@ to keep track of the agent progress.
     # ProcgenEnv is already vectorized
     venv = ProcgenEnv(num_envs=2, env_name='starpilot')
 
-    # PPO does not currently support Dict observations
-    # this will be solved in https://github.com/DLR-RM/stable-baselines3/pull/243
-    venv = VecExtractDictObs(venv, "rgb")
+
+    # To use only part of the observation:
+    # venv = VecExtractDictObs(venv, "rgb")
+
+    # Wrap with a VecMonitor to collect stats and avoid errors
     venv = VecMonitor(venv=venv)
 
-    model = PPO("MlpPolicy", venv, verbose=1)
+    model = PPO("MultiInputPolicy", venv, verbose=1)
     model.learn(10000)
diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst
index 76713efdf..7958fe0e3 100644
--- a/docs/guide/vec_envs.rst
+++ b/docs/guide/vec_envs.rst
@@ -71,6 +71,17 @@ VecFrameStack
 .. autoclass:: VecFrameStack
   :members:
 
+StackedObservations
+~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedObservations
+  :members:
+
+StackedDictObservations
+~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: stable_baselines3.common.vec_env.stacked_observations.StackedDictObservations
+  :members:
 
 VecNormalize
 ~~~~~~~~~~~~
diff --git a/docs/index.rst b/docs/index.rst
index 4f41153e7..d55a35c89 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -75,6 +75,7 @@ Main Features
 
   common/atari_wrappers
   common/env_util
+  common/envs
   common/distributions
   common/evaluation
   common/env_checker
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst
index 285cbc5d0..e4cab0f06 100644
--- a/docs/misc/changelog.rst
+++ b/docs/misc/changelog.rst
@@ -4,23 +4,45 @@
 Changelog
 ==========
 
-Release 1.1.0a5 (WIP)
+Release 1.1.0a6 (WIP)
 ---------------------------
 
+**Dict observation support, timeout handling and refactored HER**
+
 Breaking Changes:
 ^^^^^^^^^^^^^^^^^
+- All custom environments (e.g. the ``BitFlippingEnv`` or ``IdentityEnv``) were moved to ``stable_baselines3.common.envs`` folder
+- Refactored ``HER`` which is now the ``HerReplayBuffer`` class that can be passed to any off-policy algorithm
+- Handle timeout termination properly for off-policy algorithms (when using ``TimeLimit``)
 - Renamed ``_last_dones`` and ``dones`` to ``_last_episode_starts`` and ``episode_starts`` in ``RolloutBuffer``.
+- Removed ``ObsDictWrapper`` as ``Dict`` observation spaces are now supported
+
+.. code-block:: python
+
+    her_kwargs = dict(n_sampled_goal=2, goal_selection_strategy="future", online_sampling=True)
+    # SB3 < 1.1.0
+    # model = HER("MlpPolicy", env, model_class=SAC, **her_kwargs)
+    # SB3 >= 1.1.0:
+    model = SAC("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=her_kwargs)
+
 - Updated the KL Divergence estimator in the PPO algorithm to be positive definite and have lower variance (@09tangriro)
 - Updated the KL Divergence check in the PPO algorithm to be before the gradient update step rather than after end of epoch (@09tangriro)
 
 New Features:
 ^^^^^^^^^^^^^
+- Added support for single-level ``Dict`` observation space (@JadenTravnik)
+- Added ``DictRolloutBuffer`` and ``DictReplayBuffer`` to support dictionary observations (@JadenTravnik)
+- Added ``StackedObservations`` and ``StackedDictObservations`` that are used within ``VecFrameStack``
+- Added simple 4x4 room Dict test environments
+- ``HerReplayBuffer`` now supports ``VecNormalize`` when ``online_sampling=False``
 - Added `VecMonitor `_ and `VecExtractDictObs `_ wrappers to handle gym3-style vectorized environments (@vwxyzjn)
 - Ignored the terminal observation if the it is not provided by the environment such as the gym3-style vectorized environments. 
(@vwxyzjn) - Add policy_base as input to the OnPolicyAlgorithm for more flexibility (@09tangriro) +- Added support for image observation when using ``HER`` +- Added ``replay_buffer_class`` and ``replay_buffer_kwargs`` arguments to off-policy algorithms Bug Fixes: ^^^^^^^^^^ @@ -34,9 +56,9 @@ Deprecations: Others: ^^^^^^^ - Added ``flake8-bugbear`` to tests dependencies to find likely bugs +- Updated ``env_checker`` to reflect support of dict observation spaces - Added Code of Conduct - Added tests for GAE and lambda return computation -- Updated docker image with newest black version Documentation: ^^^^^^^^^^^^^^ @@ -71,6 +93,7 @@ New Features: - Added support for ``custom_objects`` when loading models + Bug Fixes: ^^^^^^^^^^ - Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space @@ -81,10 +104,14 @@ Documentation: - Added new project using SB3: rl_reach (@PierreExeter) - Added note about slow-down when switching to PyTorch - Add a note on continual learning and resetting environment + +Others: +^^^^^^^ - Updated RL-Zoo to reflect the fact that is it more than a collection of trained agents - Added images to illustrate the training loop and custom policies (created with https://excalidraw.com/) - Updated the custom policy section + Pre-Release 0.11.1 (2021-02-27) ------------------------------- @@ -132,6 +159,7 @@ New Features: - Added new wrappers to log images and matplotlib figures to tensorboard. (@zampanteymedio) - Add support for text records to ``Logger``. (@lorenz-h) + Bug Fixes: ^^^^^^^^^^ - Fixed bug where code added VecTranspose on channel-first image environments (thanks @qxcv) @@ -657,5 +685,5 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray -@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn +@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @JadenTravnik @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr @Atlis @liusida @09tangriro diff --git a/docs/modules/a2c.rst b/docs/modules/a2c.rst index 54c8a4273..6e56d89c3 100644 --- a/docs/modules/a2c.rst +++ b/docs/modules/a2c.rst @@ -40,6 +40,7 @@ Discrete ✔️ ✔️ Box ✔️ ✔️ MultiDiscrete ✔️ ✔️ MultiBinary ✔️ ✔️ +Dict ❌ ✔️ ============= ====== =========== @@ -163,3 +164,10 @@ A2C Policies .. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy :members: :noindex: + +.. autoclass:: MultiInputPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy + :members: + :noindex: diff --git a/docs/modules/ddpg.rst b/docs/modules/ddpg.rst index 75087b13f..40bef6507 100644 --- a/docs/modules/ddpg.rst +++ b/docs/modules/ddpg.rst @@ -23,6 +23,7 @@ trick for DQN with the deterministic policy gradient, to obtain an algorithm for MlpPolicy CnnPolicy + MultiInputPolicy Notes @@ -49,6 +50,7 @@ Discrete ❌ ✔️ Box ✔️ ✔️ MultiDiscrete ❌ ✔️ MultiBinary ❌ ✔️ +Dict ❌ ✔️ ============= ====== =========== @@ -168,3 +170,7 @@ DDPG Policies .. autoclass:: CnnPolicy :members: :noindex: + +.. 
autoclass:: MultiInputPolicy + :members: + :noindex: diff --git a/docs/modules/dqn.rst b/docs/modules/dqn.rst index f35788fff..2ada427e7 100644 --- a/docs/modules/dqn.rst +++ b/docs/modules/dqn.rst @@ -17,6 +17,7 @@ and make use of different tricks to stabilize the learning with neural networks: MlpPolicy CnnPolicy + MultiInputPolicy Notes @@ -44,6 +45,7 @@ Discrete ✔ ✔ Box ❌ ✔ MultiDiscrete ❌ ✔ MultiBinary ❌ ✔ +Dict ❌ ✔️ ============= ====== =========== @@ -53,7 +55,6 @@ Example .. code-block:: python import gym - import numpy as np from stable_baselines3 import DQN diff --git a/docs/modules/her.rst b/docs/modules/her.rst index fc0c696dc..047809ae0 100644 --- a/docs/modules/her.rst +++ b/docs/modules/her.rst @@ -13,6 +13,12 @@ HER uses the fact that even if a desired goal was not achieved, other goal may h It creates "virtual" transitions by relabeling transitions (changing the desired goal) from past episodes. +.. warning:: + + Starting from Stable Baselines3 v1.1.0, ``HER`` is no longer a separate algorithm + but a replay buffer class ``HerReplayBuffer`` that must be passed to an off-policy algorithm + when using ``MultiInputPolicy`` (to have Dict observation support). + .. warning:: @@ -27,11 +33,6 @@ It creates "virtual" transitions by relabeling transitions (changing the desired Otherwise, you can directly pass ``max_episode_length`` to the model constructor -.. warning:: - - ``HER`` supports ``VecNormalize`` wrapper but only when ``online_sampling=True`` - - .. warning:: Because it needs access to ``env.compute_reward()`` @@ -59,11 +60,10 @@ Example .. code-block:: python - from stable_baselines3 import HER, DDPG, DQN, SAC, TD3 + from stable_baselines3 import HerReplayBuffer, DDPG, DQN, SAC, TD3 from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy - from stable_baselines3.common.bit_flipping_env import BitFlippingEnv + from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.vec_env import DummyVecEnv - from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper model_class = DQN # works also with SAC, DDPG and TD3 N_BITS = 15 @@ -79,15 +79,27 @@ Example max_episode_length = N_BITS # Initialize the model - model = HER('MlpPolicy', env, model_class, n_sampled_goal=4, goal_selection_strategy=goal_selection_strategy, online_sampling=online_sampling, - verbose=1, max_episode_length=max_episode_length) + model = model_class( + "MultiInputPolicy", + env, + replay_buffer_class=HerReplayBuffer, + # Parameters for HER + replay_buffer_kwargs=dict( + n_sampled_goal=4, + goal_selection_strategy=goal_selection_strategy, + online_sampling=online_sampling, + max_episode_length=max_episode_length, + ), + verbose=1, + ) + # Train the model model.learn(1000) model.save("./her_bit_env") # Because it needs access to `env.compute_reward()` # HER must be loaded with the env - model = HER.load('./her_bit_env', env=env) + model = model_class.load('./her_bit_env', env=env) obs = env.reset() for _ in range(100): @@ -123,43 +135,31 @@ Run the benchmark: .. code-block:: bash - python train.py --algo her --env parking-v0 --eval-episodes 10 --eval-freq 10000 + python train.py --algo tqc --env parking-v0 --eval-episodes 10 --eval-freq 10000 Plot the results: .. code-block:: bash - python scripts/all_plots.py -a her -e parking-v0 -f logs/ --no-million + python scripts/all_plots.py -a tqc -e parking-v0 -f logs/ --no-million Parameters ---------- -.. 
autoclass:: HER - :members: - -Goal Selection Strategies -------------------------- +HER Replay Buffer +----------------- -.. autoclass:: GoalSelectionStrategy +.. autoclass:: HerReplayBuffer :members: :inherited-members: - :undoc-members: -Obs Dict Wrapper ----------------- +Goal Selection Strategies +------------------------- -.. autoclass:: ObsDictWrapper +.. autoclass:: GoalSelectionStrategy :members: :inherited-members: :undoc-members: - - -HER Replay Buffer ------------------ - -.. autoclass:: HerReplayBuffer - :members: - :inherited-members: diff --git a/docs/modules/ppo.rst b/docs/modules/ppo.rst index 09e2e636f..abc50afca 100644 --- a/docs/modules/ppo.rst +++ b/docs/modules/ppo.rst @@ -42,6 +42,7 @@ Discrete ✔️ ✔️ Box ✔️ ✔️ MultiDiscrete ✔️ ✔️ MultiBinary ✔️ ✔️ +Dict ❌ ✔️ ============= ====== =========== Example @@ -164,3 +165,10 @@ PPO Policies .. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy :members: :noindex: + +.. autoclass:: MultiInputPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy + :members: + :noindex: diff --git a/docs/modules/sac.rst b/docs/modules/sac.rst index 63d4245d3..818d8c240 100644 --- a/docs/modules/sac.rst +++ b/docs/modules/sac.rst @@ -19,6 +19,7 @@ A key feature of SAC, and a major difference with common RL algorithms, is that MlpPolicy CnnPolicy + MultiInputPolicy Notes @@ -56,6 +57,7 @@ Discrete ❌ ✔️ Box ✔️ ✔️ MultiDiscrete ❌ ✔️ MultiBinary ❌ ✔️ +Dict ❌ ✔️ ============= ====== =========== @@ -169,3 +171,6 @@ SAC Policies .. autoclass:: CnnPolicy :members: + +.. autoclass:: MultiInputPolicy + :members: diff --git a/docs/modules/td3.rst b/docs/modules/td3.rst index 7a1a0f338..13fb235af 100644 --- a/docs/modules/td3.rst +++ b/docs/modules/td3.rst @@ -19,6 +19,7 @@ We recommend reading `OpenAI Spinning guide on TD3 Optional[GymEnv]: @@ -129,10 +128,10 @@ def __init__( self.learning_rate = learning_rate self.tensorboard_log = tensorboard_log self.lr_schedule = None # type: Optional[Schedule] - self._last_obs = None # type: Optional[np.ndarray] + self._last_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] self._last_episode_starts = None # type: Optional[np.ndarray] # When using VecNormalize: - self._last_original_obs = None # type: Optional[np.ndarray] + self._last_original_obs = None # type: Optional[Union[np.ndarray, Dict[str, np.ndarray]]] self._episode_num = 0 # Used for gSDE only self.use_sde = use_sde @@ -195,18 +194,33 @@ def _wrap_env(env: GymEnv, verbose: int = 0, monitor_wrapper: bool = True) -> Ve print("Wrapping the env in a DummyVecEnv.") env = DummyVecEnv([lambda: env]) - if ( - is_image_space(env.observation_space) - and not is_vecenv_wrapped(env, VecTransposeImage) - and not is_image_space_channels_first(env.observation_space) - ): - if verbose >= 1: - print("Wrapping the env in a VecTransposeImage.") - env = VecTransposeImage(env) + # Make sure that dict-spaces are not nested (not supported) + check_for_nested_spaces(env.observation_space) + + if isinstance(env.observation_space, gym.spaces.Dict): + for space in env.observation_space.spaces.values(): + if isinstance(space, gym.spaces.Dict): + raise ValueError("Nested observation spaces are not supported (Dict spaces inside Dict space).") + + if not is_vecenv_wrapped(env, VecTransposeImage): + wrap_with_vectranspose = False + if isinstance(env.observation_space, gym.spaces.Dict): + # If even one of the keys is a image-space in need of transpose, apply transpose + # If the image spaces are not consistent 
(for instance one is channel first, + # the other channel last), VecTransposeImage will throw an error + for space in env.observation_space.spaces.values(): + wrap_with_vectranspose = wrap_with_vectranspose or ( + is_image_space(space) and not is_image_space_channels_first(space) + ) + else: + wrap_with_vectranspose = is_image_space(env.observation_space) and not is_image_space_channels_first( + env.observation_space + ) - # check if wrapper for dict support is needed when using HER - if isinstance(env.observation_space, gym.spaces.dict.Dict): - env = ObsDictWrapper(env) + if wrap_with_vectranspose: + if verbose >= 1: + print("Wrapping the env in a VecTransposeImage.") + env = VecTransposeImage(env) return env @@ -275,6 +289,7 @@ def _excluded_save_params(self) -> List[str]: "replay_buffer", "rollout_buffer", "_vec_normalize_env", + "_episode_storage", ] def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: diff --git a/stable_baselines3/common/buffers.py b/stable_baselines3/common/buffers.py index a6cba8c0d..253787d64 100644 --- a/stable_baselines3/common/buffers.py +++ b/stable_baselines3/common/buffers.py @@ -1,21 +1,26 @@ import warnings from abc import ABC, abstractmethod -from typing import Dict, Generator, Optional, Union +from typing import Any, Dict, Generator, List, Optional, Union import numpy as np import torch as th from gym import spaces +from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape +from stable_baselines3.common.type_aliases import ( + DictReplayBufferSamples, + DictRolloutBufferSamples, + ReplayBufferSamples, + RolloutBufferSamples, +) +from stable_baselines3.common.vec_env import VecNormalize + try: # Check memory used by replay buffer when possible import psutil except ImportError: psutil = None -from stable_baselines3.common.preprocessing import get_action_dim, get_obs_shape -from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples -from stable_baselines3.common.vec_env import VecNormalize - class BaseBuffer(ABC): """ @@ -42,6 +47,7 @@ def __init__( self.observation_space = observation_space self.action_space = action_space self.obs_shape = get_obs_shape(observation_space) + self.action_dim = get_action_dim(action_space) self.pos = 0 self.full = False @@ -130,7 +136,8 @@ def to_torch(self, array: np.ndarray, copy: bool = True) -> th.Tensor: @staticmethod def _normalize_obs( - obs: Union[np.ndarray, Dict[str, np.ndarray]], env: Optional[VecNormalize] = None + obs: Union[np.ndarray, Dict[str, np.ndarray]], + env: Optional[VecNormalize] = None, ) -> Union[np.ndarray, Dict[str, np.ndarray]]: if env is not None: return env.normalize_obs(obs) @@ -157,6 +164,9 @@ class ReplayBuffer(BaseBuffer): at a cost of more complexity. See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 and https://github.com/DLR-RM/stable-baselines3/pull/28#issuecomment-637559274 + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. 
+ https://github.com/DLR-RM/stable-baselines3/issues/284 """ def __init__( @@ -167,6 +177,7 @@ def __init__( device: Union[th.device, str] = "cpu", n_envs: int = 1, optimize_memory_usage: bool = False, + handle_timeout_termination: bool = True, ): super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) @@ -177,18 +188,27 @@ def __init__( mem_available = psutil.virtual_memory().available self.optimize_memory_usage = optimize_memory_usage + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) + if optimize_memory_usage: # `observations` contains also the next observation self.next_observations = None else: self.next_observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=observation_space.dtype) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=action_space.dtype) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + # Handle timeouts termination properly if needed + # see https://github.com/DLR-RM/stable-baselines3/issues/284 + self.handle_timeout_termination = handle_timeout_termination + self.timeouts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) if psutil is not None: total_memory_usage = self.observations.nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + if self.next_observations is not None: total_memory_usage += self.next_observations.nbytes @@ -201,9 +221,18 @@ def __init__( f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" ) - def add(self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, reward: np.ndarray, done: np.ndarray) -> None: + def add( + self, + obs: np.ndarray, + next_obs: np.ndarray, + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: # Copy to avoid modification by reference self.observations[self.pos] = np.array(obs).copy() + if self.optimize_memory_usage: self.observations[(self.pos + 1) % self.buffer_size] = np.array(next_obs).copy() else: @@ -213,6 +242,9 @@ def add(self, obs: np.ndarray, next_obs: np.ndarray, action: np.ndarray, reward: self.rewards[self.pos] = np.array(reward).copy() self.dones[self.pos] = np.array(done).copy() + if self.handle_timeout_termination: + self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) + self.pos += 1 if self.pos == self.buffer_size: self.full = True @@ -241,6 +273,7 @@ def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> ReplayB return self._get_samples(batch_inds, env=env) def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> ReplayBufferSamples: + if self.optimize_memory_usage: next_obs = self._normalize_obs(self.observations[(batch_inds + 1) % self.buffer_size, 0, :], env) else: @@ -250,7 +283,9 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non self._normalize_obs(self.observations[batch_inds, 0, :], env), self.actions[batch_inds, 0, :], next_obs, - self.dones[batch_inds], + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + self.dones[batch_inds] * (1 - self.timeouts[batch_inds]), self._normalize_reward(self.rewards[batch_inds], env), ) return ReplayBufferSamples(*tuple(map(self.to_torch, data))) @@ -299,6 +334,7 @@ def __init__( self.reset() def 
reset(self) -> None: + self.observations = np.zeros((self.buffer_size, self.n_envs) + self.obs_shape, dtype=np.float32) self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) @@ -391,7 +427,17 @@ def get(self, batch_size: Optional[int] = None) -> Generator[RolloutBufferSample indices = np.random.permutation(self.buffer_size * self.n_envs) # Prepare the data if not self.generator_ready: - for tensor in ["observations", "actions", "values", "log_probs", "advantages", "returns"]: + + _tensor_names = [ + "observations", + "actions", + "values", + "log_probs", + "advantages", + "returns", + ] + + for tensor in _tensor_names: self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) self.generator_ready = True @@ -414,3 +460,280 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = Non self.returns[batch_inds].flatten(), ) return RolloutBufferSamples(*tuple(map(self.to_torch, data))) + + +class DictReplayBuffer(ReplayBuffer): + """ + Dict Replay buffer used in off-policy algorithms like SAC/TD3. + Extends the ReplayBuffer to use dictionary observations + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param n_envs: Number of parallel environments + :param optimize_memory_usage: Enable a memory efficient variant + Disabled for now (see https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702) + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. + https://github.com/DLR-RM/stable-baselines3/issues/284 + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + n_envs: int = 1, + optimize_memory_usage: bool = False, + handle_timeout_termination: bool = True, + ): + super(ReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictReplayBuffer must be used with Dict obs space only" + assert n_envs == 1, "Replay buffer only support single environment for now" + + # Check that the replay buffer can fit into the memory + if psutil is not None: + mem_available = psutil.virtual_memory().available + + assert optimize_memory_usage is False, "DictReplayBuffer does not support optimize_memory_usage" + # disabling as this adds quite a bit of complexity + # https://github.com/DLR-RM/stable-baselines3/pull/243#discussion_r531535702 + self.optimize_memory_usage = optimize_memory_usage + + self.observations = { + key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items() + } + self.next_observations = { + key: np.zeros((self.buffer_size, self.n_envs) + _obs_shape) for key, _obs_shape in self.obs_shape.items() + } + + # only 1 env is supported + self.actions = np.zeros((self.buffer_size, self.action_dim), dtype=action_space.dtype) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.dones = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + + # Handle timeouts termination properly if needed + # see https://github.com/DLR-RM/stable-baselines3/issues/284 + self.handle_timeout_termination = handle_timeout_termination + self.timeouts = np.zeros((self.buffer_size, self.n_envs), 
dtype=np.float32) + + if psutil is not None: + obs_nbytes = 0 + for _, obs in self.observations.items(): + obs_nbytes += obs.nbytes + + total_memory_usage = obs_nbytes + self.actions.nbytes + self.rewards.nbytes + self.dones.nbytes + if self.next_observations is not None: + next_obs_nbytes = 0 + for _, obs in self.observations.items(): + next_obs_nbytes += obs.nbytes + total_memory_usage += next_obs_nbytes + + if total_memory_usage > mem_available: + # Convert to GB + total_memory_usage /= 1e9 + mem_available /= 1e9 + warnings.warn( + "This system does not have apparently enough memory to store the complete " + f"replay buffer {total_memory_usage:.2f}GB > {mem_available:.2f}GB" + ) + + def add( + self, + obs: Dict[str, np.ndarray], + next_obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + done: np.ndarray, + infos: List[Dict[str, Any]], + ) -> None: + # Copy to avoid modification by reference + for key in self.observations.keys(): + self.observations[key][self.pos] = np.array(obs[key]).copy() + + for key in self.next_observations.keys(): + self.next_observations[key][self.pos] = np.array(next_obs[key]).copy() + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.dones[self.pos] = np.array(done).copy() + + if self.handle_timeout_termination: + self.timeouts[self.pos] = np.array([info.get("TimeLimit.truncated", False) for info in infos]) + + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + self.pos = 0 + + def sample(self, batch_size: int, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: + """ + Sample elements from the replay buffer. + + :param batch_size: Number of element to sample + :param env: associated gym VecEnv + to normalize the observations/rewards when sampling + :return: + """ + return super(ReplayBuffer, self).sample(batch_size=batch_size, env=env) + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: + + # Normalize if needed and remove extra dimension (we are using only one env for now) + obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.observations.items()}) + next_obs_ = self._normalize_obs({key: obs[batch_inds, 0, :] for key, obs in self.next_observations.items()}) + + # Convert to torch tensor + observations = {key: self.to_torch(obs) for key, obs in obs_.items()} + next_observations = {key: self.to_torch(obs) for key, obs in next_obs_.items()} + + return DictReplayBufferSamples( + observations=observations, + actions=self.to_torch(self.actions[batch_inds]), + next_observations=next_observations, + # Only use dones that are not due to timeouts + # deactivated by default (timeouts is initialized as an array of False) + dones=self.to_torch(self.dones[batch_inds] * (1 - self.timeouts[batch_inds])), + rewards=self.to_torch(self._normalize_reward(self.rewards[batch_inds], env)), + ) + + +class DictRolloutBuffer(RolloutBuffer): + """ + Dict Rollout buffer used in on-policy algorithms like A2C/PPO. + Extends the RolloutBuffer to use dictionary observations + + It corresponds to ``buffer_size`` transitions collected + using the current policy. + This experience will be discarded after the policy update. + In order to use PPO objective, we also store the current value of each state + and the log probability of each taken action. + + The term rollout here refers to the model-free notion and should not + be used with the concept of rollout used in model-based RL or planning. 
+ Hence, it is only involved in policy and value function training but not action selection. + + :param buffer_size: Max number of element in the buffer + :param observation_space: Observation space + :param action_space: Action space + :param device: + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + Equivalent to classic advantage when set to 1. + :param gamma: Discount factor + :param n_envs: Number of parallel environments + """ + + def __init__( + self, + buffer_size: int, + observation_space: spaces.Space, + action_space: spaces.Space, + device: Union[th.device, str] = "cpu", + gae_lambda: float = 1, + gamma: float = 0.99, + n_envs: int = 1, + ): + + super(RolloutBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs=n_envs) + + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + + self.gae_lambda = gae_lambda + self.gamma = gamma + self.observations, self.actions, self.rewards, self.advantages = None, None, None, None + self.returns, self.episode_starts, self.values, self.log_probs = None, None, None, None + self.generator_ready = False + self.reset() + + def reset(self) -> None: + assert isinstance(self.obs_shape, dict), "DictRolloutBuffer must be used with Dict obs space only" + self.observations = {} + for key, obs_input_shape in self.obs_shape.items(): + self.observations[key] = np.zeros((self.buffer_size, self.n_envs) + obs_input_shape, dtype=np.float32) + self.actions = np.zeros((self.buffer_size, self.n_envs, self.action_dim), dtype=np.float32) + self.rewards = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.returns = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.episode_starts = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.values = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.log_probs = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.advantages = np.zeros((self.buffer_size, self.n_envs), dtype=np.float32) + self.generator_ready = False + super(RolloutBuffer, self).reset() + + def add( + self, + obs: Dict[str, np.ndarray], + action: np.ndarray, + reward: np.ndarray, + episode_start: np.ndarray, + value: th.Tensor, + log_prob: th.Tensor, + ) -> None: + """ + :param obs: Observation + :param action: Action + :param reward: + :param episode_start: Start of episode signal. + :param value: estimated value of the current state + following the current policy. + :param log_prob: log probability of the action + following the current policy. 
+ """ + if len(log_prob.shape) == 0: + # Reshape 0-d tensor to avoid error + log_prob = log_prob.reshape(-1, 1) + + for key in self.observations.keys(): + obs_ = np.array(obs[key]).copy() + # Reshape needed when using multiple envs with discrete observations + # as numpy cannot broadcast (n_discrete,) to (n_discrete, 1) + if isinstance(self.observation_space.spaces[key], spaces.Discrete): + obs_ = obs_.reshape((self.n_envs,) + self.obs_shape[key]) + self.observations[key][self.pos] = obs_ + + self.actions[self.pos] = np.array(action).copy() + self.rewards[self.pos] = np.array(reward).copy() + self.episode_starts[self.pos] = np.array(episode_start).copy() + self.values[self.pos] = value.clone().cpu().numpy().flatten() + self.log_probs[self.pos] = log_prob.clone().cpu().numpy() + self.pos += 1 + if self.pos == self.buffer_size: + self.full = True + + def get(self, batch_size: Optional[int] = None) -> Generator[DictRolloutBufferSamples, None, None]: + assert self.full, "" + indices = np.random.permutation(self.buffer_size * self.n_envs) + # Prepare the data + if not self.generator_ready: + + for key, obs in self.observations.items(): + self.observations[key] = self.swap_and_flatten(obs) + + _tensor_names = ["actions", "values", "log_probs", "advantages", "returns"] + + for tensor in _tensor_names: + self.__dict__[tensor] = self.swap_and_flatten(self.__dict__[tensor]) + self.generator_ready = True + + # Return everything, don't create minibatches + if batch_size is None: + batch_size = self.buffer_size * self.n_envs + + start_idx = 0 + while start_idx < self.buffer_size * self.n_envs: + yield self._get_samples(indices[start_idx : start_idx + batch_size]) + start_idx += batch_size + + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictRolloutBufferSamples: + + return DictRolloutBufferSamples( + observations={key: self.to_torch(obs[batch_inds]) for (key, obs) in self.observations.items()}, + actions=self.to_torch(self.actions[batch_inds]), + old_values=self.to_torch(self.values[batch_inds].flatten()), + old_log_prob=self.to_torch(self.log_probs[batch_inds].flatten()), + advantages=self.to_torch(self.advantages[batch_inds].flatten()), + returns=self.to_torch(self.returns[batch_inds].flatten()), + ) diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 262722fd1..6bd097dae 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -5,6 +5,7 @@ import numpy as np from gym import spaces +from stable_baselines3.common.preprocessing import is_image_space_channels_first from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan @@ -16,14 +17,14 @@ def _is_numpy_array_space(space: spaces.Space) -> bool: return not isinstance(space, (spaces.Dict, spaces.Tuple)) -def _check_image_input(observation_space: spaces.Box) -> None: +def _check_image_input(observation_space: spaces.Box, key: str = "") -> None: """ Check that the input will be compatible with Stable-Baselines when the observation is apparently an image. """ if observation_space.dtype != np.uint8: warnings.warn( - "It seems that your observation is an image but the `dtype` " + f"It seems that your observation {key} is an image but the `dtype` " "of your observation_space is not `np.uint8`. 
" "If your observation is not an image, we recommend you to flatten the observation " "to have only a 1D vector" @@ -31,38 +32,49 @@ def _check_image_input(observation_space: spaces.Box) -> None: if np.any(observation_space.low != 0) or np.any(observation_space.high != 255): warnings.warn( - "It seems that your observation space is an image but the " + f"It seems that your observation space {key} is an image but the " "upper and lower bounds are not in [0, 255]. " "Because the CNN policy normalize automatically the observation " "you may encounter issue if the values are not in that range." ) - if observation_space.shape[0] < 36 or observation_space.shape[1] < 36: + non_channel_idx = 0 + # Check only if width/height of the image is big enough + if is_image_space_channels_first(observation_space): + non_channel_idx = -1 + + if observation_space.shape[non_channel_idx] < 36 or observation_space.shape[1] < 36: warnings.warn( - "The minimal resolution for an image is 36x36 for the default CnnPolicy. " - "You might need to use a custom `cnn_extractor` " - "cf https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html" + "The minimal resolution for an image is 36x36 for the default `CnnPolicy`. " + "You might need to use a custom feature extractor " + "cf. https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html" ) def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None: """Emit warnings when the observation space or action space used is not supported by Stable-Baselines.""" - if isinstance(observation_space, spaces.Dict) and not isinstance(env, gym.GoalEnv): - warnings.warn( - "The observation space is a Dict but the environment is not a gym.GoalEnv " - "(cf https://github.com/openai/gym/blob/master/gym/core.py), " - "this is currently not supported by Stable Baselines " - "(cf https://github.com/hill-a/stable-baselines/issues/133), " - "you will need to use a custom policy. " - ) + if isinstance(observation_space, spaces.Dict): + nested_dict = False + for space in observation_space.spaces.values(): + if isinstance(space, spaces.Dict): + nested_dict = True + if nested_dict: + warnings.warn( + "Nested observation spaces are not supported by Stable Baselines3 " + "(Dict spaces inside Dict space). " + "You should flatten it to have only one level of keys." + "For example, `dict(space1=dict(space2=Box(), space3=Box()), spaces4=Discrete())` " + "is not supported but `dict(space2=Box(), spaces3=Box(), spaces4=Discrete())` is." + ) if isinstance(observation_space, spaces.Tuple): warnings.warn( "The observation space is a Tuple," - "this is currently not supported by Stable Baselines " - "(cf https://github.com/hill-a/stable-baselines/issues/133), " - "you will need to flatten the observation and maybe use a custom policy. " + "this is currently not supported by Stable Baselines3. " + "However, you can convert it to a Dict observation space " + "(cf. https://github.com/openai/gym/blob/master/gym/spaces/dict.py). " + "which is supported by SB3." 
) if not _is_numpy_array_space(action_space): @@ -89,19 +101,37 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac if not isinstance(observation_space, spaces.Tuple): assert not isinstance( obs, tuple - ), "The observation returned by the `{}()` method should be a single value, not a tuple".format(method_name) + ), f"The observation returned by the `{method_name}()` method should be a single value, not a tuple" # The check for a GoalEnv is done by the base class if isinstance(observation_space, spaces.Discrete): - assert isinstance(obs, int), "The observation returned by `{}()` method must be an int".format(method_name) + assert isinstance(obs, int), f"The observation returned by `{method_name}()` method must be an int" elif _is_numpy_array_space(observation_space): - assert isinstance(obs, np.ndarray), "The observation returned by `{}()` method must be a numpy array".format( - method_name - ) + assert isinstance(obs, np.ndarray), f"The observation returned by `{method_name}()` method must be a numpy array" assert observation_space.contains( obs - ), "The observation returned by the `{}()` method does not match the given observation space".format(method_name) + ), f"The observation returned by the `{method_name}()` method does not match the given observation space" + + +def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: + """ + Check that the observation space is correctly formatted + when dealing with a ``Box()`` space. In particular, it checks: + - that the dimensions are big enough when it is an image, and that the type matches + - that the observation has an expected shape (warn the user if not) + """ + # If image, check the low and high values, the type and the number of channels + # and the shape (minimal value) + if len(observation_space.shape) == 3: + _check_image_input(observation_space) + + if len(observation_space.shape) not in [1, 3]: + warnings.warn( + f"Your observation {key} has an unconventional shape (neither an image, nor a 1D vector). " + "We recommend you to flatten the observation " + "to have only a 1D vector or use a custom policy to properly process the data." 
+ ) def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action_space: spaces.Space) -> None: @@ -111,7 +141,15 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists obs = env.reset() - _check_obs(obs, observation_space, "reset") + if isinstance(observation_space, spaces.Dict): + assert isinstance(obs, dict), "The observation returned by `reset()` must be a dictionary" + for key in observation_space.spaces.keys(): + try: + _check_obs(obs[key], observation_space.spaces[key], "reset") + except AssertionError as e: + raise AssertionError(f"Error while checking key={key}: " + str(e)) + else: + _check_obs(obs, observation_space, "reset") # Sample a random action action = action_space.sample() @@ -122,7 +160,16 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # Unpack obs, reward, done, info = data - _check_obs(obs, observation_space, "step") + if isinstance(observation_space, spaces.Dict): + assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" + for key in observation_space.spaces.keys(): + try: + _check_obs(obs[key], observation_space.spaces[key], "step") + except AssertionError as e: + raise AssertionError(f"Error while checking key={key}: " + str(e)) + + else: + _check_obs(obs, observation_space, "step") # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" @@ -149,7 +196,8 @@ def _check_spaces(env: gym.Env) -> None: assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces -def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: +# Check render cannot be covered by CI +def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover """ Check the declared render modes and the `render()`/`close()` method of the environment. @@ -210,17 +258,10 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - if warn: _check_unsupported_spaces(env, observation_space, action_space) - # If image, check the low and high values, the type and the number of channels - # and the shape (minimal value) - if isinstance(observation_space, spaces.Box) and len(observation_space.shape) == 3: - _check_image_input(observation_space) - - if isinstance(observation_space, spaces.Box) and len(observation_space.shape) not in [1, 3]: - warnings.warn( - "Your observation has an unconventional shape (neither an image, nor a 1D vector). 
" - "We recommend you to flatten the observation " - "to have only a 1D vector" - ) + obs_spaces = observation_space.spaces if isinstance(observation_space, spaces.Dict) else {"": observation_space} + for key, space in obs_spaces.items(): + if isinstance(space, spaces.Box): + _check_box_obs(space, key) # Check for the action space, it may lead to hard-to-debug issues if isinstance(action_space, spaces.Box) and ( @@ -238,7 +279,7 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ==== Check the render method and the declared render modes ==== if not skip_render_check: - _check_render(env, warn=warn) + _check_render(env, warn=warn) # pragma: no cover # The check only works with numpy arrays if _is_numpy_array_space(observation_space) and _is_numpy_array_space(action_space): diff --git a/stable_baselines3/common/envs/__init__.py b/stable_baselines3/common/envs/__init__.py new file mode 100644 index 000000000..23bd5750f --- /dev/null +++ b/stable_baselines3/common/envs/__init__.py @@ -0,0 +1,9 @@ +from stable_baselines3.common.envs.bit_flipping_env import BitFlippingEnv +from stable_baselines3.common.envs.identity_env import ( + FakeImageEnv, + IdentityEnv, + IdentityEnvBox, + IdentityEnvMultiBinary, + IdentityEnvMultiDiscrete, +) +from stable_baselines3.common.envs.multi_input_envs import SimpleMultiObsEnv diff --git a/stable_baselines3/common/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py similarity index 57% rename from stable_baselines3/common/bit_flipping_env.py rename to stable_baselines3/common/envs/bit_flipping_env.py index 62f07100f..f5c2fb4d3 100644 --- a/stable_baselines3/common/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -20,15 +20,25 @@ class BitFlippingEnv(GoalEnv): by default, it uses the discrete one :param max_steps: Max number of steps, by default, equal to n_bits :param discrete_obs_space: Whether to use the discrete observation - version or not, by default, it uses the MultiBinary one + version or not, by default, it uses the ``MultiBinary`` one + :param image_obs_space: Use image as input instead of the ``MultiBinary`` one. + :param channel_first: Whether to use channel-first or last image. 
""" spec = EnvSpec("BitFlippingEnv-v0") def __init__( - self, n_bits: int = 10, continuous: bool = False, max_steps: Optional[int] = None, discrete_obs_space: bool = False + self, + n_bits: int = 10, + continuous: bool = False, + max_steps: Optional[int] = None, + discrete_obs_space: bool = False, + image_obs_space: bool = False, + channel_first: bool = True, ): super(BitFlippingEnv, self).__init__() + # Shape of the observation when using image space + self.image_shape = (1, 36, 36) if channel_first else (36, 36, 1) # The achieved goal is determined by the current state # here, it is a special where they are equal if discrete_obs_space: @@ -36,9 +46,35 @@ def __init__( # representation of the observation self.observation_space = spaces.Dict( { - "observation": spaces.Discrete(2 ** n_bits - 1), - "achieved_goal": spaces.Discrete(2 ** n_bits - 1), - "desired_goal": spaces.Discrete(2 ** n_bits - 1), + "observation": spaces.Discrete(2 ** n_bits), + "achieved_goal": spaces.Discrete(2 ** n_bits), + "desired_goal": spaces.Discrete(2 ** n_bits), + } + ) + elif image_obs_space: + # When using image as input, + # one image contains the bits 0 -> 0, 1 -> 255 + # and the rest is filled with zeros + self.observation_space = spaces.Dict( + { + "observation": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), + "achieved_goal": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), + "desired_goal": spaces.Box( + low=0, + high=255, + shape=self.image_shape, + dtype=np.uint8, + ), } ) else: @@ -58,6 +94,7 @@ def __init__( self.action_space = spaces.Discrete(n_bits) self.continuous = continuous self.discrete_obs_space = discrete_obs_space + self.image_obs_space = image_obs_space self.state = None self.desired_goal = np.ones((n_bits,)) if max_steps is None: @@ -79,13 +116,38 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]: # The internal state is the binary representation of the # observed one return int(sum([state[i] * 2 ** i for i in range(len(state))])) + + if self.image_obs_space: + size = np.prod(self.image_shape) + image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8))) + return image.reshape(self.image_shape).astype(np.uint8) + return state + + def convert_to_bit_vector(self, state: Union[int, np.ndarray], batch_size: int) -> np.ndarray: + """ + Convert to bit vector if needed. + + :param state: + :param batch_size: + :return: + """ + # Convert back to bit vector + if isinstance(state, int): + state = np.array(state).reshape(batch_size, -1) + # Convert to binary representation + state = (((state[:, :] & (1 << np.arange(len(self.state))))) > 0).astype(int) + elif self.image_obs_space: + state = state.reshape(batch_size, -1)[:, : len(self.state)] / 255 + else: + state = np.array(state).reshape(batch_size, -1) + return state def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: """ Helper to create the observation. - :return: + :return: The current observation. 
""" return OrderedDict( [ @@ -117,8 +179,19 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] ) -> np.float32: + # As we are using a vectorized version, we need to keep track of the `batch_size` + if isinstance(achieved_goal, int): + batch_size = 1 + elif self.image_obs_space: + batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 3 else 1 + else: + batch_size = achieved_goal.shape[0] if len(achieved_goal.shape) > 1 else 1 + + desired_goal = self.convert_to_bit_vector(desired_goal, batch_size) + achieved_goal = self.convert_to_bit_vector(achieved_goal, batch_size) + # Deceptive reward: it is positive only when the goal is achieved - # vectorized version + # Here we are using a vectorized version distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) return -(distance > 0).astype(np.float32) diff --git a/stable_baselines3/common/identity_env.py b/stable_baselines3/common/envs/identity_env.py similarity index 100% rename from stable_baselines3/common/identity_env.py rename to stable_baselines3/common/envs/identity_env.py diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py new file mode 100644 index 000000000..dd49f475e --- /dev/null +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -0,0 +1,180 @@ +from typing import Dict, Union + +import gym +import numpy as np + +from stable_baselines3.common.type_aliases import GymStepReturn + + +class SimpleMultiObsEnv(gym.Env): + """ + Base class for GridWorld-based MultiObs Environments 4x4 grid world. + + .. code-block:: text + + ____________ + | 0 1 2 3| + | 4|¯5¯¯6¯| 7| + | 8|_9_10_|11| + |12 13 14 15| + ¯¯¯¯¯¯¯¯¯¯¯¯¯¯ + + start is 0 + states 5, 6, 9, and 10 are blocked + goal is 15 + actions are = [left, down, right, up] + + simple linear state env of 15 states but encoded with a vector and an image observation: + each column is represented by a random vector and each row is + represented by a random image, both sampled once at creation time. 
+ + :param num_col: Number of columns in the grid + :param num_row: Number of rows in the grid + :param random_start: If true, agent starts in random position + :param channel_last: If true, the image will be channel last, else it will be channel first + """ + + def __init__( + self, + num_col: int = 4, + num_row: int = 4, + random_start: bool = True, + discrete_actions: bool = True, + channel_last: bool = True, + ): + super(SimpleMultiObsEnv, self).__init__() + + self.vector_size = 5 + if channel_last: + self.img_size = [64, 64, 1] + else: + self.img_size = [1, 64, 64] + + self.random_start = random_start + self.discrete_actions = discrete_actions + if discrete_actions: + self.action_space = gym.spaces.Discrete(4) + else: + self.action_space = gym.spaces.Box(0, 1, (4,)) + + self.observation_space = gym.spaces.Dict( + spaces={ + "vec": gym.spaces.Box(0, 1, (self.vector_size,)), + "img": gym.spaces.Box(0, 255, self.img_size, dtype=np.uint8), + } + ) + self.count = 0 + # Timeout + self.max_count = 100 + self.log = "" + self.state = 0 + self.action2str = ["left", "down", "right", "up"] + self.init_possible_transitions() + + self.num_col = num_col + self.state_mapping = [] + self.init_state_mapping(num_col, num_row) + + self.max_state = len(self.state_mapping) - 1 + + def init_state_mapping(self, num_col: int, num_row: int) -> None: + """ + Initializes the state_mapping array which holds the observation values for each state + + :param num_col: Number of columns. + :param num_row: Number of rows. + """ + # Each column is represented by a random vector + col_vecs = np.random.random((num_col, self.vector_size)) + # Each row is represented by a random image + row_imgs = np.random.randint(0, 255, (num_row, 64, 64), dtype=np.int32) + + for i in range(num_col): + for j in range(num_row): + self.state_mapping.append({"vec": col_vecs[i], "img": row_imgs[j].reshape(self.img_size)}) + + def get_state_mapping(self) -> Dict[str, np.ndarray]: + """ + Uses the state to get the observation mapping. + + :return: observation dict {'vec': ..., 'img': ...} + """ + return self.state_mapping[self.state] + + def init_possible_transitions(self) -> None: + """ + Initializes the transitions of the environment + The environment exploits the cardinal directions of the grid by noting that + they correspond to simple addition and subtraction from the cell id within the grid + + - up => means moving up a row => means subtracting the length of a column + - down => means moving down a row => means adding the length of a column + - left => means moving left by one => means subtracting 1 + - right => means moving right by one => means adding 1 + + Thus one only needs to specify in which states each action is possible + in order to define the transitions of the environment + """ + self.left_possible = [1, 2, 3, 13, 14, 15] + self.down_possible = [0, 4, 8, 3, 7, 11] + self.right_possible = [0, 1, 2, 12, 13, 14] + self.up_possible = [4, 8, 12, 7, 11, 15] + + def step(self, action: Union[int, float, np.ndarray]) -> GymStepReturn: + """ + Run one timestep of the environment's dynamics. When end of + episode is reached, you are responsible for calling `reset()` + to reset this environment's state. + Accepts an action and returns a tuple (observation, reward, done, info). + + :param action: + :return: tuple (observation, reward, done, info). 
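+        The reward is -0.1 per step and 1 when the goal state is reached.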
+ """ + if not self.discrete_actions: + action = np.argmax(action) + else: + action = int(action) + + self.count += 1 + + prev_state = self.state + + reward = -0.1 + # define state transition + if self.state in self.left_possible and action == 0: # left + self.state -= 1 + elif self.state in self.down_possible and action == 1: # down + self.state += self.num_col + elif self.state in self.right_possible and action == 2: # right + self.state += 1 + elif self.state in self.up_possible and action == 3: # up + self.state -= self.num_col + + got_to_end = self.state == self.max_state + reward = 1 if got_to_end else reward + done = self.count > self.max_count or got_to_end + + self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" + + return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} + + def render(self, mode: str = "human") -> None: + """ + Prints the log of the environment. + + :param mode: + """ + print(self.log) + + def reset(self) -> Dict[str, np.ndarray]: + """ + Resets the environment state and step count and returns reset observation. + + :return: observation dict {'vec': ..., 'img': ...} + """ + self.count = 0 + if not self.random_start: + self.state = 0 + else: + self.state = np.random.randint(0, self.max_state) + return self.state_mapping[self.state] diff --git a/stable_baselines3/common/off_policy_algorithm.py b/stable_baselines3/common/off_policy_algorithm.py index f140fa4fb..46e6d56ff 100644 --- a/stable_baselines3/common/off_policy_algorithm.py +++ b/stable_baselines3/common/off_policy_algorithm.py @@ -10,7 +10,7 @@ from stable_baselines3.common import logger from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.buffers import ReplayBuffer +from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.policies import BasePolicy @@ -18,6 +18,7 @@ from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, Schedule, TrainFreq, TrainFrequencyUnit from stable_baselines3.common.utils import safe_mean, should_collect_more_steps from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.her.her_replay_buffer import HerReplayBuffer class OffPolicyAlgorithm(BaseAlgorithm): @@ -42,6 +43,9 @@ class OffPolicyAlgorithm(BaseAlgorithm): during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -76,7 +80,7 @@ def __init__( env: Union[GymEnv, str], policy_base: Type[BasePolicy], learning_rate: Union[float, Schedule], - buffer_size: int = 1000000, + buffer_size: int = 1000000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, @@ -84,6 +88,8 @@ def __init__( train_freq: Union[int, Tuple[int, str]] = (1, "step"), gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, policy_kwargs: Dict[str, Any] = None, tensorboard_log: Optional[str] = None, @@ -126,6 +132,11 @@ def __init__( self.gradient_steps = gradient_steps self.action_noise = action_noise self.optimize_memory_usage = optimize_memory_usage + self.replay_buffer_class = replay_buffer_class + if replay_buffer_kwargs is None: + replay_buffer_kwargs = {} + self.replay_buffer_kwargs = replay_buffer_kwargs + self._episode_storage = None # Remove terminations (dones) that are due to time limit # see https://github.com/hill-a/stable-baselines/issues/863 @@ -167,13 +178,47 @@ def _convert_train_freq(self) -> None: def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - self.replay_buffer = ReplayBuffer( - self.buffer_size, - self.observation_space, - self.action_space, - self.device, - optimize_memory_usage=self.optimize_memory_usage, - ) + + # Use DictReplayBuffer if needed + if self.replay_buffer_class is None: + if isinstance(self.observation_space, gym.spaces.Dict): + self.replay_buffer_class = DictReplayBuffer + else: + self.replay_buffer_class = ReplayBuffer + + elif self.replay_buffer_class == HerReplayBuffer: + assert self.env is not None, "You must pass an environment when using `HerReplayBuffer`" + + # If using offline sampling, we need a classic replay buffer too + if self.replay_buffer_kwargs.get("online_sampling", True): + replay_buffer = None + else: + replay_buffer = DictReplayBuffer( + self.buffer_size, + self.observation_space, + self.action_space, + self.device, + optimize_memory_usage=self.optimize_memory_usage, + ) + + self.replay_buffer = HerReplayBuffer( + self.env, + self.buffer_size, + self.device, + replay_buffer=replay_buffer, + **self.replay_buffer_kwargs, + ) + + if self.replay_buffer is None: + self.replay_buffer = self.replay_buffer_class( + self.buffer_size, + self.observation_space, + self.action_space, + self.device, + optimize_memory_usage=self.optimize_memory_usage, + **self.replay_buffer_kwargs, + ) + self.policy = self.policy_class( # pytype:disable=not-instantiable self.observation_space, self.action_space, @@ -195,15 +240,35 @@ def save_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) assert self.replay_buffer is not None, "The replay buffer is not defined" save_to_pkl(path, self.replay_buffer, self.verbose) - def load_replay_buffer(self, path: Union[str, pathlib.Path, io.BufferedIOBase]) -> None: + def load_replay_buffer( + self, + path: Union[str, pathlib.Path, io.BufferedIOBase], + truncate_last_traj: bool = True, + ) -> None: """ Load a replay buffer from a pickle file. :param path: Path to the pickled replay buffer. + :param truncate_last_traj: When using ``HerReplayBuffer`` with online sampling: + If set to ``True``, we assume that the last trajectory in the replay buffer was finished + (and truncate it). 
+ If set to ``False``, we assume that we continue the same trajectory (same episode). """ self.replay_buffer = load_from_pkl(path, self.verbose) assert isinstance(self.replay_buffer, ReplayBuffer), "The replay buffer must inherit from ReplayBuffer class" + # Backward compatibility with SB3 < 2.1.0 replay buffer + # Keep old behavior: do not handle timeout termination separately + if not hasattr(self.replay_buffer, "handle_timeout_termination"): # pragma: no cover + self.replay_buffer.handle_timeout_termination = False + self.replay_buffer.timeouts = np.zeros_like(self.replay_buffer.dones) + + if isinstance(self.replay_buffer, HerReplayBuffer): + assert self.env is not None, "You must pass an environment at load time when using `HerReplayBuffer`" + self.replay_buffer.set_env(self.get_env()) + if truncate_last_traj: + self.replay_buffer.truncate_last_trajectory() + def _setup_learn( self, total_timesteps: int, @@ -221,11 +286,19 @@ def _setup_learn( # Prevent continuity issue by truncating trajectory # when using memory efficient replay buffer # see https://github.com/DLR-RM/stable-baselines3/issues/46 + + # Special case when using HerReplayBuffer, + # the classic replay buffer is inside it when using offline sampling + if isinstance(self.replay_buffer, HerReplayBuffer): + replay_buffer = self.replay_buffer.replay_buffer + else: + replay_buffer = self.replay_buffer + truncate_last_traj = ( self.optimize_memory_usage and reset_num_timesteps - and self.replay_buffer is not None - and (self.replay_buffer.full or self.replay_buffer.pos > 0) + and replay_buffer is not None + and (replay_buffer.full or replay_buffer.pos > 0) ) if truncate_last_traj: @@ -236,11 +309,18 @@ def _setup_learn( "to avoid that issue." ) # Go to the previous index - pos = (self.replay_buffer.pos - 1) % self.replay_buffer.buffer_size - self.replay_buffer.dones[pos] = True + pos = (replay_buffer.pos - 1) % replay_buffer.buffer_size + replay_buffer.dones[pos] = True return super()._setup_learn( - total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, log_path, reset_num_timesteps, tb_log_name + total_timesteps, + eval_env, + callback, + eval_freq, + n_eval_episodes, + log_path, + reset_num_timesteps, + tb_log_name, ) def learn( @@ -257,7 +337,14 @@ def learn( ) -> "OffPolicyAlgorithm": total_timesteps, callback = self._setup_learn( - total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name + total_timesteps, + eval_env, + callback, + eval_freq, + n_eval_episodes, + eval_log_path, + reset_num_timesteps, + tb_log_name, ) callback.on_training_start(locals(), globals()) @@ -341,13 +428,14 @@ def _dump_logs(self) -> None: """ Write log. 
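+        Records episode statistics (mean reward and length), timing information
+        (fps, elapsed time, total timesteps) and, when gSDE is used, the current
+        exploration std.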
""" - fps = int(self.num_timesteps / (time.time() - self.start_time)) + time_elapsed = time.time() - self.start_time + fps = int(self.num_timesteps / (time_elapsed + 1e-8)) logger.record("time/episodes", self._episode_num, exclude="tensorboard") if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0: logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer])) logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer])) logger.record("time/fps", fps) - logger.record("time/time_elapsed", int(time.time() - self.start_time), exclude="tensorboard") + logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard") logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard") if self.use_sde: logger.record("train/std", (self.actor.get_std()).mean().item()) @@ -386,7 +474,7 @@ def _store_transition( :param reward: reward for the current transition :param done: Termination signal :param infos: List of additional information about the transition. - It contains the terminal observations. + It may contain the terminal observations and information about timeout. """ # Store only the unnormalized version if self._vec_normalize_env is not None: @@ -406,7 +494,14 @@ def _store_transition( else: next_obs = new_obs_ - replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done) + replay_buffer.add( + self._last_original_obs, + next_obs, + buffer_action, + reward_, + done, + infos, + ) self._last_obs = new_obs # Save the unnormalized observation diff --git a/stable_baselines3/common/on_policy_algorithm.py b/stable_baselines3/common/on_policy_algorithm.py index 016954db8..924788dbd 100644 --- a/stable_baselines3/common/on_policy_algorithm.py +++ b/stable_baselines3/common/on_policy_algorithm.py @@ -7,11 +7,11 @@ from stable_baselines3.common import logger from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.buffers import RolloutBuffer +from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer from stable_baselines3.common.callbacks import BaseCallback from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule -from stable_baselines3.common.utils import safe_mean +from stable_baselines3.common.utils import obs_as_tensor, safe_mean from stable_baselines3.common.vec_env import VecEnv @@ -107,7 +107,9 @@ def _setup_model(self) -> None: self._setup_lr_schedule() self.set_random_seed(self.seed) - self.rollout_buffer = RolloutBuffer( + buffer_cls = DictRolloutBuffer if isinstance(self.observation_space, gym.spaces.Dict) else RolloutBuffer + + self.rollout_buffer = buffer_cls( self.n_steps, self.observation_space, self.action_space, @@ -126,7 +128,11 @@ def _setup_model(self) -> None: self.policy = self.policy.to(self.device) def collect_rollouts( - self, env: VecEnv, callback: BaseCallback, rollout_buffer: RolloutBuffer, n_rollout_steps: int + self, + env: VecEnv, + callback: BaseCallback, + rollout_buffer: RolloutBuffer, + n_rollout_steps: int, ) -> bool: """ Collect experiences using the current policy and fill a ``RolloutBuffer``. 
@@ -156,8 +162,8 @@ def collect_rollouts( self.policy.reset_noise(env.num_envs) with th.no_grad(): - # Convert to pytorch tensor - obs_tensor = th.as_tensor(self._last_obs).to(self.device) + # Convert to pytorch tensor or to TensorDict + obs_tensor = obs_as_tensor(self._last_obs, self.device) actions, values, log_probs = self.policy.forward(obs_tensor) actions = actions.cpu().numpy() @@ -188,7 +194,7 @@ def collect_rollouts( with th.no_grad(): # Compute value for the last timestep - obs_tensor = th.as_tensor(new_obs).to(self.device) + obs_tensor = obs_as_tensor(new_obs, self.device) _, values, _ = self.policy.forward(obs_tensor) rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones) diff --git a/stable_baselines3/common/policies.py b/stable_baselines3/common/policies.py index 234e7243c..8b6f3e649 100644 --- a/stable_baselines3/common/policies.py +++ b/stable_baselines3/common/policies.py @@ -1,9 +1,10 @@ """Policies: abstract base class and concrete implementations.""" import collections +import copy from abc import ABC, abstractmethod from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import gym import numpy as np @@ -19,11 +20,17 @@ StateDependentNoiseDistribution, make_proba_distribution, ) -from stable_baselines3.common.preprocessing import get_action_dim, maybe_transpose, preprocess_obs -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, MlpExtractor, NatureCNN, create_mlp +from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, + create_mlp, +) from stable_baselines3.common.type_aliases import Schedule -from stable_baselines3.common.utils import get_device, is_vectorized_observation -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper +from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor class BaseModel(nn.Module, ABC): @@ -84,7 +91,9 @@ def forward(self, *args, **kwargs): pass def _update_features_extractor( - self, net_kwargs: Dict[str, Any], features_extractor: Optional[BaseFeaturesExtractor] = None + self, + net_kwargs: Dict[str, Any], + features_extractor: Optional[BaseFeaturesExtractor] = None, ) -> Dict[str, Any]: """ Update the network keyword arguments and create a new features extractor object if needed. 
@@ -237,7 +246,7 @@ def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Te def predict( self, - observation: np.ndarray, + observation: Union[np.ndarray, Dict[str, np.ndarray]], state: Optional[np.ndarray] = None, mask: Optional[np.ndarray] = None, deterministic: bool = False, @@ -258,20 +267,37 @@ def predict( # state = self.initial_state # if mask is None: # mask = [False for _ in range(self.n_envs)] + + vectorized_env = False if isinstance(observation, dict): - observation = ObsDictWrapper.convert_dict(observation) + # need to copy the dict as the dict in VecFrameStack will become a torch tensor + observation = copy.deepcopy(observation) + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + if is_image_space(obs_space): + obs_ = maybe_transpose(obs, obs_space) + else: + obs_ = np.array(obs) + vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) + # Add batch dimension if needed + observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape) + + elif is_image_space(self.observation_space): + # Handle the different cases for images + # as PyTorch use channel first format + observation = maybe_transpose(observation, self.observation_space) + else: observation = np.array(observation) - # Handle the different cases for images - # as PyTorch use channel first format - observation = maybe_transpose(observation, self.observation_space) - - vectorized_env = is_vectorized_observation(observation, self.observation_space) + if not isinstance(observation, dict): + # Dict obs need to be handled separately + vectorized_env = is_vectorized_observation(observation, self.observation_space) + # Add batch dimension if needed + observation = observation.reshape((-1,) + self.observation_space.shape) - observation = observation.reshape((-1,) + self.observation_space.shape) + observation = obs_as_tensor(observation, self.device) - observation = th.as_tensor(observation).to(self.device) with th.no_grad(): actions = self._predict(observation, deterministic=deterministic) # Convert to numpy @@ -388,10 +414,10 @@ def __init__( # Default network architecture, from stable-baselines if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [dict(pi=[64, 64], vf=[64, 64])] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [dict(pi=[64, 64], vf=[64, 64])] self.net_arch = net_arch self.activation_fn = activation_fn @@ -465,7 +491,10 @@ def _build_mlp_extractor(self) -> None: # net_arch here is an empty list and mlp_extractor does not # really contain any layers (acts like an identity module). self.mlp_extractor = MlpExtractor( - self.features_dim, net_arch=self.net_arch, activation_fn=self.activation_fn, device=self.device + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, ) def _build(self, lr_schedule: Schedule) -> None: @@ -688,6 +717,81 @@ def __init__( ) +class MultiInputActorCriticPolicy(ActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space (Tuple) + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Uses the CombinedExtractor + :param features_extractor_kwargs: Keyword arguments + to pass to the feature extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Callable, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MultiInputActorCriticPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + class ContinuousCritic(BaseModel): """ Critic network(s) for DDPG/SAC/TD3. diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index 63052f350..7aaeb12ba 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -1,5 +1,5 @@ import warnings -from typing import Tuple +from typing import Dict, Tuple, Union import numpy as np import torch as th @@ -24,7 +24,11 @@ def is_image_space_channels_first(observation_space: spaces.Box) -> bool: return smallest_dimension == 0 -def is_image_space(observation_space: spaces.Space, channels_last: bool = True, check_channels: bool = False) -> bool: +def is_image_space( + observation_space: spaces.Space, + channels_last: bool = True, + check_channels: bool = False, +) -> bool: """ Check if a observation space has the shape, limits and dtype of a valid image. 
@@ -81,7 +85,11 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> return observation -def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_images: bool = True) -> th.Tensor: +def preprocess_obs( + obs: th.Tensor, + observation_space: spaces.Space, + normalize_images: bool = True, +) -> Union[th.Tensor, Dict[str, th.Tensor]]: """ Preprocess observation to be to a neural network. For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) @@ -115,11 +123,20 @@ def preprocess_obs(obs: th.Tensor, observation_space: spaces.Space, normalize_im elif isinstance(observation_space, spaces.MultiBinary): return obs.float() + elif isinstance(observation_space, spaces.Dict): + # Do not modify by reference the original observation + preprocessed_obs = {} + for key, _obs in obs.items(): + preprocessed_obs[key] = preprocess_obs(_obs, observation_space[key], normalize_images=normalize_images) + return preprocessed_obs + else: raise NotImplementedError(f"Preprocessing not implemented for {observation_space}") -def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: +def get_obs_shape( + observation_space: spaces.Space, +) -> Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]: """ Get the shape of the observation (useful for the buffers). @@ -137,6 +154,9 @@ def get_obs_shape(observation_space: spaces.Space) -> Tuple[int, ...]: elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features return (int(observation_space.n),) + elif isinstance(observation_space, spaces.Dict): + return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} + else: raise NotImplementedError(f"{observation_space} observation space is not supported") @@ -146,6 +166,8 @@ def get_flattened_obs_dim(observation_space: spaces.Space) -> int: Get the dimension of the observation space when flattened. It does not apply to image observation space. + Used by the ``FlattenExtractor`` to compute the input shape. + :param observation_space: :return: """ @@ -178,3 +200,20 @@ def get_action_dim(action_space: spaces.Space) -> int: return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") + + +def check_for_nested_spaces(obs_space: spaces.Space): + """ + Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). + If so, raise an Exception informing that there is no support for this. + + :param obs_space: an observation space + :return: + """ + if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): + sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces + for sub_space in sub_spaces: + if isinstance(sub_space, (spaces.Dict, spaces.Tuple)): + raise NotImplementedError( + "Nested observation spaces are not supported (Tuple/Dict space inside Tuple/Dict space)." 
+ ) diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 1f0c28dd4..16088f3c4 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -6,6 +6,7 @@ from torch import nn from stable_baselines3.common.preprocessing import get_flattened_obs_dim, is_image_space +from stable_baselines3.common.type_aliases import TensorDict from stable_baselines3.common.utils import get_device @@ -66,7 +67,7 @@ def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512): assert is_image_space(observation_space), ( "You should use NatureCNN " f"only with images not with {observation_space}\n" - "(you are probably using `CnnPolicy` instead of `MlpPolicy`)\n" + "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n" "If you are using a custom environment,\n" "please check it using our env checker:\n" "https://stable-baselines3.readthedocs.io/en/master/common/env_checker.html" @@ -93,7 +94,11 @@ def forward(self, observations: th.Tensor) -> th.Tensor: def create_mlp( - input_dim: int, output_dim: int, net_arch: List[int], activation_fn: Type[nn.Module] = nn.ReLU, squash_output: bool = False + input_dim: int, + output_dim: int, + net_arch: List[int], + activation_fn: Type[nn.Module] = nn.ReLU, + squash_output: bool = False, ) -> List[nn.Module]: """ Create a multi layer perceptron (MLP), which is @@ -172,11 +177,10 @@ def __init__( # Iterate through the shared layers and build the shared parts of the network for layer in net_arch: if isinstance(layer, int): # Check that this is a shared layer - layer_size = layer # TODO: give layer a meaningful name - shared_net.append(nn.Linear(last_layer_dim_shared, layer_size)) + shared_net.append(nn.Linear(last_layer_dim_shared, layer)) # add linear of size layer shared_net.append(activation_fn()) - last_layer_dim_shared = layer_size + last_layer_dim_shared = layer else: assert isinstance(layer, dict), "Error: the net_arch list can only contain ints and dicts" if "pi" in layer: @@ -224,6 +228,47 @@ def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]: return self.policy_net(shared_latent), self.value_net(shared_latent) +class CombinedExtractor(BaseFeaturesExtractor): + """ + Combined feature extractor for Dict observation spaces. + Builds a feature extractor for each key of the space. Input from each space + is fed through a separate submodule (CNN or MLP, depending on input shape), + the output features are concatenated and fed through additional MLP network ("combined"). + + :param observation_space: + :param cnn_output_dim: Number of features to output from each CNN submodule(s). Defaults to + 256 to avoid exploding network sizes. + """ + + def __init__(self, observation_space: gym.spaces.Dict, cnn_output_dim: int = 256): + # TODO we do not know features-dim here before going over all the items, so put something there. This is dirty! 
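+        # self._features_dim is overwritten at the end of __init__, once the output
+        # size of every sub-extractor (total_concat_size) is known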
+ super(CombinedExtractor, self).__init__(observation_space, features_dim=1) + + extractors = {} + + total_concat_size = 0 + for key, subspace in observation_space.spaces.items(): + if is_image_space(subspace): + extractors[key] = NatureCNN(subspace, features_dim=cnn_output_dim) + total_concat_size += cnn_output_dim + else: + # The observation key is a vector, flatten it if needed + extractors[key] = nn.Flatten() + total_concat_size += get_flattened_obs_dim(subspace) + + self.extractors = nn.ModuleDict(extractors) + + # Update the features dim manually + self._features_dim = total_concat_size + + def forward(self, observations: TensorDict) -> th.Tensor: + encoded_tensor_list = [] + + for key, extractor in self.extractors.items(): + encoded_tensor_list.append(extractor(observations[key])) + return th.cat(encoded_tensor_list, dim=1) + + def get_actor_critic_arch(net_arch: Union[List[int], Dict[str, List[int]]]) -> Tuple[List[int], List[int]]: """ Get the actor and critic network architectures for off-policy actor-critic algorithms (SAC, TD3, DDPG). diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 39df57d12..45db9eb43 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -12,9 +12,10 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] GymStepReturn = Tuple[GymObs, float, bool, Dict] -TensorDict = Dict[str, th.Tensor] +TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] + # A schedule takes the remaining progress as input # and ouputs a scalar (e.g. learning rate, clip range, ...) Schedule = Callable[[float], float] @@ -29,6 +30,15 @@ class RolloutBufferSamples(NamedTuple): returns: th.Tensor +class DictRolloutBufferSamples(RolloutBufferSamples): + observations: TensorDict + actions: th.Tensor + old_values: th.Tensor + old_log_prob: th.Tensor + advantages: th.Tensor + returns: th.Tensor + + class ReplayBufferSamples(NamedTuple): observations: th.Tensor actions: th.Tensor @@ -37,6 +47,14 @@ class ReplayBufferSamples(NamedTuple): rewards: th.Tensor +class DictReplayBufferSamples(ReplayBufferSamples): + observations: TensorDict + actions: th.Tensor + next_observations: th.Tensor + dones: th.Tensor + rewards: th.Tensor + + class RolloutReturn(NamedTuple): episode_reward: float episode_timesteps: int diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 87ef92a54..9b1a6f32c 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -3,7 +3,7 @@ import random from collections import deque from itertools import zip_longest -from typing import Iterable, Optional, Union +from typing import Dict, Iterable, Optional, Union import gym import numpy as np @@ -16,7 +16,7 @@ SummaryWriter = None from stable_baselines3.common import logger -from stable_baselines3.common.type_aliases import GymEnv, Schedule, TrainFreq, TrainFrequencyUnit +from stable_baselines3.common.type_aliases import GymEnv, Schedule, TensorDict, TrainFreq, TrainFrequencyUnit def set_random_seed(seed: int, using_cuda: bool = False) -> None: @@ -168,7 +168,10 @@ def get_latest_run_id(log_path: Optional[str] = None, log_name: str = "") -> int def configure_logger( - verbose: int = 0, tensorboard_log: Optional[str] = None, tb_log_name: str = "", reset_num_timesteps: bool = True + verbose: int = 0, + 
tensorboard_log: Optional[str] = None, + tb_log_name: str = "", + reset_num_timesteps: bool = True, ) -> None: """ Configure the logger's outputs. @@ -209,60 +212,140 @@ def check_for_correct_spaces(env: GymEnv, observation_space: gym.spaces.Space, a raise ValueError(f"Action spaces do not match: {action_space} != {env.action_space}") -def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool: +def is_vectorized_box_observation(observation: np.ndarray, observation_space: gym.spaces.Box) -> bool: """ - For every observation type, detects and validates the shape, + For box observation type, detects and validates the shape, then returns whether or not the observation is vectorized. :param observation: the input observation to validate :param observation_space: the observation space :return: whether the given observation is vectorized or not """ - if isinstance(observation_space, gym.spaces.Box): - if observation.shape == observation_space.shape: - return False - elif observation.shape[1:] == observation_space.shape: - return True - else: - raise ValueError( - f"Error: Unexpected observation shape {observation.shape} for " - + f"Box environment, please use {observation_space.shape} " - + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape))) - ) - elif isinstance(observation_space, gym.spaces.Discrete): - if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' - return False - elif len(observation.shape) == 1: - return True - else: - raise ValueError( - f"Error: Unexpected observation shape {observation.shape} for " - + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape." - ) + if observation.shape == observation_space.shape: + return False + elif observation.shape[1:] == observation_space.shape: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + f"Box environment, please use {observation_space.shape} " + + "or (n_env, {}) for the observation shape.".format(", ".join(map(str, observation_space.shape))) + ) - elif isinstance(observation_space, gym.spaces.MultiDiscrete): - if observation.shape == (len(observation_space.nvec),): - return False - elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): - return True - else: - raise ValueError( - f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " - + f"environment, please use ({len(observation_space.nvec)},) or " - + f"(n_env, {len(observation_space.nvec)}) for the observation shape." - ) - elif isinstance(observation_space, gym.spaces.MultiBinary): - if observation.shape == (observation_space.n,): + +def is_vectorized_discrete_observation(observation: np.ndarray, observation_space: gym.spaces.Discrete) -> bool: + """ + For discrete observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == (): # A numpy array of a number, has shape empty tuple '()' + return False + elif len(observation.shape) == 1: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + "Discrete environment, please use (1,) or (n_env, 1) for the observation shape." 
+ ) + + +def is_vectorized_multidiscrete_observation(observation: np.ndarray, observation_space: gym.spaces.MultiDiscrete) -> bool: + """ + For multidiscrete observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == (len(observation_space.nvec),): + return False + elif len(observation.shape) == 2 and observation.shape[1] == len(observation_space.nvec): + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiDiscrete " + + f"environment, please use ({len(observation_space.nvec)},) or " + + f"(n_env, {len(observation_space.nvec)}) for the observation shape." + ) + + +def is_vectorized_multibinary_observation(observation: np.ndarray, observation_space: gym.spaces.MultiBinary) -> bool: + """ + For multibinary observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + if observation.shape == (observation_space.n,): + return False + elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: + return True + else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for MultiBinary " + + f"environment, please use ({observation_space.n},) or " + + f"(n_env, {observation_space.n}) for the observation shape." + ) + + +def is_vectorized_dict_observation(observation: np.ndarray, observation_space: gym.spaces.Dict) -> bool: + """ + For dict observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. + + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + for key, subspace in observation_space.spaces.items(): + if observation[key].shape == subspace.shape: return False - elif len(observation.shape) == 2 and observation.shape[1] == observation_space.n: - return True - else: - raise ValueError( - f"Error: Unexpected observation shape {observation.shape} for MultiBinary " - + f"environment, please use ({observation_space.n},) or " - + f"(n_env, {observation_space.n}) for the observation shape." - ) + + all_good = True + + for key, subspace in observation_space.spaces.items(): + if observation[key].shape[1:] != subspace.shape: + all_good = False + break + + if all_good: + return True else: + raise ValueError( + f"Error: Unexpected observation shape {observation.shape} for " + + f"Tuple environment, please use {(obs.shape for obs in observation_space.spaces)} " + ) + + +def is_vectorized_observation(observation: np.ndarray, observation_space: gym.spaces.Space) -> bool: + """ + For every observation type, detects and validates the shape, + then returns whether or not the observation is vectorized. 
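+    The check is dispatched to the helper matching the type of ``observation_space``
+    (e.g. ``is_vectorized_box_observation`` for ``gym.spaces.Box``,
+    ``is_vectorized_dict_observation`` for ``gym.spaces.Dict``).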
+ + :param observation: the input observation to validate + :param observation_space: the observation space + :return: whether the given observation is vectorized or not + """ + + is_vec_obs_func_dict = { + gym.spaces.Box: is_vectorized_box_observation, + gym.spaces.Discrete: is_vectorized_discrete_observation, + gym.spaces.MultiDiscrete: is_vectorized_multidiscrete_observation, + gym.spaces.MultiBinary: is_vectorized_multibinary_observation, + gym.spaces.Dict: is_vectorized_dict_observation, + } + + try: + is_vec_obs_func = is_vec_obs_func_dict[type(observation_space)] + return is_vec_obs_func(observation, observation_space) + except KeyError: raise ValueError( "Error: Cannot determine if the observation is vectorized " + f" with the space type {observation_space}." ) @@ -297,7 +380,11 @@ def zip_strict(*iterables: Iterable) -> Iterable: yield combo -def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th.nn.Parameter], tau: float) -> None: +def polyak_update( + params: Iterable[th.nn.Parameter], + target_params: Iterable[th.nn.Parameter], + tau: float, +) -> None: """ Perform a Polyak average update on ``target_params`` using ``params``: target parameters are slowly updated towards the main parameters. @@ -320,6 +407,24 @@ def polyak_update(params: Iterable[th.nn.Parameter], target_params: Iterable[th. th.add(target_param.data, param.data, alpha=tau, out=target_param.data) +def obs_as_tensor( + obs: Union[np.ndarray, Dict[Union[str, int], np.ndarray]], device: th.device +) -> Union[th.Tensor, TensorDict]: + """ + Moves the observation to the given device. + + :param obs: + :param device: PyTorch device + :return: PyTorch tensor of the observation on a desired device. + """ + if isinstance(obs, np.ndarray): + return th.as_tensor(obs).to(device) + elif isinstance(obs, dict): + return {key: th.as_tensor(_obs).to(device) for (key, _obs) in obs.items()} + else: + raise Exception(f"Unrecognized type of observation {type(obs)}") + + def should_collect_more_steps( train_freq: TrainFreq, num_collected_steps: int, diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 8d143cd49..37ebc364d 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -5,6 +5,7 @@ from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv +from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs diff --git a/stable_baselines3/common/vec_env/obs_dict_wrapper.py b/stable_baselines3/common/vec_env/obs_dict_wrapper.py deleted file mode 100644 index d07ad2402..000000000 --- a/stable_baselines3/common/vec_env/obs_dict_wrapper.py +++ /dev/null @@ -1,68 +0,0 @@ -from typing import Dict - -import numpy as np -from gym import spaces - -from stable_baselines3.common.vec_env import VecEnv, VecEnvWrapper - - -class ObsDictWrapper(VecEnvWrapper): - """ - Wrapper for a VecEnv which overrides the observation space for Hindsight Experience Replay to support dict observations. - - :param env: The vectorized environment to wrap. 
- """ - - def __init__(self, venv: VecEnv): - super(ObsDictWrapper, self).__init__(venv, venv.observation_space, venv.action_space) - - self.venv = venv - - self.spaces = list(venv.observation_space.spaces.values()) - - # get dimensions of observation and goal - if isinstance(self.spaces[0], spaces.Discrete): - self.obs_dim = 1 - self.goal_dim = 1 - else: - self.obs_dim = venv.observation_space.spaces["observation"].shape[0] - self.goal_dim = venv.observation_space.spaces["achieved_goal"].shape[0] - - # new observation space with concatenated observation and (desired) goal - # for the different types of spaces - if isinstance(self.spaces[0], spaces.Box): - low_values = np.concatenate( - [venv.observation_space.spaces["observation"].low, venv.observation_space.spaces["desired_goal"].low] - ) - high_values = np.concatenate( - [venv.observation_space.spaces["observation"].high, venv.observation_space.spaces["desired_goal"].high] - ) - self.observation_space = spaces.Box(low_values, high_values, dtype=np.float32) - elif isinstance(self.spaces[0], spaces.MultiBinary): - total_dim = self.obs_dim + self.goal_dim - self.observation_space = spaces.MultiBinary(total_dim) - elif isinstance(self.spaces[0], spaces.Discrete): - dimensions = [venv.observation_space.spaces["observation"].n, venv.observation_space.spaces["desired_goal"].n] - self.observation_space = spaces.MultiDiscrete(dimensions) - else: - raise NotImplementedError(f"{type(self.spaces[0])} space is not supported") - - def reset(self): - return self.venv.reset() - - def step_wait(self): - return self.venv.step_wait() - - @staticmethod - def convert_dict( - observation_dict: Dict[str, np.ndarray], observation_key: str = "observation", goal_key: str = "desired_goal" - ) -> np.ndarray: - """ - Concatenate observation and (desired) goal of observation dict. - - :param observation_dict: Dictionary with observation. - :param observation_key: Key of observation in dictionary. - :param goal_key: Key of (desired) goal in dictionary. - :return: Concatenated observation. - """ - return np.concatenate([observation_dict[observation_key], observation_dict[goal_key]], axis=-1) diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py new file mode 100644 index 000000000..513d84a22 --- /dev/null +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -0,0 +1,265 @@ +import warnings +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from gym import spaces + +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first + + +class StackedObservations(object): + """ + Frame stacking wrapper for data. + + Dimension to stack over is either first (channels-first) or + last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if + observation is an image space. + + :param num_envs: number of environments + :param n_stack: Number of frames to stack + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
+ """ + + def __init__( + self, + num_envs: int, + n_stack: int, + observation_space: spaces.Space, + channels_order: Optional[str] = None, + ): + + self.n_stack = n_stack + ( + self.channels_first, + self.stack_dimension, + self.stackedobs, + self.repeat_axis, + ) = self.compute_stacking(num_envs, n_stack, observation_space, channels_order) + super().__init__() + + @staticmethod + def compute_stacking( + num_envs: int, + n_stack: int, + observation_space: spaces.Box, + channels_order: Optional[str] = None, + ) -> Tuple[bool, int, np.ndarray, int]: + """ + Calculates the parameters in order to stack observations + + :param num_envs: Number of environments in the stack + :param n_stack: The number of observations to stack + :param observation_space: The observation space + :param channels_order: The order of the channels + :return: tuple of channels_first, stack_dimension, stackedobs, repeat_axis + """ + channels_first = False + if channels_order is None: + # Detect channel location automatically for images + if is_image_space(observation_space): + channels_first = is_image_space_channels_first(observation_space) + else: + # Default behavior for non-image space, stack on the last axis + channels_first = False + else: + assert channels_order in { + "last", + "first", + }, "`channels_order` must be one of following: 'last', 'first'" + + channels_first = channels_order == "first" + + # This includes the vec-env dimension (first) + stack_dimension = 1 if channels_first else -1 + repeat_axis = 0 if channels_first else -1 + low = np.repeat(observation_space.low, n_stack, axis=repeat_axis) + stackedobs = np.zeros((num_envs,) + low.shape, low.dtype) + return channels_first, stack_dimension, stackedobs, repeat_axis + + def stack_observation_space(self, observation_space: spaces.Box) -> spaces.Box: + """ + Given an observation space, returns a new observation space with stacked observations + + :return: New observation space with stacked dimensions + """ + low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) + high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) + return spaces.Box(low=low, high=high, dtype=observation_space.dtype) + + def reset(self, observation: np.ndarray) -> np.ndarray: + """ + Resets the stackedobs, adds the reset observation to the stack, and returns the stack + + :param observation: Reset observation + :return: The stacked reset observation + """ + self.stackedobs[...] = 0 + if self.channels_first: + self.stackedobs[:, -observation.shape[self.stack_dimension] :, ...] = observation + else: + self.stackedobs[..., -observation.shape[self.stack_dimension] :] = observation + return self.stackedobs + + def update( + self, + observations: np.ndarray, + dones: np.ndarray, + infos: List[Dict[str, Any]], + ) -> Tuple[np.ndarray, List[Dict[str, Any]]]: + """ + Adds the observations to the stack and uses the dones to update the infos. 
+ + :param observations: numpy array of observations + :param dones: numpy array of done info + :param infos: numpy array of info dicts + :return: tuple of the stacked observations and the updated infos + """ + stack_ax_size = observations.shape[self.stack_dimension] + self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) + for i, done in enumerate(dones): + if done: + if "terminal_observation" in infos[i]: + old_terminal = infos[i]["terminal_observation"] + if self.channels_first: + new_terminal = np.concatenate( + (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), + axis=self.stack_dimension, + ) + else: + new_terminal = np.concatenate( + (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), + axis=self.stack_dimension, + ) + infos[i]["terminal_observation"] = new_terminal + else: + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stackedobs[i] = 0 + if self.channels_first: + self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations + else: + self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations + return self.stackedobs, infos + + +class StackedDictObservations(StackedObservations): + """ + Frame stacking wrapper for dictionary data. + + Dimension to stack over is either first (channels-first) or + last (channels-last), which is detected automatically using + ``common.preprocessing.is_image_space_channels_first`` if + observation is an image space. + + :param num_envs: number of environments + :param n_stack: Number of frames to stack + :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. + If None, automatically detect channel to stack over in case of image observation or default to "last" (default). 
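+        Alternatively, a dict mapping observation keys to "first"/"last" can be passed
+        to select the stacking dimension for each key separately.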
+ """ + + def __init__( + self, + num_envs: int, + n_stack: int, + observation_space: spaces.Dict, + channels_order: Optional[Union[str, Dict[str, str]]] = None, + ): + self.n_stack = n_stack + self.channels_first = {} + self.stack_dimension = {} + self.stackedobs = {} + self.repeat_axis = {} + + for key, subspace in observation_space.spaces.items(): + assert isinstance(subspace, spaces.Box), "StackedDictObservations only works with nested gym.spaces.Box" + if isinstance(channels_order, str) or channels_order is None: + subspace_channel_order = channels_order + else: + subspace_channel_order = channels_order[key] + ( + self.channels_first[key], + self.stack_dimension[key], + self.stackedobs[key], + self.repeat_axis[key], + ) = self.compute_stacking(num_envs, n_stack, subspace, subspace_channel_order) + + def stack_observation_space(self, observation_space: spaces.Dict) -> spaces.Dict: + """ + Returns the stacked verson of a Dict observation space + + :param observation_space: Dict observation space to stack + :return: stacked observation space + """ + spaces_dict = {} + for key, subspace in observation_space.spaces.items(): + low = np.repeat(subspace.low, self.n_stack, axis=self.repeat_axis[key]) + high = np.repeat(subspace.high, self.n_stack, axis=self.repeat_axis[key]) + spaces_dict[key] = spaces.Box(low=low, high=high, dtype=subspace.dtype) + return spaces.Dict(spaces=spaces_dict) + + def reset(self, observation: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: + """ + Resets the stacked observations, adds the reset observation to the stack, and returns the stack + + :param observation: Reset observation + :return: Stacked reset observations + """ + for key, obs in observation.items(): + self.stackedobs[key][...] = 0 + if self.channels_first[key]: + self.stackedobs[key][:, -obs.shape[self.stack_dimension[key]] :, ...] = obs + else: + self.stackedobs[key][..., -obs.shape[self.stack_dimension[key]] :] = obs + return self.stackedobs + + def update( + self, + observations: Dict[str, np.ndarray], + dones: np.ndarray, + infos: List[Dict[str, Any]], + ) -> Tuple[Dict[str, np.ndarray], List[Dict[str, Any]]]: + """ + Adds the observations to the stack and uses the dones to update the infos. + + :param observations: Dict of numpy arrays of observations + :param dones: numpy array of dones + :param infos: dict of infos + :return: tuple of the stacked observations and the updated infos + """ + for key in self.stackedobs.keys(): + stack_ax_size = observations[key].shape[self.stack_dimension[key]] + self.stackedobs[key] = np.roll( + self.stackedobs[key], + shift=-stack_ax_size, + axis=self.stack_dimension[key], + ) + + for i, done in enumerate(dones): + if done: + if "terminal_observation" in infos[i]: + old_terminal = infos[i]["terminal_observation"][key] + if self.channels_first[key]: + new_terminal = np.vstack( + ( + self.stackedobs[key][i, :-stack_ax_size, ...], + old_terminal, + ) + ) + else: + new_terminal = np.concatenate( + ( + self.stackedobs[key][i, ..., :-stack_ax_size], + old_terminal, + ), + axis=self.stack_dimension[key], + ) + infos[i]["terminal_observation"][key] = new_terminal + else: + warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") + self.stackedobs[key][i] = 0 + if self.channels_first[key]: + self.stackedobs[key][:, -stack_ax_size:, ...] 
= observations[key] + else: + self.stackedobs[key][..., -stack_ax_size:] = observations[key] + return self.stackedobs, infos diff --git a/stable_baselines3/common/vec_env/util.py b/stable_baselines3/common/vec_env/util.py index 8bc7ee44a..859f1ec95 100644 --- a/stable_baselines3/common/vec_env/util.py +++ b/stable_baselines3/common/vec_env/util.py @@ -7,6 +7,7 @@ import gym import numpy as np +from stable_baselines3.common.preprocessing import check_for_nested_spaces from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs @@ -21,22 +22,22 @@ def copy_obs_dict(obs: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]: return OrderedDict([(k, np.copy(v)) for k, v in obs.items()]) -def dict_to_obs(space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: +def dict_to_obs(obs_space: gym.spaces.Space, obs_dict: Dict[Any, np.ndarray]) -> VecEnvObs: """ Convert an internal representation raw_obs into the appropriate type specified by space. - :param space: an observation space. + :param obs_space: an observation space. :param obs_dict: a dict of numpy arrays. :return: returns an observation of the same type as space. If space is Dict, function is identity; if space is Tuple, converts dict to Tuple; otherwise, space is unstructured and returns the value raw_obs[None]. """ - if isinstance(space, gym.spaces.Dict): + if isinstance(obs_space, gym.spaces.Dict): return obs_dict - elif isinstance(space, gym.spaces.Tuple): - assert len(obs_dict) == len(space.spaces), "size of observation does not match size of observation space" - return tuple((obs_dict[i] for i in range(len(space.spaces)))) + elif isinstance(obs_space, gym.spaces.Tuple): + assert len(obs_dict) == len(obs_space.spaces), "size of observation does not match size of observation space" + return tuple((obs_dict[i] for i in range(len(obs_space.spaces)))) else: assert set(obs_dict.keys()) == {None}, "multiple observation keys for unstructured observation space" return obs_dict[None] @@ -56,6 +57,7 @@ def obs_space_info(obs_space: gym.spaces.Space) -> Tuple[List[str], Dict[Any, Tu shapes: a dict mapping keys to shapes. dtypes: a dict mapping keys to dtypes. """ + check_for_nested_spaces(obs_space) if isinstance(obs_space, gym.spaces.Dict): assert isinstance(obs_space.spaces, OrderedDict), "Dict space must have ordered subspaces" subspaces = obs_space.spaces diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index ff9a79652..e06d5125e 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -1,96 +1,64 @@ -import warnings -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np from gym import spaces -from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvWrapper +from stable_baselines3.common.vec_env.stacked_observations import StackedDictObservations, StackedObservations class VecFrameStack(VecEnvWrapper): """ Frame stacking wrapper for vectorized environment. Designed for image observations. - Dimension to stack over is either first (channels-first) or - last (channels-last), which is detected automatically using - ``common.preprocessing.is_image_space_channels_first`` if - observation is an image space. 
+ Uses the StackedObservations class, or StackedDictObservations depending on the observations space :param venv: the vectorized environment to wrap :param n_stack: Number of frames to stack :param channels_order: If "first", stack on first image dimension. If "last", stack on last dimension. If None, automatically detect channel to stack over in case of image observation or default to "last" (default). + Alternatively channels_order can be a dictionary which can be used with environments with Dict observation spaces """ - def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[str] = None): + def __init__(self, venv: VecEnv, n_stack: int, channels_order: Optional[Union[str, Dict[str, str]]] = None): self.venv = venv self.n_stack = n_stack wrapped_obs_space = venv.observation_space - assert isinstance(wrapped_obs_space, spaces.Box), "VecFrameStack only work with gym.spaces.Box observation space" - - if channels_order is None: - # Detect channel location automatically for images - if is_image_space(wrapped_obs_space): - self.channels_first = is_image_space_channels_first(wrapped_obs_space) - else: - # Default behavior for non-image space, stack on the last axis - self.channels_first = False - else: - assert channels_order in {"last", "first"}, "`channels_order` must be one of following: 'last', 'first'" - self.channels_first = channels_order == "first" + if isinstance(wrapped_obs_space, spaces.Box): + assert not isinstance( + channels_order, dict + ), f"Expected None or string for channels_order but received {channels_order}" + self.stackedobs = StackedObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) + + elif isinstance(wrapped_obs_space, spaces.Dict): + self.stackedobs = StackedDictObservations(venv.num_envs, n_stack, wrapped_obs_space, channels_order) - # This includes the vec-env dimension (first) - self.stack_dimension = 1 if self.channels_first else -1 - repeat_axis = 0 if self.channels_first else -1 - low = np.repeat(wrapped_obs_space.low, self.n_stack, axis=repeat_axis) - high = np.repeat(wrapped_obs_space.high, self.n_stack, axis=repeat_axis) - self.stackedobs = np.zeros((venv.num_envs,) + low.shape, low.dtype) - observation_space = spaces.Box(low=low, high=high, dtype=venv.observation_space.dtype) + else: + raise Exception("VecFrameStack only works with gym.spaces.Box and gym.spaces.Dict observation spaces") + + observation_space = self.stackedobs.stack_observation_space(wrapped_obs_space) VecEnvWrapper.__init__(self, venv, observation_space=observation_space) - def step_wait(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[Dict[str, Any]]]: + def step_wait( + self, + ) -> Tuple[Union[np.ndarray, Dict[str, np.ndarray]], np.ndarray, np.ndarray, List[Dict[str, Any]],]: + observations, rewards, dones, infos = self.venv.step_wait() - # Let pytype know that observation is not a dict - assert isinstance(observations, np.ndarray) - stack_ax_size = observations.shape[self.stack_dimension] - self.stackedobs = np.roll(self.stackedobs, shift=-stack_ax_size, axis=self.stack_dimension) - for i, done in enumerate(dones): - if done: - if "terminal_observation" in infos[i]: - old_terminal = infos[i]["terminal_observation"] - if self.channels_first: - new_terminal = np.concatenate( - (self.stackedobs[i, :-stack_ax_size, ...], old_terminal), axis=self.stack_dimension - ) - else: - new_terminal = np.concatenate( - (self.stackedobs[i, ..., :-stack_ax_size], old_terminal), axis=self.stack_dimension - ) - infos[i]["terminal_observation"] = new_terminal - 
else: - warnings.warn("VecFrameStack wrapping a VecEnv without terminal_observation info") - self.stackedobs[i] = 0 - if self.channels_first: - self.stackedobs[:, -observations.shape[self.stack_dimension] :, ...] = observations - else: - self.stackedobs[..., -observations.shape[self.stack_dimension] :] = observations - return self.stackedobs, rewards, dones, infos + observations, infos = self.stackedobs.update(observations, dones, infos) + + return observations, rewards, dones, infos - def reset(self) -> np.ndarray: + def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: """ Reset all environments """ - obs: np.ndarray = self.venv.reset() # pytype:disable=annotation-type-mismatch - self.stackedobs[...] = 0 - if self.channels_first: - self.stackedobs[:, -obs.shape[self.stack_dimension] :, ...] = obs - else: - self.stackedobs[..., -obs.shape[self.stack_dimension] :] = obs - return self.stackedobs + observation = self.venv.reset() # pytype:disable=annotation-type-mismatch + + observation = self.stackedobs.reset(observation) + return observation def close(self) -> None: self.venv.close() diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index f40b276ac..399fb310e 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -1,7 +1,10 @@ +from copy import deepcopy +from typing import Dict, Union + import numpy as np from gym import spaces -from stable_baselines3.common.preprocessing import is_image_space +from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper @@ -14,22 +17,38 @@ class VecTransposeImage(VecEnvWrapper): """ def __init__(self, venv: VecEnv): - assert is_image_space(venv.observation_space), "The observation space must be an image" + assert is_image_space(venv.observation_space) or isinstance( + venv.observation_space, spaces.dict.Dict + ), "The observation space must be an image or dictionary observation space" - observation_space = self.transpose_space(venv.observation_space) + if isinstance(venv.observation_space, spaces.dict.Dict): + self.image_space_keys = [] + observation_space = deepcopy(venv.observation_space) + for key, space in observation_space.spaces.items(): + if is_image_space(space): + # Keep track of which keys should be transposed later + self.image_space_keys.append(key) + observation_space.spaces[key] = self.transpose_space(space, key) + else: + observation_space = self.transpose_space(venv.observation_space) super(VecTransposeImage, self).__init__(venv, observation_space=observation_space) @staticmethod - def transpose_space(observation_space: spaces.Box) -> spaces.Box: + def transpose_space(observation_space: spaces.Box, key: str = "") -> spaces.Box: """ Transpose an observation space (re-order channels). :param observation_space: + :param key: In case of dictionary space, the key of the observation space. 
:return: """ + # Sanity checks assert is_image_space(observation_space), "The observation space must be an image" - width, height, channels = observation_space.shape - new_shape = (channels, width, height) + assert not is_image_space_channels_first( + observation_space + ), f"The observation space {key} must follow the channel last convention" + height, width, channels = observation_space.shape + new_shape = (channels, height, width) return spaces.Box(low=0, high=255, shape=new_shape, dtype=observation_space.dtype) @staticmethod @@ -44,6 +63,22 @@ def transpose_image(image: np.ndarray) -> np.ndarray: return np.transpose(image, (2, 0, 1)) return np.transpose(image, (0, 3, 1, 2)) + def transpose_observations(self, observations: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: + """ + Transpose (if needed) and return new observations. + + :param observations: + :return: Transposed observations + """ + if isinstance(observations, dict): + # Avoid modifying the original object in place + observations = deepcopy(observations) + for k in self.image_space_keys: + observations[k] = self.transpose_image(observations[k]) + else: + observations = self.transpose_image(observations) + return observations + def step_wait(self) -> VecEnvStepReturn: observations, rewards, dones, infos = self.venv.step_wait() @@ -52,15 +87,15 @@ def step_wait(self) -> VecEnvStepReturn: if not done: continue if "terminal_observation" in infos[idx]: - infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"]) + infos[idx]["terminal_observation"] = self.transpose_observations(infos[idx]["terminal_observation"]) - return self.transpose_image(observations), rewards, dones, infos + return self.transpose_observations(observations), rewards, dones, infos - def reset(self) -> np.ndarray: + def reset(self) -> Union[np.ndarray, Dict]: """ Reset all environments """ - return self.transpose_image(self.venv.reset()) + return self.transpose_observations(self.venv.reset()) def close(self) -> None: self.venv.close() diff --git a/stable_baselines3/ddpg/__init__.py b/stable_baselines3/ddpg/__init__.py index 0b164b2de..262e7f1af 100644 --- a/stable_baselines3/ddpg/__init__.py +++ b/stable_baselines3/ddpg/__init__.py @@ -1,2 +1,2 @@ from stable_baselines3.ddpg.ddpg import DDPG -from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy +from stable_baselines3.ddpg.policies import CnnPolicy, MlpPolicy, MultiInputPolicy diff --git a/stable_baselines3/ddpg/ddpg.py b/stable_baselines3/ddpg/ddpg.py index ea0865149..a7de09e39 100644 --- a/stable_baselines3/ddpg/ddpg.py +++ b/stable_baselines3/ddpg/ddpg.py @@ -2,6 +2,7 @@ import torch as th +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -36,6 +37,9 @@ class DDPG(TD3): during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -54,7 +58,7 @@ def __init__( policy: Union[str, Type[TD3Policy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-3, - buffer_size: int = 1000000, + buffer_size: int = 1000000, # 1e6 learning_starts: int = 100, batch_size: int = 100, tau: float = 0.005, @@ -62,6 +66,8 @@ def __init__( train_freq: Union[int, Tuple[int, str]] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, @@ -84,6 +90,8 @@ def __init__( train_freq=train_freq, gradient_steps=gradient_steps, action_noise=action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/stable_baselines3/ddpg/policies.py b/stable_baselines3/ddpg/policies.py index 64b166f44..945b7e775 100644 --- a/stable_baselines3/ddpg/policies.py +++ b/stable_baselines3/ddpg/policies.py @@ -1,2 +1,2 @@ # DDPG can be view as a special case of TD3 -from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy # noqa:F401 +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy # noqa:F401 diff --git a/stable_baselines3/dqn/__init__.py b/stable_baselines3/dqn/__init__.py index 4ae42872c..f36f96e8a 100644 --- a/stable_baselines3/dqn/__init__.py +++ b/stable_baselines3/dqn/__init__.py @@ -1,2 +1,2 @@ from stable_baselines3.dqn.dqn import DQN -from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy +from stable_baselines3.dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy diff --git a/stable_baselines3/dqn/dqn.py b/stable_baselines3/dqn/dqn.py index 6cea64d39..6dffaa721 100644 --- a/stable_baselines3/dqn/dqn.py +++ b/stable_baselines3/dqn/dqn.py @@ -6,6 +6,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.preprocessing import maybe_transpose from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -35,6 +36,9 @@ class DQN(OffPolicyAlgorithm): :param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``) Set to ``-1`` means to do as many gradient steps as steps done in the environment during the rollout. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -60,13 +64,15 @@ def __init__( policy: Union[str, Type[DQNPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 1e-4, - buffer_size: int = 1000000, + buffer_size: int = 1000000, # 1e6 learning_starts: int = 50000, batch_size: Optional[int] = 32, tau: float = 1.0, gamma: float = 0.99, train_freq: Union[int, Tuple[int, str]] = 4, gradient_steps: int = 1, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, target_update_interval: int = 10000, exploration_fraction: float = 0.1, @@ -95,6 +101,8 @@ def __init__( train_freq, gradient_steps, action_noise=None, # No action noise + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, @@ -124,7 +132,9 @@ def _setup_model(self) -> None: super(DQN, self)._setup_model() self._create_aliases() self.exploration_schedule = get_linear_fn( - self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction + self.exploration_initial_eps, + self.exploration_final_eps, + self.exploration_fraction, ) def _create_aliases(self) -> None: @@ -203,7 +213,10 @@ def predict( """ if not deterministic and np.random.rand() < self.exploration_rate: if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space): - n_batch = observation.shape[0] + if isinstance(self.observation_space, gym.spaces.Dict): + n_batch = observation[list(observation.keys())[0]].shape[0] + else: + n_batch = observation.shape[0] action = np.array([self.action_space.sample() for _ in range(n_batch)]) else: action = np.array(self.action_space.sample()) diff --git a/stable_baselines3/dqn/policies.py b/stable_baselines3/dqn/policies.py index 4ecf39a5f..d39d9f2b9 100644 --- a/stable_baselines3/dqn/policies.py +++ b/stable_baselines3/dqn/policies.py @@ -5,7 +5,13 @@ from torch import nn from stable_baselines3.common.policies import BasePolicy, register_policy -from stable_baselines3.common.torch_layers import BaseFeaturesExtractor, FlattenExtractor, NatureCNN, create_mlp +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + NatureCNN, + create_mlp, +) from stable_baselines3.common.type_aliases import Schedule @@ -122,10 +128,10 @@ def __init__( ) if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [64, 64] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [64, 64] self.net_arch = net_arch self.activation_fn = activation_fn @@ -233,5 +239,51 @@ def __init__( ) +class MultiInputPolicy(DQNPolicy): + """ + Policy class for DQN when using dict observations as input. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[int]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + register_policy("MlpPolicy", MlpPolicy) register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/her/__init__.py b/stable_baselines3/her/__init__.py index 24f347305..1f58921b4 100644 --- a/stable_baselines3/her/__init__.py +++ b/stable_baselines3/her/__init__.py @@ -1,4 +1,2 @@ -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy -from stable_baselines3.her.her import HER from stable_baselines3.her.her_replay_buffer import HerReplayBuffer diff --git a/stable_baselines3/her/her.py b/stable_baselines3/her/her.py deleted file mode 100644 index 43984ded3..000000000 --- a/stable_baselines3/her/her.py +++ /dev/null @@ -1,582 +0,0 @@ -import io -import pathlib -import warnings -from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union - -import numpy as np -import torch as th - -from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.callbacks import BaseCallback -from stable_baselines3.common.noise import ActionNoise -from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm -from stable_baselines3.common.policies import BasePolicy -from stable_baselines3.common.save_util import load_from_zip_file, recursive_setattr -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutReturn, TrainFreq -from stable_baselines3.common.utils import check_for_correct_spaces, should_collect_more_steps -from stable_baselines3.common.vec_env import VecEnv -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper -from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy -from stable_baselines3.her.her_replay_buffer import HerReplayBuffer - - -def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> int: - """ - Get time limit from environment. - - :param env: Environment from which we want to get the time limit. - :param current_max_episode_length: Current value for max_episode_length. 
- :return: max episode length - """ - # try to get the attribute from environment - if current_max_episode_length is None: - try: - current_max_episode_length = env.get_attr("spec")[0].max_episode_steps - # Raise the error because the attribute is present but is None - if current_max_episode_length is None: - raise AttributeError - # if not available check if a valid value was passed as an argument - except AttributeError: - raise ValueError( - "The max episode length could not be inferred.\n" - "You must specify a `max_episode_steps` when registering the environment,\n" - "use a `gym.wrappers.TimeLimit` wrapper " - "or pass `max_episode_length` to the model constructor" - ) - return current_max_episode_length - - -# TODO: rewrite HER class as soon as dict obs are supported -class HER(BaseAlgorithm): - """ - Hindsight Experience Replay (HER) - Paper: https://arxiv.org/abs/1707.01495 - - .. warning:: - - For performance reasons, the maximum number of steps per episodes must be specified. - In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment - or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None). - Otherwise, you can directly pass ``max_episode_length`` to the model constructor - - - For additional offline algorithm specific arguments please have a look at the corresponding documentation. - - :param policy: The policy model to use. - :param env: The environment to learn from (if registered in Gym, can be str) - :param model_class: Off policy model which will be used with hindsight experience replay. (SAC, TD3, DDPG, DQN) - :param n_sampled_goal: Number of sampled goals for replay. (offline sampling) - :param goal_selection_strategy: Strategy for sampling goals for replay. - One of ['episode', 'final', 'future', 'random'] - :param online_sampling: Sample HER transitions online. - :param learning_rate: learning rate for the optimizer, - it can be a function of the current progress remaining (from 1 to 0) - :param max_episode_length: The maximum length of an episode. If not specified, - it will be automatically inferred if the environment uses a ``gym.wrappers.TimeLimit`` wrapper. 
- """ - - def __init__( - self, - policy: Union[str, Type[BasePolicy]], - env: Union[GymEnv, str], - model_class: Type[OffPolicyAlgorithm], - n_sampled_goal: int = 4, - goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future", - online_sampling: bool = False, - max_episode_length: Optional[int] = None, - *args, - **kwargs, - ): - - # we will use the policy and learning rate from the model - super(HER, self).__init__(policy=BasePolicy, env=env, policy_base=BasePolicy, learning_rate=0.0) - del self.policy, self.learning_rate - - if self.get_vec_normalize_env() is not None: - assert online_sampling, "You must pass `online_sampling=True` if you want to use `VecNormalize` with `HER`" - - _init_setup_model = kwargs.get("_init_setup_model", True) - if "_init_setup_model" in kwargs: - del kwargs["_init_setup_model"] - # model initialization - self.model_class = model_class - self.model = model_class( - policy=policy, - env=self.env, - _init_setup_model=False, # pytype: disable=wrong-keyword-args - *args, - **kwargs, # pytype: disable=wrong-keyword-args - ) - - # Make HER use self.model.action_noise - del self.action_noise - self.verbose = self.model.verbose - self.tensorboard_log = self.model.tensorboard_log - - # convert goal_selection_strategy into GoalSelectionStrategy if string - if isinstance(goal_selection_strategy, str): - self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()] - else: - self.goal_selection_strategy = goal_selection_strategy - - # check if goal_selection_strategy is valid - assert isinstance( - self.goal_selection_strategy, GoalSelectionStrategy - ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}" - - self.n_sampled_goal = n_sampled_goal - # if we sample her transitions online use custom replay buffer - self.online_sampling = online_sampling - # compute ratio between HER replays and regular replays in percent for online HER sampling - self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1)) - # maximum steps in episode - self.max_episode_length = get_time_limit(self.env, max_episode_length) - # storage for transitions of current episode for offline sampling - # for online sampling, it replaces the "classic" replay buffer completely - her_buffer_size = self.buffer_size if online_sampling else self.max_episode_length - - assert self.env is not None, "Because it needs access to `env.compute_reward()` HER you must provide the env." 
- - self._episode_storage = HerReplayBuffer( - self.env, - her_buffer_size, - self.max_episode_length, - self.goal_selection_strategy, - self.env.observation_space, - self.env.action_space, - self.device, - self.n_envs, - self.her_ratio, # pytype: disable=wrong-arg-types - ) - - # counter for steps in episode - self.episode_steps = 0 - - if _init_setup_model: - self._setup_model() - - def _setup_model(self) -> None: - self.model._setup_model() - # assign episode storage to replay buffer when using online HER sampling - if self.online_sampling: - self.model.replay_buffer = self._episode_storage - - def predict( - self, - observation: np.ndarray, - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: - - return self.model.predict(observation, state, mask, deterministic) - - def learn( - self, - total_timesteps: int, - callback: MaybeCallback = None, - log_interval: int = 4, - eval_env: Optional[GymEnv] = None, - eval_freq: int = -1, - n_eval_episodes: int = 5, - tb_log_name: str = "HER", - eval_log_path: Optional[str] = None, - reset_num_timesteps: bool = True, - ) -> BaseAlgorithm: - - total_timesteps, callback = self._setup_learn( - total_timesteps, eval_env, callback, eval_freq, n_eval_episodes, eval_log_path, reset_num_timesteps, tb_log_name - ) - self.model.start_time = self.start_time - self.model.ep_info_buffer = self.ep_info_buffer - self.model.ep_success_buffer = self.ep_success_buffer - self.model.num_timesteps = self.num_timesteps - self.model._episode_num = self._episode_num - self.model._last_obs = self._last_obs - self.model._total_timesteps = self._total_timesteps - - callback.on_training_start(locals(), globals()) - - while self.num_timesteps < total_timesteps: - rollout = self.collect_rollouts( - self.env, - train_freq=self.train_freq, - action_noise=self.action_noise, - callback=callback, - learning_starts=self.learning_starts, - log_interval=log_interval, - ) - - if rollout.continue_training is False: - break - - if self.num_timesteps > 0 and self.num_timesteps > self.learning_starts and self.replay_buffer.size() > 0: - # If no `gradient_steps` is specified, - # do as many gradients steps as steps performed during the rollout - gradient_steps = self.gradient_steps if self.gradient_steps > 0 else rollout.episode_timesteps - self.train(batch_size=self.batch_size, gradient_steps=gradient_steps) - - callback.on_training_end() - - return self - - def collect_rollouts( - self, - env: VecEnv, - callback: BaseCallback, - train_freq: TrainFreq, - action_noise: Optional[ActionNoise] = None, - learning_starts: int = 0, - log_interval: Optional[int] = None, - ) -> RolloutReturn: - """ - Collect experiences and store them into a ReplayBuffer. - - :param env: The training environment - :param callback: Callback that will be called at each step - (and at the beginning and end of the rollout) - :param train_freq: How much experience to collect - by doing rollouts of current policy. - Either ``TrainFreq(, TrainFrequencyUnit.STEP)`` - or ``TrainFreq(, TrainFrequencyUnit.EPISODE)`` - with ```` being an integer greater than 0. - :param action_noise: Action noise that will be used for exploration - Required for deterministic policy (e.g. TD3). This can also be used - in addition to the stochastic policy for SAC. - :param learning_starts: Number of steps before learning for the warm-up phase. 
- :param log_interval: Log data every ``log_interval`` episodes - :return: - """ - - episode_rewards, total_timesteps = [], [] - num_collected_steps, num_collected_episodes = 0, 0 - - assert isinstance(env, VecEnv), "You must pass a VecEnv" - assert env.num_envs == 1, "OffPolicyAlgorithm only support single environment" - assert train_freq.frequency > 0, "Should at least collect one step or episode." - - if self.model.use_sde: - self.actor.reset_noise() - - callback.on_rollout_start() - continue_training = True - - while should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes): - done = False - episode_reward, episode_timesteps = 0.0, 0 - - while not done: - # concatenate observation and (desired) goal - observation = self._last_obs - self._last_obs = ObsDictWrapper.convert_dict(observation) - - if ( - self.model.use_sde - and self.model.sde_sample_freq > 0 - and num_collected_steps % self.model.sde_sample_freq == 0 - ): - # Sample a new noise matrix - self.actor.reset_noise() - - # Select action randomly or according to policy - self.model._last_obs = self._last_obs - action, buffer_action = self._sample_action(learning_starts, action_noise) - - # Perform action - new_obs, reward, done, infos = env.step(action) - - self.num_timesteps += 1 - self.model.num_timesteps = self.num_timesteps - episode_timesteps += 1 - num_collected_steps += 1 - - # Only stop training if return value is False, not when it is None. - if callback.on_step() is False: - return RolloutReturn(0.0, num_collected_steps, num_collected_episodes, continue_training=False) - - episode_reward += reward - - # Retrieve reward and episode length if using Monitor wrapper - self._update_info_buffer(infos, done) - self.model.ep_info_buffer = self.ep_info_buffer - self.model.ep_success_buffer = self.ep_success_buffer - - # == Store transition in the replay buffer and/or in the episode storage == - - if self._vec_normalize_env is not None: - # Store only the unnormalized version - new_obs_ = self._vec_normalize_env.get_original_obs() - reward_ = self._vec_normalize_env.get_original_reward() - else: - # Avoid changing the original ones - self._last_original_obs, new_obs_, reward_ = observation, new_obs, reward - self.model._last_original_obs = self._last_original_obs - - # As the VecEnv resets automatically, new_obs is already the - # first observation of the next episode - if done and infos[0].get("terminal_observation") is not None: - next_obs = infos[0]["terminal_observation"] - # VecNormalize normalizes the terminal observation - if self._vec_normalize_env is not None: - next_obs = self._vec_normalize_env.unnormalize_obs(next_obs) - else: - next_obs = new_obs_ - - if self.online_sampling: - self.replay_buffer.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos) - else: - # concatenate observation with (desired) goal - flattened_obs = ObsDictWrapper.convert_dict(self._last_original_obs) - flattened_next_obs = ObsDictWrapper.convert_dict(next_obs) - # add to replay buffer - self.replay_buffer.add(flattened_obs, flattened_next_obs, buffer_action, reward_, done) - # add current transition to episode storage - self._episode_storage.add(self._last_original_obs, next_obs, buffer_action, reward_, done, infos) - - self._last_obs = new_obs - self.model._last_obs = self._last_obs - - # Save the unnormalized new observation - if self._vec_normalize_env is not None: - self._last_original_obs = new_obs_ - self.model._last_original_obs = self._last_original_obs - - 
self.model._update_current_progress_remaining(self.num_timesteps, self._total_timesteps) - - # For DQN, check if the target network should be updated - # and update the exploration schedule - # For SAC/TD3, the update is done as the same time as the gradient update - # see https://github.com/hill-a/stable-baselines/issues/900 - self.model._on_step() - - self.episode_steps += 1 - - if not should_collect_more_steps(train_freq, num_collected_steps, num_collected_episodes): - break - - if done or self.episode_steps >= self.max_episode_length: - if self.online_sampling: - self.replay_buffer.store_episode() - else: - self._episode_storage.store_episode() - # sample virtual transitions and store them in replay buffer - self._sample_her_transitions() - # clear storage for current episode - self._episode_storage.reset() - - num_collected_episodes += 1 - self._episode_num += 1 - self.model._episode_num = self._episode_num - episode_rewards.append(episode_reward) - total_timesteps.append(episode_timesteps) - - if action_noise is not None: - action_noise.reset() - - # Log training infos - if log_interval is not None and self._episode_num % log_interval == 0: - self._dump_logs() - - self.episode_steps = 0 - - mean_reward = np.mean(episode_rewards) if num_collected_episodes > 0 else 0.0 - - callback.on_rollout_end() - - return RolloutReturn(mean_reward, num_collected_steps, num_collected_episodes, continue_training) - - def _sample_her_transitions(self) -> None: - """ - Sample additional goals and store new transitions in replay buffer - when using offline sampling. - """ - - # Sample goals and get new observations - # maybe_vec_env=None as we should store unnormalized transitions, - # they will be normalized at sampling time - observations, next_observations, actions, rewards = self._episode_storage.sample_offline( - n_sampled_goal=self.n_sampled_goal - ) - - # store data in replay buffer - dones = np.zeros((len(observations)), dtype=bool) - self.replay_buffer.extend(observations, next_observations, actions, rewards, dones) - - def __getattr__(self, item: str) -> Any: - """ - Find attribute from model class if this class does not have it. - """ - if hasattr(self.model, item): - return getattr(self.model, item) - else: - raise AttributeError(f"{self} has no attribute {item}") - - def _get_torch_save_params(self) -> Tuple[List[str], List[str]]: - return self.model._get_torch_save_params() - - def save( - self, - path: Union[str, pathlib.Path, io.BufferedIOBase], - exclude: Optional[Iterable[str]] = None, - include: Optional[Iterable[str]] = None, - ) -> None: - """ - Save all the attributes of the object and the model parameters in a zip-file. 
- - :param path: path to the file where the rl agent should be saved - :param exclude: name of parameters that should be excluded in addition to the default one - :param include: name of parameters that might be excluded but should be included anyway - """ - - # add HER parameters to model - self.model.n_sampled_goal = self.n_sampled_goal - self.model.goal_selection_strategy = self.goal_selection_strategy - self.model.online_sampling = self.online_sampling - self.model.model_class = self.model_class - self.model.max_episode_length = self.max_episode_length - - self.model.save(path, exclude, include) - - @classmethod - def load( - cls, - path: Union[str, pathlib.Path, io.BufferedIOBase], - env: Optional[GymEnv] = None, - device: Union[th.device, str] = "auto", - custom_objects: Optional[Dict[str, Any]] = None, - **kwargs, - ) -> "BaseAlgorithm": - """ - Load the model from a zip-file - - :param path: path to the file (or a file-like) where to - load the agent from - :param env: the new environment to run the loaded model on - (can be None if you only need prediction from a trained model) has priority over any saved environment - :param device: Device on which the code should run. - :param custom_objects: Dictionary of objects to replace - upon loading. If a variable is present in this dictionary as a - key, it will not be deserialized and the corresponding item - will be used instead. Similar to custom_objects in - ``keras.models.load_model``. Useful when you have an object in - file that can not be deserialized. - :param kwargs: extra arguments to change the model when loading - """ - data, params, pytorch_variables = load_from_zip_file(path, device=device, custom_objects=custom_objects) - - # Remove stored device information and replace with ours - if "policy_kwargs" in data: - if "device" in data["policy_kwargs"]: - del data["policy_kwargs"]["device"] - - if "policy_kwargs" in kwargs and kwargs["policy_kwargs"] != data["policy_kwargs"]: - raise ValueError( - f"The specified policy kwargs do not equal the stored policy kwargs." - f"Stored kwargs: {data['policy_kwargs']}, specified kwargs: {kwargs['policy_kwargs']}" - ) - - # check if observation space and action space are part of the saved parameters - if "observation_space" not in data or "action_space" not in data: - raise KeyError("The observation_space and action_space were not given, can't verify new environments") - - # check if given env is valid - if env is not None: - # Wrap first if needed - env = cls._wrap_env(env, data["verbose"]) - # Check if given env is valid - check_for_correct_spaces(env, data["observation_space"], data["action_space"]) - else: - # Use stored env, if one exists. 
If not, continue as is (can be used for predict) - if "env" in data: - env = data["env"] - - if "use_sde" in data and data["use_sde"]: - kwargs["use_sde"] = True - - # Keys that cannot be changed - for key in {"model_class", "online_sampling", "max_episode_length"}: - if key in kwargs: - del kwargs[key] - - # Keys that can be changed - for key in {"n_sampled_goal", "goal_selection_strategy"}: - if key in kwargs: - data[key] = kwargs[key] # pytype: disable=unsupported-operands - del kwargs[key] - - # noinspection PyArgumentList - her_model = cls( - policy=data["policy_class"], - env=env, - model_class=data["model_class"], - n_sampled_goal=data["n_sampled_goal"], - goal_selection_strategy=data["goal_selection_strategy"], - online_sampling=data["online_sampling"], - max_episode_length=data["max_episode_length"], - policy_kwargs=data["policy_kwargs"], - _init_setup_model=False, # pytype: disable=not-instantiable,wrong-keyword-args - **kwargs, - ) - - # load parameters - her_model.model.__dict__.update(data) - her_model.model.__dict__.update(kwargs) - her_model._setup_model() - - her_model._total_timesteps = her_model.model._total_timesteps - her_model.num_timesteps = her_model.model.num_timesteps - her_model._episode_num = her_model.model._episode_num - - # put state_dicts back in place - her_model.model.set_parameters(params, exact_match=True, device=device) - - # put other pytorch variables back in place - if pytorch_variables is not None: - for name in pytorch_variables: - recursive_setattr(her_model.model, name, pytorch_variables[name]) - - # Sample gSDE exploration matrix, so it uses the right device - # see issue #44 - if her_model.model.use_sde: - her_model.model.policy.reset_noise() # pytype: disable=attribute-error - return her_model - - def load_replay_buffer( - self, path: Union[str, pathlib.Path, io.BufferedIOBase], truncate_last_trajectory: bool = True - ) -> None: - """ - Load a replay buffer from a pickle file and set environment for replay buffer (only online sampling). - - :param path: Path to the pickled replay buffer. - :param truncate_last_trajectory: Only for online sampling. - If set to ``True`` we assume that the last trajectory in the replay buffer was finished. - If it is set to ``False`` we assume that we continue the same trajectory (same episode). - """ - self.model.load_replay_buffer(path=path) - - if self.online_sampling: - # set environment - self.replay_buffer.set_env(self.env) - # If we are at the start of an episode, no need to truncate - current_idx = self.replay_buffer.current_idx - - # truncate interrupted episode - if truncate_last_trajectory and current_idx > 0: - warnings.warn( - "The last trajectory in the replay buffer will be truncated.\n" - "If you are in the same episode as when the replay buffer was saved,\n" - "you should use `truncate_last_trajectory=False` to avoid that issue." 
- ) - # get current episode and transition index - pos = self.replay_buffer.pos - # set episode length for current episode - self.replay_buffer.episode_lengths[pos] = current_idx - # set done = True for current episode - # current_idx was already incremented - self.replay_buffer.buffer["done"][pos][current_idx - 1] = np.array([True], dtype=np.float32) - # reset current transition index - self.replay_buffer.current_idx = 0 - # increment episode counter - self.replay_buffer.pos = (self.replay_buffer.pos + 1) % self.replay_buffer.max_episode_stored - # update "full" indicator - self.replay_buffer.full = self.replay_buffer.full or self.replay_buffer.pos == 0 diff --git a/stable_baselines3/her/her_replay_buffer.py b/stable_baselines3/her/her_replay_buffer.py index edca50ff9..6a790d691 100644 --- a/stable_baselines3/her/her_replay_buffer.py +++ b/stable_baselines3/her/her_replay_buffer.py @@ -1,75 +1,148 @@ +import warnings from collections import deque from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np import torch as th -from gym import spaces -from stable_baselines3.common.buffers import ReplayBuffer -from stable_baselines3.common.type_aliases import ReplayBufferSamples, RolloutBufferSamples -from stable_baselines3.common.vec_env import VecNormalize -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper -from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy +from stable_baselines3.common.buffers import DictReplayBuffer +from stable_baselines3.common.preprocessing import get_obs_shape +from stable_baselines3.common.type_aliases import DictReplayBufferSamples +from stable_baselines3.common.vec_env import VecEnv, VecNormalize +from stable_baselines3.her.goal_selection_strategy import KEY_TO_GOAL_STRATEGY, GoalSelectionStrategy -class HerReplayBuffer(ReplayBuffer): +def get_time_limit(env: VecEnv, current_max_episode_length: Optional[int]) -> int: """ + Get time limit from environment. + + :param env: Environment from which we want to get the time limit. + :param current_max_episode_length: Current value for max_episode_length. + :return: max episode length + """ + # try to get the attribute from environment + if current_max_episode_length is None: + try: + current_max_episode_length = env.get_attr("spec")[0].max_episode_steps + # Raise the error because the attribute is present but is None + if current_max_episode_length is None: + raise AttributeError + # if not available check if a valid value was passed as an argument + except AttributeError: + raise ValueError( + "The max episode length could not be inferred.\n" + "You must specify a `max_episode_steps` when registering the environment,\n" + "use a `gym.wrappers.TimeLimit` wrapper " + "or pass `max_episode_length` to the model constructor" + ) + return current_max_episode_length + + +class HerReplayBuffer(DictReplayBuffer): + """ + Hindsight Experience Replay (HER) buffer. + Paper: https://arxiv.org/abs/1707.01495 + + .. warning:: + + For performance reasons, the maximum number of steps per episodes must be specified. + In most cases, it will be inferred if you specify ``max_episode_steps`` when registering the environment + or if you use a ``gym.wrappers.TimeLimit`` (and ``env.spec`` is not None). + Otherwise, you can directly pass ``max_episode_length`` to the replay buffer constructor. + + Replay buffer for sampling HER (Hindsight Experience Replay) transitions. 
In the online sampling case, these new transitions will not be saved in the replay buffer and will only be created at sampling time. :param env: The training environment :param buffer_size: The size of the buffer measured in transitions. - :param max_episode_length: The length of an episode. (time horizon) + :param max_episode_length: The maximum length of an episode. If not specified, + it will be automatically inferred if the environment uses a ``gym.wrappers.TimeLimit`` wrapper. :param goal_selection_strategy: Strategy for sampling goals for replay. One of ['episode', 'final', 'future'] - :param observation_space: Observation space - :param action_space: Action space :param device: PyTorch device - :param n_envs: Number of parallel environments - :her_ratio: The ratio between HER transitions and regular transitions in percent - (between 0 and 1, for online sampling) - The default value ``her_ratio=0.8`` corresponds to 4 virtual transitions - for one real transition (4 / (4 + 1) = 0.8) + :param n_sampled_goal: Number of virtual transitions to create per real transition, + by sampling new goals. + :param handle_timeout_termination: Handle timeout termination (due to timelimit) + separately and treat the task as infinite horizon task. + https://github.com/DLR-RM/stable-baselines3/issues/284 """ def __init__( self, - env: ObsDictWrapper, + env: VecEnv, buffer_size: int, - max_episode_length: int, - goal_selection_strategy: GoalSelectionStrategy, - observation_space: spaces.Space, - action_space: spaces.Space, device: Union[th.device, str] = "cpu", - n_envs: int = 1, - her_ratio: float = 0.8, + replay_buffer: Optional[DictReplayBuffer] = None, + max_episode_length: Optional[int] = None, + n_sampled_goal: int = 4, + goal_selection_strategy: Union[GoalSelectionStrategy, str] = "future", + online_sampling: bool = True, + handle_timeout_termination: bool = True, ): - super(HerReplayBuffer, self).__init__(buffer_size, observation_space, action_space, device, n_envs) + super(HerReplayBuffer, self).__init__(buffer_size, env.observation_space, env.action_space, device, env.num_envs) + + # convert goal_selection_strategy into GoalSelectionStrategy if string + if isinstance(goal_selection_strategy, str): + self.goal_selection_strategy = KEY_TO_GOAL_STRATEGY[goal_selection_strategy.lower()] + else: + self.goal_selection_strategy = goal_selection_strategy + + # check if goal_selection_strategy is valid + assert isinstance( + self.goal_selection_strategy, GoalSelectionStrategy + ), f"Invalid goal selection strategy, please use one of {list(GoalSelectionStrategy)}" + + self.n_sampled_goal = n_sampled_goal + # if we sample her transitions online use custom replay buffer + self.online_sampling = online_sampling + # compute ratio between HER replays and regular replays in percent for online HER sampling + self.her_ratio = 1 - (1.0 / (self.n_sampled_goal + 1)) + # maximum steps in episode + self.max_episode_length = get_time_limit(env, max_episode_length) + # storage for transitions of current episode for offline sampling + # for online sampling, it replaces the "classic" replay buffer completely + her_buffer_size = buffer_size if online_sampling else self.max_episode_length self.env = env - self.buffer_size = buffer_size - self.max_episode_length = max_episode_length + self.buffer_size = her_buffer_size + + if online_sampling: + replay_buffer = None + self.replay_buffer = replay_buffer + self.online_sampling = online_sampling + + # Handle timeouts termination properly if needed + # see 
https://github.com/DLR-RM/stable-baselines3/issues/284 + self.handle_timeout_termination = handle_timeout_termination # buffer with episodes # number of episodes which can be stored until buffer size is reached self.max_episode_stored = self.buffer_size // self.max_episode_length self.current_idx = 0 + # Counter to prevent overflow + self.episode_steps = 0 + + # Get shape of observation and goal (usually the same) + self.obs_shape = get_obs_shape(self.env.observation_space.spaces["observation"]) + self.goal_shape = get_obs_shape(self.env.observation_space.spaces["achieved_goal"]) # input dimensions for buffer initialization input_shape = { - "observation": (self.env.num_envs, self.env.obs_dim), - "achieved_goal": (self.env.num_envs, self.env.goal_dim), - "desired_goal": (self.env.num_envs, self.env.goal_dim), + "observation": (self.env.num_envs,) + self.obs_shape, + "achieved_goal": (self.env.num_envs,) + self.goal_shape, + "desired_goal": (self.env.num_envs,) + self.goal_shape, "action": (self.action_dim,), "reward": (1,), - "next_obs": (self.env.num_envs, self.env.obs_dim), - "next_achieved_goal": (self.env.num_envs, self.env.goal_dim), - "next_desired_goal": (self.env.num_envs, self.env.goal_dim), + "next_obs": (self.env.num_envs,) + self.obs_shape, + "next_achieved_goal": (self.env.num_envs,) + self.goal_shape, + "next_desired_goal": (self.env.num_envs,) + self.goal_shape, "done": (1,), } - self.buffer = { + self._observation_keys = ["observation", "achieved_goal", "desired_goal"] + self._buffer = { key: np.zeros((self.max_episode_stored, self.max_episode_length, *dim), dtype=np.float32) for key, dim in input_shape.items() } @@ -78,15 +151,13 @@ def __init__( # episode length storage, needed for episodes which has less steps than the maximum length self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64) - self.goal_selection_strategy = goal_selection_strategy - # percentage of her indices - self.her_ratio = her_ratio - def __getstate__(self) -> Dict[str, Any]: """ Gets state for pickling. - Excludes self.env, as in general Env's may not be pickleable.""" + Excludes self.env, as in general Env's may not be pickleable. + Note: when using offline sampling, this will also save the offline replay buffer. + """ state = self.__dict__.copy() # these attributes are not pickleable del state["env"] @@ -104,7 +175,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None: assert "env" not in state self.env = None - def set_env(self, env: ObsDictWrapper) -> None: + def set_env(self, env: VecEnv) -> None: """ Sets the environment. @@ -115,9 +186,7 @@ def set_env(self, env: ObsDictWrapper) -> None: self.env = env - def _get_samples( - self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None - ) -> Union[ReplayBufferSamples, RolloutBufferSamples]: + def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None) -> DictReplayBufferSamples: """ Abstract method from base class. """ @@ -127,7 +196,7 @@ def sample( self, batch_size: int, env: Optional[VecNormalize], - ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]: + ) -> DictReplayBufferSamples: """ Sample function for online sampling of HER transition, this replaces the "regular" replay buffer ``sample()`` @@ -138,12 +207,14 @@ def sample( to normalize the observations/rewards when sampling :return: Samples. 
""" - return self._sample_transitions(batch_size, maybe_vec_env=env, online_sampling=True) + if self.replay_buffer is not None: + return self.replay_buffer.sample(batch_size, env) + return self._sample_transitions(batch_size, maybe_vec_env=env, online_sampling=True) # pytype: disable=bad-return-type - def sample_offline( + def _sample_offline( self, n_sampled_goal: Optional[int] = None, - ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]: + ) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], np.ndarray, np.ndarray]: """ Sample function for offline sampling of HER transition, in that case, only one episode is used and transitions @@ -152,9 +223,13 @@ def sample_offline( :param n_sampled_goal: Number of sampled goals for replay :return: at most(n_sampled_goal * episode_length) HER transitions. """ - # env=None as we should store unnormalized transitions, they will be normalized at sampling time + # `maybe_vec_env=None` as we should store unnormalized transitions, + # they will be normalized at sampling time return self._sample_transitions( - batch_size=None, maybe_vec_env=None, online_sampling=False, n_sampled_goal=n_sampled_goal + batch_size=None, + maybe_vec_env=None, + online_sampling=False, + n_sampled_goal=n_sampled_goal, ) def sample_goals( @@ -191,7 +266,7 @@ def sample_goals( else: raise ValueError(f"Strategy {self.goal_selection_strategy} for sampling goals not supported!") - return self.buffer["achieved_goal"][her_episode_indices, transitions_indices] + return self._buffer["achieved_goal"][her_episode_indices, transitions_indices] def _sample_transitions( self, @@ -199,7 +274,7 @@ def _sample_transitions( maybe_vec_env: Optional[VecNormalize], online_sampling: bool, n_sampled_goal: Optional[int] = None, - ) -> Union[ReplayBufferSamples, Tuple[np.ndarray, ...]]: + ) -> Union[DictReplayBufferSamples, Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray], np.ndarray, np.ndarray]]: """ :param batch_size: Number of element to sample (only used for online sampling) :param env: associated gym VecEnv to normalize the observations/rewards @@ -248,7 +323,7 @@ def _sample_transitions( if her_indices.size == 0: # Episode of one timestep, not enough for using the "future" strategy # no virtual transitions are created in that case - return np.zeros(0), np.zeros(0), np.zeros(0), np.zeros(0) + return {}, {}, np.zeros(0), np.zeros(0) else: # Repeat every transition index n_sampled_goals times # to sample n_sampled_goal per timestep in the episode (only one is stored). 
@@ -258,7 +333,7 @@ def _sample_transitions( her_indices = np.arange(len(episode_indices)) # get selected transitions - transitions = {key: self.buffer[key][episode_indices, transitions_indices].copy() for key in self.buffer.keys()} + transitions = {key: self._buffer[key][episode_indices, transitions_indices].copy() for key in self._buffer.keys()} # sample new desired goals and relabel the transitions new_goals = self.sample_goals(episode_indices, her_indices, transitions_indices) @@ -272,37 +347,48 @@ def _sample_transitions( ] ) - # Vectorized computation of the new reward - transitions["reward"][her_indices, 0] = self.env.env_method( - "compute_reward", - # the new state depends on the previous state and action - # s_{t+1} = f(s_t, a_t) - # so the next_achieved_goal depends also on the previous state and action - # because we are in a GoalEnv: - # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal) - # therefore we have to use "next_achieved_goal" and not "achieved_goal" - transitions["next_achieved_goal"][her_indices, 0], - # here we use the new desired goal - transitions["desired_goal"][her_indices, 0], - transitions["info"][her_indices, 0], - ) + # Edge case: episode of one timesteps with the future strategy + # no virtual transition can be created + if len(her_indices) > 0: + # Vectorized computation of the new reward + transitions["reward"][her_indices, 0] = self.env.env_method( + "compute_reward", + # the new state depends on the previous state and action + # s_{t+1} = f(s_t, a_t) + # so the next_achieved_goal depends also on the previous state and action + # because we are in a GoalEnv: + # r_t = reward(s_t, a_t) = reward(next_achieved_goal, desired_goal) + # therefore we have to use "next_achieved_goal" and not "achieved_goal" + transitions["next_achieved_goal"][her_indices, 0], + # here we use the new desired goal + transitions["desired_goal"][her_indices, 0], + transitions["info"][her_indices, 0], + ) # concatenate observation with (desired) goal - observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env)) - # HACK to make normalize obs work with the next observation - transitions["observation"] = transitions["next_obs"] - next_observations = ObsDictWrapper.convert_dict(self._normalize_obs(transitions, maybe_vec_env)) + observations = self._normalize_obs(transitions, maybe_vec_env) + + # HACK to make normalize obs and `add()` work with the next observation + next_observations = { + "observation": transitions["next_obs"], + "achieved_goal": transitions["next_achieved_goal"], + # The desired goal for the next observation must be the same as the previous one + "desired_goal": transitions["desired_goal"], + } + next_observations = self._normalize_obs(next_observations, maybe_vec_env) if online_sampling: - data = ( - observations[:, 0], - transitions["action"], - next_observations[:, 0], - transitions["done"], - self._normalize_reward(transitions["reward"], maybe_vec_env), - ) + next_obs = {key: self.to_torch(next_observations[key][:, 0, :]) for key in self._observation_keys} - return ReplayBufferSamples(*tuple(map(self.to_torch, data))) + normalized_obs = {key: self.to_torch(observations[key][:, 0, :]) for key in self._observation_keys} + + return DictReplayBufferSamples( + observations=normalized_obs, + actions=self.to_torch(transitions["action"]), + next_observations=next_obs, + dones=self.to_torch(transitions["done"]), + rewards=self.to_torch(self._normalize_reward(transitions["reward"], maybe_vec_env)), + ) else: return 
observations, next_observations, transitions["action"], transitions["reward"] @@ -313,28 +399,58 @@ def add( action: np.ndarray, reward: np.ndarray, done: np.ndarray, - infos: List[dict], + infos: List[Dict[str, Any]], ) -> None: if self.current_idx == 0 and self.full: # Clear info buffer self.info_buffer[self.pos] = deque(maxlen=self.max_episode_length) - self.buffer["observation"][self.pos][self.current_idx] = obs["observation"] - self.buffer["achieved_goal"][self.pos][self.current_idx] = obs["achieved_goal"] - self.buffer["desired_goal"][self.pos][self.current_idx] = obs["desired_goal"] - self.buffer["action"][self.pos][self.current_idx] = action - self.buffer["done"][self.pos][self.current_idx] = done - self.buffer["reward"][self.pos][self.current_idx] = reward - self.buffer["next_obs"][self.pos][self.current_idx] = next_obs["observation"] - self.buffer["next_achieved_goal"][self.pos][self.current_idx] = next_obs["achieved_goal"] - self.buffer["next_desired_goal"][self.pos][self.current_idx] = next_obs["desired_goal"] + # Remove termination signals due to timeout + if self.handle_timeout_termination: + done_ = done * (1 - np.array([info.get("TimeLimit.truncated", False) for info in infos])) + else: + done_ = done + + self._buffer["observation"][self.pos][self.current_idx] = obs["observation"] + self._buffer["achieved_goal"][self.pos][self.current_idx] = obs["achieved_goal"] + self._buffer["desired_goal"][self.pos][self.current_idx] = obs["desired_goal"] + self._buffer["action"][self.pos][self.current_idx] = action + self._buffer["done"][self.pos][self.current_idx] = done_ + self._buffer["reward"][self.pos][self.current_idx] = reward + self._buffer["next_obs"][self.pos][self.current_idx] = next_obs["observation"] + self._buffer["next_achieved_goal"][self.pos][self.current_idx] = next_obs["achieved_goal"] + self._buffer["next_desired_goal"][self.pos][self.current_idx] = next_obs["desired_goal"] + + # When doing offline sampling + # Add real transition to normal replay buffer + if self.replay_buffer is not None: + self.replay_buffer.add( + obs, + next_obs, + action, + reward, + done, + infos, + ) self.info_buffer[self.pos].append(infos) # update current pointer self.current_idx += 1 + self.episode_steps += 1 + + if done or self.episode_steps >= self.max_episode_length: + self.store_episode() + if not self.online_sampling: + # sample virtual transitions and store them in replay buffer + self._sample_her_transitions() + # clear storage for current episode + self.reset() + + self.episode_steps = 0 + def store_episode(self) -> None: """ Increment episode counter @@ -354,6 +470,28 @@ def store_episode(self) -> None: # reset transition pointer self.current_idx = 0 + def _sample_her_transitions(self) -> None: + """ + Sample additional goals and store new transitions in replay buffer + when using offline sampling. + """ + + # Sample goals to create virtual transitions for the last episode. 
+ observations, next_observations, actions, rewards = self._sample_offline(n_sampled_goal=self.n_sampled_goal) + + # Store virtual transitions in the replay buffer, if available + if len(observations) > 0: + for i in range(len(observations["observation"])): + self.replay_buffer.add( + {key: obs[i] for key, obs in observations.items()}, + {key: next_obs[i] for key, next_obs in next_observations.items()}, + actions[i], + rewards[i], + # We consider the transition as non-terminal + done=[False], + infos=[{}], + ) + @property def n_episodes_stored(self) -> int: if self.full: @@ -374,3 +512,34 @@ def reset(self) -> None: self.current_idx = 0 self.full = False self.episode_lengths = np.zeros(self.max_episode_stored, dtype=np.int64) + + def truncate_last_trajectory(self) -> None: + """ + Only for online sampling, called when loading the replay buffer. + If called, we assume that the last trajectory in the replay buffer was finished + (and truncate it). + If not called, we assume that we continue the same trajectory (same episode). + """ + # If we are at the start of an episode, no need to truncate + current_idx = self.current_idx + + # truncate interrupted episode + if current_idx > 0: + warnings.warn( + "The last trajectory in the replay buffer will be truncated.\n" + "If you are in the same episode as when the replay buffer was saved,\n" + "you should use `truncate_last_trajectory=False` to avoid that issue." + ) + # get current episode and transition index + pos = self.pos + # set episode length for current episode + self.episode_lengths[pos] = current_idx + # set done = True for current episode + # current_idx was already incremented + self._buffer["done"][pos][current_idx - 1] = np.array([True], dtype=np.float32) + # reset current transition index + self.current_idx = 0 + # increment episode counter + self.pos = (self.pos + 1) % self.max_episode_stored + # update "full" indicator + self.full = self.full or self.pos == 0 diff --git a/stable_baselines3/ppo/__init__.py b/stable_baselines3/ppo/__init__.py index c5b80937c..e5c23fc9c 100644 --- a/stable_baselines3/ppo/__init__.py +++ b/stable_baselines3/ppo/__init__.py @@ -1,2 +1,2 @@ -from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy +from stable_baselines3.ppo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.ppo.ppo import PPO diff --git a/stable_baselines3/ppo/policies.py b/stable_baselines3/ppo/policies.py index 7d21de8bf..7427cfc4a 100644 --- a/stable_baselines3/ppo/policies.py +++ b/stable_baselines3/ppo/policies.py @@ -1,9 +1,16 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for PPO -from stable_baselines3.common.policies import ActorCriticCnnPolicy, ActorCriticPolicy, register_policy +from stable_baselines3.common.policies import ( + ActorCriticCnnPolicy, + ActorCriticPolicy, + MultiInputActorCriticPolicy, + register_policy, +) MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy register_policy("MlpPolicy", ActorCriticPolicy) register_policy("CnnPolicy", ActorCriticCnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/sac/__init__.py b/stable_baselines3/sac/__init__.py index 5b0e89900..5a84dde19 100644 --- a/stable_baselines3/sac/__init__.py +++ b/stable_baselines3/sac/__init__.py @@ -1,2 +1,2 @@ -from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy +from stable_baselines3.sac.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.sac.sac 
import SAC diff --git a/stable_baselines3/sac/policies.py b/stable_baselines3/sac/policies.py index 6f73f0325..86d218517 100644 --- a/stable_baselines3/sac/policies.py +++ b/stable_baselines3/sac/policies.py @@ -9,6 +9,7 @@ from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, + CombinedExtractor, FlattenExtractor, NatureCNN, create_mlp, @@ -255,10 +256,10 @@ def __init__( ) if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [256, 256] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [256, 256] actor_arch, critic_arch = get_actor_critic_arch(net_arch) @@ -435,5 +436,77 @@ def __init__( ) +class MultiInputPolicy(SACPolicy): + """ + Policy class (with both actor and critic) for SAC. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` when using gSDE to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param clip_mean: Clip the mean output when using gSDE to avoid numerical instability. + :param features_extractor_class: Features extractor to use. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + use_sde: bool = False, + log_std_init: float = -3, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + clip_mean: float = 2.0, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + use_sde, + log_std_init, + sde_net_arch, + use_expln, + clip_mean, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) + + register_policy("MlpPolicy", MlpPolicy) register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/sac/sac.py b/stable_baselines3/sac/sac.py index 73b220c79..a22095575 100644 --- a/stable_baselines3/sac/sac.py +++ b/stable_baselines3/sac/sac.py @@ -6,6 +6,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -44,6 +45,9 @@ class SAC(OffPolicyAlgorithm): during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -74,7 +78,7 @@ def __init__( policy: Union[str, Type[SACPolicy]], env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, - buffer_size: int = 1000000, + buffer_size: int = 1000000, # 1e6 learning_starts: int = 100, batch_size: int = 256, tau: float = 0.005, @@ -82,6 +86,8 @@ def __init__( train_freq: Union[int, Tuple[int, str]] = 1, gradient_steps: int = 1, action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, ent_coef: Union[str, float] = "auto", target_update_interval: int = 1, @@ -111,6 +117,8 @@ def __init__( train_freq, gradient_steps, action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/stable_baselines3/td3/__init__.py b/stable_baselines3/td3/__init__.py index ed054f0d9..0b903cd2b 100644 --- a/stable_baselines3/td3/__init__.py +++ b/stable_baselines3/td3/__init__.py @@ -1,2 +1,2 @@ -from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy +from stable_baselines3.td3.policies import CnnPolicy, MlpPolicy, MultiInputPolicy from stable_baselines3.td3.td3 import TD3 diff --git a/stable_baselines3/td3/policies.py b/stable_baselines3/td3/policies.py index e90dd6943..1288d7899 100644 --- a/stable_baselines3/td3/policies.py +++ b/stable_baselines3/td3/policies.py @@ -8,6 +8,7 @@ from stable_baselines3.common.preprocessing import get_action_dim from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, + CombinedExtractor, FlattenExtractor, NatureCNN, create_mlp, @@ -132,10 +133,10 @@ def __init__( # Default network architecture, from the original paper if net_arch is None: - if features_extractor_class == FlattenExtractor: - net_arch = [400, 300] - else: + if features_extractor_class == NatureCNN: net_arch = [] + else: + net_arch = [400, 300] actor_arch, critic_arch = get_actor_critic_arch(net_arch) @@ -282,5 +283,60 @@ def __init__( ) +class MultiInputPolicy(TD3Policy): + """ + Policy class (with both actor and critic) for TD3 to be used with Dict observation spaces. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether to share or not the features extractor + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[Union[List[int], Dict[str, List[int]]]] = None, + activation_fn: Type[nn.Module] = nn.ReLU, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super(MultiInputPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + n_critics, + share_features_extractor, + ) + + register_policy("MlpPolicy", MlpPolicy) register_policy("CnnPolicy", CnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/stable_baselines3/td3/td3.py b/stable_baselines3/td3/td3.py index 6ea068152..2b165c0f5 100644 --- a/stable_baselines3/td3/td3.py +++ b/stable_baselines3/td3/td3.py @@ -6,6 +6,7 @@ from torch.nn import functional as F from stable_baselines3.common import logger +from stable_baselines3.common.buffers import ReplayBuffer from stable_baselines3.common.noise import ActionNoise from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule @@ -39,6 +40,9 @@ class TD3(OffPolicyAlgorithm): during the rollout. :param action_noise: the action noise type (None by default), this can help for hard exploration problem. Cf common.noise for the different action noise type. + :param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``). + If ``None``, it will be automatically selected. + :param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation. :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer at a cost of more complexity. 
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195 @@ -70,6 +74,8 @@ def __init__( train_freq: Union[int, Tuple[int, str]] = (1, "episode"), gradient_steps: int = -1, action_noise: Optional[ActionNoise] = None, + replay_buffer_class: Optional[ReplayBuffer] = None, + replay_buffer_kwargs: Optional[Dict[str, Any]] = None, optimize_memory_usage: bool = False, policy_delay: int = 2, target_policy_noise: float = 0.2, @@ -96,6 +102,8 @@ def __init__( train_freq, gradient_steps, action_noise=action_noise, + replay_buffer_class=replay_buffer_class, + replay_buffer_kwargs=replay_buffer_kwargs, policy_kwargs=policy_kwargs, tensorboard_log=tensorboard_log, verbose=verbose, diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index c84ce1899..1406d2fc7 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a5 +1.1.0a6 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index d86a4d62b..48a6f34bd 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -5,8 +5,7 @@ import numpy as np import pytest -from stable_baselines3 import A2C, DDPG, DQN, HER, PPO, SAC, TD3 -from stable_baselines3.common.bit_flipping_env import BitFlippingEnv +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3, HerReplayBuffer from stable_baselines3.common.callbacks import ( CallbackList, CheckpointCallback, @@ -16,8 +15,8 @@ StopTrainingOnRewardThreshold, ) from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import BitFlippingEnv from stable_baselines3.common.vec_env import DummyVecEnv -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper @pytest.mark.parametrize("model_class", [A2C, PPO, SAC, TD3, DQN, DDPG]) @@ -108,12 +107,19 @@ def test_eval_success_logging(tmp_path): env = BitFlippingEnv(n_bits=n_bits) eval_env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=n_bits)]) eval_callback = EvalCallback( - ObsDictWrapper(eval_env), + eval_env, eval_freq=250, log_path=tmp_path, warn=False, ) - model = HER("MlpPolicy", env, DQN, learning_starts=100, seed=0, max_episode_length=n_bits) + model = DQN( + "MultiInputPolicy", + env, + replay_buffer_class=HerReplayBuffer, + learning_starts=100, + seed=0, + replay_buffer_kwargs=dict(max_episode_length=n_bits), + ) model.learn(500, callback=eval_callback) assert len(eval_callback._is_success_buffer) > 0 # More than 50% success rate diff --git a/tests/test_cnn.py b/tests/test_cnn.py index c7d5cc4e1..5cac58edb 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -7,7 +7,7 @@ from gym import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 -from stable_baselines3.common.identity_env import FakeImageEnv +from stable_baselines3.common.envs import FakeImageEnv from stable_baselines3.common.preprocessing import is_image_space, is_image_space_channels_first from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py new file mode 100644 index 000000000..b165180d5 --- /dev/null +++ b/tests/test_dict_env.py @@ -0,0 +1,309 @@ +import gym +import numpy as np +import pytest +from gym import spaces + +from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_util import make_vec_env +from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv +from stable_baselines3.common.evaluation import evaluate_policy +from 
stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize + + +class DummyDictEnv(gym.Env): + """Custom Environment for testing purposes only""" + + metadata = {"render.modes": ["human"]} + + def __init__( + self, + use_discrete_actions=False, + channel_last=False, + nested_dict_obs=False, + vec_only=False, + ): + super().__init__() + if use_discrete_actions: + self.action_space = spaces.Discrete(3) + else: + self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) + N_CHANNELS = 1 + HEIGHT = 64 + WIDTH = 64 + + if channel_last: + obs_shape = (HEIGHT, WIDTH, N_CHANNELS) + else: + obs_shape = (N_CHANNELS, HEIGHT, WIDTH) + + self.observation_space = spaces.Dict( + { + # Image obs + "img": spaces.Box(low=0, high=255, shape=obs_shape, dtype=np.uint8), + # Vector obs + "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32), + # Discrete obs + "discrete": spaces.Discrete(4), + } + ) + + # For checking consistency with normal MlpPolicy + if vec_only: + self.observation_space = spaces.Dict( + { + # Vector obs + "vec": spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32), + } + ) + + if nested_dict_obs: + # Add dictionary observation inside observation space + self.observation_space.spaces["nested-dict"] = spaces.Dict({"nested-dict-discrete": spaces.Discrete(4)}) + + def seed(self, seed=None): + if seed is not None: + self.observation_space.seed(seed) + + def step(self, action): + reward = 0.0 + done = False + return self.observation_space.sample(), reward, done, {} + + def compute_reward(self, achieved_goal, desired_goal, info): + return np.zeros((len(achieved_goal),)) + + def reset(self): + return self.observation_space.sample() + + def render(self, mode="human"): + pass + + +@pytest.mark.parametrize("model_class", [PPO, A2C]) +def test_goal_env(model_class): + env = BitFlippingEnv(n_bits=4) + # check that goal env works for PPO/A2C that cannot use HER replay buffer + model = model_class("MultiInputPolicy", env, n_steps=64).learn(250) + evaluate_policy(model, model.get_env()) + + +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +def test_consistency(model_class): + """ + Make sure that dict obs with vector only vs using flatten obs is equivalent. + This ensures notable that the network architectures are the same. 
+ """ + use_discrete_actions = model_class == DQN + dict_env = DummyDictEnv(use_discrete_actions=use_discrete_actions, vec_only=True) + dict_env = gym.wrappers.TimeLimit(dict_env, 100) + env = gym.wrappers.FlattenObservation(dict_env) + dict_env.seed(10) + obs = dict_env.reset() + + kwargs = {} + n_steps = 256 + + if model_class in {A2C, PPO}: + kwargs = dict( + n_steps=128, + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + train_freq=8, + gradient_steps=1, + ) + if model_class == DQN: + kwargs["learning_starts"] = 0 + + dict_model = model_class("MultiInputPolicy", dict_env, gamma=0.5, seed=1, **kwargs) + action_before_learning_1, _ = dict_model.predict(obs, deterministic=True) + dict_model.learn(total_timesteps=n_steps) + + normal_model = model_class("MlpPolicy", env, gamma=0.5, seed=1, **kwargs) + action_before_learning_2, _ = normal_model.predict(obs["vec"], deterministic=True) + normal_model.learn(total_timesteps=n_steps) + + action_1, _ = dict_model.predict(obs, deterministic=True) + action_2, _ = normal_model.predict(obs["vec"], deterministic=True) + + assert np.allclose(action_before_learning_1, action_before_learning_2) + assert np.allclose(action_1, action_2) + + +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("channel_last", [False, True]) +def test_dict_spaces(model_class, channel_last): + """ + Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support + with mixed observation. + """ + use_discrete_actions = model_class not in [SAC, TD3, DDPG] + env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=channel_last) + env = gym.wrappers.TimeLimit(env, 100) + + kwargs = {} + n_steps = 256 + + if model_class in {A2C, PPO}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == DQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) + + +@pytest.mark.parametrize("model_class", [PPO, A2C]) +def test_multiprocessing(model_class): + use_discrete_actions = model_class not in [SAC, TD3, DDPG] + + def make_env(): + env = DummyDictEnv(use_discrete_actions=use_discrete_actions, channel_last=False) + env = gym.wrappers.TimeLimit(env, 100) + return env + + env = make_vec_env(make_env, n_envs=2, vec_env_cls=SubprocVecEnv) + + kwargs = {} + n_steps = 256 + + if model_class in {A2C, PPO}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + ) + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +@pytest.mark.parametrize("channel_last", [False, True]) +def test_dict_vec_framestack(model_class, channel_last): + """ + Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support + for Dictionary spaces and VecEnvWrapper 
using MultiInputPolicy. + """ + use_discrete_actions = model_class not in [SAC, TD3, DDPG] + channels_order = {"vec": None, "img": "last" if channel_last else "first"} + env = DummyVecEnv( + [lambda: SimpleMultiObsEnv(random_start=True, discrete_actions=use_discrete_actions, channel_last=channel_last)] + ) + + env = VecFrameStack(env, n_stack=3, channels_order=channels_order) + + kwargs = {} + n_steps = 256 + + if model_class in {A2C, PPO}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + features_extractor_kwargs=dict(cnn_output_dim=32), + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == DQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) + + +@pytest.mark.parametrize("model_class", [PPO, A2C, DQN, DDPG, SAC, TD3]) +def test_vec_normalize(model_class): + """ + Additional tests for PPO/A2C/SAC/DDPG/TD3/DQN to check observation space support + for GoalEnv and VecNormalize using MultiInputPolicy. + """ + env = DummyVecEnv([lambda: BitFlippingEnv(n_bits=4, continuous=not (model_class == DQN))]) + env = VecNormalize(env) + + kwargs = {} + n_steps = 256 + + if model_class in {A2C, PPO}: + kwargs = dict( + n_steps=128, + policy_kwargs=dict( + net_arch=[32], + ), + ) + else: + # Avoid memory error when using replay buffer + # Reduce the size of the features and make learning faster + kwargs = dict( + buffer_size=250, + policy_kwargs=dict( + net_arch=[32], + ), + train_freq=8, + gradient_steps=1, + ) + if model_class == DQN: + kwargs["learning_starts"] = 0 + + model = model_class("MultiInputPolicy", env, gamma=0.5, seed=1, **kwargs) + + model.learn(total_timesteps=n_steps) + + evaluate_policy(model, env, n_eval_episodes=5, warn=False) + + +def test_dict_nested(): + """ + Make sure we throw an appropiate error with nested Dict observation spaces + """ + # Test without manual wrapping to vec-env + env = DummyDictEnv(nested_dict_obs=True) + + with pytest.raises(NotImplementedError): + _ = PPO("MultiInputPolicy", env, seed=1) + + # Test with manual vec-env wrapping + + with pytest.raises(NotImplementedError): + env = DummyVecEnv([lambda: DummyDictEnv(nested_dict_obs=True)]) diff --git a/tests/test_envs.py b/tests/test_envs.py index e5f6bf3a9..645c17e3f 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,19 +1,30 @@ +import types + import gym import numpy as np import pytest from gym import spaces -from stable_baselines3.common.bit_flipping_env import BitFlippingEnv from stable_baselines3.common.env_checker import check_env -from stable_baselines3.common.identity_env import ( +from stable_baselines3.common.envs import ( + BitFlippingEnv, FakeImageEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, + SimpleMultiObsEnv, ) -ENV_CLASSES = [BitFlippingEnv, IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete, FakeImageEnv] +ENV_CLASSES = [ + BitFlippingEnv, + IdentityEnv, + IdentityEnvBox, + IdentityEnvMultiBinary, + IdentityEnvMultiDiscrete, + FakeImageEnv, + SimpleMultiObsEnv, +] @pytest.mark.parametrize("env_id", ["CartPole-v0", "Pendulum-v0"]) @@ -39,7 +50,29 @@ def test_env(env_id): 
@pytest.mark.parametrize("env_class", ENV_CLASSES) def test_custom_envs(env_class): env = env_class() - check_env(env) + with pytest.warns(None) as record: + check_env(env) + # No warnings for custom envs + assert len(record) == 0 + + +@pytest.mark.parametrize( + "kwargs", + [ + dict(continuous=True), + dict(discrete_obs_space=True), + dict(image_obs_space=True, channel_first=True), + dict(image_obs_space=True, channel_first=False), + ], +) +def test_bit_flipping(kwargs): + # Additional tests for BitFlippingEnv + env = BitFlippingEnv(**kwargs) + with pytest.warns(None) as record: + check_env(env) + + # No warnings for custom envs + assert len(record) == 0 def test_high_dimension_action_space(): @@ -72,8 +105,10 @@ def patched_step(_action): spaces.Box(low=-1, high=1, shape=(64, 3), dtype=np.float32), # Tuple space is not supported by SB spaces.Tuple([spaces.Discrete(5), spaces.Discrete(10)]), - # Dict space is not supported by SB when env is not a GoalEnv - spaces.Dict({"position": spaces.Discrete(5)}), + # Nested dict space is not supported by SB3 + spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), + # Small image inside a dict + spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), ], ) def test_non_default_spaces(new_obs_space): @@ -119,6 +154,19 @@ def test_common_failures_reset(): # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) + env = SimpleMultiObsEnv() + obs = env.reset() + + def wrong_reset(self): + return {"img": obs["img"], "vec": obs["img"]} + + env.reset = types.MethodType(wrong_reset, env) + with pytest.raises(AssertionError) as excinfo: + check_env(env) + + # Check that the key is explicitly mentioned + assert "vec" in str(excinfo.value) + def check_step_assert_error(env, new_step_return=()): """ @@ -156,3 +204,16 @@ def test_common_failures_step(): # Done is not a boolean check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + + env = SimpleMultiObsEnv() + obs = env.reset() + + def wrong_step(self, action): + return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {} + + env.step = types.MethodType(wrong_step, env) + with pytest.raises(AssertionError) as excinfo: + check_env(env) + + # Check that the key is explicitly mentioned + assert "img" in str(excinfo.value) diff --git a/tests/test_her.py b/tests/test_her.py index 5d76d1735..0f6d75f6f 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -8,38 +8,57 @@ import pytest import torch as th -from stable_baselines3 import DDPG, DQN, HER, SAC, TD3 -from stable_baselines3.common.bit_flipping_env import BitFlippingEnv +from stable_baselines3 import DDPG, DQN, SAC, TD3, HerReplayBuffer +from stable_baselines3.common.envs import BitFlippingEnv +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import DummyVecEnv -from stable_baselines3.common.vec_env.obs_dict_wrapper import ObsDictWrapper from stable_baselines3.her.goal_selection_strategy import GoalSelectionStrategy -from stable_baselines3.her.her import get_time_limit +from stable_baselines3.her.her_replay_buffer import get_time_limit + + +def test_import_error(): + with pytest.raises(ImportError) as excinfo: + from stable_baselines3 import HER + + HER("MlpPolicy") + assert 
"documentation" in str(excinfo.value) @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) @pytest.mark.parametrize("online_sampling", [True, False]) -def test_her(model_class, online_sampling): +@pytest.mark.parametrize("image_obs_space", [True, False]) +def test_her(model_class, online_sampling, image_obs_space): """ Test Hindsight Experience Replay. """ n_bits = 4 - env = BitFlippingEnv(n_bits=n_bits, continuous=not (model_class == DQN)) + env = BitFlippingEnv( + n_bits=n_bits, + continuous=not (model_class == DQN), + image_obs_space=image_obs_space, + ) - model = HER( - "MlpPolicy", + model = model_class( + "MultiInputPolicy", env, - model_class, - goal_selection_strategy="future", - online_sampling=online_sampling, - gradient_steps=1, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + n_sampled_goal=2, + goal_selection_strategy="future", + online_sampling=online_sampling, + max_episode_length=n_bits, + ), train_freq=4, - max_episode_length=n_bits, + gradient_steps=1, policy_kwargs=dict(net_arch=[64]), learning_starts=100, + buffer_size=int(2e4), ) - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) + evaluate_policy(model, Monitor(env)) @pytest.mark.parametrize( @@ -62,21 +81,25 @@ def test_goal_selection_strategy(goal_selection_strategy, online_sampling): normal_action_noise = NormalActionNoise(np.zeros(1), 0.1 * np.ones(1)) - model = HER( - "MlpPolicy", + model = SAC( + "MultiInputPolicy", env, - SAC, - goal_selection_strategy=goal_selection_strategy, - online_sampling=online_sampling, - gradient_steps=1, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + goal_selection_strategy=goal_selection_strategy, + online_sampling=online_sampling, + max_episode_length=10, + n_sampled_goal=2, + ), train_freq=4, - max_episode_length=10, + gradient_steps=1, policy_kwargs=dict(net_arch=[64]), learning_starts=100, + buffer_size=int(1e5), action_noise=normal_action_noise, ) assert model.action_noise is not None - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) @pytest.mark.parametrize("model_class", [SAC, TD3, DDPG, DQN]) @@ -95,37 +118,39 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): kwargs = dict(use_sde=True) if use_sde else {} # create model - model = HER( - "MlpPolicy", + model = model_class( + "MultiInputPolicy", env, - model_class, - n_sampled_goal=5, - goal_selection_strategy="future", - online_sampling=online_sampling, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + n_sampled_goal=2, + goal_selection_strategy="future", + online_sampling=online_sampling, + max_episode_length=n_bits, + ), verbose=0, tau=0.05, batch_size=128, learning_rate=0.001, policy_kwargs=dict(net_arch=[64]), - buffer_size=int(1e6), + buffer_size=int(1e5), gamma=0.98, gradient_steps=1, train_freq=4, learning_starts=100, - max_episode_length=n_bits, **kwargs ) - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) - env.reset() + obs = env.reset() - observations_list = [] + observations = {key: [] for key in obs.keys()} for _ in range(10): obs = env.step(env.action_space.sample())[0] - observation = ObsDictWrapper.convert_dict(obs) - observations_list.append(observation) - observations = np.array(observations_list) + for key in obs.keys(): + observations[key].append(obs[key]) + observations = {key: np.array(obs) for key, obs in observations.items()} # Get dictionary of current parameters params = deepcopy(model.policy.state_dict()) @@ -153,14 +178,14 @@ def 
test_save_load(tmp_path, model_class, use_sde, online_sampling): # test custom_objects # Load with custom objects custom_objects = dict(learning_rate=2e-5, dummy=1.0) - model_ = HER.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2) + model_ = model_class.load(str(tmp_path / "test_save.zip"), env=env, custom_objects=custom_objects, verbose=2) assert model_.verbose == 2 # Check that the custom object was taken into account assert model_.learning_rate == custom_objects["learning_rate"] # Check that only parameters that are here already are replaced assert not hasattr(model_, "dummy") - model = HER.load(str(tmp_path / "test_save.zip"), env=env) + model = model_class.load(str(tmp_path / "test_save.zip"), env=env) # check if params are still the same after load new_params = model.policy.state_dict() @@ -174,18 +199,19 @@ def test_save_load(tmp_path, model_class, use_sde, online_sampling): assert np.allclose(selected_actions, new_selected_actions, 1e-4) # check if learn still works - model.learn(total_timesteps=300) + model.learn(total_timesteps=150) # Test that the change of parameters works - model = HER.load(str(tmp_path / "test_save.zip"), env=env, verbose=3, learning_rate=2.0) - assert model.model.learning_rate == 2.0 + model = model_class.load(str(tmp_path / "test_save.zip"), env=env, verbose=3, learning_rate=2.0) + assert model.learning_rate == 2.0 assert model.verbose == 3 # clear file from os os.remove(tmp_path / "test_save.zip") -@pytest.mark.parametrize("online_sampling, truncate_last_trajectory", [(False, False), (True, True), (True, False)]) +@pytest.mark.parametrize("online_sampling", [False, True]) +@pytest.mark.parametrize("truncate_last_trajectory", [False, True]) def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_last_trajectory): """ Test if 'save_replay_buffer' and 'load_replay_buffer' works correctly @@ -194,26 +220,32 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la warnings.filterwarnings(action="ignore", category=DeprecationWarning) warnings.filterwarnings(action="ignore", category=UserWarning, module="gym") - path = pathlib.Path(tmp_path / "logs/replay_buffer.pkl") + path = pathlib.Path(tmp_path / "replay_buffer.pkl") path.parent.mkdir(exist_ok=True, parents=True) # to not raise a warning env = BitFlippingEnv(n_bits=4, continuous=True) - model = HER( - "MlpPolicy", + model = SAC( + "MultiInputPolicy", env, - SAC, - goal_selection_strategy="future", - online_sampling=online_sampling, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + n_sampled_goal=2, + goal_selection_strategy="future", + online_sampling=online_sampling, + max_episode_length=4, + ), gradient_steps=1, train_freq=4, - max_episode_length=4, buffer_size=int(2e4), policy_kwargs=dict(net_arch=[64]), - seed=0, + seed=1, ) model.learn(200) - old_replay_buffer = deepcopy(model.replay_buffer) + if online_sampling: + old_replay_buffer = deepcopy(model.replay_buffer) + else: + old_replay_buffer = deepcopy(model.replay_buffer.replay_buffer) model.save_replay_buffer(path) - del model.model.replay_buffer + del model.replay_buffer with pytest.raises(AttributeError): model.replay_buffer @@ -221,7 +253,7 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la # Check that there is no warning assert len(recwarn) == 0 - model.load_replay_buffer(path, truncate_last_trajectory) + model.load_replay_buffer(path, truncate_last_traj=truncate_last_trajectory) if truncate_last_trajectory: 
assert len(recwarn) == 1 @@ -233,29 +265,33 @@ def test_save_load_replay_buffer(tmp_path, recwarn, online_sampling, truncate_la if online_sampling: n_episodes_stored = model.replay_buffer.n_episodes_stored assert np.allclose( - old_replay_buffer.buffer["observation"][:n_episodes_stored], - model.replay_buffer.buffer["observation"][:n_episodes_stored], + old_replay_buffer._buffer["observation"][:n_episodes_stored], + model.replay_buffer._buffer["observation"][:n_episodes_stored], ) assert np.allclose( - old_replay_buffer.buffer["next_obs"][:n_episodes_stored], - model.replay_buffer.buffer["next_obs"][:n_episodes_stored], + old_replay_buffer._buffer["next_obs"][:n_episodes_stored], + model.replay_buffer._buffer["next_obs"][:n_episodes_stored], ) assert np.allclose( - old_replay_buffer.buffer["action"][:n_episodes_stored], model.replay_buffer.buffer["action"][:n_episodes_stored] + old_replay_buffer._buffer["action"][:n_episodes_stored], + model.replay_buffer._buffer["action"][:n_episodes_stored], ) assert np.allclose( - old_replay_buffer.buffer["reward"][:n_episodes_stored], model.replay_buffer.buffer["reward"][:n_episodes_stored] + old_replay_buffer._buffer["reward"][:n_episodes_stored], + model.replay_buffer._buffer["reward"][:n_episodes_stored], ) # we might change the last done of the last trajectory so we don't compare it assert np.allclose( - old_replay_buffer.buffer["done"][: n_episodes_stored - 1], - model.replay_buffer.buffer["done"][: n_episodes_stored - 1], + old_replay_buffer._buffer["done"][: n_episodes_stored - 1], + model.replay_buffer._buffer["done"][: n_episodes_stored - 1], ) else: - assert np.allclose(old_replay_buffer.observations, model.replay_buffer.observations) - assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) - assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) - assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + replay_buffer = model.replay_buffer.replay_buffer + assert np.allclose(old_replay_buffer.observations["observation"], replay_buffer.observations["observation"]) + assert np.allclose(old_replay_buffer.observations["desired_goal"], replay_buffer.observations["desired_goal"]) + assert np.allclose(old_replay_buffer.actions, replay_buffer.actions) + assert np.allclose(old_replay_buffer.rewards, replay_buffer.rewards) + assert np.allclose(old_replay_buffer.dones, replay_buffer.dones) # test if continuing training works properly reset_num_timesteps = False if truncate_last_trajectory is False else True @@ -271,19 +307,23 @@ def test_full_replay_buffer(): env = BitFlippingEnv(n_bits=n_bits, continuous=True) # use small buffer size to get the buffer full - model = HER( - "MlpPolicy", + model = SAC( + "MultiInputPolicy", env, - SAC, - goal_selection_strategy="future", - online_sampling=True, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + n_sampled_goal=2, + goal_selection_strategy="future", + online_sampling=True, + max_episode_length=n_bits, + ), gradient_steps=1, train_freq=4, - max_episode_length=n_bits, policy_kwargs=dict(net_arch=[64]), learning_starts=1, buffer_size=20, verbose=1, + seed=757, ) model.learn(total_timesteps=100) @@ -313,15 +353,15 @@ def test_get_max_episode_length(): get_time_limit(vec_env, current_max_episode_length=None) # Initialize HER and specify max_episode_length, should not raise an issue - HER("MlpPolicy", dict_env, DQN, max_episode_length=5) + DQN("MultiInputPolicy", dict_env, replay_buffer_class=HerReplayBuffer, 
replay_buffer_kwargs=dict(max_episode_length=5)) with pytest.raises(ValueError): - HER("MlpPolicy", dict_env, DQN) + DQN("MultiInputPolicy", dict_env, replay_buffer_class=HerReplayBuffer) # Wrapped in a timelimit, should be fine # Note: it requires env.spec to be defined env = DummyVecEnv([lambda: gym.wrappers.TimeLimit(BitFlippingEnv(), 10)]) - HER("MlpPolicy", env, DQN) + DQN("MultiInputPolicy", env, replay_buffer_class=HerReplayBuffer, replay_buffer_kwargs=dict(max_episode_length=5)) @pytest.mark.parametrize("online_sampling", [False, True]) @@ -333,22 +373,25 @@ def test_performance_her(online_sampling, n_bits): """ env = BitFlippingEnv(n_bits=n_bits, continuous=False) - model = HER( - "MlpPolicy", + model = DQN( + "MultiInputPolicy", env, - DQN, - n_sampled_goal=5, - goal_selection_strategy="future", - online_sampling=online_sampling, + replay_buffer_class=HerReplayBuffer, + replay_buffer_kwargs=dict( + n_sampled_goal=5, + goal_selection_strategy="future", + online_sampling=online_sampling, + max_episode_length=n_bits, + ), verbose=1, learning_rate=5e-4, - max_episode_length=n_bits, train_freq=1, learning_starts=100, exploration_final_eps=0.02, target_update_interval=500, seed=0, batch_size=32, + buffer_size=int(1e5), ) model.learn(total_timesteps=5000, log_interval=50) diff --git a/tests/test_identity.py b/tests/test_identity.py index fdde0d2d4..6226580ac 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -2,8 +2,8 @@ import pytest from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.envs import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.identity_env import IdentityEnv, IdentityEnvBox, IdentityEnvMultiBinary, IdentityEnvMultiDiscrete from stable_baselines3.common.noise import NormalActionNoise from stable_baselines3.common.vec_env import DummyVecEnv diff --git a/tests/test_save_load.py b/tests/test_save_load.py index 8f71f5203..5e4e9e705 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,7 +12,7 @@ from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 from stable_baselines3.common.base_class import BaseAlgorithm -from stable_baselines3.common.identity_env import FakeImageEnv, IdentityEnv, IdentityEnvBox +from stable_baselines3.common.envs import FakeImageEnv, IdentityEnv, IdentityEnvBox from stable_baselines3.common.save_util import load_from_pkl, open_path, save_to_pkl from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -289,6 +289,8 @@ def test_save_load_replay_buffer(tmp_path, model_class): assert np.allclose(old_replay_buffer.actions, model.replay_buffer.actions) assert np.allclose(old_replay_buffer.rewards, model.replay_buffer.rewards) assert np.allclose(old_replay_buffer.dones, model.replay_buffer.dones) + assert np.allclose(old_replay_buffer.timeouts, model.replay_buffer.timeouts) + infos = [[{"TimeLimit.truncated": truncated}] for truncated in old_replay_buffer.timeouts] # test extending replay buffer model.replay_buffer.extend( @@ -297,6 +299,7 @@ def test_save_load_replay_buffer(tmp_path, model_class): old_replay_buffer.actions, old_replay_buffer.rewards, old_replay_buffer.dones, + infos, ) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 46c0c4450..63d4dbfba 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -3,7 +3,7 @@ import pytest from gym 
import spaces -from stable_baselines3 import HER, SAC, TD3 +from stable_baselines3 import SAC, TD3, HerReplayBuffer from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.running_mean_std import RunningMeanStd from stable_baselines3.common.vec_env import ( @@ -217,18 +217,39 @@ def test_normalize_external(): assert np.all(norm_rewards < 1) -@pytest.mark.parametrize("model_class", [SAC, TD3, HER]) -def test_offpolicy_normalization(model_class): - make_env_ = make_dict_env if model_class == HER else make_env +@pytest.mark.parametrize("model_class", [SAC, TD3, HerReplayBuffer]) +@pytest.mark.parametrize("online_sampling", [False, True]) +def test_offpolicy_normalization(model_class, online_sampling): + + if online_sampling and model_class != HerReplayBuffer: + pytest.skip() + + make_env_ = make_dict_env if model_class == HerReplayBuffer else make_env env = DummyVecEnv([make_env_]) env = VecNormalize(env, norm_obs=True, norm_reward=True, clip_obs=10.0, clip_reward=10.0) eval_env = DummyVecEnv([make_env_]) eval_env = VecNormalize(eval_env, training=False, norm_obs=True, norm_reward=False, clip_obs=10.0, clip_reward=10.0) - kwargs = dict(model_class=SAC, max_episode_length=200, online_sampling=True) if model_class == HER else {} - model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64]), **kwargs) - model.learn(total_timesteps=500, eval_env=eval_env, eval_freq=250) + if model_class == HerReplayBuffer: + model = SAC( + "MultiInputPolicy", + env, + verbose=1, + learning_starts=100, + policy_kwargs=dict(net_arch=[64]), + replay_buffer_kwargs=dict( + max_episode_length=100, + online_sampling=online_sampling, + n_sampled_goal=2, + ), + replay_buffer_class=HerReplayBuffer, + seed=2, + ) + else: + model = model_class("MlpPolicy", env, verbose=1, learning_starts=100, policy_kwargs=dict(net_arch=[64])) + + model.learn(total_timesteps=150, eval_env=eval_env, eval_freq=75) # Check getter assert isinstance(model.get_vec_normalize_env(), VecNormalize)
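
The changes to ``tests/test_her.py``, ``tests/test_callbacks.py`` and ``tests/test_vec_normalize.py`` above all follow the same migration: the standalone ``HER`` model class is replaced by passing ``HerReplayBuffer`` to an off-policy algorithm via ``replay_buffer_class``, together with a ``MultiInputPolicy`` for the ``Dict`` observation space. The sketch below condenses that pattern from the updated tests; it targets the API introduced in this patch (version ``1.1.0a6``), and the hyperparameter values are illustrative only, not tuned recommendations.

.. code-block:: python

    from stable_baselines3 import SAC, HerReplayBuffer
    from stable_baselines3.common.envs import BitFlippingEnv

    # BitFlippingEnv is a GoalEnv with a Dict observation space
    # ("observation", "achieved_goal", "desired_goal")
    env = BitFlippingEnv(n_bits=4, continuous=True)

    model = SAC(
        "MultiInputPolicy",  # required for Dict observations
        env,
        replay_buffer_class=HerReplayBuffer,
        # Keyword arguments forwarded to HerReplayBuffer on creation
        replay_buffer_kwargs=dict(
            n_sampled_goal=4,
            goal_selection_strategy="future",
            online_sampling=True,
            max_episode_length=4,
        ),
        learning_starts=100,
        buffer_size=int(1e5),
        verbose=1,
    )
    model.learn(total_timesteps=1000)

Saving and loading go through the algorithm class itself (e.g. ``SAC.load(path, env=env)``), as exercised in ``test_save_load`` above, while importing the old ``HER`` class now raises an ``ImportError`` pointing to the documentation (see ``test_import_error``).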