Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch action_space in VectorEnv #2280

Merged
merged 9 commits into from
Dec 9, 2021
46 changes: 31 additions & 15 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
write_to_shared_memory,
read_from_shared_memory,
concatenate,
iterate,
CloudpickleWrapper,
clear_mpi_env_vars,
)
Expand Down Expand Up @@ -185,7 +186,7 @@ def __init__(
child_pipe.close()

self._state = AsyncState.DEFAULT
self._check_observation_spaces()
self._check_spaces()

def seed(self, seeds=None):
self._assert_is_running()
Expand Down Expand Up @@ -307,6 +308,7 @@ def step_async(self, actions):
self._state.value,
)

actions = iterate(self.action_space, actions)
for pipe, action in zip(self.parent_pipes, actions):
pipe.send(("step", action))
self._state = AsyncState.WAITING_STEP
Expand Down Expand Up @@ -439,18 +441,25 @@ def _poll(self, timeout=None):
return False
return True

def _check_observation_spaces(self):
def _check_spaces(self):
self._assert_is_running()
spaces = (self.single_observation_space, self.single_action_space)
for pipe in self.parent_pipes:
pipe.send(("_check_observation_space", self.single_observation_space))
same_spaces, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
pipe.send(("_check_spaces", spaces))
results, successes = zip(*[pipe.recv() for pipe in self.parent_pipes])
self._raise_if_errors(successes)
if not all(same_spaces):
same_observation_spaces, same_action_spaces = zip(*results)
if not all(same_observation_spaces):
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)
if not all(same_action_spaces):
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)

def _assert_is_running(self):
Expand Down Expand Up @@ -500,13 +509,18 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == env.observation_space, True))
elif command == "_check_spaces":
pipe.send(
(
(data[0] == env.observation_space, data[1] == env.action_space),
True,
)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
Expand Down Expand Up @@ -544,13 +558,15 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
elif command == "close":
pipe.send((None, True))
break
elif command == "_check_observation_space":
pipe.send((data == observation_space, True))
elif command == "_check_spaces":
pipe.send(
((data[0] == observation_space, data[1] == env.action_space), True)
)
else:
raise RuntimeError(
"Received unknown command `{0}`. Must "
"be one of {`reset`, `step`, `seed`, `close`, "
"`_check_observation_space`}.".format(command)
"`_check_spaces`}.".format(command)
)
except (KeyboardInterrupt, Exception):
error_queue.put((index,) + sys.exc_info()[:2])
Expand Down
28 changes: 17 additions & 11 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from gym import logger
from gym.vector.vector_env import VectorEnv
from gym.vector.utils import concatenate, create_empty_array
from gym.vector.utils import concatenate, iterate, create_empty_array

__all__ = ["SyncVectorEnv"]

Expand Down Expand Up @@ -64,7 +64,7 @@ def __init__(self, env_fns, observation_space=None, action_space=None, copy=True
action_space=action_space,
)

self._check_observation_spaces()
self._check_spaces()
self.observations = create_empty_array(
self.single_observation_space, n=self.num_envs, fn=np.zeros
)
Expand Down Expand Up @@ -95,7 +95,7 @@ def reset_wait(self):
return deepcopy(self.observations) if self.copy else self.observations

def step_async(self, actions):
self._actions = actions
self._actions = iterate(self.action_space, actions)

def step_wait(self):
observations, infos = [], []
Expand All @@ -121,15 +121,21 @@ def close_extras(self, **kwargs):
"""Close the environments."""
[env.close() for env in self.envs]

def _check_observation_spaces(self):
def _check_spaces(self):
for env in self.envs:
if not (env.observation_space == self.single_observation_space):
break
raise RuntimeError(
"Some environments have an observation space different from "
f"`{self.single_observation_space}`. In order to batch observations, "
"the observation spaces from all environments must be equal."
)

if not (env.action_space == self.single_action_space):
raise RuntimeError(
"Some environments have an action space different from "
f"`{self.single_action_space}`. In order to batch actions, the "
"action spaces from all environments must be equal."
)

else:
return True
raise RuntimeError(
"Some environments have an observation space "
"different from `{}`. In order to batch observations, the "
"observation spaces from all environments must be "
"equal.".format(self.single_observation_space)
)
3 changes: 2 additions & 1 deletion gym/vector/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
read_from_shared_memory,
write_to_shared_memory,
)
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space
from gym.vector.utils.spaces import _BaseGymSpaces, batch_space, iterate

__all__ = [
"CloudpickleWrapper",
Expand All @@ -17,4 +17,5 @@
"write_to_shared_memory",
"_BaseGymSpaces",
"batch_space",
"iterate",
]
95 changes: 94 additions & 1 deletion gym/vector/utils/spaces.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import numpy as np
from collections import OrderedDict
from functools import singledispatch

from gym.spaces import Space, Box, Discrete, MultiDiscrete, MultiBinary, Tuple, Dict
from gym.error import CustomSpaceError

_BaseGymSpaces = (Box, Discrete, MultiDiscrete, MultiBinary)
__all__ = ["_BaseGymSpaces", "batch_space"]
__all__ = ["_BaseGymSpaces", "batch_space", "iterate"]


def batch_space(space, n=1):
Expand Down Expand Up @@ -86,3 +88,94 @@ def batch_space_dict(space, n=1):

def batch_space_custom(space, n=1):
return Tuple(tuple(space for _ in range(n)))


