-
Notifications
You must be signed in to change notification settings - Fork 1.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support for
VecMonitor
for gym3-style environments (#311)
* add vectorized monitor * auto format of the code * add documentation and VecExtractDictObs * refactor and add test cases * add test cases and format * avoid circular import and fix doc * fix type * fix type * oops * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * Update stable_baselines3/common/monitor.py Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org> * add test cases * update changelog * fix mutable argument * quick fix * Apply suggestions from code review * fix terminal observation for gym3 envs * delete comment * Update doc and bump version * Add warning when already using `Monitor` wrapper * Update vecmonitor tests * Fixes Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
- Loading branch information
Showing
13 changed files
with
424 additions
and
36 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
|
||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper | ||
|
||
|
||
class VecExtractDictObs(VecEnvWrapper): | ||
""" | ||
A vectorized wrapper for extracting dictionary observations. | ||
:param venv: The vectorized environment | ||
:param key: The key of the dictionary observation | ||
""" | ||
|
||
def __init__(self, venv: VecEnv, key: str): | ||
self.key = key | ||
super().__init__(venv=venv, observation_space=venv.observation_space.spaces[self.key]) | ||
|
||
def reset(self) -> np.ndarray: | ||
obs = self.venv.reset() | ||
return obs[self.key] | ||
|
||
def step_wait(self) -> VecEnvStepReturn: | ||
obs, reward, done, info = self.venv.step_wait() | ||
return obs[self.key], reward, done, info |
Oops, something went wrong.