Skip to content

Commit

Permalink
Black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
tristandeleu committed Jul 31, 2021
1 parent 924e96f commit 7c64777
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 10 deletions.
2 changes: 1 addition & 1 deletion gym/vector/tests/test_spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def test_batch_space_custom_space(space, expected_batch_space_4):


@pytest.mark.parametrize(
"space,batch_space",
"space,batch_space",
list(zip(spaces, expected_batch_spaces_4)),
ids=[space.__class__.__name__ for space in spaces],
)
Expand Down
2 changes: 1 addition & 1 deletion gym/vector/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
"write_to_shared_memory",
"_BaseGymSpaces",
"batch_space",
"iterate"
"iterate",
]
23 changes: 15 additions & 8 deletions gym/vector/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,25 +148,32 @@ def iterate_base(items, space):


def iterate_tuple(items, space):
# If this is a tuple of custome subspaces only, then simply iterate over items
if all(not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict))
for subspace in space.spaces):
# If this is a tuple of custom subspaces only, then simply iterate over items
if all(
not isinstance(subspace, (_BaseGymSpaces, Tuple, Dict))
for subspace in space.spaces
):
return iter(items)

return zip(*[iterate(items[i], subspace)
for i, subspace in enumerate(space.spaces)])
return zip(
*[iterate(items[i], subspace) for i, subspace in enumerate(space.spaces)]
)


def iterate_dict(items, space):
keys, values = zip(*[(key, iterate(items[key], subspace))
for key, subspace in space.spaces.items()])
keys, values = zip(
*[
(key, iterate(items[key], subspace))
for key, subspace in space.spaces.items()
]
)
for item in zip(*values):
yield OrderedDict([(key, value) for (key, value) in zip(keys, item)])


def iterate_custom(items, space):
raise CustomSpaceError(
f"Unable to iterate over {items}, since {space} "
"is a custome `gym.Space` instance (i.e. not one of "
"is a custom `gym.Space` instance (i.e. not one of "
"`Box`, `Dict`, etc...)."
)

0 comments on commit 7c64777

Please sign in to comment.