Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch action_space in VectorEnv #2280

Merged
merged 9 commits into from
Dec 9, 2021
Prev Previous commit
Next Next commit
Check for same action spaces in vectorized environments
  • Loading branch information
tristandeleu committed Dec 7, 2021
commit 8c0cda7d60f78ec239a0835d8d463296b6a825b4
44 changes: 29 additions & 15 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
@@ -186,7 +186,7 @@ def __init__(
child_pipe.close()

self._state = AsyncState.DEFAULT
self._check_observation_spaces()
self._check_spaces()

def seed(self, seeds=None):
self._assert_is_running()
@@ -441,18 +441,25 @@ def _poll(self, timeout=None):
return False
return True

def _check_observation_spaces(self):
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_observation_space", self.single_observation_space))
same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
if not all(same_spaces):
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)

def _assert_is_running(self):
@@ -502,13 +509,18 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == env.observation_space, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == env.observation_space, data[1] == env.action_space),
True,
)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
@@ -546,13 +558,15 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == observation_space, True))
elif command == "_check_spaces":
pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
24 changes: 15 additions & 9 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
@@ -64,7 +64,7 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True
action_space=action_space,
)

self._check_observation_spaces()
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
@@ -121,15 +121,21 @@ def close_extras(self, **kwargs):
"""Close the environments."""
[env.close() for env in self.envs]

def _check_observation_spaces(self):
def _check_spaces(self):
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
break
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)

if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)

else:
return True
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
)
6 changes: 3 additions & 3 deletions tests/vector/test_async_vector_env.py
Original file line number Diff line number Diff line change
@@ -193,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory):


@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_observations_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
def test_check_spaces_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
6 changes: 3 additions & 3 deletions tests/vector/test_sync_vector_env.py
Original file line number Diff line number Diff line change
@@ -67,10 +67,10 @@ def test_step_sync_vector_env(use_single_action_space):
assert dones.size == 8


def test_check_observations_sync_vector_env():
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
def test_check_spaces_sync_vector_env():
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = SyncVectorEnv(env_fns)