Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support only new step API (while retaining compatibility functions) #3019

Merged
merged 17 commits on
Aug 30, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 8 additions & 37 deletions gym/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,7 @@ def np_random(self) -> RandomNumberGenerator:
def np_random(self, value: RandomNumberGenerator):
self._np_random = value

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
arjun-kg marked this conversation as resolved.
Show resolved Hide resolved
"""Run one timestep of the environment's dynamics.

When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Expand Down Expand Up @@ -311,25 +307,18 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""

def __init__(self, env: Env, new_step_api: bool = False):
def __init__(self, env: Env):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args:
env: The environment to wrap
new_step_api: Whether the wrapper's step method will output in new or old step API
"""
self.env = env

self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self.new_step_api = new_step_api

if not self.new_step_api:
deprecation(
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
)

def __getattr__(self, name):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
Expand Down Expand Up @@ -411,17 +400,9 @@ def _np_random(self):
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)

def step(
self, action: ActType
) -> Union[
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict]
]:
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment with action."""
from gym.utils.step_api_compatibility import ( # avoid circular import
step_api_compatibility,
)

return step_api_compatibility(self.env.step(action), self.new_step_api)
return self.env.step(action)

def reset(self, **kwargs) -> Union[ObsType, Tuple[ObsType, dict]]:
"""Resets the environment with kwargs."""
Expand Down Expand Up @@ -493,13 +474,8 @@ def reset(self, **kwargs):

def step(self, action):
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return self.observation(observation), reward, terminated, truncated, info
else:
observation, reward, done, info = step_returns
return self.observation(observation), reward, done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info

def observation(self, observation):
"""Returns a modified observation."""
Expand Down Expand Up @@ -532,13 +508,8 @@ def reward(self, reward):

def step(self, action):
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
step_returns = self.env.step(action)
if len(step_returns) == 5:
observation, reward, terminated, truncated, info = step_returns
return observation, self.reward(reward), terminated, truncated, info
else:
observation, reward, done, info = step_returns
return observation, self.reward(reward), done, info
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, self.reward(reward), terminated, truncated, info

def reward(self, reward):
"""Returns a modified ``reward``."""
Expand Down
15 changes: 9 additions & 6 deletions gym/envs/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def make(
id: Union[str, EnvSpec],
max_episode_steps: Optional[int] = None,
autoreset: bool = False,
new_step_api: bool = False,
new_step_api: bool = True,
arjun-kg marked this conversation as resolved.
Show resolved Hide resolved
disable_env_checker: Optional[bool] = None,
**kwargs,
) -> Env:
Expand All @@ -557,7 +557,7 @@ def make(
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0'
max_episode_steps: Maximum length of an episode (TimeLimit wrapper).
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper).
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper)
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker`
(which is by default False, running the environment checker),
otherwise will run according to this parameter (`True` = not run, `False` = run)
Expand Down Expand Up @@ -684,26 +684,29 @@ def make(
):
env = PassiveEnvChecker(env)

env = StepAPICompatibility(env, new_step_api)

# Add the order enforcing wrapper
if spec_.order_enforce:
env = OrderEnforcing(env)

# Add the time limit wrapper
if max_episode_steps is not None:
env = TimeLimit(env, max_episode_steps, new_step_api)
env = TimeLimit(env, max_episode_steps)
elif spec_.max_episode_steps is not None:
env = TimeLimit(env, spec_.max_episode_steps, new_step_api)
env = TimeLimit(env, spec_.max_episode_steps)

# Add the autoreset wrapper
if autoreset:
env = AutoResetWrapper(env, new_step_api)
env = AutoResetWrapper(env)

# Add human rendering wrapper
if apply_human_rendering:
env = HumanRendering(env)

# Add step API wrapper
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it work if the compatibility wrapper is at the end? As far as I understand, the use case here is if someone has a legacy environment, then it would convert it to a new-style environment. But wouldn't one of the wrappers before this crash out if the compatibility is not handled in advance?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(checked now, it doesn't work, at least assuming my understanding is correct)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think you are correct, this should occur after the environment checker in order of wrapper

if not new_step_api:
env = StepAPICompatibility(env, new_step_api)

return env


Expand Down
8 changes: 4 additions & 4 deletions gym/utils/step_api_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0."""
"""Contains methods for step compatibility, from old-to-new and new-to-old API"""
from typing import Tuple, Union

