Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update the type hinting for core.py #39

Merged
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 71 additions & 46 deletions gymnasium/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,16 @@ class Env(Generic[ObsType, ActType]):
def np_random(self) -> np.random.Generator:
"""Returns the environment's internal :attr:`_np_random` that if not set will initialize with a random seed."""
if self._np_random is None:
self._np_random, seed = seeding.np_random()
self._np_random, _ = seeding.np_random()
return self._np_random

@np_random.setter
def np_random(self, value: np.random.Generator):
self._np_random = value

def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
"""Run one timestep of the environment's dynamics.

When end of episode is reached, you are responsible for calling :meth:`reset` to reset this environment's state.
Expand Down Expand Up @@ -113,8 +115,8 @@ def reset(
self,
*,
seed: Optional[int] = None,
options: Optional[dict] = None,
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[ObsType, dict]:
options: Optional[Dict[str, Any]] = None,
) -> Tuple[ObsType, Dict[str, Any]]:
"""Resets the environment to an initial state and returns the initial observation.

This method can reset the environment's random number generator(s) if ``seed`` is an integer or
Expand All @@ -134,7 +136,6 @@ def reset(
options (optional dict): Additional information to specify how the environment is reset (optional,
depending on the specific environment)


Returns:
observation (object): Observation of the initial state. This will be an element of :attr:`observation_space`
(typically a numpy array) and is analogous to the observation returned by :meth:`step`.
Expand Down Expand Up @@ -175,7 +176,7 @@ def close(self):
pass

@property
def unwrapped(self) -> "Env":
def unwrapped(self) -> "Env[ObsType, ActType]":
"""Returns the base non-wrapped environment.

Returns:
Expand All @@ -194,14 +195,18 @@ def __enter__(self):
"""Support with-statement for the environment."""
return self

def __exit__(self, *args):
def __exit__(self, *args: List[Any]):
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
"""Support with-statement for the environment."""
self.close()
# propagate exception
return False


class Wrapper(Env[ObsType, ActType]):
WrapperObsType = TypeVar("WrapperObsType")
WrapperActType = TypeVar("WrapperActType")


class Wrapper(Env[WrapperObsType, WrapperActType]):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

This class is the base class for all wrappers. The subclass could override
Expand All @@ -212,55 +217,63 @@ class Wrapper(Env[ObsType, ActType]):
Don't forget to call ``super().__init__(env)`` if the subclass overrides :meth:`__init__`.
"""

def __init__(self, env: Env):
def __init__(self, env: Env[ObsType, ActType]):
"""Wraps an environment to allow a modular transformation of the :meth:`step` and :meth:`reset` methods.

Args:
env: The environment to wrap
"""
self.env = env

self._action_space: Optional[spaces.Space] = None
self._observation_space: Optional[spaces.Space] = None
self._action_space: Optional[spaces.Space[WrapperActType]] = None
self._observation_space: Optional[spaces.Space[WrapperObsType]] = None
self._reward_range: Optional[Tuple[SupportsFloat, SupportsFloat]] = None
self._metadata: Optional[dict] = None
self._metadata: Optional[Dict[str, Any]] = None

def __getattr__(self, name):
def __getattr__(self, name: str):
"""Returns an attribute with ``name``, unless ``name`` starts with an underscore."""
if name.startswith("_"):
if name == "_np_random":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure about this, it's out of scope, and seems like a guardrail that's not necessarily needed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was a guardrail that already existed, but it was bugged so didn't work. I had to modify this function to re-add it such that the correct error occurred.

We could replace it with a warning? I added it to gym a couple of months ago as I thought I had found a bug, but what was accurately happening was that I was using the wrapper _np_random not the environment _np_random which took me about an hour + to solve. Therefore, I proposed this solution. I guess the issue is that wrappers could modify the reproducibility of an environment using the environment rng not the wrapper rng however in this case, the wrapper could use its own rng with a different variable name. Thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine for now, we can reconsider this guardrail in general in the future

raise AttributeError(
"Can't access `_np_random` of a wrapper, use `self.unwrapped._np_random` or `self.np_random`."
)
elif name.startswith("_"):
raise AttributeError(f"accessing private attribute '{name}' is prohibited")
return getattr(self.env, name)

@property
def spec(self):
def spec(self) -> "EnvSpec":
"""Returns the environment specification."""
return self.env.spec

@classmethod
def class_name(cls):
def class_name(cls) -> str:
"""Returns the class name of the wrapper."""
return cls.__name__

@property
def action_space(self) -> spaces.Space[ActType]:
def action_space(
self,
) -> Union[spaces.Space[ActType], spaces.Space[WrapperActType]]:
"""Returns the action space of the environment."""
if self._action_space is None:
return self.env.action_space
return self._action_space

@action_space.setter
def action_space(self, space: spaces.Space):
def action_space(self, space: spaces.Space[WrapperActType]):
self._action_space = space

@property
def observation_space(self) -> spaces.Space:
def observation_space(
self,
) -> Union[spaces.Space[ObsType], spaces.Space[WrapperObsType]]:
"""Returns the observation space of the environment."""
if self._observation_space is None:
return self.env.observation_space
return self._observation_space

@observation_space.setter
def observation_space(self, space: spaces.Space):
def observation_space(self, space: spaces.Space[WrapperObsType]):
self._observation_space = space

@property
Expand All @@ -275,14 +288,14 @@ def reward_range(self, value: Tuple[SupportsFloat, SupportsFloat]):
self._reward_range = value

@property
def metadata(self) -> dict:
def metadata(self) -> Dict[str, Any]:
"""Returns the environment metadata."""
if self._metadata is None:
return self.env.metadata
return self._metadata

@metadata.setter
def metadata(self, value):
def metadata(self, value: Dict[str, Any]):
self._metadata = value

@property
Expand All @@ -296,11 +309,15 @@ def np_random(self) -> np.random.Generator:
return self.env.np_random

@np_random.setter
def np_random(self, value):
def np_random(self, value: np.random.Generator):
self.env.np_random = value

@property
def _np_random(self):
"""This code will never be run due to __getattr__ being called prior this.

It seems that @property overwrites the variable (`_np_random`) meaning that __getattr__ gets called with the missing variable.
"""
raise AttributeError(
"Can't access `_np_random` of a wrapper, use `.unwrapped._np_random` or `.np_random`."
)
Expand All @@ -309,15 +326,15 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
"""Steps through the environment with action."""
return self.env.step(action)

def reset(self, **kwargs) -> Tuple[ObsType, dict]:
"""Resets the environment with kwargs."""
return self.env.reset(**kwargs)
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[WrapperObsType, Dict[str, Any]]:
"""Resets the environment with a seed and options."""
return self.env.reset(seed=seed, options=options)

def render(
self, *args, **kwargs
) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
"""Renders the environment."""
return self.env.render(*args, **kwargs)
return self.env.render()

def close(self):
"""Closes the environment."""
Expand All @@ -332,12 +349,12 @@ def __repr__(self):
return str(self)

@property
def unwrapped(self) -> Env:
def unwrapped(self) -> Env[ObsType, ActType]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily true - consider the wrappers that discretize actions, or add pixel observations

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, removed

"""Returns the base environment of the wrapper."""
return self.env.unwrapped


class ObservationWrapper(Wrapper):
class ObservationWrapper(Wrapper[WrapperObsType, ActType]):
"""Superclass of wrappers that can modify observations using :meth:`observation` for :meth:`reset` and :meth:`step`.

If you would like to apply a function to the observation that is returned by the base environment before
Expand Down Expand Up @@ -365,22 +382,26 @@ def observation(self, obs):
index of the timestep to the observation.
"""

def reset(self, **kwargs):
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[WrapperObsType, Dict[str, Any]]:
"""Resets the environment, returning a modified observation using :meth:`self.observation`."""
obs, info = self.env.reset(**kwargs)
obs, info = self.env.reset(seed=seed, options=options)
return self.observation(obs), info

def step(self, action):
def step(
self, action: ActType
) -> Tuple[WrapperObsType, float, bool, bool, Dict[str, Any]]:
"""Returns a modified observation using :meth:`self.observation` after calling :meth:`env.step`."""
observation, reward, terminated, truncated, info = self.env.step(action)
return self.observation(observation), reward, terminated, truncated, info

def observation(self, observation):
def observation(self, observation: ObsType) -> WrapperObsType:
"""Returns a modified observation."""
raise NotImplementedError


class RewardWrapper(Wrapper):
class RewardWrapper(Wrapper[ObsType, ActType]):
"""Superclass of wrappers that can modify the returning reward from a step.

If you would like to apply a function to the reward that is returned by the base environment before
Expand All @@ -393,28 +414,30 @@ class RewardWrapper(Wrapper):
because it is intrinsic), we want to clip the reward to a range to gain some numerical stability.
To do that, we could, for instance, implement the following wrapper::

class ClipReward(gym.RewardWrapper):
class ClipReward(gymnasium.RewardWrapper):
pseudo-rnd-thoughts marked this conversation as resolved.
Show resolved Hide resolved
def __init__(self, env, min_reward, max_reward):
super().__init__(env)
self.min_reward = min_reward
self.max_reward = max_reward
self.reward_range = (min_reward, max_reward)

def reward(self, reward):
return np.clip(reward, self.min_reward, self.max_reward)
def reward(self, r: float) -> float:
return np.clip(r, self.min_reward, self.max_reward)
"""

def step(self, action):
def step(
self, action: ActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
"""Modifies the reward using :meth:`self.reward` after the environment :meth:`env.step`."""
observation, reward, terminated, truncated, info = self.env.step(action)
return observation, self.reward(reward), terminated, truncated, info

def reward(self, reward):
def reward(self, reward: SupportsFloat) -> float:
"""Returns a modified ``reward``."""
raise NotImplementedError


class ActionWrapper(Wrapper):
class ActionWrapper(Wrapper[ObsType, WrapperActType]):
"""Superclass of wrappers that can modify the action before :meth:`env.step`.

If you would like to apply a function to the action before passing it to the base environment,
Expand Down Expand Up @@ -446,14 +469,16 @@ def action(self, act):
Among others, Gymnasium provides the action wrappers :class:`ClipAction` and :class:`RescaleAction`.
"""

def step(self, action):
def step(
self, action: WrapperActType
) -> Tuple[ObsType, float, bool, bool, Dict[str, Any]]:
"""Runs the environment :meth:`env.step` using the modified ``action`` from :meth:`self.action`."""
return self.env.step(self.action(action))

def action(self, action):
def action(self, action: WrapperActType) -> ActType:
"""Returns a modified action before :meth:`env.step` is called."""
raise NotImplementedError

def reverse_action(self, action):
def reverse_action(self, action: ActType) -> WrapperActType:
"""Returns a reversed ``action``."""
raise NotImplementedError
Loading