-
Notifications
You must be signed in to change notification settings - Fork 8.7k
Support only new step API (while retaining compatibility functions) #3019
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d412e40
fd427bd
7c26262
515fe55
af9f0e9
9d4de3e
be7fa6f
0bd7157
fb61509
d3db77e
48dad41
21bb6ee
1d5c4df
a001631
0277ca4
f8add0f
7a8f6d4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,7 @@ | |
import numpy as np | ||
|
||
from gym import spaces | ||
from gym.logger import deprecation, warn | ||
from gym.logger import warn | ||
from gym.utils import seeding | ||
|
||
if TYPE_CHECKING: | ||
|
@@ -83,16 +83,11 @@ def np_random(self) -> np.random.Generator: | |
def np_random(self, value: np.random.Generator): | ||
self._np_random = value | ||
|
||
def step( | ||
self, action: ActType | ||
) -> Union[ | ||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] | ||
]: | ||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: | ||
"""Run one timestep of the environment's dynamics. | ||
|
||
When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. | ||
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`, or a tuple | ||
(observation, reward, done, info). The latter is deprecated and will be removed in future versions. | ||
Accepts an action and returns either a tuple `(observation, reward, terminated, truncated, info)`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove "either" |
||
|
||
Args: | ||
action (ActType): an action provided by the agent | ||
|
@@ -226,25 +221,18 @@ class Wrapper(Env[ObsType, ActType]): | |
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. | ||
""" | ||
|
||
def __init__(self, env: Env, new_step_api: bool = False): | ||
def __init__(self, env: Env): | ||
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. | ||
|
||
Args: | ||
env: The environment to wrap | ||
new_step_api: Whether the wrapper's step method will output in new or old step API | ||
""" | ||
self.env = env | ||
|
||
self._action_space: Optional[spaces.Space] = None | ||
self._observation_space: Optional[spaces.Space] = None | ||
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None | ||
self._metadata: Optional[dict] = None | ||
self.new_step_api = new_step_api | ||
|
||
if not self.new_step_api: | ||
deprecation( | ||
"Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future." | ||
) | ||
|
||
def __getattr__(self, name): | ||
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" | ||
|
@@ -326,17 +314,9 @@ def _np_random(self): | |
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." | ||
) | ||
|
||
def step( | ||
self, action: ActType | ||
) -> Union[ | ||
Tuple[ObsType, float, bool, bool, dict], Tuple[ObsType, float, bool, dict] | ||
]: | ||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: | ||
"""Steps through the environment with action.""" | ||
from gym.utils.step_api_compatibility import ( # avoid circular import | ||
step_api_compatibility, | ||
) | ||
|
||
return step_api_compatibility(self.env.step(action), self.new_step_api) | ||
return self.env.step(action) | ||
|
||
def reset(self, **kwargs) -> Tuple[ObsType, dict]: | ||
"""Resets the environment with kwargs.""" | ||
|
@@ -401,13 +381,8 @@ def reset(self, **kwargs): | |
|
||
def step(self, action): | ||
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" | ||
step_returns = self.env.step(action) | ||
if len(step_returns) == 5: | ||
observation, reward, terminated, truncated, info = step_returns | ||
return self.observation(observation), reward, terminated, truncated, info | ||
else: | ||
observation, reward, done, info = step_returns | ||
return self.observation(observation), reward, done, info | ||
observation, reward, terminated, truncated, info = self.env.step(action) | ||
return self.observation(observation), reward, terminated, truncated, info | ||
|
||
def observation(self, observation): | ||
"""Returns a modified observation.""" | ||
|
@@ -440,13 +415,8 @@ def reward(self, reward): | |
|
||
def step(self, action): | ||
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" | ||
step_returns = self.env.step(action) | ||
if len(step_returns) == 5: | ||
observation, reward, terminated, truncated, info = step_returns | ||
return observation, self.reward(reward), terminated, truncated, info | ||
else: | ||
observation, reward, done, info = step_returns | ||
return observation, self.reward(reward), done, info | ||
observation, reward, terminated, truncated, info = self.env.step(action) | ||
return observation, self.reward(reward), terminated, truncated, info | ||
|
||
def reward(self, reward): | ||
"""Returns a modified ``reward``.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -140,7 +140,7 @@ class EnvSpec: | |
order_enforce: bool = field(default=True) | ||
autoreset: bool = field(default=False) | ||
disable_env_checker: bool = field(default=False) | ||
new_step_api: bool = field(default=False) | ||
apply_step_compatibility: bool = field(default=False) | ||
|
||
# Environment arguments | ||
kwargs: dict = field(default_factory=dict) | ||
|
@@ -547,7 +547,7 @@ def make( | |
id: Union[str, EnvSpec], | ||
max_episode_steps: Optional[int] = None, | ||
autoreset: bool = False, | ||
new_step_api: bool = False, | ||
apply_step_compatibility: bool = False, | ||
disable_env_checker: Optional[bool] = None, | ||
**kwargs, | ||
) -> Env: | ||
|
@@ -557,7 +557,7 @@ def make( | |
id: Name of the environment. Optionally, a module to import can be included, eg. 'module:Env-v0' | ||
max_episode_steps: Maximum length of an episode (TimeLimit wrapper). | ||
autoreset: Whether to automatically reset the environment after each episode (AutoResetWrapper). | ||
new_step_api: Whether to use old or new step API (StepAPICompatibility wrapper). Will be removed at v1.0 | ||
apply_step_compatibility: Whether to use apply compatibility wrapper that converts step method to return two bools (StepAPICompatibility wrapper) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aren't we removing this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or I guess it might be useful for automatically supporting legacy environments? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I think we should keep a parameter in make to easily apply the compatibility wrapper |
||
disable_env_checker: If to run the env checker, None will default to the environment specification `disable_env_checker` | ||
(which is by default False, running the environment checker), | ||
otherwise will run according to this parameter (`True` = not run, `False` = run) | ||
|
@@ -684,26 +684,28 @@ def make( | |
): | ||
env = PassiveEnvChecker(env) | ||
|
||
env = StepAPICompatibility(env, new_step_api) | ||
|
||
# Add the order enforcing wrapper | ||
if spec_.order_enforce: | ||
env = OrderEnforcing(env) | ||
|
||
# Add the time limit wrapper | ||
if max_episode_steps is not None: | ||
env = TimeLimit(env, max_episode_steps, new_step_api) | ||
env = TimeLimit(env, max_episode_steps) | ||
elif spec_.max_episode_steps is not None: | ||
env = TimeLimit(env, spec_.max_episode_steps, new_step_api) | ||
env = TimeLimit(env, spec_.max_episode_steps) | ||
|
||
# Add the autoreset wrapper | ||
if autoreset: | ||
env = AutoResetWrapper(env, new_step_api) | ||
env = AutoResetWrapper(env) | ||
|
||
# Add human rendering wrapper | ||
if apply_human_rendering: | ||
env = HumanRendering(env) | ||
|
||
# Add step API wrapper | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does it work if the compatibility wrapper is at the end? As far as I understand, the use case here is if someone has a legacy environment, then it would convert it to a new-style environment. But wouldn't one of the wrappers before this crash out if the compatibility is not handled in advance? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (checked now, it doesn't work, at least assuming my understanding is correct) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I think you are correct, this should occur after the environment checker in order of wrapper |
||
if apply_step_compatibility: | ||
env = StepAPICompatibility(env, True) | ||
|
||
return env | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -170,7 +170,7 @@ def play( | |
:class:`gym.utils.play.PlayPlot`. Here's a sample code for plotting the reward | ||
for last 150 steps. | ||
|
||
>>> def callback(obs_t, obs_tp1, action, rew, done, info): | ||
>>> def callback(obs_t, obs_tp1, action, rew, terminated, truncated, info): | ||
... return [rew,] | ||
>>> plotter = PlayPlot(callback, 150, ["reward"]) | ||
>>> play(gym.make("ALE/AirRaid-v5"), callback=plotter.callback) | ||
|
@@ -187,7 +187,8 @@ def play( | |
obs_tp1: observation after performing action | ||
action: action that was executed | ||
rew: reward that was received | ||
done: whether the environment is done or not | ||
terminated: whether the environment is terminated or not | ||
truncated: whether the environment is truncated or not | ||
info: debug info | ||
keys_to_action: Mapping from keys pressed to action performed. | ||
Different formats are supported: Key combinations can either be expressed as a tuple of unicode code | ||
|
@@ -219,11 +220,6 @@ def play( | |
deprecation( | ||
"`play.py` currently supports only the old step API which returns one boolean, however this will soon be updated to support only the new step api that returns two bools." | ||
) | ||
if env.render_mode not in {"rgb_array", "single_rgb_array"}: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why was this removed? Seems irrelevant |
||
logger.error( | ||
"play method works only with rgb_array and single_rgb_array render modes, " | ||
f"but your environment render_mode = {env.render_mode}." | ||
) | ||
|
||
env.reset(seed=seed) | ||
|
||
|
@@ -261,9 +257,10 @@ def play( | |
else: | ||
action = key_code_to_action.get(tuple(sorted(game.pressed_keys)), noop) | ||
prev_obs = obs | ||
obs, rew, done, info = env.step(action) | ||
obs, rew, terminated, truncated, info = env.step(action) | ||
arjun-kg marked this conversation as resolved.
Show resolved
Hide resolved
|
||
done = terminated or truncated | ||
if callback is not None: | ||
callback(prev_obs, obs, action, rew, done, info) | ||
callback(prev_obs, obs, action, rew, terminated, truncated, info) | ||
if obs is not None: | ||
rendered = env.render() | ||
if isinstance(rendered, List): | ||
|
@@ -290,13 +287,14 @@ class PlayPlot: | |
- obs_tp1: observation after performing action | ||
- action: action that was executed | ||
- rew: reward that was received | ||
- done: whether the environment is done or not | ||
- terminated: whether the environment is terminated or not | ||
- truncated: whether the environment is truncated or not | ||
- info: debug info | ||
|
||
It should return a list of metrics that are computed from this data. | ||
For instance, the function may look like this:: | ||
|
||
>>> def compute_metrics(obs_t, obs_tp, action, reward, done, info): | ||
>>> def compute_metrics(obs_t, obs_tp, action, reward, terminated, truncated, info): | ||
... return [reward, info["cumulative_reward"], np.linalg.norm(action)] | ||
|
||
:class:`PlayPlot` provides the method :meth:`callback` which will pass its arguments along to that function | ||
|
@@ -353,7 +351,8 @@ def callback( | |
obs_tp1: ObsType, | ||
action: ActType, | ||
rew: float, | ||
done: bool, | ||
terminated: bool, | ||
truncated: bool, | ||
info: dict, | ||
): | ||
"""The callback that calls the provided data callback and adds the data to the plots. | ||
|
@@ -363,10 +362,13 @@ def callback( | |
obs_tp1: The observation at time step t+1 | ||
action: The action | ||
rew: The reward | ||
done: If the environment is done | ||
terminated: If the environment is terminated | ||
truncated: If the environment is truncated | ||
info: The information from the environment | ||
""" | ||
points = self.data_callback(obs_t, obs_tp1, action, rew, done, info) | ||
points = self.data_callback( | ||
obs_t, obs_tp1, action, rew, terminated, truncated, info | ||
) | ||
for point, data_series in zip(points, self.data): | ||
data_series.append(point) | ||
self.t += 1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,18 @@ | ||
"""Contains methods for step compatibility, from old-to-new and new-to-old API, to be removed in 1.0.""" | ||
"""Contains methods for step compatibility, from old-to-new and new-to-old API.""" | ||
from typing import Tuple, Union | ||
|
||
import numpy as np | ||
|
||
from gym.core import ObsType | ||
|
||
OldStepType = Tuple[ | ||
DoneStepType = Tuple[ | ||
Union[ObsType, np.ndarray], | ||
Union[float, np.ndarray], | ||
Union[bool, np.ndarray], | ||
Union[dict, list], | ||
] | ||
|
||
NewStepType = Tuple[ | ||
TerminatedTruncatedStepType = Tuple[ | ||
Union[ObsType, np.ndarray], | ||
Union[float, np.ndarray], | ||
Union[bool, np.ndarray], | ||
|
@@ -21,9 +21,9 @@ | |
] | ||
|
||
|
||
def step_to_new_api( | ||
step_returns: Union[OldStepType, NewStepType], is_vector_env=False | ||
) -> NewStepType: | ||
def convert_to_terminated_truncated_step_api( | ||
step_returns: Union[DoneStepType, TerminatedTruncatedStepType], is_vector_env=False | ||
) -> TerminatedTruncatedStepType: | ||
"""Function to transform step returns to new step API irrespective of input API. | ||
|
||
Args: | ||
|
@@ -73,9 +73,10 @@ def step_to_new_api( | |
) | ||
|
||
|
||
def step_to_old_api( | ||
step_returns: Union[NewStepType, OldStepType], is_vector_env: bool = False | ||
) -> OldStepType: | ||
def convert_to_done_step_api( | ||
step_returns: Union[TerminatedTruncatedStepType, DoneStepType], | ||
is_vector_env: bool = False, | ||
) -> DoneStepType: | ||
"""Function to transform step returns to old step API irrespective of input API. | ||
|
||
Args: | ||
|
@@ -128,33 +129,33 @@ def step_to_old_api( | |
|
||
|
||
def step_api_compatibility( | ||
step_returns: Union[NewStepType, OldStepType], | ||
new_step_api: bool = False, | ||
step_returns: Union[TerminatedTruncatedStepType, DoneStepType], | ||
output_truncation_bool: bool = True, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like this argument name, it's pretty unclear. I think there's a different name used earlier? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
is_vector_env: bool = False, | ||
) -> Union[NewStepType, OldStepType]: | ||
"""Function to transform step returns to the API specified by `new_step_api` bool. | ||
) -> Union[TerminatedTruncatedStepType, DoneStepType]: | ||
"""Function to transform step returns to the API specified by `output_truncation_bool` bool. | ||
|
||
Old step API refers to step() method returning (observation, reward, done, info) | ||
New step API refers to step() method returning (observation, reward, terminated, truncated, info) | ||
Done (old) step API refers to step() method returning (observation, reward, done, info) | ||
Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info) | ||
(Refer to docs for details on the API change) | ||
|
||
Args: | ||
step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info) | ||
new_step_api (bool): Whether the output should be in new step API or old (False by default) | ||
output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default) | ||
is_vector_env (bool): Whether the step_returns are from a vector environment | ||
|
||
Returns: | ||
step_returns (tuple): Depending on `new_step_api` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) | ||
step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info) | ||
|
||
Examples: | ||
This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API, | ||
wrapper is written in new API, and the final step output is desired to be in old API. | ||
|
||
>>> obs, rew, done, info = step_api_compatibility(env.step(action)) | ||
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), new_step_api=True) | ||
>>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False) | ||
>>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True) | ||
>>> observations, rewards, dones, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True) | ||
""" | ||
if new_step_api: | ||
return step_to_new_api(step_returns, is_vector_env) | ||
if output_truncation_bool: | ||
return convert_to_terminated_truncated_step_api(step_returns, is_vector_env) | ||
else: | ||
return step_to_old_api(step_returns, is_vector_env) | ||
return convert_to_done_step_api(step_returns, is_vector_env) |
Uh oh!
There was an error while loading. Please reload this page.