import logging
from typing import Callable, Dict, List, Optional, Tuple

+import gym
+
import ray
from ray.rllib.env.base_env import BaseEnv, _DUMMY_AGENT_ID, ASYNC_RESET_RETURN
from ray.rllib.utils.annotations import override, PublicAPI
@@ -17,6 +19,8 @@ class RemoteBaseEnv(BaseEnv):
    from the remote simulator actors. Both single and multi-agent child envs
    are supported, and envs can be stepped synchronously or asynchronously.

+    NOTE: This class implicitly assumes the remote envs are gym.Env instances.
+
    You shouldn't need to instantiate this class directly. It's automatically
    inserted when you use the `remote_worker_envs=True` option in your
    Trainer's config.
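For context, the option referenced in the docstring lives in the Trainer config. A hedged sketch of how this class would get pulled in, assuming a Ray 1.x RLlib setup (the algorithm, env name, and env count below are illustrative only):

```python
# Illustrative only: any Trainer config with `remote_worker_envs=True`
# wraps its sub-environments in RemoteBaseEnv automatically.
import ray
from ray.rllib.agents.ppo import PPOTrainer

ray.init()
trainer = PPOTrainer(config={
    "env": "CartPole-v0",
    "num_envs_per_worker": 4,    # four sub-environments per rollout worker
    "remote_worker_envs": True,  # step each sub-env as a @ray.remote actor
})
trainer.train()
```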
@@ -61,6 +65,8 @@ def __init__(self,
        # List of ray actor handles (each handle points to one @ray.remote
        # sub-environment).
        self.actors: Optional[List[ray.actor.ActorHandle]] = None
+        self._observation_space = None
+        self._action_space = None
        # Dict mapping object refs (return values of @ray.remote calls),
        # whose actual values we are waiting for (via ray.wait in
        # `self.poll()`) to their corresponding actor handles (the actors
@@ -97,6 +103,10 @@ def make_remote_env(i):
        self.actors = [
            make_remote_env(i) for i in range(self.num_envs)
        ]
+        self._observation_space = ray.get(
+            self.actors[0].observation_space.remote())
+        self._action_space = ray.get(
+            self.actors[0].action_space.remote())

        # Lazy initialization. Call `reset()` on all @ray.remote
        # sub-environment actors at the beginning.
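The two `ray.get()` calls added above block during construction until the first sub-environment actor reports its spaces, and they implicitly assume that all sub-environments share identical spaces (consistent with the gym.Env note added to the docstring). A purely hypothetical consistency check, not part of this patch, could look like:

```python
# Hypothetical (not in the patch): verify every remote sub-environment agrees
# with actor 0 before caching its spaces on the wrapper.
obs_spaces = ray.get([a.observation_space.remote() for a in self.actors])
assert all(space == obs_spaces[0] for space in obs_spaces), \
    "All remote sub-environments must report the same observation space."
```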
@@ -199,9 +209,23 @@ def stop(self) -> None:

    @override(BaseEnv)
    @PublicAPI
-    def get_sub_environments(self) -> List[EnvType]:
+    def get_sub_environments(self, as_dict: bool = False) -> List[EnvType]:
+        if as_dict:
+            return {env_id: actor for env_id, actor in enumerate(self.actors)}
        return self.actors

+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def observation_space(self) -> gym.spaces.Dict:
+        return self._observation_space
+
+    @property
+    @override(BaseEnv)
+    @PublicAPI
+    def action_space(self) -> gym.Space:
+        return self._action_space
+

@ray.remote(num_cpus=0)
class _RemoteMultiAgentEnv:
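Taken together, this hunk adds a dict view over the remote actor handles plus cached space properties. A driver-side usage sketch, assuming `base_env` is an already-constructed `RemoteBaseEnv` (its constructor arguments are not part of this diff):

```python
# Sketch only; `base_env` is an existing RemoteBaseEnv instance.
actor_list = base_env.get_sub_environments()              # [ActorHandle, ...]
actor_dict = base_env.get_sub_environments(as_dict=True)  # {0: ActorHandle, ...}

# The new properties return the spaces fetched from sub-environment 0 in
# __init__, without another round-trip to the remote actors.
print(base_env.observation_space, base_env.action_space)
```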
@@ -221,6 +245,14 @@ def reset(self):
    def step(self, action_dict):
        return self.env.step(action_dict)

+    # Defined as methods (not attributes) so that this information can be
+    # queried remotely via `.remote()` calls and ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space
+

@ray.remote(num_cpus=0)
class _RemoteSingleAgentEnv:
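Both remote wrapper classes expose the spaces as methods rather than attributes because Ray actor handles can only invoke methods remotely; plain instance attributes are not reachable through `.remote()`. A short sketch of the driver-side query, assuming `actor` is a handle to one of these wrappers:

```python
# `.remote()` returns ObjectRefs immediately; ray.get() blocks for the values.
obs_space, act_space = ray.get([
    actor.observation_space.remote(),
    actor.action_space.remote(),
])
```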
@@ -243,3 +275,11 @@ def step(self, action):
        } for x in [obs, rew, done, info]]
        done["__all__"] = done[_DUMMY_AGENT_ID]
        return obs, rew, done, info
+
+    # Defined as methods (not attributes) so that this information can be
+    # queried remotely via `.remote()` calls and ray.get().
+    def observation_space(self):
+        return self.env.observation_space
+
+    def action_space(self):
+        return self.env.action_space