
Commit 6996eaa

[RLlib] Add necessary fields to Base Envs, and BaseEnv wrapper classes (#20832)
1 parent 8bb9bfe commit 6996eaa

File tree

6 files changed: +192 -21 lines changed


rllib/BUILD

Lines changed: 3 additions & 3 deletions

@@ -290,7 +290,7 @@ py_test(
     name = "learning_cartpole_simpleq_fake_gpus",
     main = "tests/run_regression_tests.py",
     tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
-    size = "large",
+    size = "medium",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
     args = ["--yaml-dir=tuned_examples/dqn"]
@@ -468,7 +468,7 @@ py_test(
 py_test(
     name = "learning_tests_transformed_actions_pendulum_sac",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
@@ -478,7 +478,7 @@ py_test(
 py_test(
     name = "learning_pendulum_sac_fake_gpus",
     main = "tests/run_regression_tests.py",
-    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus"],
+    tags = ["team:ml", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous", "fake_gpus", "flaky"],
     size = "large",
     srcs = ["tests/run_regression_tests.py"],
     data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],

rllib/env/base_env.py

Lines changed: 64 additions & 8 deletions

@@ -1,6 +1,7 @@
 from typing import Callable, Tuple, Optional, List, Dict, Any, TYPE_CHECKING,\
     Union
 
+import gym
 import ray
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
 from ray.rllib.utils.typing import AgentID, EnvID, EnvType, MultiAgentDict, \
@@ -31,12 +32,6 @@ class BaseEnv:
     rllib.MultiAgentEnv (is-a gym.Env) => rllib.VectorEnv => rllib.BaseEnv
     rllib.ExternalEnv => rllib.BaseEnv
 
-    Attributes:
-        action_space (gym.Space): Action space. This must be defined for
-            single-agent envs. Multi-agent envs can set this to None.
-        observation_space (gym.Space): Observation space. This must be defined
-            for single-agent envs. Multi-agent envs can set this to None.
-
     Examples:
         >>> env = MyBaseEnv()
         >>> obs, rewards, dones, infos, off_policy_actions = env.poll()
@@ -185,12 +180,18 @@ def try_reset(self, env_id: Optional[EnvID] = None
         return None
 
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
         """Return a reference to the underlying sub environments, if any.
 
+        Args:
+            as_dict: If True, return a dict mapping from env_id to env.
+
         Returns:
-            List of the underlying sub environments or [].
+            List or dictionary of the underlying sub environments or [] / {}.
         """
+        if as_dict:
+            return {}
         return []
 
     @PublicAPI
@@ -218,6 +219,61 @@ def stop(self) -> None:
     def get_unwrapped(self) -> List[EnvType]:
         return self.get_sub_environments()
 
+    @PublicAPI
+    @property
+    def observation_space(self) -> gym.spaces.Dict:
+        """Returns the observation space for each environment.
+
+        Note: samples from the observation space need to be preprocessed into a
+            `MultiEnvDict` before being used by a policy.
+
+        Returns:
+            The observation space for each environment.
+        """
+        raise NotImplementedError
+
+    @PublicAPI
+    @property
+    def action_space(self) -> gym.Space:
+        """Returns the action space for each environment.
+
+        Note: samples from the action space need to be preprocessed into a
+            `MultiEnvDict` before being passed to `send_actions`.
+
+        Returns:
+            The action space for each environment.
+        """
+        raise NotImplementedError
+
+    def observation_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.observation_space, x)
+
+    def action_space_contains(self, x: MultiEnvDict) -> bool:
+        return self._space_contains(self.action_space, x)
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        # This removes the agent_id keys and inner dicts in MultiEnvDicts.
+        flattened_obs = {
+            env_id: list(obs.values())
+            for env_id, obs in x.items()
+        }
+        ret = True
+        for env_id in flattened_obs:
+            for obs in flattened_obs[env_id]:
+                ret = ret and space[env_id].contains(obs)
+        return ret
+
 
 # Fixed agent identifier when there is only the single agent in the env
 _DUMMY_AGENT_ID = "agent0"
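
To make the intent of the new BaseEnv fields concrete, here is a minimal, hedged sketch (not part of this commit) of a subclass that fills in the new observation_space / action_space properties and the as_dict variant of get_sub_environments(). The TwoCartPoleBaseEnv class and its two CartPole sub-environments are illustrative assumptions, not RLlib code.

# Illustrative sketch only: a BaseEnv subclass exposing the new per-env_id
# spaces and the as_dict form of get_sub_environments(). Assumes gym and
# ray[rllib] are installed; the class name and sub-envs are made up.
import gym
from ray.rllib.env.base_env import BaseEnv


class TwoCartPoleBaseEnv(BaseEnv):
    def __init__(self):
        self.envs = [gym.make("CartPole-v1") for _ in range(2)]

    def get_sub_environments(self, as_dict: bool = False):
        # New signature: optionally return {env_id: env} instead of a list.
        if as_dict:
            return dict(enumerate(self.envs))
        return self.envs

    @property
    def observation_space(self) -> gym.spaces.Dict:
        # One entry per sub-environment, keyed by env_id.
        return gym.spaces.Dict(
            {i: e.observation_space for i, e in enumerate(self.envs)})

    @property
    def action_space(self) -> gym.spaces.Dict:
        return gym.spaces.Dict(
            {i: e.action_space for i, e in enumerate(self.envs)})


env = TwoCartPoleBaseEnv()
print(list(env.get_sub_environments(as_dict=True)))   # [0, 1]
print(env.observation_space.spaces[0].shape)          # (4,)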

rllib/env/external_env.py

Lines changed: 15 additions & 3 deletions

@@ -337,11 +337,11 @@ def __init__(self,
         self.external_env = external_env
         self.prep = preprocessor
         self.multiagent = issubclass(type(external_env), ExternalMultiAgentEnv)
-        self.action_space = external_env.action_space
+        self._action_space = external_env.action_space
         if preprocessor:
-            self.observation_space = preprocessor.observation_space
+            self._observation_space = preprocessor.observation_space
         else:
-            self.observation_space = external_env.observation_space
+            self._observation_space = external_env.observation_space
         external_env.start()
 
     @override(BaseEnv)
@@ -413,3 +413,15 @@ def fix(d, zero_val):
             with_dummy_agent_id(all_dones, "__all__"), \
             with_dummy_agent_id(all_infos), \
             with_dummy_agent_id(off_policy_actions)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space

rllib/env/multi_agent_env.py

Lines changed: 23 additions & 1 deletion

@@ -336,7 +336,12 @@ def try_reset(self,
         return obs
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {
+                _id: env_state
+                for _id, env_state in enumerate(self.env_states)
+            }
         return [state.env for state in self.env_states]
 
     @override(BaseEnv)
@@ -346,6 +351,23 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
         assert isinstance(env_id, int)
         return self.envs[env_id].render()
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        space = {
+            _id: env.observation_space
+            for _id, env in enumerate(self.envs)
+        }
+        return gym.spaces.Dict(space)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        space = {_id: env.action_space for _id, env in enumerate(self.envs)}
+        return gym.spaces.Dict(space)
+
 
 class _MultiAgentEnvState:
     def __init__(self, env: MultiAgentEnv):
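
As a hedged illustration of how samples relate to the new per-env_id Dict space (this snippet is not part of the diff): a MultiEnvDict is laid out as env_id -> agent_id -> observation, and the containment helper added to BaseEnv drops the agent_id level before testing each observation against space[env_id]. The two CartPole envs below are assumptions standing in for self.envs.

# Illustration only: the MultiEnvDict layout expected by the new containment
# helpers, and the same flattening that BaseEnv._space_contains performs.
import gym

envs = [gym.make("CartPole-v1") for _ in range(2)]  # stand-ins for self.envs
space = gym.spaces.Dict(
    {_id: env.observation_space for _id, env in enumerate(envs)})

# MultiEnvDict: env_id -> agent_id -> observation.
x = {_id: {"agent0": env.observation_space.sample()}
     for _id, env in enumerate(envs)}

# Drop the agent_id keys, keep env_ids, then check each obs per sub-space.
flattened_obs = {env_id: list(obs.values()) for env_id, obs in x.items()}
ok = all(
    space.spaces[env_id].contains(obs)
    for env_id in flattened_obs
    for obs in flattened_obs[env_id])
print(ok)  # True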

rllib/env/remote_base_env.py

Lines changed: 41 additions & 1 deletion

@@ -1,6 +1,8 @@
 import logging
 from typing import Callable, Dict, List, Optional, Tuple
 
+import gym
+
 import ray
 from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
 from ray.rllib.utils.annotations import override, PublicAPI
@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
     from the remote simulator actors. Both single and multi-agent child envs
     are supported, and envs can be stepped synchronously or asynchronously.
 
+    NOTE: This class implicitly assumes that the remote envs are gym.Envs.
+
     You shouldn't need to instantiate this class directly. It's automatically
     inserted when you use the `remote_worker_envs=True` option in your
     Trainer's config.
@@ -61,6 +65,8 @@ def __init__(self,
         # List of ray actor handles (each handle points to one @ray.remote
         # sub-environment).
         self.actors: Optional[List[ray.actor.ActorHandle]] = None
+        self._observation_space = None
+        self._action_space = None
         # Dict mapping object refs (return values of @ray.remote calls),
         # whose actual values we are waiting for (via ray.wait in
         # `self.poll()`) to their corresponding actor handles (the actors
@@ -97,6 +103,10 @@ def make_remote_env(i):
            self.actors = [
                make_remote_env(i) for i in range(self.num_envs)
            ]
+            self._observation_space = ray.get(
+                self.actors[0].observation_space.remote())
+            self._action_space = ray.get(
+                self.actors[0].action_space.remote())
 
        # Lazy initialization. Call `reset()` on all @ray.remote
        # sub-environment actors at the beginning.
@@ -199,9 +209,23 @@ def stop(self) -> None:
 
     @override(BaseEnv)
     @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {env_id: actor for env_id, actor in enumerate(self.actors)}
         return self.actors
 
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteMultiAgentEnv:
@@ -221,6 +245,14 @@ def reset(self):
     def step(self, action_dict):
         return self.env.step(action_dict)
 
+    # These two methods are defined so that this information can be queried
+    # with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
+
 
 @ray.remote(num_cpus=0)
 class _RemoteSingleAgentEnv:
@@ -243,3 +275,11 @@ def step(self, action):
         } for x in [obs, rew, done, info]]
         done["__all__"] = done[_DUMMY_AGENT_ID]
         return obs, rew, done, info
+
+    # These two methods are defined so that this information can be queried
+    # with a call to ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
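
The pattern the new _observation_space / _action_space fields rely on is that the wrapped actors expose their spaces as remote methods, so the driver can fetch them once with ray.get(). Below is a minimal, self-contained sketch of that pattern; the _SpaceProbe actor and the CartPole env are assumptions for illustration, not RLlib's actual actor classes.

# Minimal sketch of querying gym spaces from a @ray.remote actor, mirroring
# how RemoteBaseEnv fetches them from its first sub-environment actor.
import gym
import ray


@ray.remote(num_cpus=0)
class _SpaceProbe:  # illustrative stand-in for _RemoteSingleAgentEnv
    def __init__(self):
        self.env = gym.make("CartPole-v1")

    def observation_space(self):
        return self.env.observation_space

    def action_space(self):
        return self.env.action_space


ray.init(ignore_reinit_error=True)
actor = _SpaceProbe.remote()
observation_space = ray.get(actor.observation_space.remote())
action_space = ray.get(actor.action_space.remote())
print(observation_space, action_space)  # Box(..., (4,), float32) Discrete(2)
ray.shutdown()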

rllib/env/vector_env.py

Lines changed: 46 additions & 5 deletions

@@ -1,7 +1,7 @@
 import logging
 import gym
 import numpy as np
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 from ray.rllib.env.base_env import BaseEnv
 from ray.rllib.utils.annotations import Deprecated, override, PublicAPI
@@ -265,13 +265,13 @@ class VectorEnvWrapper(BaseEnv):
 
     def __init__(self, vector_env: VectorEnv):
         self.vector_env = vector_env
-        self.action_space = vector_env.action_space
-        self.observation_space = vector_env.observation_space
         self.num_envs = vector_env.num_envs
         self.new_obs = None  # lazily initialized
         self.cur_rewards = [None for _ in range(self.num_envs)]
         self.cur_dones = [False for _ in range(self.num_envs)]
         self.cur_infos = [None for _ in range(self.num_envs)]
+        self._observation_space = vector_env.observation_space
+        self._action_space = vector_env.action_space
 
     @override(BaseEnv)
     def poll(self) -> Tuple[MultiEnvDict, MultiEnvDict, MultiEnvDict,
@@ -312,10 +312,51 @@ def try_reset(self, env_id: Optional[EnvID] = None) -> MultiEnvDict:
         }
 
     @override(BaseEnv)
-    def get_sub_environments(self) -> List[EnvType]:
-        return self.vector_env.get_sub_environments()
+    def get_sub_environments(
+            self, as_dict: bool = False) -> Union[List[EnvType], dict]:
+        if not as_dict:
+            return self.vector_env.get_sub_environments()
+        else:
+            return {
+                _id: env
+                for _id, env in enumerate(
+                    self.vector_env.get_sub_environments())
+            }
 
     @override(BaseEnv)
     def try_render(self, env_id: Optional[EnvID] = None) -> None:
         assert env_id is None or isinstance(env_id, int)
         return self.vector_env.try_render_at(env_id)
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+
+    @staticmethod
+    def _space_contains(space: gym.Space, x: MultiEnvDict) -> bool:
+        """Check if the given space contains the observations of x.
+
+        Args:
+            space: The space to check if x's observations are contained in.
+            x: The observations to check.
+
+        Note: With vector envs, we can process the raw observations and
+            ignore the agent ids and env ids, since vector envs'
+            sub-environments are guaranteed to be the same.
+
+        Returns:
+            True if the observations of x are contained in space.
+        """
+        for _, multi_agent_dict in x.items():
+            for _, element in multi_agent_dict.items():
+                if not space.contains(element):
+                    return False
+        return True
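
Because all sub-environments of a VectorEnv share the same spaces, the wrapper's containment check can ignore env_ids and agent_ids and test each raw element against the single space. A hedged, stand-alone sketch of that check follows; the CartPole space and the hand-built MultiEnvDict are assumptions for illustration.

# Illustration only: the vector-env shortcut for space containment, where one
# shared space is checked against every element of the MultiEnvDict.
import gym

space = gym.make("CartPole-v1").observation_space  # shared by all sub-envs

# MultiEnvDict: env_id -> agent_id -> observation.
x = {
    0: {"agent0": space.sample()},
    1: {"agent0": space.sample()},
}

ok = all(
    space.contains(element)
    for multi_agent_dict in x.values()
    for element in multi_agent_dict.values())
print(ok)  # True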
