Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the type hinting for core.py #39

Merged
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 89 additions & 72 deletions gymnasium/core.py

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions gymnasium/wrappers/monitoring/video_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ def __init__(
Error: You can pass at most one of `path` or `base_path`
Error: Invalid path given that must have a particular file extension
"""
self._async = env.metadata.get("semantics.async")
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
self.enabled = enabled
self._closed = False

try:
# check that moviepy is now installed
import moviepy # noqa: F401
Expand All @@ -46,16 +50,12 @@ def __init__(
"MoviePy is not installed, run `pip install moviepy`"
)

self._async = env.metadata.get("semantics.async")
self.enabled = enabled
self._closed = False

self.render_history = []
self.env = env

self.render_mode = env.render_mode

if "rgb_array_list" != self.render_mode and "rgb_array" != self.render_mode:
if "rgb_array" != self.render_mode and "rgb_array_list" != self.render_mode:
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
logger.warn(
f"Disabling video recorder because environment {env} was not initialized with any compatible video "
"mode between `rgb_array` and `rgb_array_list`"
Expand Down
2 changes: 1 addition & 1 deletion tests/envs/test_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from tests.envs.test_envs import PASSIVE_CHECK_IGNORE_WARNING
from tests.envs.utils import all_testing_env_specs
from tests.envs.utils_envs import ArgumentEnv, RegisterDuringMakeEnv
from tests.testing_env import GenericTestEnv, old_step_fn
from tests.generic_test_env import GenericTestEnv, old_step_fn
from tests.wrappers.utils import has_wrapper

gym.register(
Expand Down
2 changes: 2 additions & 0 deletions tests/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def try_make_env(env_spec: EnvSpec) -> Optional[gym.Env]:
return env_spec.make(disable_env_checker=True).unwrapped
except (ImportError, gym.error.DependencyNotInstalled) as e:
logger.warn(f"Not testing {env_spec.id} due to error: {e}")
except Exception as e:
logger.warn(f"Unexpected exception occurred: {e}")
return None


Expand Down
43 changes: 27 additions & 16 deletions tests/envs/utils_envs.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,54 @@
from typing import Any, Tuple, Dict, Optional

from gymnasium.core import ActType, ObsType

import gymnasium as gym
from tests.generic_test_env import GenericTestEnv


class RegisterDuringMakeEnv(gym.Env):
class RegisterDuringMakeEnv(GenericTestEnv):
"""Used in `test_registration.py` to check if `env.make` can import and register an env"""

def __init__(self):
self.action_space = gym.spaces.Discrete(1)
self.observation_space = gym.spaces.Discrete(1)


class ArgumentEnv(gym.Env):
observation_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
action_space = gym.spaces.Box(low=-1, high=1, shape=(1,))
class ArgumentEnv(GenericTestEnv):
def __init__(self, arg1: Any, arg2: Any, arg3: Any):
super().__init__()

def __init__(self, arg1, arg2, arg3):
self.arg1 = arg1
self.arg2 = arg2
self.arg3 = arg3
self.arg1, self.arg2, self.arg3 = arg1, arg2, arg3


# Environments to test render_mode
class NoHuman(gym.Env):
class NoHuman(GenericTestEnv):
"""Environment that does not have human-rendering."""

metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}

def __init__(self, render_mode=None):
def __init__(self, render_mode: str = None):
assert render_mode in self.metadata["render_modes"]
self.render_mode = render_mode
super().__init__(render_mode=render_mode)


class NoHumanOldAPI(gym.Env):
"""Environment that does not have human-rendering."""

metadata = {"render_modes": ["rgb_array_list"], "render_fps": 4}
metadata = {"render_modes": ["rgb_array"], "render_fps": 4}

def __init__(self):
pass

def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
return self.observation_space.sample(), 0, False, False, {}

def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
return self.observation_space.sample(), {}


class NoHumanNoRGB(gym.Env):
"""Environment that has neither human- nor rgb-rendering"""
Expand Down
7 changes: 5 additions & 2 deletions tests/testing_env.py → tests/generic_test_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,12 @@ def __init__(
render_fn: callable = basic_render_fn,
metadata: Optional[Dict[str, Any]] = None,
render_mode: Optional[str] = None,
spec: EnvSpec = EnvSpec("TestingEnv-v0", "testing-env-no-entry-point"),
spec: Optional[EnvSpec] = EnvSpec(
"TestingEnv-v0", "testing-env-no-entry-point"
),
):
self.metadata = {} if metadata is None else metadata
if metadata is not None:
self.metadata = metadata
self.render_mode = render_mode
self.spec = spec

Expand Down
195 changes: 77 additions & 118 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,139 +1,98 @@
from typing import Optional
from typing import Any, Dict, Optional, SupportsFloat, Tuple

import numpy as np
import pytest

from gymnasium import core, spaces
from gymnasium.wrappers import OrderEnforcing, TimeLimit
from gymnasium import Env, ObservationWrapper, RewardWrapper, Wrapper
from gymnasium.core import ActionWrapper, ActType, ObsType
from gymnasium.spaces import Box
from tests.generic_test_env import GenericTestEnv


class ArgumentEnv(core.Env):
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
observation_space = spaces.Box(low=0, high=1, shape=(1,))
action_space = spaces.Box(low=0, high=1, shape=(1,))
calls = 0
class ExampleEnv(Env):
def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
return 0, 0, False, False, {}

def __init__(self, arg):
self.calls += 1
self.arg = arg
def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
) -> Tuple[ObsType, dict]:
return 0, {}


class UnittestEnv(core.Env):
observation_space = spaces.Box(low=0, high=255, shape=(64, 64, 3), dtype=np.uint8)
action_space = spaces.Discrete(3)
def test_gymnasium_env():
env = ExampleEnv()

def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
return self.observation_space.sample(), {"info": "dummy"}
assert env.metadata == {"render_modes": []}
assert env.render_mode is None
assert env.reward_range == (-float("inf"), float("inf"))
assert env.spec is None
assert env._np_random is None

def step(self, action):
observation = self.observation_space.sample() # Dummy observation
return (observation, 0.0, False, {})

class ExampleWrapper(Wrapper):
pass

class UnknownSpacesEnv(core.Env):
"""This environment defines its observation & action spaces only
after the first call to reset. Although this pattern is sometimes
necessary when implementing a new environment (e.g. if it depends
on external resources), it is not encouraged.
"""

def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
super().reset(seed=seed)
self.observation_space = spaces.Box(
low=0, high=255, shape=(64, 64, 3), dtype=np.uint8
)
self.action_space = spaces.Discrete(3)
return self.observation_space.sample(), {} # Dummy observation with info
def test_gymnasium_wrapper():
env = ExampleEnv()
wrapper_env = ExampleWrapper(env)

def step(self, action):
observation = self.observation_space.sample() # Dummy observation
return (observation, 0.0, False, {})
assert env.metadata == wrapper_env.metadata
wrapper_env.metadata = {"render_modes": ["rgb_array"]}
assert env.metadata != wrapper_env.metadata

assert env.render_mode == wrapper_env.render_mode

class OldStyleEnv(core.Env):
"""This environment doesn't accept any arguments in reset, ideally we want to support this too (for now)"""
assert env.reward_range == wrapper_env.reward_range
wrapper_env.reward_range = (-1.0, 1.0)
assert env.reward_range != wrapper_env.reward_range

def __init__(self):
pass
assert env.spec == wrapper_env.spec

def reset(self):
super().reset()
return 0
env.observation_space = Box(0, 1)
env.action_space = Box(0, 1)
assert env.observation_space == wrapper_env.observation_space
assert env.action_space == wrapper_env.action_space
wrapper_env.observation_space = Box(1, 2)
wrapper_env.action_space = Box(1, 2)
assert env.observation_space != wrapper_env.observation_space
assert env.action_space != wrapper_env.action_space

def step(self, action):
return 0, 0, False, {}

class ExampleRewardWrapper(RewardWrapper):
def reward(self, reward: SupportsFloat) -> SupportsFloat:
return 1

class NewPropertyWrapper(core.Wrapper):
def __init__(
self,
env,
observation_space=None,
action_space=None,
reward_range=None,
metadata=None,
):
super().__init__(env)
if observation_space is not None:
# Only set the observation space if not None to test property forwarding
self.observation_space = observation_space
if action_space is not None:
self.action_space = action_space
if reward_range is not None:
self.reward_range = reward_range
if metadata is not None:
self.metadata = metadata


def test_env_instantiation():
# This looks like a pretty trivial, but given our usage of
# __new__, it's worth having.
env = ArgumentEnv("arg")
assert env.arg == "arg"
assert env.calls == 1


properties = [
{
"observation_space": spaces.Box(
low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32
)
},
{"action_space": spaces.Discrete(2)},
{"reward_range": (-1.0, 1.0)},
{"metadata": {"render_modes": ["human", "rgb_array_list"]}},
{
"observation_space": spaces.Box(
low=0.0, high=1.0, shape=(64, 64, 3), dtype=np.float32
),
"action_space": spaces.Discrete(2),
},
]


@pytest.mark.parametrize("class_", [UnittestEnv, UnknownSpacesEnv])
@pytest.mark.parametrize("props", properties)
def test_wrapper_property_forwarding(class_, props):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still a bit hesitant about removing the existing tests.

Are these tests redundant now? Do they fail with the new changes? (if so, that's a potential problem) Are they 100% replaced with the new tests?

I'd lean towards keeping them in, and just adding new tests instead. And if some tests turn out to be unnecessary, we can remove them separately with no other changes.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I will keep both then and we can decide what tests to keep later then

env = class_()
env = NewPropertyWrapper(env, **props)

# If UnknownSpacesEnv, then call reset to define the spaces
if isinstance(env.unwrapped, UnknownSpacesEnv):
_ = env.reset()

# Test the properties set by the wrapper
for key, value in props.items():
assert getattr(env, key) == value

# Otherwise, test if the properties are forwarded
all_properties = {"observation_space", "action_space", "reward_range", "metadata"}
for key in all_properties - props.keys():
assert getattr(env, key) == getattr(env.unwrapped, key)


def test_compatibility_with_old_style_env():
env = OldStyleEnv()
env = OrderEnforcing(env)
env = TimeLimit(env)
obs = env.reset()
assert obs == 0

class ExampleObservationWrapper(ObservationWrapper):
def observation(self, observation: ObsType) -> ObsType:
return np.array([1])


class ExampleActionWrapper(ActionWrapper):
def action(self, action: ActType) -> ActType:
return np.array([1])


def test_wrapper_types():
env = GenericTestEnv()

reward_env = ExampleRewardWrapper(env)
reward_env.reset()
_, reward, _, _, _ = reward_env.step(0)
assert reward == 1

observation_env = ExampleObservationWrapper(env)
obs, _ = observation_env.reset()
assert obs == np.array([1])
obs, _, _, _, _ = observation_env.step(0)
assert obs == np.array([1])

env = GenericTestEnv(step_fn=lambda self, action: (action, 0, False, False, {}))
action_env = ExampleActionWrapper(env)
obs, _, _, _, _ = action_env.step(0)
assert obs == np.array([1])
2 changes: 1 addition & 1 deletion tests/utils/test_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
check_reset_seed,
check_seed_deprecation,
)
from tests.testing_env import GenericTestEnv
from tests.generic_test_env import GenericTestEnv


@pytest.mark.parametrize(
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_passive_env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
env_reset_passive_checker,
env_step_passive_checker,
)
from tests.testing_env import GenericTestEnv
from tests.generic_test_env import GenericTestEnv


def _modify_space(space: spaces.Space, attribute: str, value):
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_play.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

import gymnasium as gym
from gymnasium.utils.play import MissingKeysToAction, PlayableGame, play
from tests.testing_env import GenericTestEnv
from tests.generic_test_env import GenericTestEnv

RELEVANT_KEY_1 = ord("a") # 97
RELEVANT_KEY_2 = ord("d") # 100
Expand Down
2 changes: 1 addition & 1 deletion tests/wrappers/test_atari_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from gymnasium.spaces import Box, Discrete
from gymnasium.wrappers import AtariPreprocessing, StepAPICompatibility
from tests.testing_env import GenericTestEnv, old_step_fn
from tests.generic_test_env import GenericTestEnv, old_step_fn


class AleTesting:
Expand Down
Loading