Skip to content

Commit

Permalink
Merge pull request #27 from joeryjoery/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
joeryjoery authored Nov 10, 2023
2 parents 7add117 + aa855bb commit 7f81dd1
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 76 deletions.
6 changes: 3 additions & 3 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@ def __init__(self):
def reset(
self,
key: jax.random.KeyArray
) -> tuple[jit_env.State, jit_env.TimeStep]:
) -> tuple[DummyState, jit_env.TimeStep]:
return DummyState(key=key), jit_env.restart(jax.numpy.zeros(()))

def step(
self,
state: jit_env.State,
state: DummyState,
action: jit_env.Action
) -> tuple[jit_env.State, jit_env.TimeStep]:
) -> tuple[DummyState, jit_env.TimeStep]:
if action is None:
return state, jit_env.termination(*jax.numpy.ones((2,)))
return state, jit_env.transition(*jax.numpy.ones((3,)))
Expand Down
4 changes: 3 additions & 1 deletion jit_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from jit_env._core import (
Action as Action,
State as State,
Observation as Observation
Observation as Observation,
RewardT as Reward,
DiscountT as Discount
)

from jit_env import specs
Expand Down
17 changes: 8 additions & 9 deletions jit_env/_compat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,17 @@ def test_spec_conversion(
):
out_spec = to_dm_spec(in_spec)

_ = jax.tree_map(lambda a, b: type(a) == type(b), out_spec, dm_spec)
_ = jax.tree_map(lambda a, b: type(a) is type(b), out_spec, dm_spec)
_ = jax.tree_map(lambda a, b: a.name == b.name, out_spec, dm_spec)
_ = jax.tree_map(lambda a, b: a.shape == b.shape, out_spec, dm_spec)
_ = jax.tree_map(lambda a, b: a.dtype == b.dtype, out_spec, dm_spec)

samples = jax.tree_map(lambda s: s.generate_value(), out_spec)
dm_samples = jax.tree_map(lambda s: s.generate_value(), dm_spec)

chex.assert_trees_all_equal(samples, dm_samples, ignore_nones=True)
chex.assert_trees_all_equal(samples, dm_samples)
chex.assert_trees_all_equal_shapes_and_dtypes(
samples, dm_samples, ignore_nones=True
samples, dm_samples
)

@pytest.mark.usefixtures('dummy_env')
Expand Down Expand Up @@ -246,7 +246,7 @@ def test_spec_conversion(
"""
out_space = to_gym_space(in_spec)

_ = jax.tree_map(lambda a, b: type(a) == type(b), out_space, gym_space)
_ = jax.tree_map(lambda a, b: type(a) is type(b), out_space, gym_space)
_ = jax.tree_map(lambda a, b: a.shape == b.shape, out_space, gym_space)
_ = jax.tree_map(lambda a, b: a.dtype == b.dtype, out_space, gym_space)

Expand All @@ -256,25 +256,24 @@ def test_spec_conversion(

# Numpy uses 64-bits for default precision, jax uses 32-bits.
chex.assert_trees_all_equal_shapes( # Differs in Type Accuracy
samples, gym_converted, ignore_nones=True
samples, gym_converted
)
chex.assert_trees_all_equal_shapes( # Differs in Type Accuracy
samples, gym_samples, ignore_nones=True
samples, gym_samples
)
# Note: The following line is important to test, but the current API
# of gymnasium does not allow setting explicit dtypes for `Discrete`.
# TODO: Uncomment the following lines once the dtype API is compatible.
# chex.assert_trees_all_equal_shapes_and_dtypes(
# gym_samples, gym_converted, ignore_nones=True
# gym_samples, gym_converted
# )

# Check if dtype can be accurately promoted/demoted/converted.
chex.assert_trees_all_equal_comparator(
lambda a, b: np.can_cast(a.dtype, b.dtype, casting='same_kind'),
lambda a, b: f'DType conversion of {a.dtype} does '
f'not return a subdtype of {b.dtype}',
samples, gym_samples,
ignore_nones=True
samples, gym_samples
)

@pytest.mark.usefixtures('dummy_env')
Expand Down
68 changes: 52 additions & 16 deletions jit_env/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,13 @@ class StateProtocol(Protocol):

# The following should all be valid Jax types
Action = TypeVar("Action")

Observation = TypeVar("Observation")
State = TypeVar("State", bound="StateProtocol")
StepT = TypeVar("StepT", bound=Int8[Array, ''])
RewardT = TypeVar("RewardT", bound=PyTree[ArrayLike])
DiscountT = TypeVar("DiscountT", bound=PyTree[ArrayLike])

State = TypeVar("State", bound=StateProtocol)


class StepType:
Expand All @@ -56,7 +61,7 @@ class StepType:


@dataclass(init=True, repr=True, eq=True, frozen=True)
class TimeStep(Generic[Observation]):
class TimeStep(Generic[Observation, RewardT, DiscountT, StepT]):
"""Defines the datastructure that is communicated to the Agent.
While dm_env utilizes a NamedTuple, we opted for a mappable dataclass
Expand Down Expand Up @@ -86,9 +91,9 @@ class TimeStep(Generic[Observation]):
training or for generally monitoring Agent behaviour.
This field is excluded from object comparisons.
"""
step_type: Int8[Array, '']
reward: PyTree[ArrayLike]
discount: PyTree[ArrayLike]
step_type: StepT
reward: RewardT
discount: DiscountT
observation: Observation
extras: dict[str, Any] | None = field(default=None, compare=False)

Expand All @@ -107,7 +112,10 @@ def last(self) -> Bool[Array, '']:

# Define Environment and Wrapper

class Environment(Generic[State, Action, Observation], metaclass=abc.ABCMeta):
class Environment(
Generic[State, Action, Observation, RewardT, DiscountT],
metaclass=abc.ABCMeta
):
"""Interface for defining Environment logic for RL-Agents. """

def __init__(self, *, renderer: Callable[[State], Any] | None = None):
Expand Down Expand Up @@ -145,7 +153,9 @@ def unwrapped(self) -> Environment:
return self

@abc.abstractmethod
def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]:
def reset(self, key: jax.random.KeyArray) -> tuple[
State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']]
]:
"""Starts a new episode as a functionally pure transformation.
Args:
Expand All @@ -161,7 +171,9 @@ def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]:
"""

@abc.abstractmethod
def step(self, state: State, action: Action) -> tuple[State, TimeStep]:
def step(self, state: State, action: Action) -> tuple[
State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']]
]:
"""Updates the environment according to the given state and action.
If the environment already returned a `TimeStep` with `StepType.LAST`
Expand Down Expand Up @@ -277,8 +289,8 @@ def __exit__(self, exc_type, exc_value, traceback):