import numpy as np
Expand Down Expand Up @@ -149,7 +149,7 @@ def step_to_old_api(

def step_api_compatibility(
step_returns: Union[NewStepType, OldStepType],
new_step_api: bool = False,
new_step_api: bool = True,
is_vector_env: bool = False,
) -> Union[NewStepType, OldStepType]:
"""Function to transform step returns to the API specified by `new_step_api` bool.
Expand All @@ -160,7 +160,7 @@ def step_api_compatibility(

Args:
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
new_step_api (bool): Whether the output should be in new step API or old (False by default)
new_step_api (bool): Whether the output should be in new step API or old (True by default)
is_vector_env (bool): Whether the step_returns are from a vector environment

Returns:
Expand All @@ -170,7 +170,7 @@ def step_api_compatibility(
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
wrapper is written in new API, and the final step output is desired to be in old API.

>>> obs, rew, done, info = step_api_compatibility(env.step(action))
>>> obs, rew, done, info = step_api_compatibility(env.step(action), new_step_api=False)
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True)
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
"""
Expand Down
9 changes: 1 addition & 8 deletions gym/vector/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def make(
asynchronous: bool = True,
wrappers: Optional[Union[callable, List[callable]]] = None,
disable_env_checker: Optional[bool] = None,
new_step_api: bool = False,
arjun-kg marked this conversation as resolved.
Show resolved Hide resolved
**kwargs,
) -> VectorEnv:
"""Create a vectorized environment from multiple copies of an environment, from its id.
Expand All @@ -37,7 +36,6 @@ def make(
wrappers: If not ``None``, then apply the wrappers to each internal environment during creation.
disable_env_checker: If to run the env checker for the first environment only. None will default to the environment spec `disable_env_checker` parameter
(that is by default False), otherwise will run according to this argument (True = not run, False = run)
new_step_api: If True, the vector environment's step method outputs two booleans `terminated`, `truncated` instead of one `done`.
**kwargs: Keywords arguments applied during `gym.make`

Returns:
Expand All @@ -53,7 +51,6 @@ def _make_env():
env = gym.envs.registration.make(
id,
disable_env_checker=_disable_env_checker,
new_step_api=True,
**kwargs,
)
if wrappers is not None:
Expand All @@ -73,8 +70,4 @@ def _make_env():
env_fns = [
create_env(disable_env_checker or env_num > 0) for env_num in range(num_envs)
]
return (
AsyncVectorEnv(env_fns, new_step_api=new_step_api)
if asynchronous
else SyncVectorEnv(env_fns, new_step_api=new_step_api)
)
return AsyncVectorEnv(env_fns) if asynchronous else SyncVectorEnv(env_fns)
30 changes: 11 additions & 19 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
CustomSpaceError,
NoAsyncCallError,
)
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.vector.utils import (
CloudpickleWrapper,
clear_mpi_env_vars,
Expand Down Expand Up @@ -67,7 +66,6 @@ def __init__(
context: Optional[str] = None,
daemon: bool = True,
worker: Optional[callable] = None,
new_step_api: bool = False,
):
"""Vectorized environment that runs multiple environments in parallel.

Expand All @@ -87,7 +85,6 @@ def __init__(
so for some environments you may want to have it set to ``False``.
worker: If set, then use that worker in a subprocess instead of a default one.
Can be useful to override some inner vector env logic, for instance, how resets on termination or truncation are handled.
new_step_api: If True, step method returns 2 bools - terminated, truncated, instead of 1 bool - done

Warnings: worker is an advanced mode option. It provides a high degree of flexibility and a high chance
to shoot yourself in the foot; thus, if you are writing your own worker, it is recommended to start
Expand Down Expand Up @@ -115,7 +112,6 @@ def __init__(
num_envs=len(env_fns),
observation_space=observation_space,
action_space=action_space,
new_step_api=new_step_api,
)

if self.shared_memory:
Expand Down Expand Up @@ -335,14 +331,14 @@ def step_async(self, actions: np.ndarray):

def step_wait(
self, timeout: Optional[Union[int, float]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, List[dict]]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, List[dict]]:
"""Wait for the calls to :obj:`step` in each sub-environment to finish.

Args:
timeout: Number of seconds before the call to :meth:`step_wait` times out. If ``None``, the call to :meth:`step_wait` never times out.

Returns:
The batched environment step information, (obs, reward, terminated, truncated, info) or (obs, reward, done, info) depending on new_step_api
The batched environment step information, (obs, reward, terminated, truncated, info)

Raises:
ClosedEnvironmentError: If the environment was closed (if :meth:`close` was previously called).
Expand All @@ -366,7 +362,7 @@ def step_wait(
successes = []
for i, pipe in enumerate(self.parent_pipes):
result, success = pipe.recv()
obs, rew, terminated, truncated, info = step_api_compatibility(result, True)
obs, rew, terminated, truncated, info = result

successes.append(success)
observations_list.append(obs)
Expand All @@ -385,16 +381,12 @@ def step_wait(
self.observations,
)

return step_api_compatibility(
(
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
),
self.new_step_api,
True,
return (
deepcopy(self.observations) if self.copy else self.observations,
np.array(rewards),
np.array(terminateds, dtype=np.bool_),
np.array(truncateds, dtype=np.bool_),
infos,
)

def call_async(self, name: str, *args, **kwargs):
Expand Down Expand Up @@ -620,7 +612,7 @@ def _worker(index, env_fn, pipe, parent_pipe, shared_memory, error_queue):
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
) = env.step(data)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
Expand Down Expand Up @@ -695,7 +687,7 @@ def _worker_shared_memory(index, env_fn, pipe, parent_pipe, shared_memory, error
terminated,
truncated,
info,
) = step_api_compatibility(env.step(data), True)
) = env.step(data)
if terminated or truncated:
info["final_observation"] = observation
observation = env.reset()
Expand Down
23 changes: 9 additions & 14 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from gym import Env
from gym.spaces import Space
from gym.utils.step_api_compatibility import step_api_compatibility
from gym.vector.utils import concatenate, create_empty_array, iterate
from gym.vector.vector_env import VectorEnv

Expand Down Expand Up @@ -34,7 +33,6 @@ def __init__(
observation_space: Space = None,
action_space: Space = None,
copy: bool = True,
new_step_api: bool = False,
):
"""Vectorized environment that serially runs multiple environments.

Expand Down Expand Up @@ -62,7 +60,6 @@ def __init__(
num_envs=len(self.envs),
observation_space=observation_space,
action_space=action_space,
new_step_api=new_step_api,
)

self._check_spaces()
Expand Down Expand Up @@ -156,13 +153,15 @@ def step_wait(self):
"""
observations, infos = [], {}
for i, (env, action) in enumerate(zip(self.envs, self._actions)):

(
observation,
self._rewards[i],
self._terminateds[i],
self._truncateds[i],
info,
) = step_api_compatibility(env.step(action), True)
) = env.step(action)

if self._terminateds[i] or self._truncateds[i]:
info["final_observation"] = observation
observation = env.reset()
Expand All @@ -172,16 +171,12 @@ def step_wait(self):
self.single_observation_space, observations, self.observations
)

return step_api_compatibility(
(
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
),
new_step_api=self.new_step_api,
is_vector_env=True,
return (
deepcopy(self.observations) if self.copy else self.observations,
np.copy(self._rewards),
np.copy(self._terminateds),
np.copy(self._truncateds),
infos,
)

def call(self, name, *args, **kwargs) -> tuple:
Expand Down
Loading