diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 1974d43ba..6a0a5737a 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -642,6 +642,30 @@ A2C policy gradient updates on the model. print(f"Best fitness: {top_candidates[0][1]:.2f}") +SB3 and ProcgenEnv +------------------ + +Some environments like `Procgen <https://github.com/openai/procgen>`_ already produce a vectorized +environment (see discussion in `issue #314 <https://github.com/DLR-RM/stable-baselines3/issues/314>`_). In order to use it with SB3, you must wrap it in a ``VecMonitor`` wrapper which will also allow you +to keep track of the agent's progress. + +.. code-block:: python + + from procgen import ProcgenEnv + + from stable_baselines3 import PPO + from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor + + # ProcgenEnv is already vectorized + venv = ProcgenEnv(num_envs=2, env_name='starpilot') + # PPO does not currently support Dict observations + # this will be solved in https://github.com/DLR-RM/stable-baselines3/pull/243 + venv = VecExtractDictObs(venv, "rgb") + venv = VecMonitor(venv=venv) + + model = PPO("MlpPolicy", venv, verbose=1) + model.learn(10000) + Record a Video -------------- diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index 7a52a2754..76713efdf 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -27,14 +27,22 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ When using vectorized environments, the environments are automatically reset at the end of each episode. Thus, the observation returned for the i-th environment when ``done[i]`` is true will in fact be the first observation of the next episode, not the last observation of the episode that has just terminated. - You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the vecenv. 
+ You can access the "real" final observation of the terminated episode—that is, the one that accompanied the ``done`` event provided by the underlying environment—using the ``terminal_observation`` keys in the info dicts returned by the ``VecEnv``. + .. warning:: - When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows). - On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks. + When defining a custom ``VecEnv`` (for instance, using gym3 ``ProcgenEnv``), you should provide ``terminal_observation`` keys in the info dicts returned by the ``VecEnv`` + (cf. note above). + + +.. warning:: + + When using ``SubprocVecEnv``, users must wrap the code in an ``if __name__ == "__main__":`` if using the ``forkserver`` or ``spawn`` start method (default on Windows). + On Linux, the default start method is ``fork`` which is not thread safe and can create deadlocks. + + For more information, see Python's `multiprocessing guidelines `_. - For more information, see Python's `multiprocessing guidelines `_. VecEnv ------ @@ -90,3 +98,15 @@ VecTransposeImage .. autoclass:: VecTransposeImage :members: + +VecMonitor +~~~~~~~~~~~~~~~~~ + +.. autoclass:: VecMonitor + :members: + +VecExtractDictObs +~~~~~~~~~~~~~~~~~ + +.. 
autoclass:: VecExtractDictObs + :members: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 3e698954b..a7b780f0e 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,7 +4,7 @@ Changelog ========== -Release 1.1.0a1 (WIP) +Release 1.1.0a2 (WIP) --------------------------- Breaking Changes: @@ -12,6 +12,11 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added `VecMonitor `_ and + `VecExtractDictObs `_ wrappers + to handle gym3-style vectorized environments (@vwxyzjn) +- Ignored the terminal observation if it is not provided by the environment, + such as the gym3-style vectorized environments. (@vwxyzjn) Bug Fixes: ^^^^^^^^^^ @@ -33,6 +38,8 @@ Documentation: - Clarify channel-first/channel-last recommendation - Update sphinx environment installation instructions (@tom-doerr) - Clarify pip installation in Zsh (@tom-doerr) +- Added example for using ``ProcgenEnv`` + Release 1.0 (2021-03-15) ------------------------ @@ -54,6 +61,7 @@ New Features: ^^^^^^^^^^^^^ - Added support for ``custom_objects`` when loading models + Bug Fixes: ^^^^^^^^^^ - Fixed a bug with ``DQN`` predict method when using ``deterministic=False`` with image space @@ -640,5 +648,5 @@ And all the contributors: @flodorner @KuKuXia @NeoExtended @PartiallyTyped @mmcenta @richardwu @kinalmehta @rolandgvc @tkelestemur @mloo3 @tirafesi @blurLake @koulakis @joeljosephjin @shwang @rk37 @andyshih12 @RaphaelWag @xicocaio @diditforlulz273 @liorcohen5 @ManifoldFR @mloo3 @SwamyDev @wmmc88 @megan-klaiber @thisray -@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 +@tfederico @hn2 @LucasAlegre @AptX395 @zampanteymedio @decodyng @ardabbour @lorenz-h @mschweizer @lorepieri8 @vwxyzjn @ShangqunYu @PierreExeter @JacopoPan @ltbd78 @tom-doerr diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index a35200066..6d1febdd9 100644 --- a/stable_baselines3/common/evaluation.py +++ 
b/stable_baselines3/common/evaluation.py @@ -5,7 +5,7 @@ import numpy as np from stable_baselines3.common import base_class -from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env import VecEnv, VecMonitor, is_vecenv_wrapped def evaluate_policy( @@ -57,7 +57,7 @@ def evaluate_policy( if isinstance(env, VecEnv): assert env.num_envs == 1, "You must pass only one environment when using this function" - is_monitor_wrapped = env.env_is_wrapped(Monitor)[0] + is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] else: is_monitor_wrapped = is_wrapped(env, Monitor) diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index bb50fe40b..74e2b9c0a 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -1,11 +1,11 @@ -__all__ = ["Monitor", "get_monitor_files", "load_results"] +__all__ = ["Monitor", "ResultsWriter", "get_monitor_files", "load_results"] import csv import json import os import time from glob import glob -from typing import List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Union import gym import numpy as np @@ -38,27 +38,20 @@ def __init__( ): super(Monitor, self).__init__(env=env) self.t_start = time.time() - if filename is None: - self.file_handler = None - self.logger = None + if filename is not None: + self.results_writer = ResultsWriter( + filename, + header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, + extra_keys=reset_keywords + info_keywords, + ) else: - if not filename.endswith(Monitor.EXT): - if os.path.isdir(filename): - filename = os.path.join(filename, Monitor.EXT) - else: - filename = filename + "." 
+ Monitor.EXT - self.file_handler = open(filename, "wt") - self.file_handler.write("#%s\n" % json.dumps({"t_start": self.t_start, "env_id": env.spec and env.spec.id})) - self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + reset_keywords + info_keywords) - self.logger.writeheader() - self.file_handler.flush() - + self.results_writer = None self.reset_keywords = reset_keywords self.info_keywords = info_keywords self.allow_early_resets = allow_early_resets self.rewards = None self.needs_reset = True - self.episode_rewards = [] + self.episode_returns = [] self.episode_lengths = [] self.episode_times = [] self.total_steps = 0 @@ -81,7 +74,7 @@ def reset(self, **kwargs) -> GymObs: for key in self.reset_keywords: value = kwargs.get(key) if value is None: - raise ValueError("Expected you to pass kwarg {} into reset".format(key)) + raise ValueError(f"Expected you to pass keyword argument {key} into reset") self.current_reset_info[key] = value return self.env.reset(**kwargs) @@ -103,13 +96,12 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: ep_info = {"r": round(ep_rew, 6), "l": ep_len, "t": round(time.time() - self.t_start, 6)} for key in self.info_keywords: ep_info[key] = info[key] - self.episode_rewards.append(ep_rew) + self.episode_returns.append(ep_rew) self.episode_lengths.append(ep_len) self.episode_times.append(time.time() - self.t_start) ep_info.update(self.current_reset_info) - if self.logger: - self.logger.writerow(ep_info) - self.file_handler.flush() + if self.results_writer: + self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 return observation, reward, done, info @@ -119,8 +111,8 @@ def close(self) -> None: Closes the environment """ super(Monitor, self).close() - if self.file_handler is not None: - self.file_handler.close() + if self.results_writer is not None: + self.results_writer.close() def get_total_steps(self) -> int: """ @@ -136,7 +128,7 @@ def get_episode_rewards(self) -> 
List[float]: :return: """ - return self.episode_rewards + return self.episode_returns def get_episode_lengths(self) -> List[int]: """ @@ -163,6 +155,52 @@ class LoadMonitorResultsError(Exception): pass +class ResultsWriter: + """ + A result writer that saves the data from the `Monitor` class + + :param filename: the location to save a log file (a directory, or a file path that is given the ``Monitor.EXT`` suffix when it lacks it) + :param header: the header dictionary object of the saved csv + :param extra_keys: the extra information to log, typically is composed of + ``reset_keywords`` and ``info_keywords`` + """ + + def __init__( + self, + filename: str = "", + header: Dict[str, Union[float, str]] = None, + extra_keys: Tuple[str, ...] = (), + ): + if header is None: + header = {} + if not filename.endswith(Monitor.EXT): + if os.path.isdir(filename): + filename = os.path.join(filename, Monitor.EXT) + else: + filename = filename + "." + Monitor.EXT + self.file_handler = open(filename, "wt") + self.file_handler.write("#%s\n" % json.dumps(header)) + self.logger = csv.DictWriter(self.file_handler, fieldnames=("r", "l", "t") + extra_keys) + self.logger.writeheader() + self.file_handler.flush() + + def write_row(self, epinfo: Dict[str, Union[float, int]]) -> None: + """ + Write one episode's information as a row of the csv log and flush the file + + :param epinfo: the information on episodic return, length, and time + """ + if self.logger: + self.logger.writerow(epinfo) + self.file_handler.flush() + + def close(self) -> None: + """ + Close the file handler + """ + self.file_handler.close() + + def get_monitor_files(path: str) -> List[str]: """ get all the monitor files in the given path diff --git a/stable_baselines3/common/vec_env/__init__.py b/stable_baselines3/common/vec_env/__init__.py index 42f08da6d..8d143cd49 100644 --- a/stable_baselines3/common/vec_env/__init__.py +++ b/stable_baselines3/common/vec_env/__init__.py @@ -7,7 +7,9 @@ from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv from stable_baselines3.common.vec_env.subproc_vec_env import 
SubprocVecEnv from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan +from stable_baselines3.common.vec_env.vec_extract_dict_obs import VecExtractDictObs from stable_baselines3.common.vec_env.vec_frame_stack import VecFrameStack +from stable_baselines3.common.vec_env.vec_monitor import VecMonitor from stable_baselines3.common.vec_env.vec_normalize import VecNormalize from stable_baselines3.common.vec_env.vec_transpose import VecTransposeImage from stable_baselines3.common.vec_env.vec_video_recorder import VecVideoRecorder diff --git a/stable_baselines3/common/vec_env/vec_extract_dict_obs.py b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py new file mode 100644 index 000000000..8582b7a30 --- /dev/null +++ b/stable_baselines3/common/vec_env/vec_extract_dict_obs.py @@ -0,0 +1,24 @@ +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper + + +class VecExtractDictObs(VecEnvWrapper): + """ + A vectorized wrapper for extracting dictionary observations. 
+ + :param venv: The vectorized environment + :param key: The key of the dictionary observation + """ + + def __init__(self, venv: VecEnv, key: str): + self.key = key + super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) + + def reset(self) -> np.ndarray: + obs = self.venv.reset() + return obs[self.key] + + def step_wait(self) -> VecEnvStepReturn: + obs, reward, done, info = self.venv.step_wait() + return obs[self.key], reward, done, info diff --git a/stable_baselines3/common/vec_env/vec_monitor.py b/stable_baselines3/common/vec_env/vec_monitor.py new file mode 100644 index 000000000..61e0748ff --- /dev/null +++ b/stable_baselines3/common/vec_env/vec_monitor.py @@ -0,0 +1,98 @@ +import time +import warnings +from typing import Optional, Tuple + +import numpy as np + +from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper + + +class VecMonitor(VecEnvWrapper): + """ + A vectorized monitor wrapper for *vectorized* Gym environments, + it is used to record the episode reward, length, time and other data. + + Some environments like `openai/procgen `_ + or `gym3 `_ directly initialize the + vectorized environments, without giving us a chance to use the ``Monitor`` + wrapper. So this class simply does the job of the ``Monitor`` wrapper on + a vectorized level. + + :param venv: The vectorized environment + :param filename: the location to save a log file, can be None for no log + :param info_keywords: extra information to log, from the information return of env.step() + """ + + def __init__( + self, + venv: VecEnv, + filename: Optional[str] = None, + info_keywords: Tuple[str, ...] 
= (), + ): + # Avoid circular import + from stable_baselines3.common.monitor import Monitor, ResultsWriter + + # This check is not valid for special `VecEnv` + # like the ones created by Procgen, that do not completely follow + # the `VecEnv` interface + try: + is_wrapped_with_monitor = venv.env_is_wrapped(Monitor)[0] + except AttributeError: + is_wrapped_with_monitor = False + + if is_wrapped_with_monitor: + warnings.warn( + "The environment is already wrapped with a `Monitor` wrapper " + "but you are wrapping it with a `VecMonitor` wrapper, the `Monitor` statistics will be " + "overwritten by the `VecMonitor` ones.", + UserWarning, + ) + + VecEnvWrapper.__init__(self, venv) + self.episode_returns = None + self.episode_lengths = None + self.episode_count = 0 + self.t_start = time.time() + + env_id = None + if hasattr(venv, "spec") and venv.spec is not None: + env_id = venv.spec.id + + if filename: + self.results_writer = ResultsWriter( + filename, header={"t_start": self.t_start, "env_id": env_id}, extra_keys=info_keywords + ) + else: + self.results_writer = None + self.info_keywords = info_keywords + + def reset(self) -> VecEnvObs: + obs = self.venv.reset() + self.episode_returns = np.zeros(self.num_envs, dtype=np.float32) + self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32) + return obs + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, infos = self.venv.step_wait() + self.episode_returns += rewards + self.episode_lengths += 1 + new_infos = list(infos[:]) + for i in range(len(dones)): + if dones[i]: + info = infos[i].copy() + episode_return = self.episode_returns[i] + episode_length = self.episode_lengths[i] + episode_info = {"r": episode_return, "l": episode_length, "t": round(time.time() - self.t_start, 6)} + info["episode"] = episode_info + self.episode_count += 1 + self.episode_returns[i] = 0 + self.episode_lengths[i] = 0 + if self.results_writer: + self.results_writer.write_row(episode_info) + new_infos[i] = info + return obs, 
rewards, dones, new_infos + + def close(self) -> None: + if self.results_writer: + self.results_writer.close() + return self.venv.close() diff --git a/stable_baselines3/common/vec_env/vec_normalize.py b/stable_baselines3/common/vec_env/vec_normalize.py index 55ed2c54e..f1feeeef3 100644 --- a/stable_baselines3/common/vec_env/vec_normalize.py +++ b/stable_baselines3/common/vec_env/vec_normalize.py @@ -131,7 +131,8 @@ def step_wait(self) -> VecEnvStepReturn: for idx, done in enumerate(dones): if not done: continue - infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) + if "terminal_observation" in infos[idx]: + infos[idx]["terminal_observation"] = self.normalize_obs(infos[idx]["terminal_observation"]) self.ret[dones] = 0 return obs, rewards, dones, infos diff --git a/stable_baselines3/common/vec_env/vec_transpose.py b/stable_baselines3/common/vec_env/vec_transpose.py index 04cfde4cc..f40b276ac 100644 --- a/stable_baselines3/common/vec_env/vec_transpose.py +++ b/stable_baselines3/common/vec_env/vec_transpose.py @@ -51,7 +51,8 @@ def step_wait(self) -> VecEnvStepReturn: for idx, done in enumerate(dones): if not done: continue - infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"]) + if "terminal_observation" in infos[idx]: + infos[idx]["terminal_observation"] = self.transpose_image(infos[idx]["terminal_observation"]) return self.transpose_image(observations), rewards, dones, infos diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 6b34d9c0e..c733209cf 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.1.0a1 +1.1.0a2 diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py new file mode 100644 index 000000000..15074425e --- /dev/null +++ b/tests/test_vec_extract_dict_obs.py @@ -0,0 +1,52 @@ +import numpy as np +from gym import spaces + +from stable_baselines3 import PPO +from 
stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor + + +class DictObsVecEnv: + """Custom Environment that produces observation in a dictionary like the procgen env""" + + metadata = {"render.modes": ["human"]} + + def __init__(self): + self.num_envs = 4 + self.action_space = spaces.Discrete(2) + self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)}) + + def step_async(self, actions): + self.actions = actions + + def step_wait(self): + return ( + {"rgb": np.zeros((self.num_envs, 86, 86))}, + np.zeros((self.num_envs,)), + np.zeros((self.num_envs,), dtype=bool), + [{} for _ in range(self.num_envs)], + ) + + def reset(self): + return {"rgb": np.zeros((self.num_envs, 86, 86))} + + def render(self, mode="human", close=False): + pass + + +def test_extract_dict_obs(): + """Test VecExtractDictObs""" + + env = DictObsVecEnv() + env = VecExtractDictObs(env, "rgb") + assert env.reset().shape == (4, 86, 86) + + +def test_vec_with_ppo(): + """ + Test the `VecExtractDictObs` with PPO + """ + env = DictObsVecEnv() + env = VecExtractDictObs(env, "rgb") + monitor_env = VecMonitor(env) + model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") + model.learn(total_timesteps=250) diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py new file mode 100644 index 000000000..585e6906c --- /dev/null +++ b/tests/test_vec_monitor.py @@ -0,0 +1,120 @@ +import json +import os +import uuid + +import gym +import pandas +import pytest + +from stable_baselines3 import PPO +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.monitor import Monitor, get_monitor_files, load_results +from stable_baselines3.common.vec_env import DummyVecEnv, VecMonitor, VecNormalize + + +def test_vec_monitor(tmp_path): + """ + Test the `VecMonitor` wrapper + """ + env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + env.seed(0) + monitor_file = 
os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") + monitor_env = VecMonitor(env, monitor_file) + monitor_env.reset() + total_steps = 1000 + ep_len, ep_reward = 0, 0 + for _ in range(total_steps): + _, rewards, dones, infos = monitor_env.step([monitor_env.action_space.sample()]) + ep_len += 1 + ep_reward += rewards[0] + if dones[0]: + assert ep_reward == infos[0]["episode"]["r"] + assert ep_len == infos[0]["episode"]["l"] + ep_len, ep_reward = 0, 0 + + monitor_env.close() + + with open(monitor_file, "rt") as file_handler: + first_line = file_handler.readline() + assert first_line.startswith("#") + metadata = json.loads(first_line[1:]) + assert set(metadata.keys()) == {"t_start", "env_id"}, "Incorrect keys in monitor metadata" + + last_logline = pandas.read_csv(file_handler, index_col=None) + assert set(last_logline.keys()) == {"l", "t", "r"}, "Incorrect keys in monitor logline" + os.remove(monitor_file) + + +def test_vec_monitor_load_results(tmp_path): + """ + test load_results on log files produced by the monitor wrapper + """ + tmp_path = str(tmp_path) + env1 = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + env1.seed(0) + monitor_file1 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") + monitor_env1 = VecMonitor(env1, monitor_file1) + + monitor_files = get_monitor_files(tmp_path) + assert len(monitor_files) == 1 + assert monitor_file1 in monitor_files + + monitor_env1.reset() + episode_count1 = 0 + for _ in range(1000): + _, _, dones, _ = monitor_env1.step([monitor_env1.action_space.sample()]) + if dones[0]: + episode_count1 += 1 + monitor_env1.reset() + + results_size1 = len(load_results(os.path.join(tmp_path)).index) + assert results_size1 == episode_count1 + + env2 = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + env2.seed(0) + monitor_file2 = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") + monitor_env2 = VecMonitor(env2, monitor_file2) + monitor_files = 
get_monitor_files(tmp_path) + assert len(monitor_files) == 2 + assert monitor_file1 in monitor_files + assert monitor_file2 in monitor_files + + monitor_env2.reset() + episode_count2 = 0 + for _ in range(1000): + _, _, dones, _ = monitor_env2.step([monitor_env2.action_space.sample()]) + if dones[0]: + episode_count2 += 1 + monitor_env2.reset() + + results_size2 = len(load_results(os.path.join(tmp_path)).index) + + assert results_size2 == (results_size1 + episode_count2) + + os.remove(monitor_file1) + os.remove(monitor_file2) + + +def test_vec_monitor_ppo(recwarn): + """ + Test the `VecMonitor` with PPO + """ + env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) + env.seed(0) + monitor_env = VecMonitor(env) + model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") + model.learn(total_timesteps=250) + + # No warnings because using `VecMonitor` + evaluate_policy(model, monitor_env) + assert len(recwarn) == 0 + + +def test_vec_monitor_warn(): + env = DummyVecEnv([lambda: Monitor(gym.make("CartPole-v1"))]) + # We should warn the user when the env is already wrapped with a Monitor wrapper + with pytest.warns(UserWarning): + VecMonitor(env) + + with pytest.warns(UserWarning): + VecMonitor(VecNormalize(env))