class Wrapper(
Environment[State, Action, Observation],
Generic[State, Action, Observation],
Environment[State, Action, Observation, RewardT, DiscountT],
Generic[State, Action, Observation, RewardT, DiscountT],
metaclass=abc.ABCMeta
):
"""Interface for Composing Environment logic for RL-Agents. """
Expand Down Expand Up @@ -320,10 +332,14 @@ def unwrapped(self) -> Environment:
"""Helper function to unpack Composite Environments to the base."""
return self.env.unwrapped

def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]:
def reset(self, key: jax.random.KeyArray) -> tuple[
State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']]
]:
return self.env.reset(key)

def step(self, state: State, action: Action) -> tuple[State, TimeStep]:
def step(self, state: State, action: Action) -> tuple[
State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']]
]:
return self.env.step(state, action)

def reward_spec(self) -> specs.Spec:
Expand All @@ -350,7 +366,12 @@ def restart(
extras: dict | None = None,
shape: int | Sequence[int] = (),
dtype: Any = float
) -> TimeStep:
) -> TimeStep[
Observation,
PyTree[Num[Array, '...']],
PyTree[Num[Array, '...']],
Int8[Array, '']
]:
"""Returns a `TimeStep` with `step_type` set to `StepType.FIRST`.
Unlike dm_env the reward and discount are not `None` to prevent array
Expand All @@ -371,7 +392,12 @@ def transition(
discount: PyTree[Num[Array, '...']] | None = None,
extras: dict | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep:
) -> TimeStep[
Observation,
PyTree[Num[Array, '...']],
PyTree[Num[Array, '...']],
Int8[Array, '']
]:
"""Returns a `TimeStep` with `step_type` set to `StepType.MID`. """
return TimeStep(
step_type=StepType.MID,
Expand All @@ -387,7 +413,12 @@ def termination(
observation: Observation,
extras: dict | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep:
) -> TimeStep[
Observation,
PyTree[Num[Array, '...']],
PyTree[Num[Array, '...']],
Int8[Array, '']
]:
"""Returns a `TimeStep` with `step_type` set to `StepType.LAST`. """
return TimeStep(
step_type=StepType.LAST,
Expand All @@ -404,7 +435,12 @@ def truncation(
discount: PyTree[Num[Array, '...']] | None = None,
extras: dict | None = None,
shape: int | Sequence[int] = ()
) -> TimeStep:
) -> TimeStep[
Observation,
PyTree[Num[Array, '...']],
PyTree[Num[Array, '...']],
Int8[Array, '']
]:
"""Alternative to `termination` that does not set `discount` to zero. """
return TimeStep(
step_type=StepType.LAST,
Expand Down
12 changes: 6 additions & 6 deletions jit_env/_core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,25 @@ def test_empty_wrapper(dummy_env: jit_env.Environment):
_ = jax.tree_map(lambda a, b: a.validate(b), spec, wrap_samples)
_ = jax.tree_map(lambda a, b: a.validate(b), wrapped_spec, spec_samples)

chex.assert_trees_all_equal(spec_samples, wrap_samples, ignore_nones=True)
chex.assert_trees_all_equal(spec_samples, wrap_samples)
chex.assert_trees_all_equal_shapes_and_dtypes(
spec_samples, wrap_samples, ignore_nones=True
spec_samples, wrap_samples
)

out = dummy_env.reset(jax.random.PRNGKey(0))
wrap_out = wrapped.reset(jax.random.PRNGKey(0))

chex.assert_trees_all_equal(out, wrap_out, ignore_nones=True)
chex.assert_trees_all_equal(out, wrap_out)
chex.assert_trees_all_equal_shapes_and_dtypes(
out, wrap_out, ignore_nones=True
out, wrap_out
)

out = dummy_env.step(out[0], spec.actions.generate_value())
wrap_out = wrapped.step(wrap_out[0], wrapped_spec.actions.generate_value())

chex.assert_trees_all_equal(out, wrap_out, ignore_nones=True)
chex.assert_trees_all_equal(out, wrap_out)
chex.assert_trees_all_equal_shapes_and_dtypes(
out, wrap_out, ignore_nones=True
out, wrap_out
)


Expand Down
4 changes: 2 additions & 2 deletions jit_env/_specs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,10 @@ def test_unpack_spec(in_spec: specs.Spec, expected_tree: PyTree[specs.Spec]):
in_spec.validate(sample_spec)

chex.assert_trees_all_equal_shapes_and_dtypes(
sample_normal, sample_tree, ignore_nones=True
sample_normal, sample_tree
)
chex.assert_trees_all_equal_shapes_and_dtypes(
sample_normal, sample_spec, ignore_nones=True
sample_normal, sample_spec
)


Expand Down
Loading

0 comments on commit 7f81dd1

Please sign in to comment.