From 046c18c958de6ee0235482fb5dae68cbdb487ca8 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Thu, 6 Oct 2022 16:40:38 +0100 Subject: [PATCH 01/12] Initial commit --- gymnasium/core.py | 98 ++++++++++--------- .../test_record_episode_statistics.py | 1 + 2 files changed, 55 insertions(+), 44 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 82e160e0b..c9e1c0d9c 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,5 +1,5 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" -import sys +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, @@ -16,23 +16,18 @@ import numpy as np from gymnasium import spaces -from gymnasium.logger import warn from gymnasium.utils import seeding if TYPE_CHECKING: from gymnasium.envs.registration import EnvSpec -if sys.version_info[0:2] == (3, 6): - warn( - "Gymnasium minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+" - ) ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") RenderFrame = TypeVar("RenderFrame") -class Env(Generic[ObsType, ActType]): +class Env(Generic[ObsType, ActType], ABC): r"""The main Gymnasium class. It encapsulates an environment with arbitrary behind-the-scenes dynamics. @@ -66,8 +61,8 @@ class Env(Generic[ObsType, ActType]): spec: "EnvSpec" = None # Set these in ALL subclasses - action_space: spaces.Space[ActType] - observation_space: spaces.Space[ObsType] + action_space: spaces.Space + observation_space: spaces.Space # Created _np_random: Optional[np.random.Generator] = None @@ -76,14 +71,17 @@ class Env(Generic[ObsType, ActType]): def np_random(self) -> np.random.Generator: """Returns the environment's internal :attr:`_np_random` that if not set will initialise with a random seed.""" if self._np_random is None: - self._np_random, seed = seeding.np_random() + self._np_random, _ = seeding.np_random() return self._np_random @np_random.setter def np_random(self, value: np.random.Generator): self._np_random = value - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: + @abstractmethod + def step( + self, action: ActType + ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. @@ -99,7 +97,7 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: terminated (bool): whether a `terminal state` (as defined under the MDP of the task) is reached. In this case further step() calls could return undefined results. truncated (bool): whether a truncation condition outside the scope of the MDP is satisfied. - Typically a timelimit, but could also be used to indicate agent physically going out of bounds. + Typically, a timelimit, but could also be used to indicate agent physically going out of bounds. Can be used to end the episode prematurely before a `terminal state` is reached. info (dictionary): `info` contains auxiliary diagnostic information (helpful for debugging, learning, and logging). This might, for instance, contain: metrics that describe the agent's performance state, variables that are @@ -113,6 +111,7 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """ raise NotImplementedError + @abstractmethod def reset( self, *, @@ -138,7 +137,6 @@ def reset( options (optional dict): Additional information to specify how the environment is reset (optional, depending on the specific environment) - Returns: observation (object): Observation of the initial state. This will be an element of :attr:`observation_space` (typically a numpy array) and is analogous to the observation returned by :meth:`step`. @@ -183,7 +181,7 @@ def close(self): pass @property - def unwrapped(self) -> "Env": + def unwrapped(self) -> "Env[ObsType, ActType]": """Returns the base non-wrapped environment. Returns: @@ -202,14 +200,14 @@ def __enter__(self): """Support with-statement for the environment.""" return self - def __exit__(self, *args): + def __exit__(self, *args: List[Any]): """Support with-statement for the environment.""" self.close() # propagate exception return False -class Wrapper(Env[ObsType, ActType]): +class Wrapper(Env[ObsType, ActType], ABC): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. This class is the base class for all wrappers. The subclass could override @@ -220,7 +218,7 @@ class Wrapper(Env[ObsType, ActType]): Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`. """ - def __init__(self, env: Env): + def __init__(self, env: Env[ObsType, ActType]): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. Args: @@ -231,16 +229,16 @@ def __init__(self, env: Env): self._action_space: Optional[spaces.Space] = None self._observation_space: Optional[spaces.Space] = None self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None - self._metadata: Optional[dict] = None + self._metadata: Optional[Dict[str, Any]] = None - def __getattr__(self, name): + def __getattr__(self, name: str): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" if name.startswith("_"): raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) @property - def spec(self): + def spec(self) -> Optional["EnvSpec"]: """Returns the environment specification.""" return self.env.spec @@ -250,7 +248,7 @@ def class_name(cls): return cls.__name__ @property - def action_space(self) -> spaces.Space[ActType]: + def action_space(self) -> spaces.Space: """Returns the action space of the environment.""" if self._action_space is None: return self.env.action_space @@ -283,14 +281,14 @@ def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]): self._reward_range = value @property - def metadata(self) -> dict: + def metadata(self) -> Dict[str, Any]: """Returns the environment metadata.""" if self._metadata is None: return self.env.metadata return self._metadata @metadata.setter - def metadata(self, value): + def metadata(self, value: Dict[str, Any]): self._metadata = value @property @@ -304,7 +302,7 @@ def np_random(self) -> np.random.Generator: return self.env.np_random @np_random.setter - def np_random(self, value): + def np_random(self, value: np.random.Generator): self.env.np_random = value @property @@ -321,11 +319,9 @@ def reset(self, **kwargs) -> Tuple[ObsType, dict]: """Resets the environment with kwargs.""" return self.env.reset(**kwargs) - def render( - self, *args, **kwargs - ) -> Optional[Union[RenderFrame, List[RenderFrame]]]: + def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: """Renders the environment.""" - return self.env.render(*args, **kwargs) + return self.env.render() def close(self): """Closes the environment.""" @@ -340,18 +336,18 @@ def __repr__(self): return str(self) @property - def unwrapped(self) -> Env: + def unwrapped(self) -> Env[ObsType, ActType]: """Returns the base environment of the wrapper.""" return self.env.unwrapped -class ObservationWrapper(Wrapper): +class ObservationWrapper(Wrapper[ObsType, ActType], ABC): """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. If you would like to apply a function to the observation that is returned by the base environment before passing it to learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method :meth:`observation` to implement that transformation. The transformation defined in that method must be - defined on the base environment’s observation space. However, it may take values in a different space. + defined on the base environment's observation space. However, it may take values in a different space. In that case, you need to specify the new observation space of the wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper. @@ -373,22 +369,30 @@ def observation(self, obs): index of the timestep to the observation. """ - def reset(self, **kwargs): + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Tuple[ObsType, Dict[str, Any]]: """Resets the environment, returning a modified observation using :meth:`self.observation`.""" - obs, info = self.env.reset(**kwargs) + obs, info = self.env.reset(seed=seed, options=options) return self.observation(obs), info - def step(self, action): + def step( + self, action: ActType + ) -> Tuple[ObsType, SupportsFloat, bool, Dict[str, Any]]: """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return self.observation(observation), reward, terminated, truncated, info - def observation(self, observation): + @abstractmethod + def observation(self, observation: ObsType) -> ObsType: """Returns a modified observation.""" raise NotImplementedError -class RewardWrapper(Wrapper): +class RewardWrapper(Wrapper[ObsType, ActType], ABC): """Superclass of wrappers that can modify the returning reward from a step. If you would like to apply a function to the reward that is returned by the base environment before @@ -401,7 +405,7 @@ class RewardWrapper(Wrapper): because it is intrinsic), we want to clip the reward to a range to gain some numerical stability. To do that, we could, for instance, implement the following wrapper:: - class ClipReward(gymnasium.RewardWrapper): + class ClipReward(gym.RewardWrapper): def __init__(self, env, min_reward, max_reward): super().__init__(env) self.min_reward = min_reward @@ -412,17 +416,20 @@ def reward(self, reward): return np.clip(reward, self.min_reward, self.max_reward) """ - def step(self, action): + def step( + self, action: ActType + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return observation, self.reward(reward), terminated, truncated, info - def reward(self, reward): + @abstractmethod + def reward(self, reward: SupportsFloat) -> SupportsFloat: """Returns a modified ``reward``.""" raise NotImplementedError -class ActionWrapper(Wrapper): +class ActionWrapper(Wrapper[ObsType, ActType], ABC): """Superclass of wrappers that can modify the action before :meth:`env.step`. If you would like to apply a function to the action before passing it to the base environment, @@ -432,7 +439,7 @@ class ActionWrapper(Wrapper): In that case, you need to specify the new action space of the wrapper by setting :attr:`self.action_space` in the :meth:`__init__` method of your wrapper. - Let’s say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like + Let's say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like to use a finite subset of actions. Then, you might want to implement the following wrapper:: class DiscreteActions(gym.ActionWrapper): @@ -454,14 +461,17 @@ def action(self, act): Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction`. """ - def step(self, action): + def step( + self, action: ActType + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`.""" return self.env.step(self.action(action)) - def action(self, action): + @abstractmethod + def action(self, action: ActType) -> ActType: """Returns a modified action before :meth:`env.step` is called.""" raise NotImplementedError - def reverse_action(self, action): + def reverse_action(self, action: ActType) -> ActType: """Returns a reversed ``action``.""" raise NotImplementedError diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index 72d33c4e9..c9ed0d6d0 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -11,6 +11,7 @@ def test_record_episode_statistics(env_id, deque_size): env = gym.make(env_id, disable_env_checker=True) env = RecordEpisodeStatistics(env, deque_size) + assert env.spec is not None for n in range(5): env.reset() From 1cd0a4c0770595e5db8c113eefa4fcc48355e8e0 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Thu, 6 Oct 2022 21:05:56 +0100 Subject: [PATCH 02/12] Fix the tests through using GenericTestEnv --- gymnasium/core.py | 67 +++--- .../wrappers/monitoring/video_recorder.py | 10 +- tests/envs/test_make.py | 2 +- tests/envs/utils.py | 2 + tests/envs/utils_envs.py | 33 ++- tests/{testing_env.py => generic_test_env.py} | 7 +- tests/test_core.py | 195 +++++++----------- tests/utils/test_env_checker.py | 2 +- tests/utils/test_passive_env_checker.py | 2 +- tests/utils/test_play.py | 2 +- tests/wrappers/test_atari_preprocessing.py | 2 +- tests/wrappers/test_flatten.py | 115 +++-------- tests/wrappers/test_passive_env_checker.py | 2 +- tests/wrappers/test_step_compatibility.py | 32 +-- tests/wrappers/test_video_recorder.py | 18 +- 15 files changed, 188 insertions(+), 303 deletions(-) rename tests/{testing_env.py => generic_test_env.py} (93%) diff --git a/gymnasium/core.py b/gymnasium/core.py index c9e1c0d9c..d32d7f326 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -357,13 +357,15 @@ class ObservationWrapper(Wrapper[ObsType, ActType], ABC): ``observation["target_position"] - observation["agent_position"]``. For this, you could implement an observation wrapper like this:: - class RelativePosition(gym.ObservationWrapper): - def __init__(self, env): - super().__init__(env) - self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf) - - def observation(self, obs): - return obs["target"] - obs["agent"] + >>> import gymnasium as gym + >>> from gymnasium.spaces import Box + >>> class RelativePosition(gym.ObservationWrapper): + ... def __init__(self, env): + ... super().__init__(env) + ... self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf) + ... + ... def observation(self, obs): + ... return obs["target"] - obs["agent"] Among others, Gymnasium provides the observation wrapper :class:`TimeAwareObservation`, which adds information about the index of the timestep to the observation. @@ -381,7 +383,7 @@ def reset( def step( self, action: ActType - ) -> Tuple[ObsType, SupportsFloat, bool, Dict[str, Any]]: + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return self.observation(observation), reward, terminated, truncated, info @@ -405,15 +407,16 @@ class RewardWrapper(Wrapper[ObsType, ActType], ABC): because it is intrinsic), we want to clip the reward to a range to gain some numerical stability. To do that, we could, for instance, implement the following wrapper:: - class ClipReward(gym.RewardWrapper): - def __init__(self, env, min_reward, max_reward): - super().__init__(env) - self.min_reward = min_reward - self.max_reward = max_reward - self.reward_range = (min_reward, max_reward) - - def reward(self, reward): - return np.clip(reward, self.min_reward, self.max_reward) + >>> import gymnasium as gym + >>> class ClipReward(gym.RewardWrapper): + ... def __init__(self, env, min_reward, max_reward): + ... super().__init__(env) + ... self.min_reward = min_reward + ... self.max_reward = max_reward + ... self.reward_range = (min_reward, max_reward) + ... + ... def reward(self, r): + ... return np.clip(r, self.min_reward, self.max_reward) """ def step( @@ -442,20 +445,22 @@ class ActionWrapper(Wrapper[ObsType, ActType], ABC): Let's say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like to use a finite subset of actions. Then, you might want to implement the following wrapper:: - class DiscreteActions(gym.ActionWrapper): - def __init__(self, env, disc_to_cont): - super().__init__(env) - self.disc_to_cont = disc_to_cont - self.action_space = Discrete(len(disc_to_cont)) - - def action(self, act): - return self.disc_to_cont[act] - - if __name__ == "__main__": - env = gym.make("LunarLanderContinuous-v2") - wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]), - np.array([0,1]), np.array([0,-1])]) - print(wrapped_env.action_space) #Discrete(4) + >>> import gymnasium as gym + >>> from gymnasium.spaces import Discrete + >>> class DiscreteActions(gym.ActionWrapper): + ... def __init__(self, env, disc_to_cont): + ... super().__init__(env) + ... self.disc_to_cont = disc_to_cont + ... self.action_space = Discrete(len(disc_to_cont)) + ... + ... def action(self, act): + ... return self.disc_to_cont[act] + + >>> if __name__ == "__main__": + >>> env = gym.make("LunarLanderContinuous-v2") + >>> wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]), + ... np.array([0,1]), np.array([0,-1])]) + >>> print(wrapped_env.action_space) #Discrete(4) Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction`. diff --git a/gymnasium/wrappers/monitoring/video_recorder.py b/gymnasium/wrappers/monitoring/video_recorder.py index 899a71f9d..72a5b9e42 100644 --- a/gymnasium/wrappers/monitoring/video_recorder.py +++ b/gymnasium/wrappers/monitoring/video_recorder.py @@ -38,6 +38,10 @@ def __init__( Error: You can pass at most one of `path` or `base_path` Error: Invalid path given that must have a particular file extension """ + self._async = env.metadata.get("semantics.async") + self.enabled = enabled + self._closed = False + try: # check that moviepy is now installed import moviepy # noqa: F401 @@ -46,16 +50,12 @@ def __init__( "MoviePy is not installed, run `pip install moviepy`" ) - self._async = env.metadata.get("semantics.async") - self.enabled = enabled - self._closed = False - self.render_history = [] self.env = env self.render_mode = env.render_mode - if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode: + if "rgb_array" != self.render_mode and "rgb_array_list" != self.render_mode: logger.warn( f"Disabling video recorder because environment {env} was not initialized with any compatible video " "mode between `rgb_array` and `rgb_array_list`" diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 4844b8a28..ebe919143 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -19,7 +19,7 @@ from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_env_specs from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv -from tests.testing_env import GenericTestEnv, old_step_fn +from tests.generic_test_env import GenericTestEnv, old_step_fn from tests.wrappers.utils import has_wrapper gym.register( diff --git a/tests/envs/utils.py b/tests/envs/utils.py index 36f60a1a1..d9b1d1db5 100644 --- a/tests/envs/utils.py +++ b/tests/envs/utils.py @@ -19,6 +19,8 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]: return env_spec.make(disable_env_checker=True).unwrapped except (ImportError, gym.error.DependencyNotInstalled) as e: logger.warn(f"Not testing {env_spec.id} due to error: {e}") + except Exception as e: + logger.warn(f"Unexpected exception occurred: {e}") return None diff --git a/tests/envs/utils_envs.py b/tests/envs/utils_envs.py index 8713f4b20..34dfca0cc 100644 --- a/tests/envs/utils_envs.py +++ b/tests/envs/utils_envs.py @@ -1,42 +1,35 @@ +from typing import Any + import gymnasium as gym +from tests.generic_test_env import GenericTestEnv -class RegisterDuringMakeEnv(gym.Env): +class RegisterDuringMakeEnv(GenericTestEnv): """Used in `test_registration.py` to check if `env.make` can import and register an env""" - def __init__(self): - self.action_space = gym.spaces.Discrete(1) - self.observation_space = gym.spaces.Discrete(1) - -class ArgumentEnv(gym.Env): - observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) - action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) +class ArgumentEnv(GenericTestEnv): + def __init__(self, arg1: Any, arg2: Any, arg3: Any): + super().__init__() - def __init__(self, arg1, arg2, arg3): - self.arg1 = arg1 - self.arg2 = arg2 - self.arg3 = arg3 + self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3 # Environments to test render_mode -class NoHuman(gym.Env): +class NoHuman(GenericTestEnv): """Environment that does not have human-rendering.""" metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} - def __init__(self, render_mode=None): + def __init__(self, render_mode: str = None): assert render_mode in self.metadata["render_modes"] - self.render_mode = render_mode + super().__init__(render_mode=render_mode) -class NoHumanOldAPI(gym.Env): +class NoHumanOldAPI(GenericTestEnv): """Environment that does not have human-rendering.""" - metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} - - def __init__(self): - pass + metadata = {"render_modes": ["rgb_array"], "render_fps": 4} class NoHumanNoRGB(gym.Env): diff --git a/tests/testing_env.py b/tests/generic_test_env.py similarity index 93% rename from tests/testing_env.py rename to tests/generic_test_env.py index c3e957aea..65e899fd7 100644 --- a/tests/testing_env.py +++ b/tests/generic_test_env.py @@ -48,9 +48,12 @@ def __init__( render_fn: callable = basic_render_fn, metadata: Optional[Dict[str, Any]] = None, render_mode: Optional[str] = None, - spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"), + spec: Optional[EnvSpec] = EnvSpec( + "TestingEnv-v0", "testing-env-no-entry-point" + ), ): - self.metadata = {} if metadata is None else metadata + if not hasattr(self, "metadata"): + self.metadata = {} if metadata is None else metadata self.render_mode = render_mode self.spec = spec diff --git a/tests/test_core.py b/tests/test_core.py index cc56049bb..f34ee62ba 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,139 +1,98 @@ -from typing import Optional +from typing import Any, Dict, Optional, SupportsFloat, Tuple import numpy as np -import pytest -from gymnasium import core, spaces -from gymnasium.wrappers import OrderEnforcing, TimeLimit +from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper +from gymnasium.core import ActionWrapper, ActType, ObsType +from gymnasium.spaces import Box +from tests.generic_test_env import GenericTestEnv -class ArgumentEnv(core.Env): - observation_space = spaces.Box(low=0, high=1, shape=(1,)) - action_space = spaces.Box(low=0, high=1, shape=(1,)) - calls = 0 +class ExampleEnv(Env): + def step( + self, action: ActType + ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + return 0, 0, False, False, {} - def __init__(self, arg): - self.calls += 1 - self.arg = arg + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[ObsType, dict]: + return 0, {} -class UnittestEnv(core.Env): - observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) - action_space = spaces.Discrete(3) +def test_gymnasium_env(): + env = ExampleEnv() - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - return self.observation_space.sample(), {"info": "dummy"} + assert env.metadata == {"render_modes": []} + assert env.render_mode is None + assert env.reward_range == (-float("inf"), float("inf")) + assert env.spec is None + assert env._np_random is None - def step(self, action): - observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, {}) +class ExampleWrapper(Wrapper): + pass -class UnknownSpacesEnv(core.Env): - """This environment defines its observation & action spaces only - after the first call to reset. Although this pattern is sometimes - necessary when implementing a new environment (e.g. if it depends - on external resources), it is not encouraged. - """ - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - self.observation_space = spaces.Box( - low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 - ) - self.action_space = spaces.Discrete(3) - return self.observation_space.sample(), {} # Dummy observation with info +def test_gymnasium_wrapper(): + env = ExampleEnv() + wrapper_env = ExampleWrapper(env) - def step(self, action): - observation = self.observation_space.sample() # Dummy observation - return (observation, 0.0, False, {}) + assert env.metadata == wrapper_env.metadata + wrapper_env.metadata = {"render_modes": ["rgb_array"]} + assert env.metadata != wrapper_env.metadata + assert env.render_mode == wrapper_env.render_mode -class OldStyleEnv(core.Env): - """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)""" + assert env.reward_range == wrapper_env.reward_range + wrapper_env.reward_range = (-1.0, 1.0) + assert env.reward_range != wrapper_env.reward_range - def __init__(self): - pass + assert env.spec == wrapper_env.spec - def reset(self): - super().reset() - return 0 + env.observation_space = Box(0, 1) + env.action_space = Box(0, 1) + assert env.observation_space == wrapper_env.observation_space + assert env.action_space == wrapper_env.action_space + wrapper_env.observation_space = Box(1, 2) + wrapper_env.action_space = Box(1, 2) + assert env.observation_space != wrapper_env.observation_space + assert env.action_space != wrapper_env.action_space - def step(self, action): - return 0, 0, False, {} +class ExampleRewardWrapper(RewardWrapper): + def reward(self, reward: SupportsFloat) -> SupportsFloat: + return 1 -class NewPropertyWrapper(core.Wrapper): - def __init__( - self, - env, - observation_space=None, - action_space=None, - reward_range=None, - metadata=None, - ): - super().__init__(env) - if observation_space is not None: - # Only set the observation space if not None to test property forwarding - self.observation_space = observation_space - if action_space is not None: - self.action_space = action_space - if reward_range is not None: - self.reward_range = reward_range - if metadata is not None: - self.metadata = metadata - - -def test_env_instantiation(): - # This looks like a pretty trivial, but given our usage of - # __new__, it's worth having. - env = ArgumentEnv("arg") - assert env.arg == "arg" - assert env.calls == 1 - - -properties = [ - { - "observation_space": spaces.Box( - low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 - ) - }, - {"action_space": spaces.Discrete(2)}, - {"reward_range": (-1.0, 1.0)}, - {"metadata": {"render_modes": ["human", "rgb_array_list"]}}, - { - "observation_space": spaces.Box( - low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 - ), - "action_space": spaces.Discrete(2), - }, -] - - -@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv]) -@pytest.mark.parametrize("props", properties) -def test_wrapper_property_forwarding(class_, props): - env = class_() - env = NewPropertyWrapper(env, **props) - - # If UnknownSpacesEnv, then call reset to define the spaces - if isinstance(env.unwrapped, UnknownSpacesEnv): - _ = env.reset() - - # Test the properties set by the wrapper - for key, value in props.items(): - assert getattr(env, key) == value - - # Otherwise, test if the properties are forwarded - all_properties = {"observation_space", "action_space", "reward_range", "metadata"} - for key in all_properties - props.keys(): - assert getattr(env, key) == getattr(env.unwrapped, key) - - -def test_compatibility_with_old_style_env(): - env = OldStyleEnv() - env = OrderEnforcing(env) - env = TimeLimit(env) - obs = env.reset() - assert obs == 0 + +class ExampleObservationWrapper(ObservationWrapper): + def observation(self, observation: ObsType) -> ObsType: + return np.array([1]) + + +class ExampleActionWrapper(ActionWrapper): + def action(self, action: ActType) -> ActType: + return np.array([1]) + + +def test_wrapper_types(): + env = GenericTestEnv() + + reward_env = ExampleRewardWrapper(env) + reward_env.reset() + _, reward, _, _, _ = reward_env.step(0) + assert reward == 1 + + observation_env = ExampleObservationWrapper(env) + obs, _ = observation_env.reset() + assert obs == np.array([1]) + obs, _, _, _, _ = observation_env.step(0) + assert obs == np.array([1]) + + env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {})) + action_env = ExampleActionWrapper(env) + obs, _, _, _, _ = action_env.step(0) + assert obs == np.array([1]) diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index eaed92273..cf518a11f 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -17,7 +17,7 @@ check_reset_seed, check_seed_deprecation, ) -from tests.testing_env import GenericTestEnv +from tests.generic_test_env import GenericTestEnv @pytest.mark.parametrize( diff --git a/tests/utils/test_passive_env_checker.py b/tests/utils/test_passive_env_checker.py index 719fa2d85..cba60201d 100644 --- a/tests/utils/test_passive_env_checker.py +++ b/tests/utils/test_passive_env_checker.py @@ -15,7 +15,7 @@ env_reset_passive_checker, env_step_passive_checker, ) -from tests.testing_env import GenericTestEnv +from tests.generic_test_env import GenericTestEnv def _modify_space(space: spaces.Space, attribute: str, value): diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index 054acf137..b4b321278 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -10,7 +10,7 @@ import gymnasium as gym from gymnasium.utils.play import MissingKeysToAction, PlayableGame, play -from tests.testing_env import GenericTestEnv +from tests.generic_test_env import GenericTestEnv RELEVANT_KEY_1 = ord("a") # 97 RELEVANT_KEY_2 = ord("d") # 100 diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index b451f528f..0263f4002 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -3,7 +3,7 @@ from gymnasium.spaces import Box, Discrete from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility -from tests.testing_env import GenericTestEnv, old_step_fn +from tests.generic_test_env import GenericTestEnv, old_step_fn class AleTesting: diff --git a/tests/wrappers/test_flatten.py b/tests/wrappers/test_flatten.py index 9c6f08022..bb0398f86 100644 --- a/tests/wrappers/test_flatten.py +++ b/tests/wrappers/test_flatten.py @@ -1,98 +1,51 @@ """Tests for the flatten observation wrapper.""" from collections import OrderedDict -from typing import Optional -import numpy as np import pytest -import gymnasium as gym -from gymnasium.spaces import Box, Dict, flatten, unflatten +from gymnasium.spaces import Box, Dict, unflatten from gymnasium.wrappers import FlattenObservation - - -class FakeEnvironment(gym.Env): - def __init__(self, observation_space): - self.observation_space = observation_space - - def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): - super().reset(seed=seed) - self.observation = self.observation_space.sample() - return self.observation, {} - +from tests.generic_test_env import GenericTestEnv OBSERVATION_SPACES = ( - ( - Dict( - OrderedDict( - [ - ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)), - ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)), - ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)), - ] - ) - ), - True, + Dict( + OrderedDict( + [ + ("key1", Box(shape=(2, 3), low=0, high=0)), + ("key2", Box(shape=(1,), low=1, high=1)), + ("key3", Box(shape=(2,), low=2, high=2)), + ] + ) ), - ( - Dict( - OrderedDict( - [ - ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)), - ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)), - ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)), - ] - ) - ), - True, + Dict( + OrderedDict( + [ + ("key2", Box(shape=(1,), low=0, high=0)), + ("key3", Box(shape=(2,), low=1, high=1)), + ("key1", Box(shape=(2, 3), low=2, high=2)), + ] + ) ), - ( - Dict( - { - "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32), - "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), - "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), - } - ), - False, + Dict( + { + "key1": Box(shape=(2, 3), low=-1, high=1), + "key2": Box(shape=(1,), low=-1, high=1), + "key3": Box(shape=(2,), low=-1, high=1), + } ), ) -class TestFlattenEnvironment: - @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) - def test_flattened_environment(self, observation_space, ordered_values): - """ - make sure that flattened observations occur in the order expected - """ - env = FakeEnvironment(observation_space=observation_space) - wrapped_env = FlattenObservation(env) - flattened, info = wrapped_env.reset() - - unflattened = unflatten(env.observation_space, flattened) - original = env.observation - - self._check_observations(original, flattened, unflattened, ordered_values) - - @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) - def test_flatten_unflatten(self, observation_space, ordered_values): - """ - test flatten and unflatten functions directly - """ - original = observation_space.sample() - - flattened = flatten(observation_space, original) - unflattened = unflatten(observation_space, flattened) - - self._check_observations(original, flattened, unflattened, ordered_values) +@pytest.mark.parametrize("observation_space", OBSERVATION_SPACES) +def test_flattened_environment(observation_space): + env = GenericTestEnv(observation_space=observation_space) + flattened_env = FlattenObservation(env) + flattened_obs, info = flattened_env.reset() - def _check_observations(self, original, flattened, unflattened, ordered_values): - # make sure that unflatten(flatten(original)) == original - assert set(unflattened.keys()) == set(original.keys()) - for k, v in original.items(): - np.testing.assert_allclose(unflattened[k], v) + assert flattened_obs in flattened_env.observation_space + assert flattened_obs not in env.observation_space - if ordered_values: - # make sure that the values were flattened in the order they appeared in the - # OrderedDict - np.testing.assert_allclose(sorted(flattened), flattened) + unflattened_obs = unflatten(env.observation_space, flattened_obs) + assert unflattened_obs in env.observation_space + assert unflattened_obs not in flattened_env.observation_space diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index e49d901de..63d441233 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -8,7 +8,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_initialised_envs -from tests.testing_env import GenericTestEnv +from tests.generic_test_env import GenericTestEnv @pytest.mark.parametrize( diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index f4c7f465c..b7557d5ed 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -1,35 +1,13 @@ +from functools import partial + import pytest import gymnasium as gym -from gymnasium.spaces import Discrete from gymnasium.wrappers import StepAPICompatibility +from tests.generic_test_env import GenericTestEnv, old_step_fn - -class OldStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - done = False - info = {} - return obs, rew, done, info - - -class NewStepEnv(gym.Env): - def __init__(self): - self.action_space = Discrete(2) - self.observation_space = Discrete(2) - - def step(self, action): - obs = self.observation_space.sample() - rew = 0 - terminated = False - truncated = False - info = {} - return obs, rew, terminated, truncated, info +OldStepEnv = partial(GenericTestEnv, step_fn=old_step_fn) +NewStepEnv = GenericTestEnv @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) diff --git a/tests/wrappers/test_video_recorder.py b/tests/wrappers/test_video_recorder.py index a0e38adcf..84d20c64c 100644 --- a/tests/wrappers/test_video_recorder.py +++ b/tests/wrappers/test_video_recorder.py @@ -7,27 +7,19 @@ import gymnasium as gym from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder +from tests.generic_test_env import GenericTestEnv -class BrokenRecordableEnv(gym.Env): +class BrokenRecordableEnv(GenericTestEnv): metadata = {"render_modes": ["rgb_array_list"]} def __init__(self, render_mode="rgb_array_list"): - self.render_mode = render_mode + super().__init__(render_mode=render_mode) - def render(self): - pass - -class UnrecordableEnv(gym.Env): +class UnrecordableEnv(GenericTestEnv): metadata = {"render_modes": [None]} - def __init__(self, render_mode=None): - self.render_mode = render_mode - - def render(self): - pass - def test_record_simple(): env = gym.make( @@ -82,7 +74,7 @@ def test_record_unrecordable_method(): with pytest.warns( UserWarning, match=re.escape( - "\x1b[33mWARN: Disabling video recorder because environment was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m" + "\x1b[33mWARN: Disabling video recorder because environment > was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m" ), ): env = UnrecordableEnv() From 8eaa5bd5b5ef952826cf0abb9b9fb1aa3af08e90 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Thu, 6 Oct 2022 21:23:28 +0100 Subject: [PATCH 03/12] Fix the tests through using GenericTestEnv --- tests/envs/utils_envs.py | 22 ++++++++++++++++++++-- tests/generic_test_env.py | 4 ++-- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/tests/envs/utils_envs.py b/tests/envs/utils_envs.py index 34dfca0cc..1af99fc26 100644 --- a/tests/envs/utils_envs.py +++ b/tests/envs/utils_envs.py @@ -1,4 +1,6 @@ -from typing import Any +from typing import Any, Tuple, Dict, Optional + +from gymnasium.core import ActType, ObsType import gymnasium as gym from tests.generic_test_env import GenericTestEnv @@ -26,11 +28,27 @@ def __init__(self, render_mode: str = None): super().__init__(render_mode=render_mode) -class NoHumanOldAPI(GenericTestEnv): +class NoHumanOldAPI(gym.Env): """Environment that does not have human-rendering.""" metadata = {"render_modes": ["rgb_array"], "render_fps": 4} + def __init__(self): + pass + + def step( + self, action: ActType + ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + return self.observation_space.sample(), 0, False, False, {} + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[ObsType, dict]: + return self.observation_space.sample(), {} + class NoHumanNoRGB(gym.Env): """Environment that has neither human- nor rgb-rendering""" diff --git a/tests/generic_test_env.py b/tests/generic_test_env.py index 65e899fd7..37ce3962e 100644 --- a/tests/generic_test_env.py +++ b/tests/generic_test_env.py @@ -52,8 +52,8 @@ def __init__( "TestingEnv-v0", "testing-env-no-entry-point" ), ): - if not hasattr(self, "metadata"): - self.metadata = {} if metadata is None else metadata + if metadata is not None: + self.metadata = metadata self.render_mode = render_mode self.spec = spec From c43e48c6df581c2398f56131f3d61d23d6048d94 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Fri, 7 Oct 2022 12:05:25 +0100 Subject: [PATCH 04/12] Change the type hint of wrapper.reset to the true env.reset parameters --- gymnasium/core.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index d32d7f326..c6ff84c50 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -61,8 +61,8 @@ class Env(Generic[ObsType, ActType], ABC): spec: "EnvSpec" = None # Set these in ALL subclasses - action_space: spaces.Space - observation_space: spaces.Space + action_space: spaces.Space[ActType] + observation_space: spaces.Space[ObsType] # Created _np_random: Optional[np.random.Generator] = None @@ -315,9 +315,11 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Steps through the environment with action.""" return self.env.step(action) - def reset(self, **kwargs) -> Tuple[ObsType, dict]: + def reset( + self, *, seed: Optional[int] = None, options: Optional[dict] = None + ) -> Tuple[ObsType, dict]: """Resets the environment with kwargs.""" - return self.env.reset(**kwargs) + return self.env.reset(seed=seed, options=options) def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: """Renders the environment.""" From 8cbf48486f35b10dd775fd164c21dc6fc3f8f3d8 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Mon, 10 Oct 2022 11:10:33 +0100 Subject: [PATCH 05/12] Reverted all changes that are not to core.py and test_core.py --- gymnasium/core.py | 155 +++++++++--------- .../wrappers/monitoring/video_recorder.py | 10 +- tests/envs/test_make.py | 2 +- tests/envs/utils.py | 2 - tests/envs/utils_envs.py | 43 ++--- tests/test_core.py | 37 ++++- tests/{generic_test_env.py => testing_env.py} | 7 +- tests/utils/test_env_checker.py | 2 +- tests/utils/test_passive_env_checker.py | 2 +- tests/utils/test_play.py | 2 +- tests/wrappers/test_atari_preprocessing.py | 2 +- tests/wrappers/test_flatten.py | 115 +++++++++---- tests/wrappers/test_passive_env_checker.py | 2 +- .../test_record_episode_statistics.py | 1 - tests/wrappers/test_step_compatibility.py | 32 +++- tests/wrappers/test_video_recorder.py | 18 +- 16 files changed, 265 insertions(+), 167 deletions(-) rename tests/{generic_test_env.py => testing_env.py} (94%) diff --git a/gymnasium/core.py b/gymnasium/core.py index c6ff84c50..4f0e74282 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,5 +1,5 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" -from abc import ABC, abstractmethod +import sys from typing import ( TYPE_CHECKING, Any, @@ -16,18 +16,23 @@ import numpy as np from gymnasium import spaces +from gymnasium.logger import warn from gymnasium.utils import seeding if TYPE_CHECKING: from gymnasium.envs.registration import EnvSpec +if sys.version_info[0:2] == (3, 6): + warn( + "Gymnasium minimally supports python 3.6 as the python foundation not longer supports the version, please update your version to 3.7+" + ) ObsType = TypeVar("ObsType") ActType = TypeVar("ActType") RenderFrame = TypeVar("RenderFrame") -class Env(Generic[ObsType, ActType], ABC): +class Env(Generic[ObsType, ActType]): r"""The main Gymnasium class. It encapsulates an environment with arbitrary behind-the-scenes dynamics. @@ -78,7 +83,6 @@ def np_random(self) -> np.random.Generator: def np_random(self, value: np.random.Generator): self._np_random = value - @abstractmethod def step( self, action: ActType ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: @@ -111,13 +115,12 @@ def step( """ raise NotImplementedError - @abstractmethod def reset( self, *, seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> Tuple[ObsType, dict]: + options: Optional[Dict[str, Any]] = None, + ) -> Tuple[ObsType, Dict[str, Any]]: """Resets the environment to an initial state and returns the initial observation. This method can reset the environment's random number generator(s) if ``seed`` is an integer or @@ -207,7 +210,11 @@ def __exit__(self, *args: List[Any]): return False -class Wrapper(Env[ObsType, ActType], ABC): +WrapperObsType = TypeVar("WrapperObsType") +WrapperActType = TypeVar("WrapperActType") + + +class Wrapper(Env[WrapperObsType, WrapperActType]): """Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods. This class is the base class for all wrappers. The subclass could override @@ -226,47 +233,51 @@ def __init__(self, env: Env[ObsType, ActType]): """ self.env = env - self._action_space: Optional[spaces.Space] = None - self._observation_space: Optional[spaces.Space] = None + self._action_space: Optional[spaces.Space[WrapperActType]] = None + self._observation_space: Optional[spaces.Space[WrapperObsType]] = None self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None self._metadata: Optional[Dict[str, Any]] = None def __getattr__(self, name: str): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" - if name.startswith("_"): + if name == "_np_random": + raise AttributeError( + "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`." + ) + elif name.startswith("_"): raise AttributeError(f"accessing private attribute '{name}' is prohibited") return getattr(self.env, name) @property - def spec(self) -> Optional["EnvSpec"]: + def spec(self) -> "EnvSpec": """Returns the environment specification.""" return self.env.spec @classmethod - def class_name(cls): + def class_name(cls) -> str: """Returns the class name of the wrapper.""" return cls.__name__ @property - def action_space(self) -> spaces.Space: + def action_space(self) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]: """Returns the action space of the environment.""" if self._action_space is None: return self.env.action_space return self._action_space @action_space.setter - def action_space(self, space: spaces.Space): + def action_space(self, space: spaces.Space[WrapperActType]): self._action_space = space @property - def observation_space(self) -> spaces.Space: + def observation_space(self) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]: """Returns the observation space of the environment.""" if self._observation_space is None: return self.env.observation_space return self._observation_space @observation_space.setter - def observation_space(self, space: spaces.Space): + def observation_space(self, space: spaces.Space[WrapperObsType]): self._observation_space = space @property @@ -307,6 +318,9 @@ def np_random(self, value: np.random.Generator): @property def _np_random(self): + """This code will never be run due to __getattr__ being called prior this. + + It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.""" raise AttributeError( "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) @@ -316,9 +330,9 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: return self.env.step(action) def reset( - self, *, seed: Optional[int] = None, options: Optional[dict] = None - ) -> Tuple[ObsType, dict]: - """Resets the environment with kwargs.""" + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[WrapperObsType, Dict[str, Any]]: + """Resets the environment with a seed and options.""" return self.env.reset(seed=seed, options=options) def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: @@ -343,13 +357,13 @@ def unwrapped(self) -> Env[ObsType, ActType]: return self.env.unwrapped -class ObservationWrapper(Wrapper[ObsType, ActType], ABC): +class ObservationWrapper(Wrapper[WrapperObsType, ActType]): """Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`. If you would like to apply a function to the observation that is returned by the base environment before passing it to learning code, you can simply inherit from :class:`ObservationWrapper` and overwrite the method :meth:`observation` to implement that transformation. The transformation defined in that method must be - defined on the base environment's observation space. However, it may take values in a different space. + defined on the base environment’s observation space. However, it may take values in a different space. In that case, you need to specify the new observation space of the wrapper by setting :attr:`self.observation_space` in the :meth:`__init__` method of your wrapper. @@ -359,44 +373,40 @@ class ObservationWrapper(Wrapper[ObsType, ActType], ABC): ``observation["target_position"] - observation["agent_position"]``. For this, you could implement an observation wrapper like this:: - >>> import gymnasium as gym - >>> from gymnasium.spaces import Box - >>> class RelativePosition(gym.ObservationWrapper): - ... def __init__(self, env): - ... super().__init__(env) - ... self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf) - ... - ... def observation(self, obs): - ... return obs["target"] - obs["agent"] + import gymnasium as gym + from gymnasium.spaces import Box + class RelativePosition(gym.ObservationWrapper): + def __init__(self, env): + super().__init__(env) + self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf) + + def observation(self, obs): + return obs["target"] - obs["agent"] Among others, Gymnasium provides the observation wrapper :class:`TimeAwareObservation`, which adds information about the index of the timestep to the observation. """ def reset( - self, - *, - seed: Optional[int] = None, - options: Optional[Dict[str, Any]] = None, - ) -> Tuple[ObsType, Dict[str, Any]]: + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[WrapperObsType, Dict[str, Any]]: """Resets the environment, returning a modified observation using :meth:`self.observation`.""" obs, info = self.env.reset(seed=seed, options=options) return self.observation(obs), info def step( self, action: ActType - ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return self.observation(observation), reward, terminated, truncated, info - @abstractmethod - def observation(self, observation: ObsType) -> ObsType: + def observation(self, observation: ObsType) -> WrapperObsType: """Returns a modified observation.""" raise NotImplementedError -class RewardWrapper(Wrapper[ObsType, ActType], ABC): +class RewardWrapper(Wrapper[ObsType, ActType]): """Superclass of wrappers that can modify the returning reward from a step. If you would like to apply a function to the reward that is returned by the base environment before @@ -409,32 +419,31 @@ class RewardWrapper(Wrapper[ObsType, ActType], ABC): because it is intrinsic), we want to clip the reward to a range to gain some numerical stability. To do that, we could, for instance, implement the following wrapper:: - >>> import gymnasium as gym - >>> class ClipReward(gym.RewardWrapper): - ... def __init__(self, env, min_reward, max_reward): - ... super().__init__(env) - ... self.min_reward = min_reward - ... self.max_reward = max_reward - ... self.reward_range = (min_reward, max_reward) - ... - ... def reward(self, r): - ... return np.clip(r, self.min_reward, self.max_reward) + import gymnasium as gym + class ClipReward(gym.RewardWrapper): + def __init__(self, env, min_reward, max_reward): + super().__init__(env) + self.min_reward = min_reward + self.max_reward = max_reward + self.reward_range = (min_reward, max_reward) + + def reward(self, r: float) -> float: + return np.clip(r, self.min_reward, self.max_reward) """ def step( self, action: ActType - ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return observation, self.reward(reward), terminated, truncated, info - @abstractmethod - def reward(self, reward: SupportsFloat) -> SupportsFloat: + def reward(self, reward: SupportsFloat) -> float: """Returns a modified ``reward``.""" raise NotImplementedError -class ActionWrapper(Wrapper[ObsType, ActType], ABC): +class ActionWrapper(Wrapper[ObsType, WrapperActType]): """Superclass of wrappers that can modify the action before :meth:`env.step`. If you would like to apply a function to the action before passing it to the base environment, @@ -444,41 +453,39 @@ class ActionWrapper(Wrapper[ObsType, ActType], ABC): In that case, you need to specify the new action space of the wrapper by setting :attr:`self.action_space` in the :meth:`__init__` method of your wrapper. - Let's say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like + Let’s say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like to use a finite subset of actions. Then, you might want to implement the following wrapper:: - >>> import gymnasium as gym - >>> from gymnasium.spaces import Discrete - >>> class DiscreteActions(gym.ActionWrapper): - ... def __init__(self, env, disc_to_cont): - ... super().__init__(env) - ... self.disc_to_cont = disc_to_cont - ... self.action_space = Discrete(len(disc_to_cont)) - ... - ... def action(self, act): - ... return self.disc_to_cont[act] + import gymnasium as gym + from gymnasium.spaces import Discrete + class DiscreteActions(gym.ActionWrapper): + def __init__(self, env, disc_to_cont): + super().__init__(env) + self.disc_to_cont = disc_to_cont + self.action_space = Discrete(len(disc_to_cont)) - >>> if __name__ == "__main__": - >>> env = gym.make("LunarLanderContinuous-v2") - >>> wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]), - ... np.array([0,1]), np.array([0,-1])]) - >>> print(wrapped_env.action_space) #Discrete(4) + def action(self, act): + return self.disc_to_cont[ac + if __name__ == "__main__": + env = gym.make("LunarLanderContinuous-v2") + wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]), + np.array([0,1]), np.array([0,-1])]) + print(wrapped_env.action_space) #Discrete(4) Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction`. """ def step( - self, action: ActType - ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: + self, action: WrapperActType + ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: """Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`.""" return self.env.step(self.action(action)) - @abstractmethod - def action(self, action: ActType) -> ActType: + def action(self, action: WrapperActType) -> ActType: """Returns a modified action before :meth:`env.step` is called.""" raise NotImplementedError - def reverse_action(self, action: ActType) -> ActType: + def reverse_action(self, action: ActType) -> WrapperActType: """Returns a reversed ``action``.""" raise NotImplementedError diff --git a/gymnasium/wrappers/monitoring/video_recorder.py b/gymnasium/wrappers/monitoring/video_recorder.py index 72a5b9e42..899a71f9d 100644 --- a/gymnasium/wrappers/monitoring/video_recorder.py +++ b/gymnasium/wrappers/monitoring/video_recorder.py @@ -38,10 +38,6 @@ def __init__( Error: You can pass at most one of `path` or `base_path` Error: Invalid path given that must have a particular file extension """ - self._async = env.metadata.get("semantics.async") - self.enabled = enabled - self._closed = False - try: # check that moviepy is now installed import moviepy # noqa: F401 @@ -50,12 +46,16 @@ def __init__( "MoviePy is not installed, run `pip install moviepy`" ) + self._async = env.metadata.get("semantics.async") + self.enabled = enabled + self._closed = False + self.render_history = [] self.env = env self.render_mode = env.render_mode - if "rgb_array" != self.render_mode and "rgb_array_list" != self.render_mode: + if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode: logger.warn( f"Disabling video recorder because environment {env} was not initialized with any compatible video " "mode between `rgb_array` and `rgb_array_list`" diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index ebe919143..4844b8a28 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -19,7 +19,7 @@ from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_env_specs from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv -from tests.generic_test_env import GenericTestEnv, old_step_fn +from tests.testing_env import GenericTestEnv, old_step_fn from tests.wrappers.utils import has_wrapper gym.register( diff --git a/tests/envs/utils.py b/tests/envs/utils.py index d9b1d1db5..36f60a1a1 100644 --- a/tests/envs/utils.py +++ b/tests/envs/utils.py @@ -19,8 +19,6 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]: return env_spec.make(disable_env_checker=True).unwrapped except (ImportError, gym.error.DependencyNotInstalled) as e: logger.warn(f"Not testing {env_spec.id} due to error: {e}") - except Exception as e: - logger.warn(f"Unexpected exception occurred: {e}") return None diff --git a/tests/envs/utils_envs.py b/tests/envs/utils_envs.py index 1af99fc26..8713f4b20 100644 --- a/tests/envs/utils_envs.py +++ b/tests/envs/utils_envs.py @@ -1,54 +1,43 @@ -from typing import Any, Tuple, Dict, Optional - -from gymnasium.core import ActType, ObsType - import gymnasium as gym -from tests.generic_test_env import GenericTestEnv -class RegisterDuringMakeEnv(GenericTestEnv): +class RegisterDuringMakeEnv(gym.Env): """Used in `test_registration.py` to check if `env.make` can import and register an env""" + def __init__(self): + self.action_space = gym.spaces.Discrete(1) + self.observation_space = gym.spaces.Discrete(1) + -class ArgumentEnv(GenericTestEnv): - def __init__(self, arg1: Any, arg2: Any, arg3: Any): - super().__init__() +class ArgumentEnv(gym.Env): + observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) + action_space = gym.spaces.Box(low=-1, high=1, shape=(1,)) - self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3 + def __init__(self, arg1, arg2, arg3): + self.arg1 = arg1 + self.arg2 = arg2 + self.arg3 = arg3 # Environments to test render_mode -class NoHuman(GenericTestEnv): +class NoHuman(gym.Env): """Environment that does not have human-rendering.""" metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} - def __init__(self, render_mode: str = None): + def __init__(self, render_mode=None): assert render_mode in self.metadata["render_modes"] - super().__init__(render_mode=render_mode) + self.render_mode = render_mode class NoHumanOldAPI(gym.Env): """Environment that does not have human-rendering.""" - metadata = {"render_modes": ["rgb_array"], "render_fps": 4} + metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4} def __init__(self): pass - def step( - self, action: ActType - ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: - return self.observation_space.sample(), 0, False, False, {} - - def reset( - self, - *, - seed: Optional[int] = None, - options: Optional[dict] = None, - ) -> Tuple[ObsType, dict]: - return self.observation_space.sample(), {} - class NoHumanNoRGB(gym.Env): """Environment that has neither human- nor rgb-rendering""" diff --git a/tests/test_core.py b/tests/test_core.py index f34ee62ba..c533ae7ee 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,14 +1,23 @@ +"""Checks that the core Gymnasium API is implemented as expected.""" +import re from typing import Any, Dict, Optional, SupportsFloat, Tuple import numpy as np +import pytest from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper -from gymnasium.core import ActionWrapper, ActType, ObsType +from gymnasium.core import ActionWrapper, ActType, ObsType, WrapperObsType, WrapperActType from gymnasium.spaces import Box -from tests.generic_test_env import GenericTestEnv +from gymnasium.utils import seeding +from tests.testing_env import GenericTestEnv class ExampleEnv(Env): + + def __init__(self): + self.observation_space = Box(0, 1) + self.action_space = Box(0, 1) + def step( self, action: ActType ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: @@ -34,7 +43,23 @@ def test_gymnasium_env(): class ExampleWrapper(Wrapper): - pass + + def __init__(self, env: Env[ObsType, ActType]): + super().__init__(env) + + self.new_reward = 3 + + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ + WrapperObsType, Dict[str, Any]]: + return super().reset(seed=seed, options=options) + + def step(self, action: WrapperActType) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: + obs, reward, termination, truncation, info = self.env.step(action) + return obs, self.new_reward, termination, truncation, info + + def access_hidden_np_random(self): + """This should raise an error when called as wrappers should not access their own `_np_random` instances and should use the unwrapped environments.""" + return self._np_random def test_gymnasium_wrapper(): @@ -62,6 +87,12 @@ def test_gymnasium_wrapper(): assert env.observation_space != wrapper_env.observation_space assert env.action_space != wrapper_env.action_space + wrapper_env.np_random, _ = seeding.np_random() + assert env._np_random is env.np_random is wrapper_env.np_random + assert 0 <= wrapper_env.np_random.uniform() <= 1 + with pytest.raises(AttributeError, match=re.escape("Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`.")): + print(wrapper_env.access_hidden_np_random()) + class ExampleRewardWrapper(RewardWrapper): def reward(self, reward: SupportsFloat) -> SupportsFloat: diff --git a/tests/generic_test_env.py b/tests/testing_env.py similarity index 94% rename from tests/generic_test_env.py rename to tests/testing_env.py index 37ce3962e..c3e957aea 100644 --- a/tests/generic_test_env.py +++ b/tests/testing_env.py @@ -48,12 +48,9 @@ def __init__( render_fn: callable = basic_render_fn, metadata: Optional[Dict[str, Any]] = None, render_mode: Optional[str] = None, - spec: Optional[EnvSpec] = EnvSpec( - "TestingEnv-v0", "testing-env-no-entry-point" - ), + spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"), ): - if metadata is not None: - self.metadata = metadata + self.metadata = {} if metadata is None else metadata self.render_mode = render_mode self.spec = spec diff --git a/tests/utils/test_env_checker.py b/tests/utils/test_env_checker.py index cf518a11f..eaed92273 100644 --- a/tests/utils/test_env_checker.py +++ b/tests/utils/test_env_checker.py @@ -17,7 +17,7 @@ check_reset_seed, check_seed_deprecation, ) -from tests.generic_test_env import GenericTestEnv +from tests.testing_env import GenericTestEnv @pytest.mark.parametrize( diff --git a/tests/utils/test_passive_env_checker.py b/tests/utils/test_passive_env_checker.py index cba60201d..719fa2d85 100644 --- a/tests/utils/test_passive_env_checker.py +++ b/tests/utils/test_passive_env_checker.py @@ -15,7 +15,7 @@ env_reset_passive_checker, env_step_passive_checker, ) -from tests.generic_test_env import GenericTestEnv +from tests.testing_env import GenericTestEnv def _modify_space(space: spaces.Space, attribute: str, value): diff --git a/tests/utils/test_play.py b/tests/utils/test_play.py index b4b321278..054acf137 100644 --- a/tests/utils/test_play.py +++ b/tests/utils/test_play.py @@ -10,7 +10,7 @@ import gymnasium as gym from gymnasium.utils.play import MissingKeysToAction, PlayableGame, play -from tests.generic_test_env import GenericTestEnv +from tests.testing_env import GenericTestEnv RELEVANT_KEY_1 = ord("a") # 97 RELEVANT_KEY_2 = ord("d") # 100 diff --git a/tests/wrappers/test_atari_preprocessing.py b/tests/wrappers/test_atari_preprocessing.py index 0263f4002..b451f528f 100644 --- a/tests/wrappers/test_atari_preprocessing.py +++ b/tests/wrappers/test_atari_preprocessing.py @@ -3,7 +3,7 @@ from gymnasium.spaces import Box, Discrete from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility -from tests.generic_test_env import GenericTestEnv, old_step_fn +from tests.testing_env import GenericTestEnv, old_step_fn class AleTesting: diff --git a/tests/wrappers/test_flatten.py b/tests/wrappers/test_flatten.py index bb0398f86..9c6f08022 100644 --- a/tests/wrappers/test_flatten.py +++ b/tests/wrappers/test_flatten.py @@ -1,51 +1,98 @@ """Tests for the flatten observation wrapper.""" from collections import OrderedDict +from typing import Optional +import numpy as np import pytest -from gymnasium.spaces import Box, Dict, unflatten +import gymnasium as gym +from gymnasium.spaces import Box, Dict, flatten, unflatten from gymnasium.wrappers import FlattenObservation -from tests.generic_test_env import GenericTestEnv + + +class FakeEnvironment(gym.Env): + def __init__(self, observation_space): + self.observation_space = observation_space + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + super().reset(seed=seed) + self.observation = self.observation_space.sample() + return self.observation, {} + OBSERVATION_SPACES = ( - Dict( - OrderedDict( - [ - ("key1", Box(shape=(2, 3), low=0, high=0)), - ("key2", Box(shape=(1,), low=1, high=1)), - ("key3", Box(shape=(2,), low=2, high=2)), - ] - ) + ( + Dict( + OrderedDict( + [ + ("key1", Box(shape=(2, 3), low=0, high=0, dtype=np.float32)), + ("key2", Box(shape=(), low=1, high=1, dtype=np.float32)), + ("key3", Box(shape=(2,), low=2, high=2, dtype=np.float32)), + ] + ) + ), + True, ), - Dict( - OrderedDict( - [ - ("key2", Box(shape=(1,), low=0, high=0)), - ("key3", Box(shape=(2,), low=1, high=1)), - ("key1", Box(shape=(2, 3), low=2, high=2)), - ] - ) + ( + Dict( + OrderedDict( + [ + ("key2", Box(shape=(), low=0, high=0, dtype=np.float32)), + ("key3", Box(shape=(2,), low=1, high=1, dtype=np.float32)), + ("key1", Box(shape=(2, 3), low=2, high=2, dtype=np.float32)), + ] + ) + ), + True, ), - Dict( - { - "key1": Box(shape=(2, 3), low=-1, high=1), - "key2": Box(shape=(1,), low=-1, high=1), - "key3": Box(shape=(2,), low=-1, high=1), - } + ( + Dict( + { + "key1": Box(shape=(2, 3), low=-1, high=1, dtype=np.float32), + "key2": Box(shape=(), low=-1, high=1, dtype=np.float32), + "key3": Box(shape=(2,), low=-1, high=1, dtype=np.float32), + } + ), + False, ), ) -@pytest.mark.parametrize("observation_space", OBSERVATION_SPACES) -def test_flattened_environment(observation_space): - env = GenericTestEnv(observation_space=observation_space) - flattened_env = FlattenObservation(env) - flattened_obs, info = flattened_env.reset() +class TestFlattenEnvironment: + @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) + def test_flattened_environment(self, observation_space, ordered_values): + """ + make sure that flattened observations occur in the order expected + """ + env = FakeEnvironment(observation_space=observation_space) + wrapped_env = FlattenObservation(env) + flattened, info = wrapped_env.reset() + + unflattened = unflatten(env.observation_space, flattened) + original = env.observation + + self._check_observations(original, flattened, unflattened, ordered_values) + + @pytest.mark.parametrize("observation_space, ordered_values", OBSERVATION_SPACES) + def test_flatten_unflatten(self, observation_space, ordered_values): + """ + test flatten and unflatten functions directly + """ + original = observation_space.sample() + + flattened = flatten(observation_space, original) + unflattened = unflatten(observation_space, flattened) + + self._check_observations(original, flattened, unflattened, ordered_values) - assert flattened_obs in flattened_env.observation_space - assert flattened_obs not in env.observation_space + def _check_observations(self, original, flattened, unflattened, ordered_values): + # make sure that unflatten(flatten(original)) == original + assert set(unflattened.keys()) == set(original.keys()) + for k, v in original.items(): + np.testing.assert_allclose(unflattened[k], v) - unflattened_obs = unflatten(env.observation_space, flattened_obs) - assert unflattened_obs in env.observation_space - assert unflattened_obs not in flattened_env.observation_space + if ordered_values: + # make sure that the values were flattened in the order they appeared in the + # OrderedDict + np.testing.assert_allclose(sorted(flattened), flattened) diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index 63d441233..e49d901de 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -8,7 +8,7 @@ from gymnasium.wrappers.env_checker import PassiveEnvChecker from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING from tests.envs.utils import all_testing_initialised_envs -from tests.generic_test_env import GenericTestEnv +from tests.testing_env import GenericTestEnv @pytest.mark.parametrize( diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index c9ed0d6d0..72d33c4e9 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -11,7 +11,6 @@ def test_record_episode_statistics(env_id, deque_size): env = gym.make(env_id, disable_env_checker=True) env = RecordEpisodeStatistics(env, deque_size) - assert env.spec is not None for n in range(5): env.reset() diff --git a/tests/wrappers/test_step_compatibility.py b/tests/wrappers/test_step_compatibility.py index b7557d5ed..f4c7f465c 100644 --- a/tests/wrappers/test_step_compatibility.py +++ b/tests/wrappers/test_step_compatibility.py @@ -1,13 +1,35 @@ -from functools import partial - import pytest import gymnasium as gym +from gymnasium.spaces import Discrete from gymnasium.wrappers import StepAPICompatibility -from tests.generic_test_env import GenericTestEnv, old_step_fn -OldStepEnv = partial(GenericTestEnv, step_fn=old_step_fn) -NewStepEnv = GenericTestEnv + +class OldStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + done = False + info = {} + return obs, rew, done, info + + +class NewStepEnv(gym.Env): + def __init__(self): + self.action_space = Discrete(2) + self.observation_space = Discrete(2) + + def step(self, action): + obs = self.observation_space.sample() + rew = 0 + terminated = False + truncated = False + info = {} + return obs, rew, terminated, truncated, info @pytest.mark.parametrize("env", [OldStepEnv, NewStepEnv]) diff --git a/tests/wrappers/test_video_recorder.py b/tests/wrappers/test_video_recorder.py index 84d20c64c..a0e38adcf 100644 --- a/tests/wrappers/test_video_recorder.py +++ b/tests/wrappers/test_video_recorder.py @@ -7,19 +7,27 @@ import gymnasium as gym from gymnasium.wrappers.monitoring.video_recorder import VideoRecorder -from tests.generic_test_env import GenericTestEnv -class BrokenRecordableEnv(GenericTestEnv): +class BrokenRecordableEnv(gym.Env): metadata = {"render_modes": ["rgb_array_list"]} def __init__(self, render_mode="rgb_array_list"): - super().__init__(render_mode=render_mode) + self.render_mode = render_mode + def render(self): + pass -class UnrecordableEnv(GenericTestEnv): + +class UnrecordableEnv(gym.Env): metadata = {"render_modes": [None]} + def __init__(self, render_mode=None): + self.render_mode = render_mode + + def render(self): + pass + def test_record_simple(): env = gym.make( @@ -74,7 +82,7 @@ def test_record_unrecordable_method(): with pytest.warns( UserWarning, match=re.escape( - "\x1b[33mWARN: Disabling video recorder because environment > was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m" + "\x1b[33mWARN: Disabling video recorder because environment was not initialized with any compatible video mode between `rgb_array` and `rgb_array_list`\x1b[0m" ), ): env = UnrecordableEnv() From 8559e12b856f19fc83aa7981726555414d1e92a5 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Mon, 10 Oct 2022 11:12:56 +0100 Subject: [PATCH 06/12] pre-commit --- gymnasium/core.py | 11 ++++++++--- tests/test_core.py | 30 ++++++++++++++++++++++-------- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 4f0e74282..1065c4dab 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -259,7 +259,9 @@ def class_name(cls) -> str: return cls.__name__ @property - def action_space(self) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]: + def action_space( + self, + ) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]: """Returns the action space of the environment.""" if self._action_space is None: return self.env.action_space @@ -270,7 +272,9 @@ def action_space(self, space: spaces.Space[WrapperActType]): self._action_space = space @property - def observation_space(self) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]: + def observation_space( + self, + ) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]: """Returns the observation space of the environment.""" if self._observation_space is None: return self.env.observation_space @@ -320,7 +324,8 @@ def np_random(self, value: np.random.Generator): def _np_random(self): """This code will never be run due to __getattr__ being called prior this. - It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.""" + It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable. + """ raise AttributeError( "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) diff --git a/tests/test_core.py b/tests/test_core.py index c533ae7ee..3a966f5d3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -6,14 +6,19 @@ import pytest from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper -from gymnasium.core import ActionWrapper, ActType, ObsType, WrapperObsType, WrapperActType +from gymnasium.core import ( + ActionWrapper, + ActType, + ObsType, + WrapperActType, + WrapperObsType, +) from gymnasium.spaces import Box from gymnasium.utils import seeding from tests.testing_env import GenericTestEnv class ExampleEnv(Env): - def __init__(self): self.observation_space = Box(0, 1) self.action_space = Box(0, 1) @@ -43,17 +48,19 @@ def test_gymnasium_env(): class ExampleWrapper(Wrapper): - def __init__(self, env: Env[ObsType, ActType]): super().__init__(env) self.new_reward = 3 - def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ - WrapperObsType, Dict[str, Any]]: + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + ) -> Tuple[WrapperObsType, Dict[str, Any]]: return super().reset(seed=seed, options=options) - def step(self, action: WrapperActType) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: + def step( + self, action: WrapperActType + ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: obs, reward, termination, truncation, info = self.env.step(action) return obs, self.new_reward, termination, truncation, info @@ -88,9 +95,16 @@ def test_gymnasium_wrapper(): assert env.action_space != wrapper_env.action_space wrapper_env.np_random, _ = seeding.np_random() - assert env._np_random is env.np_random is wrapper_env.np_random + assert ( + env._np_random is env.np_random is wrapper_env.np_random + ) # ignore: reportPrivateUsage assert 0 <= wrapper_env.np_random.uniform() <= 1 - with pytest.raises(AttributeError, match=re.escape("Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`.")): + with pytest.raises( + AttributeError, + match=re.escape( + "Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`." + ), + ): print(wrapper_env.access_hidden_np_random()) From cb6c92c1149c34db6d9f4d662bf87afd246fd0a6 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Wed, 26 Oct 2022 16:49:25 +0100 Subject: [PATCH 07/12] Code review by Ariel --- gymnasium/core.py | 55 +++++++++++++++++++++++------------------------ 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index df4431f92..eaeae88da 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,14 +1,13 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" +from __future__ import annotations + import sys from typing import ( TYPE_CHECKING, Any, - Dict, Generic, - List, Optional, SupportsFloat, - Tuple, TypeVar, Union, ) @@ -59,7 +58,7 @@ class Env(Generic[ObsType, ActType]): """ # Set this in SOME subclasses - metadata: Dict[str, Any] = {"render_modes": []} + metadata: dict[str, Any] = {"render_modes": []} # define render_mode if your environment supports rendering render_mode: Optional[str] = None reward_range = (-float("inf"), float("inf")) @@ -85,7 +84,7 @@ def np_random(self, value: np.random.Generator): def step( self, action: ActType - ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state. @@ -115,8 +114,8 @@ def reset( self, *, seed: Optional[int] = None, - options: Optional[Dict[str, Any]] = None, - ) -> Tuple[ObsType, Dict[str, Any]]: + options: Optional[dict[str, Any]] = None, + ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment to an initial state and returns the initial observation. This method can reset the environment's random number generator(s) if ``seed`` is an integer or @@ -146,7 +145,7 @@ def reset( if seed is not None: self._np_random, seed = seeding.np_random(seed) - def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: + def render(self) -> Optional[Union[RenderFrame, list[RenderFrame]]]: """Compute the render frames as specified by render_mode attribute during initialization of the environment. The set of supported modes varies per environment. (And some @@ -195,7 +194,7 @@ def __enter__(self): """Support with-statement for the environment.""" return self - def __exit__(self, *args: List[Any]): + def __exit__(self, *args: Any): """Support with-statement for the environment.""" self.close() # propagate exception @@ -227,8 +226,8 @@ def __init__(self, env: Env[ObsType, ActType]): self._action_space: Optional[spaces.Space[WrapperActType]] = None self._observation_space: Optional[spaces.Space[WrapperObsType]] = None - self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None - self._metadata: Optional[Dict[str, Any]] = None + self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None + self._metadata: Optional[dict[str, Any]] = None def __getattr__(self, name: str): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" @@ -277,25 +276,25 @@ def observation_space(self, space: spaces.Space[WrapperObsType]): self._observation_space = space @property - def reward_range(self) -> Tuple[SupportsFloat, SupportsFloat]: + def reward_range(self) -> tuple[SupportsFloat, SupportsFloat]: """Return the reward range of the environment.""" if self._reward_range is None: return self.env.reward_range return self._reward_range @reward_range.setter - def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]): + def reward_range(self, value: tuple[SupportsFloat, SupportsFloat]): self._reward_range = value @property - def metadata(self) -> Dict[str, Any]: + def metadata(self) -> dict[str, Any]: """Returns the environment metadata.""" if self._metadata is None: return self.env.metadata return self._metadata @metadata.setter - def metadata(self, value: Dict[str, Any]): + def metadata(self, value: dict[str, Any]): self._metadata = value @property @@ -322,17 +321,17 @@ def _np_random(self): "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) - def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: + def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]: """Steps through the environment with action.""" return self.env.step(action) def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[WrapperObsType, Dict[str, Any]]: + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + ) -> tuple[WrapperObsType, dict[str, Any]]: """Resets the environment with a seed and options.""" return self.env.reset(seed=seed, options=options) - def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]: + def render(self) -> Optional[Union[RenderFrame, list[RenderFrame]]]: """Renders the environment.""" return self.env.render() @@ -370,7 +369,7 @@ class ObservationWrapper(Wrapper[WrapperObsType, ActType]): ``observation["target_position"] - observation["agent_position"]``. For this, you could implement an observation wrapper like this:: - class RelativePosition(gymnasium.ObservationWrapper): + class RelativePosition(gym.ObservationWrapper): def __init__(self, env): super().__init__(env) self.observation_space = Box(shape=(2,), low=-np.inf, high=np.inf) @@ -383,15 +382,15 @@ def observation(self, obs): """ def reset( - self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None - ) -> Tuple[WrapperObsType, Dict[str, Any]]: + self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + ) -> tuple[WrapperObsType, dict[str, Any]]: """Resets the environment, returning a modified observation using :meth:`self.observation`.""" obs, info = self.env.reset(seed=seed, options=options) return self.observation(obs), info def step( self, action: ActType - ) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: """Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return self.observation(observation), reward, terminated, truncated, info @@ -414,7 +413,7 @@ class RewardWrapper(Wrapper[ObsType, ActType]): because it is intrinsic), we want to clip the reward to a range to gain some numerical stability. To do that, we could, for instance, implement the following wrapper:: - class ClipReward(gymnasium.RewardWrapper): + class ClipReward(gym.RewardWrapper): def __init__(self, env, min_reward, max_reward): super().__init__(env) self.min_reward = min_reward @@ -427,7 +426,7 @@ def reward(self, r: float) -> float: def step( self, action: ActType - ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`.""" observation, reward, terminated, truncated, info = self.env.step(action) return observation, self.reward(reward), terminated, truncated, info @@ -450,7 +449,7 @@ class ActionWrapper(Wrapper[ObsType, WrapperActType]): Let’s say you have an environment with action space of type :class:`gymnasium.spaces.Box`, but you would only like to use a finite subset of actions. Then, you might want to implement the following wrapper:: - class DiscreteActions(gymnasium.ActionWrapper): + class DiscreteActions(gym.ActionWrapper): def __init__(self, env, disc_to_cont): super().__init__(env) self.disc_to_cont = disc_to_cont @@ -460,7 +459,7 @@ def action(self, act): return self.disc_to_cont[act] if __name__ == "__main__": - env = gymnasium.make("LunarLanderContinuous-v2") + env = gym.make("LunarLanderContinuous-v2") wrapped_env = DiscreteActions(env, [np.array([1,0]), np.array([-1,0]), np.array([0,1]), np.array([0,-1])]) print(wrapped_env.action_space) #Discrete(4) @@ -471,7 +470,7 @@ def action(self, act): def step( self, action: WrapperActType - ) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]: + ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: """Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`.""" return self.env.step(self.action(action)) From 19224f5c1019f694a916c110e1269127794d05ac Mon Sep 17 00:00:00 2001 From: StringTheory Date: Wed, 26 Oct 2022 16:51:12 +0100 Subject: [PATCH 08/12] pre-commit --- gymnasium/core.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index eaeae88da..b00637d3b 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -2,15 +2,7 @@ from __future__ import annotations import sys -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Optional, - SupportsFloat, - TypeVar, - Union, -) +from typing import TYPE_CHECKING, Any, Generic, Optional, SupportsFloat, TypeVar, Union import numpy as np @@ -62,7 +54,7 @@ class Env(Generic[ObsType, ActType]): # define render_mode if your environment supports rendering render_mode: Optional[str] = None reward_range = (-float("inf"), float("inf")) - spec: "EnvSpec" = None + spec: EnvSpec = None # Set these in ALL subclasses action_space: spaces.Space[ActType] @@ -175,7 +167,7 @@ def close(self): pass @property - def unwrapped(self) -> "Env[ObsType, ActType]": + def unwrapped(self) -> Env[ObsType, ActType]: """Returns the base non-wrapped environment. Returns: @@ -240,7 +232,7 @@ def __getattr__(self, name: str): return getattr(self.env, name) @property - def spec(self) -> "EnvSpec": + def spec(self) -> EnvSpec: """Returns the environment specification.""" return self.env.spec From 78cec07f42fb97fd934a91f0b4c00d1373ff4483 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Wed, 26 Oct 2022 18:14:56 +0100 Subject: [PATCH 09/12] pre-commit --- tests/test_core.py | 146 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 143 insertions(+), 3 deletions(-) diff --git a/tests/test_core.py b/tests/test_core.py index 3a966f5d3..00d278860 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -5,7 +5,7 @@ import numpy as np import pytest -from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper +from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper, spaces from gymnasium.core import ( ActionWrapper, ActType, @@ -15,8 +15,146 @@ ) from gymnasium.spaces import Box from gymnasium.utils import seeding +from gymnasium.wrappers import OrderEnforcing, TimeLimit from tests.testing_env import GenericTestEnv +# ==== Old testing code + + +class ArgumentEnv(Env): + observation_space = spaces.Box(low=0, high=1, shape=(1,)) + action_space = spaces.Box(low=0, high=1, shape=(1,)) + calls = 0 + + def __init__(self, arg): + self.calls += 1 + self.arg = arg + + +class UnittestEnv(Env): + observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8) + action_space = spaces.Discrete(3) + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + super().reset(seed=seed) + return self.observation_space.sample(), {"info": "dummy"} + + def step(self, action): + observation = self.observation_space.sample() # Dummy observation + return (observation, 0.0, False, {}) + + +class UnknownSpacesEnv(Env): + """This environment defines its observation & action spaces only + after the first call to reset. Although this pattern is sometimes + necessary when implementing a new environment (e.g. if it depends + on external resources), it is not encouraged. + """ + + def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): + super().reset(seed=seed) + self.observation_space = spaces.Box( + low=0, high=255, shape=(64, 64, 3), dtype=np.uint8 + ) + self.action_space = spaces.Discrete(3) + return self.observation_space.sample(), {} # Dummy observation with info + + def step(self, action): + observation = self.observation_space.sample() # Dummy observation + return (observation, 0.0, False, {}) + + +class OldStyleEnv(Env): + """This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)""" + + def __init__(self): + pass + + def reset(self): + super().reset() + return 0 + + def step(self, action): + return 0, 0, False, {} + + +class NewPropertyWrapper(Wrapper): + def __init__( + self, + env, + observation_space=None, + action_space=None, + reward_range=None, + metadata=None, + ): + super().__init__(env) + if observation_space is not None: + # Only set the observation space if not None to test property forwarding + self.observation_space = observation_space + if action_space is not None: + self.action_space = action_space + if reward_range is not None: + self.reward_range = reward_range + if metadata is not None: + self.metadata = metadata + + +def test_env_instantiation(): + # This looks like a pretty trivial, but given our usage of + # __new__, it's worth having. + env = ArgumentEnv("arg") + assert env.arg == "arg" + assert env.calls == 1 + + +properties = [ + { + "observation_space": spaces.Box( + low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 + ) + }, + {"action_space": spaces.Discrete(2)}, + {"reward_range": (-1.0, 1.0)}, + {"metadata": {"render_modes": ["human", "rgb_array_list"]}}, + { + "observation_space": spaces.Box( + low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32 + ), + "action_space": spaces.Discrete(2), + }, +] + + +@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv]) +@pytest.mark.parametrize("props", properties) +def test_wrapper_property_forwarding(class_, props): + env = class_() + env = NewPropertyWrapper(env, **props) + + # If UnknownSpacesEnv, then call reset to define the spaces + if isinstance(env.unwrapped, UnknownSpacesEnv): + _ = env.reset() + + # Test the properties set by the wrapper + for key, value in props.items(): + assert getattr(env, key) == value + + # Otherwise, test if the properties are forwarded + all_properties = {"observation_space", "action_space", "reward_range", "metadata"} + for key in all_properties - props.keys(): + assert getattr(env, key) == getattr(env.unwrapped, key) + + +def test_compatibility_with_old_style_env(): + env = OldStyleEnv() + env = OrderEnforcing(env) + env = TimeLimit(env) + obs = env.reset() + assert obs == 0 + + +# ==== New testing code + class ExampleEnv(Env): def __init__(self): @@ -96,8 +234,10 @@ def test_gymnasium_wrapper(): wrapper_env.np_random, _ = seeding.np_random() assert ( - env._np_random is env.np_random is wrapper_env.np_random - ) # ignore: reportPrivateUsage + env._np_random # pyright: ignore [reportPrivateUsage] + is env.np_random + is wrapper_env.np_random + ) assert 0 <= wrapper_env.np_random.uniform() <= 1 with pytest.raises( AttributeError, From ec71d7d14df9911e1994451ddd94d61b78ab2e5f Mon Sep 17 00:00:00 2001 From: StringTheory Date: Wed, 26 Oct 2022 18:26:17 +0100 Subject: [PATCH 10/12] pre-commit --- gymnasium/core.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index d9fd30cbf..c85f19c5c 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,7 +1,7 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Optional, SupportsFloat, TypeVar, Union +from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar import numpy as np @@ -48,7 +48,7 @@ class Env(Generic[ObsType, ActType]): # Set this in SOME subclasses metadata: dict[str, Any] = {"render_modes": []} # define render_mode if your environment supports rendering - render_mode: Optional[str] = None + render_mode: str | None = None reward_range = (-float("inf"), float("inf")) spec: EnvSpec = None @@ -57,7 +57,7 @@ class Env(Generic[ObsType, ActType]): observation_space: spaces.Space[ObsType] # Created - _np_random: Optional[np.random.Generator] = None + _np_random: np.random.Generator | None = None def step( self, action: ActType @@ -102,8 +102,8 @@ def step( def reset( self, *, - seed: Optional[int] = None, - options: Optional[dict[str, Any]] = None, + seed: int | None = None, + options: dict[str, Any] | None = None, ) -> tuple[ObsType, dict[str, Any]]: """Resets the environment to an initial internal state, returning an initial observation and info. @@ -142,7 +142,7 @@ def reset( if seed is not None: self._np_random, seed = seeding.np_random(seed) - def render(self) -> Optional[Union[RenderFrame, list[RenderFrame]]]: + def render(self) -> RenderFrame | list[RenderFrame] | None: """Compute the render frames as specified by :attr:`render_mode` during the initialization of the environment. The environment's :attr:`metadata` render modes (`env.metadata["render_modes"]`) should contain the possible @@ -298,10 +298,10 @@ def __init__(self, env: Env[ObsType, ActType]): """ self.env = env - self._action_space: Optional[spaces.Space[WrapperActType]] = None - self._observation_space: Optional[spaces.Space[WrapperObsType]] = None - self._reward_range: Optional[tuple[SupportsFloat, SupportsFloat]] = None - self._metadata: Optional[dict[str, Any]] = None + self._action_space: spaces.Space[WrapperActType] | None = None + self._observation_space: spaces.Space[WrapperObsType] | None = None + self._reward_range: tuple[SupportsFloat, SupportsFloat] | None = None + self._metadata: dict[str, Any] | None = None def __getattr__(self, name: str): """Returns an attribute with ``name``, unless ``name`` starts with an underscore.""" @@ -326,7 +326,7 @@ def class_name(cls) -> str: @property def action_space( self, - ) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]: + ) -> spaces.Space[ActType] | spaces.Space[WrapperActType]: """Return the :attr:`Env` :attr:`action_space` unless overwritten then the wrapper :attr:`action_space` is used.""" if self._action_space is None: return self.env.action_space @@ -339,7 +339,7 @@ def action_space(self, space: spaces.Space[WrapperActType]): @property def observation_space( self, - ) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]: + ) -> spaces.Space[ObsType] | spaces.Space[WrapperObsType]: """Return the :attr:`Env` :attr:`observation_space` unless overwritten then the wrapper :attr:`observation_space` is used.""" if self._observation_space is None: return self.env.observation_space @@ -372,7 +372,7 @@ def metadata(self, value: dict[str, Any]): self._metadata = value @property - def render_mode(self) -> Optional[str]: + def render_mode(self) -> str | None: """Returns the :attr:`Env` :attr:`render_mode`.""" return self.env.render_mode @@ -400,12 +400,12 @@ def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]: return self.env.step(action) def reset( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[WrapperObsType, dict[str, Any]]: """Uses the :meth:`reset` of the :attr:`env` that can be overwritten to change the returned data.""" return self.env.reset(seed=seed, options=options) - def render(self) -> Optional[Union[RenderFrame, list[RenderFrame]]]: + def render(self) -> RenderFrame | list[RenderFrame] | None: """Uses the :meth:`render` of the :attr:`env` that can be overwritten to change the returned data.""" return self.env.render() @@ -455,7 +455,7 @@ def observation(self, obs): """ def reset( - self, *, seed: Optional[int] = None, options: Optional[dict[str, Any]] = None + self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[WrapperObsType, dict[str, Any]]: """Modifies the :attr:`env` after calling :meth:`reset`, returning a modified observation using :meth:`self.observation`.""" obs, info = self.env.reset(seed=seed, options=options) From c49e12e9d1907189e9caef4e3d9673d147530226 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Thu, 27 Oct 2022 09:12:55 +0100 Subject: [PATCH 11/12] Fix pyright with env.spec: Optional[EnvSpec] --- gymnasium/core.py | 38 +++++++++++++------ gymnasium/envs/box2d/bipedal_walker.py | 1 + gymnasium/envs/box2d/car_racing.py | 1 + gymnasium/envs/box2d/lunar_lander.py | 1 + gymnasium/envs/classic_control/acrobot.py | 1 + gymnasium/envs/classic_control/cartpole.py | 1 + .../continuous_mountain_car.py | 1 + .../envs/classic_control/mountain_car.py | 1 + gymnasium/envs/classic_control/pendulum.py | 1 + gymnasium/envs/toy_text/blackjack.py | 1 + gymnasium/envs/toy_text/cliffwalking.py | 1 + gymnasium/envs/toy_text/frozen_lake.py | 1 + gymnasium/envs/toy_text/taxi.py | 1 + gymnasium/utils/play.py | 2 + gymnasium/wrappers/atari_preprocessing.py | 2 +- gymnasium/wrappers/time_limit.py | 1 + tests/envs/test_action_dim_check.py | 5 ++- tests/envs/test_envs.py | 2 +- tests/envs/test_make.py | 4 ++ tests/envs/test_register.py | 1 + tests/envs/test_spec.py | 1 + tests/test_core.py | 2 +- tests/vector/test_vector_make.py | 1 + tests/wrappers/test_passive_env_checker.py | 2 +- .../test_record_episode_statistics.py | 1 + 25 files changed, 56 insertions(+), 18 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index c85f19c5c..68c7f7a30 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,7 +1,7 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Optional, SupportsFloat, TypeVar import numpy as np @@ -50,7 +50,7 @@ class Env(Generic[ObsType, ActType]): # define render_mode if your environment supports rendering render_mode: str | None = None reward_range = (-float("inf"), float("inf")) - spec: EnvSpec = None + spec: EnvSpec | None = None # Set these in ALL subclasses action_space: spaces.Space[ActType] @@ -61,7 +61,7 @@ class Env(Generic[ObsType, ActType]): def step( self, action: ActType - ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Run one timestep of the environment's dynamics using the agent actions. When the end of an episode is reached (``terminated or truncated``), it is necessary to call :meth:`reset` to @@ -79,7 +79,7 @@ def step( Returns: observation (ObsType): An element of the environment's :attr:`observation_space` as the next observation due to the agent actions. An example is a numpy array containing the positions and velocities of the pole in CartPole. - reward (float): The reward as a result of taking the action. + reward (SupportsFloat): The reward as a result of taking the action. terminated (bool): Whether the agent reaches the terminal state (as defined under the MDP of the task) which can be positive or negative. An example is reaching the goal state or moving into the lava from the Sutton and Barton, Gridworld. If true, the user needs to call :meth:`reset`. @@ -104,7 +104,7 @@ def reset( *, seed: int | None = None, options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: + ) -> tuple[ObsType, dict[str, Any]]: # type: ignore """Resets the environment to an initial internal state, returning an initial observation and info. This method generates a new starting state often with some randomness to ensure that the agent explores the @@ -314,7 +314,7 @@ def __getattr__(self, name: str): return getattr(self.env, name) @property - def spec(self) -> EnvSpec: + def spec(self) -> Optional[EnvSpec]: """Returns the :attr:`Env` :attr:`spec` attribute.""" return self.env.spec @@ -395,7 +395,9 @@ def _np_random(self): "Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`." ) - def step(self, action: ActType) -> tuple[ObsType, float, bool, bool, dict]: + def step( + self, action: WrapperActType + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict]: """Uses the :meth:`step` of the :attr:`env` that can be overwritten to change the returned data.""" return self.env.step(action) @@ -454,6 +456,10 @@ def observation(self, obs): index of the timestep to the observation. """ + def __init__(self, env: Env[ObsType, ActType]): + """Constructor for the observation wrapper.""" + super().__init__(env) + def reset( self, *, seed: int | None = None, options: dict[str, Any] | None = None ) -> tuple[WrapperObsType, dict[str, Any]]: @@ -463,7 +469,7 @@ def reset( def step( self, action: ActType - ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: + ) -> tuple[WrapperObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Modifies the :attr:`env` after calling :meth:`step` using :meth:`self.observation` on the returned observations.""" observation, reward, terminated, truncated, info = self.env.step(action) return self.observation(observation), reward, terminated, truncated, info @@ -500,18 +506,22 @@ def __init__(self, env, min_reward, max_reward): self.max_reward = max_reward self.reward_range = (min_reward, max_reward) - def reward(self, r: float) -> float: + def reward(self, r: SupportsFloat) -> SupportsFloat: return np.clip(r, self.min_reward, self.max_reward) """ + def __init__(self, env: Env[ObsType, ActType]): + """Constructor for the Reward wrapper.""" + super().__init__(env) + def step( self, action: ActType - ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Modifies the :attr:`env` :meth:`step` reward using :meth:`self.reward`.""" observation, reward, terminated, truncated, info = self.env.step(action) return observation, self.reward(reward), terminated, truncated, info - def reward(self, reward: SupportsFloat) -> float: + def reward(self, reward: SupportsFloat) -> SupportsFloat: """Returns a modified environment ``reward``. Args: @@ -554,9 +564,13 @@ def action(self, act): Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction` for clipping and rescaling actions. """ + def __init__(self, env: Env[ObsType, ActType]): + """Constructor for the action wrapper.""" + super().__init__(env) + def step( self, action: WrapperActType - ) -> tuple[ObsType, float, bool, bool, dict[str, Any]]: + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """Runs the :attr:`env` :meth:`env.step` using the modified ``action`` from :meth:`self.action`.""" return self.env.step(self.action(action)) diff --git a/gymnasium/envs/box2d/bipedal_walker.py b/gymnasium/envs/box2d/bipedal_walker.py index 5f078f58f..81eeff0e5 100644 --- a/gymnasium/envs/box2d/bipedal_walker.py +++ b/gymnasium/envs/box2d/bipedal_walker.py @@ -609,6 +609,7 @@ def step(self, action: np.ndarray): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/box2d/car_racing.py b/gymnasium/envs/box2d/car_racing.py index db6822a86..38d017552 100644 --- a/gymnasium/envs/box2d/car_racing.py +++ b/gymnasium/envs/box2d/car_racing.py @@ -570,6 +570,7 @@ def step(self, action: Union[np.ndarray, int]): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/box2d/lunar_lander.py b/gymnasium/envs/box2d/lunar_lander.py index 29bc67f32..5945dfcc4 100644 --- a/gymnasium/envs/box2d/lunar_lander.py +++ b/gymnasium/envs/box2d/lunar_lander.py @@ -603,6 +603,7 @@ def step(self, action): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/classic_control/acrobot.py b/gymnasium/envs/classic_control/acrobot.py index fc335d068..34e08d7ae 100644 --- a/gymnasium/envs/classic_control/acrobot.py +++ b/gymnasium/envs/classic_control/acrobot.py @@ -282,6 +282,7 @@ def _dsdt(self, s_augmented): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/classic_control/cartpole.py b/gymnasium/envs/classic_control/cartpole.py index c173de7a7..4d36a36ff 100644 --- a/gymnasium/envs/classic_control/cartpole.py +++ b/gymnasium/envs/classic_control/cartpole.py @@ -210,6 +210,7 @@ def reset( def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/classic_control/continuous_mountain_car.py b/gymnasium/envs/classic_control/continuous_mountain_car.py index 4d9460020..a12da6258 100644 --- a/gymnasium/envs/classic_control/continuous_mountain_car.py +++ b/gymnasium/envs/classic_control/continuous_mountain_car.py @@ -194,6 +194,7 @@ def _height(self, xs): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/classic_control/mountain_car.py b/gymnasium/envs/classic_control/mountain_car.py index 3dabf5698..25ab6603d 100644 --- a/gymnasium/envs/classic_control/mountain_car.py +++ b/gymnasium/envs/classic_control/mountain_car.py @@ -171,6 +171,7 @@ def _height(self, xs): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/classic_control/pendulum.py b/gymnasium/envs/classic_control/pendulum.py index 908ce2ec8..7f4dae966 100644 --- a/gymnasium/envs/classic_control/pendulum.py +++ b/gymnasium/envs/classic_control/pendulum.py @@ -168,6 +168,7 @@ def _get_obs(self): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/toy_text/blackjack.py b/gymnasium/envs/toy_text/blackjack.py index 4d15b038e..7649c1386 100644 --- a/gymnasium/envs/toy_text/blackjack.py +++ b/gymnasium/envs/toy_text/blackjack.py @@ -192,6 +192,7 @@ def reset( def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/toy_text/cliffwalking.py b/gymnasium/envs/toy_text/cliffwalking.py index 0c2686812..7677977c1 100644 --- a/gymnasium/envs/toy_text/cliffwalking.py +++ b/gymnasium/envs/toy_text/cliffwalking.py @@ -165,6 +165,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None): def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/toy_text/frozen_lake.py b/gymnasium/envs/toy_text/frozen_lake.py index 9993330e4..98895eee5 100644 --- a/gymnasium/envs/toy_text/frozen_lake.py +++ b/gymnasium/envs/toy_text/frozen_lake.py @@ -270,6 +270,7 @@ def reset( def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/envs/toy_text/taxi.py b/gymnasium/envs/toy_text/taxi.py index a8146aafd..fdb79def5 100644 --- a/gymnasium/envs/toy_text/taxi.py +++ b/gymnasium/envs/toy_text/taxi.py @@ -281,6 +281,7 @@ def reset( def render(self): if self.render_mode is None: + assert self.spec is not None gym.logger.warn( "You are calling render method without specifying any render mode. " "You can specify the render_mode at initialization, " diff --git a/gymnasium/utils/play.py b/gymnasium/utils/play.py index c171323db..247777c69 100644 --- a/gymnasium/utils/play.py +++ b/gymnasium/utils/play.py @@ -72,6 +72,7 @@ def _get_relevant_keys( elif hasattr(self.env.unwrapped, "get_keys_to_action"): keys_to_action = self.env.unwrapped.get_keys_to_action() else: + assert self.env.spec is not None raise MissingKeysToAction( f"{self.env.spec.id} does not have explicit key to action mapping, " "please specify one manually" @@ -230,6 +231,7 @@ def play( elif hasattr(env.unwrapped, "get_keys_to_action"): keys_to_action = env.unwrapped.get_keys_to_action() else: + assert env.spec is not None raise MissingKeysToAction( f"{env.spec.id} does not have explicit key to action mapping, " "please specify one manually" diff --git a/gymnasium/wrappers/atari_preprocessing.py b/gymnasium/wrappers/atari_preprocessing.py index d4768c126..3e8c678cb 100644 --- a/gymnasium/wrappers/atari_preprocessing.py +++ b/gymnasium/wrappers/atari_preprocessing.py @@ -68,7 +68,7 @@ def __init__( assert screen_size > 0 assert noop_max >= 0 if frame_skip > 1: - if ( + if (env.spec is not None and "NoFrameskip" not in env.spec.id and getattr(env.unwrapped, "_frameskip", None) != 1 ): diff --git a/gymnasium/wrappers/time_limit.py b/gymnasium/wrappers/time_limit.py index bcd9821f6..47e75a874 100644 --- a/gymnasium/wrappers/time_limit.py +++ b/gymnasium/wrappers/time_limit.py @@ -30,6 +30,7 @@ def __init__( """ super().__init__(env) if max_episode_steps is None and self.env.spec is not None: + assert env.spec is not None max_episode_steps = env.spec.max_episode_steps if self.env.spec is not None: self.env.spec.max_episode_steps = max_episode_steps diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 5dc6786b9..68f8fc34c 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -57,7 +57,7 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec): @pytest.mark.parametrize( - "env", DISCRETE_ENVS, ids=[env.spec.id for env in DISCRETE_ENVS] + "env", DISCRETE_ENVS, ids=[env.spec.id for env in DISCRETE_ENVS if env.spec is not None] ) def test_discrete_actions_out_of_bound(env: gym.Env): """Test out of bound actions in Discrete action_space. @@ -87,7 +87,7 @@ def test_discrete_actions_out_of_bound(env: gym.Env): OOB_VALUE = 100 -@pytest.mark.parametrize("env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS]) +@pytest.mark.parametrize("env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS if env.spec is not None]) def test_box_actions_out_of_bound(env: gym.Env): """Test out of bound actions in Box action_space. @@ -100,6 +100,7 @@ def test_box_actions_out_of_bound(env: gym.Env): """ env.reset(seed=42) + assert env.spec is not None oob_env = gym.make(env.spec.id, disable_env_checker=True) oob_env.reset(seed=42) diff --git a/tests/envs/test_envs.py b/tests/envs/test_envs.py index 2f9880a37..fec015adc 100644 --- a/tests/envs/test_envs.py +++ b/tests/envs/test_envs.py @@ -187,7 +187,7 @@ def test_render_modes(spec): @pytest.mark.parametrize( "env", all_testing_initialised_envs, - ids=[env.spec.id for env in all_testing_initialised_envs], + ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ) def test_pickle_env(env: gym.Env): pickled_env = pickle.loads(pickle.dumps(env)) diff --git a/tests/envs/test_make.py b/tests/envs/test_make.py index 4844b8a28..850ec515c 100644 --- a/tests/envs/test_make.py +++ b/tests/envs/test_make.py @@ -53,6 +53,7 @@ def test_make(): env = gym.make("CartPole-v1", disable_env_checker=True) + assert env.spec is not None assert env.spec.id == "CartPole-v1" assert isinstance(env.unwrapped, cartpole.CartPoleEnv) env.close() @@ -73,6 +74,7 @@ def test_make_max_episode_steps(): # Default, uses the spec's env = gym.make("CartPole-v1", disable_env_checker=True) assert has_wrapper(env, TimeLimit) + assert env.spec is not None assert ( env.spec.max_episode_steps == gym.envs.registry["CartPole-v1"].max_episode_steps ) @@ -81,6 +83,7 @@ def test_make_max_episode_steps(): # Custom max episode steps env = gym.make("CartPole-v1", max_episode_steps=100, disable_env_checker=True) assert has_wrapper(env, TimeLimit) + assert env.spec is not None assert env.spec.max_episode_steps == 100 env.close() @@ -297,6 +300,7 @@ def test_make_kwargs(): arg3="override_arg3", disable_env_checker=True, ) + assert env.spec is not None assert env.spec.id == "test.ArgumentEnv-v0" assert isinstance(env.unwrapped, ArgumentEnv) assert env.arg1 == "arg1" diff --git a/tests/envs/test_register.py b/tests/envs/test_register.py index f7b0ca7ba..079902135 100644 --- a/tests/envs/test_register.py +++ b/tests/envs/test_register.py @@ -183,6 +183,7 @@ def test_make_latest_versioned_env(register_testing_envs): env = gym.make( "MyAwesomeNamespace/MyAwesomeVersionedEnv", disable_env_checker=True ) + assert env.spec is not None assert env.spec.id == "MyAwesomeNamespace/MyAwesomeVersionedEnv-v5" diff --git a/tests/envs/test_spec.py b/tests/envs/test_spec.py index 94b78ad75..283f9e0b7 100644 --- a/tests/envs/test_spec.py +++ b/tests/envs/test_spec.py @@ -16,6 +16,7 @@ def test_spec(): def test_spec_kwargs(): map_name_value = "8x8" env = gym.make("FrozenLake-v1", map_name=map_name_value) + assert env.spec is not None assert env.spec.kwargs["map_name"] == map_name_value diff --git a/tests/test_core.py b/tests/test_core.py index 00d278860..97e2d07c2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -182,7 +182,7 @@ def test_gymnasium_env(): assert env.render_mode is None assert env.reward_range == (-float("inf"), float("inf")) assert env.spec is None - assert env._np_random is None + assert env._np_random is None # pyright: ignore [reportPrivateUsage] class ExampleWrapper(Wrapper): diff --git a/tests/vector/test_vector_make.py b/tests/vector/test_vector_make.py index 58bc75a88..97d74640c 100644 --- a/tests/vector/test_vector_make.py +++ b/tests/vector/test_vector_make.py @@ -38,6 +38,7 @@ def test_vector_make_wrappers(): sub_env = env.envs[0] assert isinstance(sub_env, gym.Env) + assert sub_env.spec is not None if sub_env.spec.order_enforce: assert has_wrapper(sub_env, OrderEnforcing) if sub_env.spec.max_episode_steps is not None: diff --git a/tests/wrappers/test_passive_env_checker.py b/tests/wrappers/test_passive_env_checker.py index e49d901de..fbc5c1cec 100644 --- a/tests/wrappers/test_passive_env_checker.py +++ b/tests/wrappers/test_passive_env_checker.py @@ -14,7 +14,7 @@ @pytest.mark.parametrize( "env", all_testing_initialised_envs, - ids=[env.spec.id for env in all_testing_initialised_envs], + ids=[env.spec.id for env in all_testing_initialised_envs if env.spec is not None], ) def test_passive_checker_wrapper_warnings(env): with warnings.catch_warnings(record=True) as caught_warnings: diff --git a/tests/wrappers/test_record_episode_statistics.py b/tests/wrappers/test_record_episode_statistics.py index c5e3114f2..f277ea937 100644 --- a/tests/wrappers/test_record_episode_statistics.py +++ b/tests/wrappers/test_record_episode_statistics.py @@ -16,6 +16,7 @@ def test_record_episode_statistics(env_id, deque_size): assert env.episode_returns is not None and env.episode_lengths is not None assert env.episode_returns[0] == 0.0 assert env.episode_lengths[0] == 0 + assert env.spec is not None for t in range(env.spec.max_episode_steps): _, _, terminated, truncated, info = env.step(env.action_space.sample()) if terminated or truncated: From 53b9ac649e002fd637f5a68a5e360136f30c4988 Mon Sep 17 00:00:00 2001 From: StringTheory Date: Sun, 30 Oct 2022 12:13:27 +0000 Subject: [PATCH 12/12] pre-commit --- gymnasium/core.py | 4 ++-- gymnasium/wrappers/atari_preprocessing.py | 5 +++-- tests/envs/test_action_dim_check.py | 8 ++++++-- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index 68c7f7a30..de7005dec 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -1,7 +1,7 @@ """Core API for Environment, Wrapper, ActionWrapper, RewardWrapper and ObservationWrapper.""" from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generic, Optional, SupportsFloat, TypeVar +from typing import TYPE_CHECKING, Any, Generic, SupportsFloat, TypeVar import numpy as np @@ -314,7 +314,7 @@ def __getattr__(self, name: str): return getattr(self.env, name) @property - def spec(self) -> Optional[EnvSpec]: + def spec(self) -> EnvSpec | None: """Returns the :attr:`Env` :attr:`spec` attribute.""" return self.env.spec diff --git a/gymnasium/wrappers/atari_preprocessing.py b/gymnasium/wrappers/atari_preprocessing.py index 3e8c678cb..c18655117 100644 --- a/gymnasium/wrappers/atari_preprocessing.py +++ b/gymnasium/wrappers/atari_preprocessing.py @@ -68,8 +68,9 @@ def __init__( assert screen_size > 0 assert noop_max >= 0 if frame_skip > 1: - if (env.spec is not None and - "NoFrameskip" not in env.spec.id + if ( + env.spec is not None + and "NoFrameskip" not in env.spec.id and getattr(env.unwrapped, "_frameskip", None) != 1 ): raise ValueError( diff --git a/tests/envs/test_action_dim_check.py b/tests/envs/test_action_dim_check.py index 68f8fc34c..9610a861b 100644 --- a/tests/envs/test_action_dim_check.py +++ b/tests/envs/test_action_dim_check.py @@ -57,7 +57,9 @@ def test_mujoco_action_dimensions(env_spec: EnvSpec): @pytest.mark.parametrize( - "env", DISCRETE_ENVS, ids=[env.spec.id for env in DISCRETE_ENVS if env.spec is not None] + "env", + DISCRETE_ENVS, + ids=[env.spec.id for env in DISCRETE_ENVS if env.spec is not None], ) def test_discrete_actions_out_of_bound(env: gym.Env): """Test out of bound actions in Discrete action_space. @@ -87,7 +89,9 @@ def test_discrete_actions_out_of_bound(env: gym.Env): OOB_VALUE = 100 -@pytest.mark.parametrize("env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS if env.spec is not None]) +@pytest.mark.parametrize( + "env", BOX_ENVS, ids=[env.spec.id for env in BOX_ENVS if env.spec is not None] +) def test_box_actions_out_of_bound(env: gym.Env): """Test out of bound actions in Box action_space.