Skip to content

Commit

Permalink
Removing frozen=True from dataclass decorator shuts mypy up about Sta…
Browse files Browse the repository at this point in the history
…te subtyping...
  • Loading branch information
joeryjoery committed Nov 11, 2023
1 parent 5abc46b commit d098bbb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion examples/counting_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from chex import dataclass


@dataclass(frozen=True)
@dataclass
class MyState:
key: PRNGKeyArray
count: Int32[jax.Array, '']
Expand Down
18 changes: 12 additions & 6 deletions jit_env/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,11 @@ def reset(
"""

@abc.abstractmethod
def step(self, state: State, action: Action) -> tuple[
def step(
self,
state: State,
action: Action
) -> tuple[
State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']]
]:
"""Updates the environment according to the given state and action.
Expand Down Expand Up @@ -283,7 +287,9 @@ def render(self, state: State) -> Any:
)
return self._renderer(state)

def __enter__(self) -> Environment:
def __enter__(self) -> Environment[
State, Action, Observation, RewardT, DiscountT
]:
"""Allows the environment to be used in a with-statement context."""
return self

Expand Down Expand Up @@ -374,7 +380,7 @@ def render(self, state: State) -> Any:

def restart(
observation: Observation,
extras: dict | None = None,
extras: dict[str, Any] | None = None,
shape: int | Sequence[int] = (),
dtype: Any = float
) -> TimeStep[
Expand All @@ -401,7 +407,7 @@ def transition(
reward: PyTree[Num[Array, '...']],
observation: Observation,
discount: PyTree[Num[Array, '...']] | None = None,
extras: dict | None = None,
extras: dict[str, Any] | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep[
Observation,
Expand All @@ -422,7 +428,7 @@ def transition(
def termination(
reward: PyTree[Num[Array, '...']],
observation: Observation,
extras: dict | None = None,
extras: dict[str, Any] | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep[
Observation,
Expand All @@ -444,7 +450,7 @@ def truncation(
reward: PyTree[Num[Array, '...']],
observation: Observation,
discount: PyTree[Num[Array, '...']] | None = None,
extras: dict | None = None,
extras: dict[str, Any] | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep[
Observation,
Expand Down

0 comments on commit d098bbb

Please sign in to comment.