-
Notifications
You must be signed in to change notification settings - Fork 1.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support for VecMonitor
for gym3-style environments
#311
Changes from 4 commits
333482b
7553232
287ca33
621e8cf
d698499
e5d4636
dda2436
696af57
65e7a0f
23d0e69
cc6cbc9
aa0e400
1c7bf32
90601fe
4a85d21
99229e2
a64bf95
59781fb
99022b5
fa09feb
b48e956
c07637a
91a2fcd
cfbadbb
52f803b
cfbb5f0
4500ec9
723224b
153afa7
df8e27a
ca56818
fcc8609
e59d82c
6a58cb7
bdd8b19
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import time | ||
|
||
import numpy as np | ||
|
||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvStepReturn, VecEnvWrapper | ||
|
||
|
||
class VecExtractDictObs(VecEnvWrapper):
    """
    A vectorized wrapper that exposes a single entry of a dictionary observation.

    The wrapped environment is expected to have a dictionary observation space;
    this wrapper advertises the sub-space stored under ``key`` and returns only
    that entry of every observation.

    :param venv: The vectorized environment
    :param key: The key of the dictionary observation
    """

    def __init__(self, venv: VecEnv, key: str):
        # Store the key first, then advertise the matching sub-space as ours.
        self.key = key
        selected_space = venv.observation_space.spaces[self.key]
        super().__init__(venv=venv, observation_space=selected_space)

    def reset(self) -> np.ndarray:
        """Reset the wrapped environments and return only the selected observation entry."""
        observations = self.venv.reset()
        return observations[self.key]

    def step_wait(self) -> VecEnvStepReturn:
        """Wait for the wrapped step and replace the observation by its selected entry."""
        observations, rewards, dones, infos = self.venv.step_wait()
        selected = observations[self.key]
        return selected, rewards, dones, infos
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import time | ||
|
||
import numpy as np | ||
|
||
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn, VecEnvWrapper | ||
|
||
|
||
class VecMonitor(VecEnvWrapper):
    """
    A vectorized monitor wrapper for *vectorized* Gym environments,
    it is used to record the episode reward, length, time and other data.

    Some environments like [`openai/procgen`](https://github.com/openai/procgen)
    or [`gym3`](https://github.com/openai/gym3) directly initialize the
    vectorized environments, without giving us a chance to use the `Monitor`
    wrapper. So this class simply does the job of the `Monitor` wrapper on a
    vectorized level.

    As an example, the following two ways of initializing vectorized envs
    should be equivalent

    ```python
    from stable_baselines3.common.monitor import Monitor
    from stable_baselines3.common.vec_env import DummyVecEnv
    import gym
    def make_env(gym_id):
        def thunk():
            env = gym.make(gym_id, render_mode='rgb_array')
            env = Monitor(env)
            return env
        return thunk
    envs = DummyVecEnv([make_env('procgen-starpilot-v0')])
    ```

    ```python
    from procgen import ProcgenEnv
    from stable_baselines3.common.vec_env import VecExtractDictObs, VecMonitor
    venv = ProcgenEnv(num_envs=1, env_name='starpilot')
    venv = VecExtractDictObs(venv, "rgb")
    venv = VecMonitor(venv=venv)
    ```
    See [here](https://github.com/openai/train-procgen/blob/1a2ae2194a61f76a733a39339530401c024c3ad8/train_procgen/train.py#L36-L43) for a full example.

    :param venv: The vectorized environment
    """

    def __init__(self, venv: VecEnv):
        super().__init__(venv)
        # Running return and length of the current episode of each environment.
        # They stay ``None`` until ``reset`` allocates them.
        self.eprets = None
        self.eplens = None
        # Total number of episodes finished so far, over all environments.
        self.epcount = 0
        # Wall-clock reference for the "t" field of the episode statistics.
        self.tstart = time.time()

    def reset(self) -> VecEnvObs:
        """
        Reset the wrapped environments and restart the per-environment statistics.

        :return: The observation returned by the wrapped environment's ``reset``
        """
        obs = self.venv.reset()
        self.eprets = np.zeros(self.num_envs, dtype=np.float32)
        self.eplens = np.zeros(self.num_envs, dtype=np.int32)
        return obs

    def step_wait(self) -> VecEnvStepReturn:
        """
        Wait for the wrapped step and attach episode statistics to finished episodes.

        For every environment whose ``done`` flag is set, a copy of its info dict
        receives an ``"episode"`` entry with the episode return (``"r"``), the
        episode length (``"l"``) and the time elapsed since this wrapper was
        created (``"t"``, in seconds, rounded to 6 decimals); the counters of
        that environment are then reset to zero.

        :return: ``(obs, rews, dones, infos)`` where ``infos`` is a new list
        """
        obs, rews, dones, infos = self.venv.step_wait()
        self.eprets += rews
        self.eplens += 1
        # Shallow copy so the list handed over by the wrapped env is not mutated.
        new_infos = list(infos)
        for i, done in enumerate(dones):
            if not done:
                continue
            # Copy the dict too: only our returned list should see "episode".
            info = infos[i].copy()
            # NOTE(review): a reviewer suggested checking up front whether a
            # monitor wrapper is already present; if one is, the "episode"
            # entry it wrote is overwritten here.
            info["episode"] = {
                "r": self.eprets[i],
                "l": self.eplens[i],
                "t": round(time.time() - self.tstart, 6),
            }
            self.epcount += 1
            self.eprets[i] = 0
            self.eplens[i] = 0
            new_infos[i] = info
        return obs, rews, dones, new_infos
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As a general remark, as mentioned in the PR template and contributing guide, please update the changelog accordingly too.