@singledispatch
def iterate(space, items):
"""Iterate over the elements of a (batched) space.

Parameters
----------
space : `gym.spaces.Space` instance
Space to which `items` belong to.

items : samples of `space`
Items to be iterated over.

Returns
-------
iterator : `Iterable` instance
Iterator over the elements in `items`.

Example
-------
>>> from gym.spaces import Box, Dict
>>> space = Dict({
... 'position': Box(low=0, high=1, shape=(2, 3), dtype=np.float32),
... 'velocity': Box(low=0, high=1, shape=(2, 2), dtype=np.float32)})
>>> items = space.sample()
>>> it = iterate(space, items)
>>> next(it)
{'position': array([-0.99644893, -0.08304597, -0.7238421 ], dtype=float32),
'velocity': array([0.35848552, 0.1533453 ], dtype=float32)}
>>> next(it)
{'position': array([-0.67958736, -0.49076623, 0.38661423], dtype=float32),
'velocity': array([0.7975036 , 0.93317133], dtype=float32)}
>>> next(it)
StopIteration
"""
raise ValueError(
"Space of type `{0}` is not a valid `gym.Space` "
"instance.".format(type(space))
)


@iterate.register(Discrete)
def iterate_discrete(space, items):
raise TypeError("Unable to iterate over a space of type `Discrete`.")


@iterate.register(Box)
@iterate.register(MultiDiscrete)
@iterate.register(MultiBinary)
def iterate_base(space, items):
try:
return iter(items)
except TypeError:
raise TypeError(f"Unable to iterate over the following elements: {items}")


@iterate.register(Tuple)
def iterate_tuple(space, items):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
isinstance(subspace, Space)
and (not isinstance(subspace, _BaseGymSpaces + (Tuple, Dict)))
for subspace in space.spaces
):
return iter(items)

return zip(
*[iterate(subspace, items[i]) for i, subspace in enumerate(space.spaces)]
)


@iterate.register(Dict)
def iterate_dict(space, items):
keys, values = zip(
*[
(key, iterate(subspace, items[key]))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])


@iterate.register(Space)
def iterate_custom(space, items):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custom `gym.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)
2 changes: 1 addition & 1 deletion gym/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, num_envs, observation_space, action_space):
self.num_envs = num_envs
self.is_vector_env = True
self.observation_space = batch_space(observation_space, n=num_envs)
self.action_space = Tuple((action_space,) * num_envs)
self.action_space = batch_space(action_space, n=num_envs)

self.closed = False
self.viewer = None
Expand Down
16 changes: 12 additions & 4 deletions tests/vector/test_async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import numpy as np

from multiprocessing import TimeoutError
from gym.spaces import Box, Tuple
from gym.spaces import Box, Tuple, Discrete, MultiDiscrete
from gym.error import AlreadyPendingCallError, NoAsyncCallError, ClosedEnvironmentError
from tests.vector.utils import (
CustomSpace,
Expand Down Expand Up @@ -48,6 +48,10 @@ def test_step_async_vector_env(shared_memory, use_single_action_space):
try:
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
observations = env.reset()

assert isinstance(env.single_action_space, Discrete)
assert isinstance(env.action_space, MultiDiscrete)

if use_single_action_space:
actions = [env.single_action_space.sample() for _ in range(8)]
else:
Expand Down Expand Up @@ -189,10 +193,10 @@ def test_already_closed_async_vector_env(shared_memory):


@pytest.mark.parametrize("shared_memory", [True, False])
def test_check_observations_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3)
def test_check_spaces_async_vector_env(shared_memory):
# CubeCrash-v0 - observation_space: Box(40, 32, 3), action_space: Discrete(3)
env_fns = [make_env("CubeCrash-v0", i) for i in range(8)]
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3)
# MemorizeDigits-v0 - observation_space: Box(24, 32, 3), action_space: Discrete(10)
env_fns[1] = make_env("MemorizeDigits-v0", 1)
with pytest.raises(RuntimeError):
env = AsyncVectorEnv(env_fns, shared_memory=shared_memory)
Expand All @@ -204,6 +208,10 @@ def test_custom_space_async_vector_env():
try:
env = AsyncVectorEnv(env_fns, shared_memory=False)
reset_observations = env.reset()

assert isinstance(env.single_action_space, CustomSpace)
assert isinstance(env.action_space, Tuple)

actions = ("action-2", "action-3", "action-5", "action-7")
step_observations, rewards, dones, _ = env.step(actions)
finally:
Expand Down
28 changes: 27 additions & 1 deletion tests/vector/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from gym.spaces import Box, MultiDiscrete, Tuple, Dict
from tests.vector.utils import spaces, custom_spaces, CustomSpace

from gym.vector.utils.spaces import batch_space
from gym.vector.utils.spaces import batch_space, iterate

expected_batch_spaces_4 = [
Box(low=-1.0, high=1.0, shape=(4,), dtype=np.float64),
Expand Down Expand Up @@ -103,3 +103,29 @@ def test_batch_space(space, expected_batch_space_4):
def test_batch_space_custom_space(space, expected_batch_space_4):
batch_space_4 = batch_space(space, n=4)
assert batch_space_4 == expected_batch_space_4


@pytest.mark.parametrize(
"space,batch_space",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
def test_iterate(space, batch_space):
items = batch_space.sample()
iterator = iterate(batch_space, items)
for i, item in enumerate(iterator):
assert item in space
assert i == 3


@pytest.mark.parametrize(
"space,batch_space",
list(zip(custom_spaces, expected_custom_batch_spaces_4)),
ids=[space.__class__.__name__ for space in custom_spaces],
)
def test_iterate_custom_space(space, batch_space):
items = batch_space.sample()
iterator = iterate(batch_space, items)
for i, item in enumerate(iterator):
assert item in space
assert i == 3
Loading