diff --git a/conftest.py b/conftest.py index b81f4a8..817f384 100644 --- a/conftest.py +++ b/conftest.py @@ -21,14 +21,14 @@ def __init__(self): def reset( self, key: jax.random.KeyArray - ) -> tuple[jit_env.State, jit_env.TimeStep]: + ) -> tuple[DummyState, jit_env.TimeStep]: return DummyState(key=key), jit_env.restart(jax.numpy.zeros(())) def step( self, - state: jit_env.State, + state: DummyState, action: jit_env.Action - ) -> tuple[jit_env.State, jit_env.TimeStep]: + ) -> tuple[DummyState, jit_env.TimeStep]: if action is None: return state, jit_env.termination(*jax.numpy.ones((2,))) return state, jit_env.transition(*jax.numpy.ones((3,))) diff --git a/jit_env/__init__.py b/jit_env/__init__.py index fd8be16..2ff129d 100644 --- a/jit_env/__init__.py +++ b/jit_env/__init__.py @@ -9,7 +9,9 @@ from jit_env._core import ( Action as Action, State as State, - Observation as Observation + Observation as Observation, + RewardT as Reward, + DiscountT as Discount ) from jit_env import specs diff --git a/jit_env/_compat_test.py b/jit_env/_compat_test.py index 5ffabed..ca7a828 100644 --- a/jit_env/_compat_test.py +++ b/jit_env/_compat_test.py @@ -97,7 +97,7 @@ def test_spec_conversion( ): out_spec = to_dm_spec(in_spec) - _ = jax.tree_map(lambda a, b: type(a) == type(b), out_spec, dm_spec) + _ = jax.tree_map(lambda a, b: type(a) is type(b), out_spec, dm_spec) _ = jax.tree_map(lambda a, b: a.name == b.name, out_spec, dm_spec) _ = jax.tree_map(lambda a, b: a.shape == b.shape, out_spec, dm_spec) _ = jax.tree_map(lambda a, b: a.dtype == b.dtype, out_spec, dm_spec) @@ -105,9 +105,9 @@ def test_spec_conversion( samples = jax.tree_map(lambda s: s.generate_value(), out_spec) dm_samples = jax.tree_map(lambda s: s.generate_value(), dm_spec) - chex.assert_trees_all_equal(samples, dm_samples, ignore_nones=True) + chex.assert_trees_all_equal(samples, dm_samples) chex.assert_trees_all_equal_shapes_and_dtypes( - samples, dm_samples, ignore_nones=True + samples, dm_samples ) @pytest.mark.usefixtures('dummy_env') @@ -246,7 +246,7 @@ def test_spec_conversion( """ out_space = to_gym_space(in_spec) - _ = jax.tree_map(lambda a, b: type(a) == type(b), out_space, gym_space) + _ = jax.tree_map(lambda a, b: type(a) is type(b), out_space, gym_space) _ = jax.tree_map(lambda a, b: a.shape == b.shape, out_space, gym_space) _ = jax.tree_map(lambda a, b: a.dtype == b.dtype, out_space, gym_space) @@ -256,16 +256,16 @@ def test_spec_conversion( # Numpy uses 64-bits for default precision, jax uses 32-bits. chex.assert_trees_all_equal_shapes( # Differs in Type Accuracy - samples, gym_converted, ignore_nones=True + samples, gym_converted ) chex.assert_trees_all_equal_shapes( # Differs in Type Accuracy - samples, gym_samples, ignore_nones=True + samples, gym_samples ) # Note: The following line is important to test, but the current API # of gymnasium does not allow setting explicit dtypes for `Discrete`. # TODO: Comment out following line when dtype API compatible. # chex.assert_trees_all_equal_shapes_and_dtypes( - # gym_samples, gym_converted, ignore_nones=True + # gym_samples, gym_converted # ) # Check if dtype can be accurately promoted/ demoted/ converted. @@ -273,8 +273,7 @@ def test_spec_conversion( lambda a, b: np.can_cast(a.dtype, b.dtype, casting='same_kind'), lambda a, b: f'DType conversion of {a.dtype} does ' f'not return a subdtype of {b.dtype}', - samples, gym_samples, - ignore_nones=True + samples, gym_samples ) @pytest.mark.usefixtures('dummy_env') diff --git a/jit_env/_core.py b/jit_env/_core.py index 08511ca..52e7650 100644 --- a/jit_env/_core.py +++ b/jit_env/_core.py @@ -44,8 +44,13 @@ class StateProtocol(Protocol): # The following should all be valid Jax types Action = TypeVar("Action") + Observation = TypeVar("Observation") -State = TypeVar("State", bound="StateProtocol") +StepT = TypeVar("StepT", bound=Int8[Array, '']) +RewardT = TypeVar("RewardT", bound=PyTree[ArrayLike]) +DiscountT = TypeVar("DiscountT", bound=PyTree[ArrayLike]) + +State = TypeVar("State", bound=StateProtocol) class StepType: @@ -56,7 +61,7 @@ class StepType: @dataclass(init=True, repr=True, eq=True, frozen=True) -class TimeStep(Generic[Observation]): +class TimeStep(Generic[Observation, RewardT, DiscountT, StepT]): """Defines the datastructure that is communicated to the Agent. While dm_env utilizes a NamedTuple, we opted for a mappable dataclass @@ -86,9 +91,9 @@ class TimeStep(Generic[Observation]): training or for generally monitorring Agent behaviour. This field is excluded from object comparisons. """ - step_type: Int8[Array, ''] - reward: PyTree[ArrayLike] - discount: PyTree[ArrayLike] + step_type: StepT + reward: RewardT + discount: DiscountT observation: Observation extras: dict[str, Any] | None = field(default=None, compare=False) @@ -107,7 +112,10 @@ def last(self) -> Bool[Array, '']: # Define Environment and Wrapper -class Environment(Generic[State, Action, Observation], metaclass=abc.ABCMeta): +class Environment( + Generic[State, Action, Observation, RewardT, DiscountT], + metaclass=abc.ABCMeta +): """Interface for defining Environment logic for RL-Agents. """ def __init__(self, *, renderer: Callable[[State], Any] | None = None): @@ -145,7 +153,9 @@ def unwrapped(self) -> Environment: return self @abc.abstractmethod - def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]: + def reset(self, key: jax.random.KeyArray) -> tuple[ + State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] + ]: """Starts a new episode as a functionally pure transformation. Args: @@ -161,7 +171,9 @@ def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]: """ @abc.abstractmethod - def step(self, state: State, action: Action) -> tuple[State, TimeStep]: + def step(self, state: State, action: Action) -> tuple[ + State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] + ]: """Updates the environment according to the given state and action. If the environment already returned a `TimeStep` with `StepType.LAST` @@ -277,8 +289,8 @@ def __exit__(self, exc_type, exc_value, traceback): class Wrapper( - Environment[State, Action, Observation], - Generic[State, Action, Observation], + Environment[State, Action, Observation, RewardT, DiscountT], + Generic[State, Action, Observation, RewardT, DiscountT], metaclass=abc.ABCMeta ): """Interface for Composing Environment logic for RL-Agents. """ @@ -320,10 +332,14 @@ def unwrapped(self) -> Environment: """Helper function to unpack Composite Environments to the base.""" return self.env.unwrapped - def reset(self, key: jax.random.KeyArray) -> tuple[State, TimeStep]: + def reset(self, key: jax.random.KeyArray) -> tuple[ + State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] + ]: return self.env.reset(key) - def step(self, state: State, action: Action) -> tuple[State, TimeStep]: + def step(self, state: State, action: Action) -> tuple[ + State, TimeStep[Observation, RewardT, DiscountT, Int8[Array, '']] + ]: return self.env.step(state, action) def reward_spec(self) -> specs.Spec: @@ -350,7 +366,12 @@ def restart( extras: dict | None = None, shape: int | Sequence[int] = (), dtype: Any = float -) -> TimeStep: +) -> TimeStep[ + Observation, + PyTree[Num[Array, '...']], + PyTree[Num[Array, '...']], + Int8[Array, ''] +]: """Returns a `TimeStep` with `step_type` set to `StepType.FIRST`. Unlike dm_env the reward and discount are not `None` to prevent array @@ -371,7 +392,12 @@ def transition( discount: PyTree[Num[Array, '...']] | None = None, extras: dict | None = None, shape: int | Sequence[int] = () -) -> TimeStep: +) -> TimeStep[ + Observation, + PyTree[Num[Array, '...']], + PyTree[Num[Array, '...']], + Int8[Array, ''] +]: """Returns a `TimeStep` with `step_type` set to `StepType.MID`. """ return TimeStep( step_type=StepType.MID, @@ -387,7 +413,12 @@ def termination( observation: Observation, extras: dict | None = None, shape: int | Sequence[int] = () -) -> TimeStep: +) -> TimeStep[ + Observation, + PyTree[Num[Array, '...']], + PyTree[Num[Array, '...']], + Int8[Array, ''] +]: """Returns a `TimeStep` with `step_type` set to `StepType.LAST`. """ return TimeStep( step_type=StepType.LAST, @@ -404,7 +435,12 @@ def truncation( discount: PyTree[Num[Array, '...']] | None = None, extras: dict | None = None, shape: int | Sequence[int] = () -) -> TimeStep: +) -> TimeStep[ + Observation, + PyTree[Num[Array, '...']], + PyTree[Num[Array, '...']], + Int8[Array, ''] +]: """Alternative to `termination` that does not set `discount` to zero. """ return TimeStep( step_type=StepType.LAST, diff --git a/jit_env/_core_test.py b/jit_env/_core_test.py index 0a85e16..8cccbd4 100644 --- a/jit_env/_core_test.py +++ b/jit_env/_core_test.py @@ -55,25 +55,25 @@ def test_empty_wrapper(dummy_env: jit_env.Environment): _ = jax.tree_map(lambda a, b: a.validate(b), spec, wrap_samples) _ = jax.tree_map(lambda a, b: a.validate(b), wrapped_spec, spec_samples) - chex.assert_trees_all_equal(spec_samples, wrap_samples, ignore_nones=True) + chex.assert_trees_all_equal(spec_samples, wrap_samples) chex.assert_trees_all_equal_shapes_and_dtypes( - spec_samples, wrap_samples, ignore_nones=True + spec_samples, wrap_samples ) out = dummy_env.reset(jax.random.PRNGKey(0)) wrap_out = wrapped.reset(jax.random.PRNGKey(0)) - chex.assert_trees_all_equal(out, wrap_out, ignore_nones=True) + chex.assert_trees_all_equal(out, wrap_out) chex.assert_trees_all_equal_shapes_and_dtypes( - out, wrap_out, ignore_nones=True + out, wrap_out ) out = dummy_env.step(out[0], spec.actions.generate_value()) wrap_out = wrapped.step(wrap_out[0], wrapped_spec.actions.generate_value()) - chex.assert_trees_all_equal(out, wrap_out, ignore_nones=True) + chex.assert_trees_all_equal(out, wrap_out) chex.assert_trees_all_equal_shapes_and_dtypes( - out, wrap_out, ignore_nones=True + out, wrap_out ) diff --git a/jit_env/_specs_test.py b/jit_env/_specs_test.py index d197559..9492c0c 100644 --- a/jit_env/_specs_test.py +++ b/jit_env/_specs_test.py @@ -176,10 +176,10 @@ def test_unpack_spec(in_spec: specs.Spec, expected_tree: PyTree[specs.Spec]): in_spec.validate(sample_spec) chex.assert_trees_all_equal_shapes_and_dtypes( - sample_normal, sample_tree, ignore_nones=True + sample_normal, sample_tree ) chex.assert_trees_all_equal_shapes_and_dtypes( - sample_normal, sample_spec, ignore_nones=True + sample_normal, sample_spec ) diff --git a/jit_env/_wrappers_test.py b/jit_env/_wrappers_test.py index 90eb8aa..b6702a0 100644 --- a/jit_env/_wrappers_test.py +++ b/jit_env/_wrappers_test.py @@ -68,7 +68,8 @@ def test_unwrap(dummy_env: jit_env.Environment): @pytest.mark.usefixtures('dummy_env') def test_jit( dummy_env: jit_env.Environment[ - jit_env.State, jax.Array, jit_env.Observation + jit_env.State, jax.Array, jit_env.Observation, + jit_env.Reward, jit_env.Discount ] ): jitted = wrappers.Jit(dummy_env) @@ -80,26 +81,26 @@ def test_jit( state, step = dummy_env.reset(jax.random.PRNGKey(0)) jit_state, jit_step = jitted.reset(jax.random.PRNGKey(0)) - chex.assert_trees_all_equal(state, jit_state, ignore_nones=True) - chex.assert_trees_all_equal(step, jit_step, ignore_nones=True) + chex.assert_trees_all_equal(state, jit_state) + chex.assert_trees_all_equal(step, jit_step) chex.assert_trees_all_equal_shapes_and_dtypes( - state, jit_state, ignore_nones=True + state, jit_state ) chex.assert_trees_all_equal_shapes_and_dtypes( - (step,), (jit_step,), ignore_nones=True + (step,), (jit_step,) ) # Step logic state, step = dummy_env.step(state, jnp.zeros(())) jit_state, jit_step = jitted.step(jit_state, jnp.zeros(())) - chex.assert_trees_all_equal(state, jit_state, ignore_nones=True) - chex.assert_trees_all_equal(step, jit_step, ignore_nones=True) + chex.assert_trees_all_equal(state, jit_state) + chex.assert_trees_all_equal(step, jit_step) chex.assert_trees_all_equal_shapes_and_dtypes( - state, jit_state, ignore_nones=True + state, jit_state ) chex.assert_trees_all_equal_shapes_and_dtypes( - (step,), (jit_step,), ignore_nones=True + (step,), (jit_step,) ) @@ -156,9 +157,9 @@ def test_autoreset(dummy_env: jit_env.Environment): assert step.first() assert ref_step.first() - chex.assert_trees_all_equal(step, ref_step, ignore_nones=True) + chex.assert_trees_all_equal(step, ref_step) chex.assert_trees_all_equal_shapes_and_dtypes( - (step,), (ref_step,), ignore_nones=True + (step,), (ref_step,) ) for _ in range(5): @@ -168,9 +169,9 @@ def test_autoreset(dummy_env: jit_env.Environment): assert step.mid() assert ref_step.mid() - chex.assert_trees_all_equal(step, ref_step, ignore_nones=True) + chex.assert_trees_all_equal(step, ref_step) chex.assert_trees_all_equal_shapes_and_dtypes( - (step,), (ref_step,), ignore_nones=True + (step,), (ref_step,) ) state, step = env.step(state, None) @@ -196,7 +197,8 @@ class TestVmap: def test_env( self, dummy_env: jit_env.Environment[ - jit_env.State, jax.Array, jit_env.Observation + jit_env.State, jax.Array, jit_env.Observation, + jit_env.Reward, jit_env.Discount ], batch_size: int = 5 ): @@ -211,12 +213,12 @@ def test_env( states, steps = batched.reset(jax.random.split(key, num=batch_size)) chex.assert_tree_shape_prefix( - (states, steps), (batch_size,), ignore_nones=True + (states, steps), (batch_size,) ) sliced = jax.tree_map(lambda x: x.at[0].get(), (states, steps)) chex.assert_trees_all_equal_shapes_and_dtypes( - sliced, (state, step), ignore_nones=True + sliced, (state, step) ) # Step logic @@ -224,19 +226,20 @@ def test_env( states, steps = batched.step(states, jnp.zeros((batch_size,))) chex.assert_tree_shape_prefix( - (states, steps), (batch_size,), ignore_nones=True + (states, steps), (batch_size,) ) sliced = jax.tree_map(lambda x: x.at[0].get(), (states, steps)) chex.assert_trees_all_equal_shapes_and_dtypes( - sliced, (state, step), ignore_nones=True + sliced, (state, step) ) @pytest.mark.usefixtures('dummy_env') def test_render( self, dummy_env: jit_env.Environment[ - jit_env.State, jit_env.Action, jit_env.Observation + jit_env.State, jit_env.Action, jit_env.Observation, + jit_env.Reward, jit_env.Discount ], batch_size: int = 5 ): @@ -253,17 +256,18 @@ def test_render( batch_render = batched.render(states) chex.assert_trees_all_equal( - single_render, batch_render, ignore_nones=True + single_render, batch_render ) chex.assert_trees_all_equal_shapes_and_dtypes( - single_render, batch_render, ignore_nones=True + single_render, batch_render ) @pytest.mark.usefixtures('dummy_env') def test_wrongly_wrapped_autoreset( self, dummy_env: jit_env.Environment[ - jit_env.State, jit_env.Action, jit_env.Observation + jit_env.State, jit_env.Action, jit_env.Observation, + jit_env.Reward, jit_env.Discount ] ): vmap_first = wrappers.AutoReset(wrappers.Vmap(dummy_env)) @@ -284,7 +288,8 @@ def test_wrongly_wrapped_autoreset( def test_autoreset( self, dummy_env: jit_env.Environment[ - jit_env.State, jit_env.Action, jit_env.Observation + jit_env.State, jit_env.Action, jit_env.Observation, + jit_env.Reward, jit_env.Discount ], num: int = 2 ): @@ -303,9 +308,9 @@ def test_autoreset( doubly_out = doubly_wrapped.reset(keys) singly_out = singly_wrapped.reset(keys) - chex.assert_trees_all_equal(doubly_out, singly_out, ignore_nones=True) + chex.assert_trees_all_equal(doubly_out, singly_out) chex.assert_trees_all_equal_shapes_and_dtypes( - singly_out, doubly_out, ignore_nones=True + singly_out, doubly_out ) a = jnp.zeros(()) # Action is held constant across batch @@ -314,20 +319,20 @@ def test_autoreset( singly_out = singly_wrapped.step(singly_out[0], a) chex.assert_trees_all_equal( - doubly_out, singly_out, ignore_nones=True + doubly_out, singly_out ) chex.assert_trees_all_equal_shapes_and_dtypes( - singly_out, doubly_out, ignore_nones=True + singly_out, doubly_out ) doubly_out = doubly_wrapped.step(doubly_out[0], None) singly_out = singly_wrapped.step(singly_out[0], None) chex.assert_trees_all_equal( - doubly_out, singly_out, ignore_nones=True + doubly_out, singly_out ) chex.assert_trees_all_equal_shapes_and_dtypes( - singly_out, doubly_out, ignore_nones=True + singly_out, doubly_out ) @@ -348,12 +353,12 @@ def test_spec(self, dummy_env: jit_env.Environment, num: int = 2): samples = jax.tree_map(lambda s: s.generate_value(), spec) batch = jax.tree_map(lambda s: s.generate_value(), batch_spec) - chex.assert_tree_shape_prefix(batch, (num,), ignore_nones=True) + chex.assert_tree_shape_prefix(batch, (num,)) for i in range(num): sliced = jax.tree_map(lambda x: x.at[i].get(), batch) chex.assert_trees_all_equal_shapes_and_dtypes( - sliced, samples, ignore_nones=True + sliced, samples ) @pytest.mark.usefixtures('dummy_env') @@ -384,14 +389,15 @@ def test_tile(self, dummy_env: jit_env.Environment, num: int = 2): states, steps = tiled_env.reset(jax.random.PRNGKey(0)) # type: ignore chex.assert_tree_shape_prefix( - (states, steps), (num,), ignore_nones=True + (states, steps), (num,) ) @pytest.mark.usefixtures('dummy_env') def test_autoreset( self, dummy_env: jit_env.Environment[ - jit_env.State, jit_env.Action, jit_env.Observation + jit_env.State, jit_env.Action, jit_env.Observation, + jit_env.Reward, jit_env.Discount ], num: int = 2 ): @@ -408,9 +414,9 @@ def test_autoreset( vmap_out = vmapped.reset(keys) tile_out = tiled.reset(jax.random.PRNGKey(0)) - chex.assert_trees_all_equal(vmap_out, tile_out, ignore_nones=True) + chex.assert_trees_all_equal(vmap_out, tile_out) chex.assert_trees_all_equal_shapes_and_dtypes( - tile_out, vmap_out, ignore_nones=True + tile_out, vmap_out ) a = jnp.zeros(()) # Action is held constant across batch @@ -419,18 +425,18 @@ def test_autoreset( tile_out = tiled.step(tile_out[0], a) chex.assert_trees_all_equal( - vmap_out, tile_out, ignore_nones=True + vmap_out, tile_out ) chex.assert_trees_all_equal_shapes_and_dtypes( - tile_out, vmap_out, ignore_nones=True + tile_out, vmap_out ) vmap_out = vmapped.step(vmap_out[0], None) tile_out = tiled.step(tile_out[0], None) chex.assert_trees_all_equal( - vmap_out, tile_out, ignore_nones=True + vmap_out, tile_out ) chex.assert_trees_all_equal_shapes_and_dtypes( - tile_out, vmap_out, ignore_nones=True + tile_out, vmap_out )