Support for VecMonitor for gym3-style environments (#311)
* add vectorized monitor

* auto format of the code

* add documentation and VecExtractDictObs

* refactor and add test cases

* add test cases and format

* avoid circular import and fix doc

* fix type

* fix type

* oops

* Update stable_baselines3/common/monitor.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* Update stable_baselines3/common/monitor.py

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>

* add test cases

* update changelog

* fix mutable argument

* quick fix

* Apply suggestions from code review

* fix terminal observation for gym3 envs

* delete comment

* Update doc and bump version

* Add warning when already using `Monitor` wrapper

* Update vecmonitor tests

* Fixes

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
vwxyzjn and araffin authored Apr 13, 2021
1 parent 1ed15bf commit ddbe0e9
Showing 13 changed files with 424 additions and 36 deletions.
24 changes: 24 additions & 0 deletions docs/guide/examples.rst
@@ -642,6 +642,30 @@ A2C policy gradient updates on the model.
print(f"Best fitness: {top_candidates[0][1]:.2f}")
SB3 and ProcgenEnv
------------------

Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized
environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). To use such an environment with SB3, you must wrap it in a ``VecMonitor`` wrapper,
which also lets you keep track of the agent's progress.

.. code-block:: python

  from procgen import ProcgenEnv

  from stable_baselines3 import PPO
  from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor

  # ProcgenEnv is already vectorized
  venv = ProcgenEnv(num_envs=2, env_name='starpilot')

  # PPO does not currently support Dict observations
  # this will be solved in https://github.com/DLR-RM/stable-baselines3/pull/243
  venv = VecExtractDictObs(venv, "rgb")
  venv = VecMonitor(venv=venv)

  model = PPO("MlpPolicy", venv, verbose=1)
  model.learn(10000)

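To illustrate the kind of bookkeeping a vectorized monitor performs, here is a minimal, dependency-free sketch. It is not the ``VecMonitor`` implementation from this commit: the class name ``ToyVecMonitor`` and its ``record_step`` method are hypothetical. It accumulates one return and one length per sub-environment and, when a sub-environment reports ``done``, attaches an ``episode`` entry with the ``r`` and ``l`` keys that ``Monitor`` also writes.

```python
from typing import Dict, List


class ToyVecMonitor:
    """Hypothetical stand-in that tracks episode statistics for several envs at once."""

    def __init__(self, num_envs: int) -> None:
        self.episode_returns = [0.0] * num_envs
        self.episode_lengths = [0] * num_envs

    def record_step(self, rewards: List[float], dones: List[bool], infos: List[Dict]) -> List[Dict]:
        # Copy the info dicts so the caller's objects are left untouched
        new_infos = [dict(info) for info in infos]
        for i, (reward, done) in enumerate(zip(rewards, dones)):
            self.episode_returns[i] += reward
            self.episode_lengths[i] += 1
            if done:
                new_infos[i]["episode"] = {"r": self.episode_returns[i], "l": self.episode_lengths[i]}
                # The sub-env resets automatically, so start the next episode from zero
                self.episode_returns[i] = 0.0
                self.episode_lengths[i] = 0
        return new_infos


monitor = ToyVecMonitor(num_envs=2)
monitor.record_step([1.0, 0.5], [False, False], [{}, {}])
infos = monitor.record_step([2.0, 0.5], [True, False], [{}, {}])
print(infos)
```

Only the first sub-environment finished an episode here, so only its info dict receives an ``episode`` entry; the second one keeps accumulating.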
Record a Video
--------------
28 changes: 24 additions & 4 deletions docs/guide/vec_envs.rst
@@ -27,14 +27,22 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️

When using vectorized environments, the environments are automatically reset at the end of each episode.
Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv.
You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``.


.. warning::

When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.
When defining a custom ``VecEnv`` (for instance, using gym3 ``ProcgenEnv``), you should provide ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``
(cf. note above).


.. warning::

When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows).
On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks.

For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.

For more information, see Python's `multiprocessing guidelines <https://docs.python.org/3/library/multiprocessing.html#the-spawn-and-forkserver-start-methods>`_.
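The auto-reset behaviour and the role of ``terminal_observation`` can be sketched without any dependency. The class ``ToyAutoResetVecEnv`` below is hypothetical and far simpler than a real ``VecEnv``: each sub-environment just counts steps and ends after ``horizon`` steps. It shows why the returned observation differs from the true final one when ``done`` is true.

```python
from typing import Dict, List, Tuple


class ToyAutoResetVecEnv:
    """Hypothetical vectorized env: each sub-env counts up and ends after `horizon` steps."""

    def __init__(self, num_envs: int, horizon: int) -> None:
        self.horizon = horizon
        self.counters = [0] * num_envs

    def reset(self) -> List[int]:
        self.counters = [0] * len(self.counters)
        return list(self.counters)

    def step(self) -> Tuple[List[int], List[bool], List[Dict]]:
        obs, dones, infos = [], [], []
        for i in range(len(self.counters)):
            self.counters[i] += 1
            done = self.counters[i] >= self.horizon
            info: Dict = {}
            if done:
                # Keep the true last observation before the automatic reset overwrites it
                info["terminal_observation"] = self.counters[i]
                self.counters[i] = 0
            obs.append(self.counters[i])
            dones.append(done)
            infos.append(info)
        return obs, dones, infos


venv = ToyAutoResetVecEnv(num_envs=2, horizon=2)
venv.reset()
venv.step()
obs, dones, infos = venv.step()
for i, done in enumerate(dones):
    if done:
        print(f"env {i}: returned obs={obs[i]}, terminal obs={infos[i]['terminal_observation']}")
```

The returned observation is already the first one of the next episode, while the info dict still carries the last observation of the finished episode.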

VecEnv
------
@@ -90,3 +98,15 @@ VecTransposeImage

.. autoclass:: VecTransposeImage
:members:

VecMonitor
~~~~~~~~~~~~~~~~~

.. autoclass:: VecMonitor
:members:

VecExtractDictObs
~~~~~~~~~~~~~~~~~

.. autoclass:: VecExtractDictObs
:members:
12 changes: 10 additions & 2 deletions docs/misc/changelog.rst
@@ -4,14 +4,19 @@ Changelog
==========


Release 1.1.0a1 (WIP)
Release 1.1.0a2 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added `VecMonitor <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_monitor.py>`_ and
`VecExtractDictObs <https://github.com/DLR-RM/stable-baselines3/blob/master/stable_baselines3/common/vec_env/vec_extract_dict_obs.py>`_ wrappers
to handle gym3-style vectorized environments (@vwxyzjn)
- Ignored the terminal observation if it is not provided by the environment,
as is the case for gym3-style vectorized environments. (@vwxyzjn)

Bug Fixes:
^^^^^^^^^^
@@ -33,6 +38,8 @@ Documentation:
- Clarify channel-first/channel-last recommendation
- Update sphinx environment installation instructions (@tom-doerr)
- Clarify pip installation in Zsh (@tom-doerr)
- Added example for using ``ProcgenEnv``


Release 1.0 (2021-03-15)
------------------------
@@ -54,6 +61,7 @@ New Features:
^^^^^^^^^^^^^
- Added support for ``custom_objects`` when loading models


Bug Fixes:
^^^^^^^^^^
- Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space
@@ -640,5 +648,5 @@ And all the contributors:
@flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3
@tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio
@diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8
@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn
@ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr
4 changes: 2 additions & 2 deletions stable_baselines3/common/evaluation.py
@@ -5,7 +5,7 @@
import numpy as np

from stable_baselines3.common import base_class
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped


def evaluate_policy(
@@ -57,7 +57,7 @@ def evaluate_policy(

if isinstance(env, VecEnv):
assert env.num_envs == 1, "You must pass only one environment when using this function"
is_monitor_wrapped = env.env_is_wrapped(Monitor)[0]
is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0]
else:
is_monitor_wrapped = is_wrapped(env, Monitor)
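The updated check accepts either a ``VecMonitor`` around the whole vectorized env or a ``Monitor`` around the sub-environment. As a rough sketch of how a wrapper-chain lookup of this kind can work, the snippet below walks from the outermost wrapper inwards. All class and function names here are hypothetical toys, and this is an assumption about the general technique, not the code of ``is_vecenv_wrapped``.

```python
from typing import Optional, Type


class ToyVecEnv:
    """Hypothetical base vectorized env."""


class ToyVecEnvWrapper(ToyVecEnv):
    """Hypothetical wrapper that stores the wrapped env in `venv`."""

    def __init__(self, venv: ToyVecEnv) -> None:
        self.venv = venv


class ToyMonitorWrapper(ToyVecEnvWrapper):
    pass


class ToyNormalizeWrapper(ToyVecEnvWrapper):
    pass


def find_wrapper(env: ToyVecEnv, wrapper_class: Type[ToyVecEnvWrapper]) -> Optional[ToyVecEnvWrapper]:
    """Walk the wrapper chain from the outside in and return the first match, if any."""
    current = env
    while isinstance(current, ToyVecEnvWrapper):
        if isinstance(current, wrapper_class):
            return current
        current = current.venv
    return None


def is_wrapped_with(env: ToyVecEnv, wrapper_class: Type[ToyVecEnvWrapper]) -> bool:
    return find_wrapper(env, wrapper_class) is not None


monitored = ToyNormalizeWrapper(ToyMonitorWrapper(ToyVecEnv()))
unmonitored = ToyNormalizeWrapper(ToyVecEnv())
print(is_wrapped_with(monitored, ToyMonitorWrapper))
print(is_wrapped_with(unmonitored, ToyMonitorWrapper))
```

The lookup finds the monitor even when another wrapper sits on top of it, which is why the order of wrappers does not matter for the check.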

88 changes: 63 additions & 25 deletions stable_baselines3/common/monitor.py
@@ -1,11 +1,11 @@
__all__ = ["Monitor", "get_monitor_files", "load_results"]
__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"]

import csv
import json
import os
import time
from glob import glob
from typing import List, Optional, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import gym
import numpy as np
@@ -38,27 +38,20 @@ def __init__(
):
super(Monitor, self).__init__(env=env)
self.t_start = time.time()
if filename is None:
self.file_handler = None
self.logger = None
if filename is not None:
self.results_writer = ResultsWriter(
filename,
header={"t_start": self.t_start, "env_id": env.spec and env.spec.id},
extra_keys=reset_keywords + info_keywords,
)
else:
if not filename.endswith(Monitor.EXT):
if os.path.isdir(filename):
filename = os.path.join(filename, Monitor.EXT)
else:
filename = filename + "." + Monitor.EXT
self.file_handler = open(filename, "wt")
self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id}))
self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords)
self.logger.writeheader()
self.file_handler.flush()

self.results_writer = None
self.reset_keywords = reset_keywords
self.info_keywords = info_keywords
self.allow_early_resets = allow_early_resets
self.rewards = None
self.needs_reset = True
self.episode_rewards = []
self.episode_returns = []
self.episode_lengths = []
self.episode_times = []
self.total_steps = 0
@@ -81,7 +74,7 @@ def reset(self, **kwargs) -> GymObs:
for key in self.reset_keywords:
value = kwargs.get(key)
if value is None:
raise ValueError("Expected you to pass kwarg {} into reset".format(key))
raise ValueError(f"Expected you to pass keyword argument {key} into reset")
self.current_reset_info[key] = value
return self.env.reset(**kwargs)

@@ -103,13 +96,12 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn:
ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)}
for key in self.info_keywords:
ep_info[key] = info[key]
self.episode_rewards.append(ep_rew)
self.episode_returns.append(ep_rew)
self.episode_lengths.append(ep_len)
self.episode_times.append(time.time() - self.t_start)
ep_info.update(self.current_reset_info)
if self.logger:
self.logger.writerow(ep_info)
self.file_handler.flush()
if self.results_writer:
self.results_writer.write_row(ep_info)
info["episode"] = ep_info
self.total_steps += 1
return observation, reward, done, info
@@ -119,8 +111,8 @@ def close(self) -> None:
Closes the environment
"""
super(Monitor, self).close()
if self.file_handler is not None:
self.file_handler.close()
if self.results_writer is not None:
self.results_writer.close()

def get_total_steps(self) -> int:
"""
@@ -136,7 +128,7 @@ def get_episode_rewards(self) -> List[float]:
:return:
"""
return self.episode_rewards
return self.episode_returns

def get_episode_lengths(self) -> List[int]:
"""
@@ -163,6 +155,52 @@ class LoadMonitorResultsError(Exception):
pass


class ResultsWriter:
    """
    A result writer that saves the data from the `Monitor` class

    :param filename: the location to save a log file
    :param header: the header dictionary object of the saved csv
    :param extra_keys: the extra information to log, typically is composed of
        ``reset_keywords`` and ``info_keywords``
    """

    def __init__(
        self,
        filename: str = "",
        header: Optional[Dict[str, Union[float, str]]] = None,
        extra_keys: Tuple[str, ...] = (),
    ):
        if header is None:
            header = {}
        if not filename.endswith(Monitor.EXT):
            if os.path.isdir(filename):
                filename = os.path.join(filename, Monitor.EXT)
            else:
                filename = filename + "." + Monitor.EXT
        self.file_handler = open(filename, "wt")
        self.file_handler.write("#%s\n" % json.dumps(header))
        self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys)
        self.logger.writeheader()
        self.file_handler.flush()

    def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None:
        """
        Write one row of episode information to the log file

        :param epinfo: the information on episodic return, length, and time
        """
        if self.logger:
            self.logger.writerow(epinfo)
            self.file_handler.flush()

    def close(self) -> None:
        """
        Close the file handler
        """
        self.file_handler.close()
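The file layout that ``ResultsWriter`` produces (a ``#``-prefixed JSON header line, then a CSV table starting with the ``r``, ``l``, ``t`` columns) can be reproduced and read back with the standard library alone. The sketch below writes to an in-memory buffer instead of a file; the header values and the single row are made up for illustration.

```python
import csv
import io
import json

buffer = io.StringIO()  # stands in for the opened monitor file
header = {"t_start": 0.0, "env_id": "ToyEnv-v0"}  # illustrative values only

# Same layout as above: commented JSON header, then a CSV header and rows
buffer.write("#%s\n" % json.dumps(header))
writer = csv.DictWriter(buffer, fieldnames=("r", "l", "t"))
writer.writeheader()
writer.writerow({"r": 1.5, "l": 10, "t": 0.25})

# Read it back: strip the leading "#" from the first line, parse the rest as CSV
buffer.seek(0)
first_line = buffer.readline()
parsed_header = json.loads(first_line[1:])
rows = list(csv.DictReader(buffer))
print(parsed_header)
print(rows)
```

Note that ``csv.DictReader`` returns every value as a string, so a consumer has to convert the columns back to numbers itself.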


def get_monitor_files(path: str) -> List[str]:
"""
get all the monitor files in the given path
2 changes: 2 additions & 0 deletions stable_baselines3/common/vec_env/__init__.py
@@ -7,7 +7,9 @@
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs
from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack
from stable_baselines3.common.vec_env.vec_monitor import VecMonitor
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage
from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder
24 changes: 24 additions & 0 deletions stable_baselines3/common/vec_env/vec_extract_dict_obs.py
@@ -0,0 +1,24 @@
import numpy as np

from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper


class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper for extracting dictionary observations.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        self.key = key
        super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key])

    def reset(self) -> np.ndarray:
        obs = self.venv.reset()
        return obs[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        obs, reward, done, info = self.venv.step_wait()
        return obs[self.key], reward, done, info
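The effect of the wrapper can be shown with plain Python containers. In the sketch below, ``ToyDictVecEnv`` and ``ToyExtractKey`` are hypothetical stand-ins (no gym or SB3 involved): the inner env returns a dict of batched observations, and the wrapper forwards only the batch stored under one key, leaving rewards, dones and infos unchanged.

```python
from typing import Dict, List, Tuple


class ToyDictVecEnv:
    """Hypothetical vectorized env with two sub-envs and dict observations."""

    def _obs(self) -> Dict[str, List]:
        return {"rgb": [[0, 0, 0], [255, 255, 255]], "level": [1, 2]}

    def reset(self) -> Dict[str, List]:
        return self._obs()

    def step_wait(self) -> Tuple[Dict[str, List], List[float], List[bool], List[Dict]]:
        return self._obs(), [0.0, 1.0], [False, False], [{}, {}]


class ToyExtractKey:
    """Hypothetical wrapper that keeps only one entry of the dict observation."""

    def __init__(self, venv: ToyDictVecEnv, key: str) -> None:
        self.venv = venv
        self.key = key

    def reset(self) -> List:
        return self.venv.reset()[self.key]

    def step_wait(self) -> Tuple[List, List[float], List[bool], List[Dict]]:
        obs, rewards, dones, infos = self.venv.step_wait()
        # Only the observation is narrowed; the other return values pass through
        return obs[self.key], rewards, dones, infos


wrapped = ToyExtractKey(ToyDictVecEnv(), "rgb")
first_obs = wrapped.reset()
step_obs, rewards, dones, infos = wrapped.step_wait()
print(first_obs)
```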