Skip to content

Commit

Permalink
test: add test to ensure timestep shapes and dtypes remain consistent…
Browse files Browse the repository at this point in the history
… over reset and step
  • Loading branch information
RuanJohn committed Jan 16, 2024
1 parent 86b150c commit 7b0acb6
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions matrax/env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_matrix_game__reset(matrix_game_env: MatrixGame) -> None:

key1, key2 = random.PRNGKey(0), random.PRNGKey(1)
state1, timestep1 = reset_fn(key1)
state2, timestep2 = reset_fn(key2)
state2, _ = reset_fn(key2)

assert isinstance(timestep1, TimeStep)
assert isinstance(state1, State)
Expand Down Expand Up @@ -111,10 +111,10 @@ def test_matrix_game__step(matrix_game_env_with_state: MatrixGame) -> None:

# Check that rewards have the correct number of dimensions
assert jnp.ndim(timestep1.reward) == 1
assert jnp.ndim(timestep.reward) == 0
assert jnp.ndim(timestep.reward) == 1
# Check that discounts have the correct number of dimensions
assert jnp.ndim(timestep1.discount) == 0
assert jnp.ndim(timestep.discount) == 0
assert jnp.ndim(timestep1.discount) == 1
assert jnp.ndim(timestep.discount) == 1
# Check that the state is made of DeviceArrays, this is false for the non-jitted
# step function since unpacking random.split returns numpy arrays and not device arrays.
assert_is_jax_array_tree(new_state1)
Expand Down Expand Up @@ -157,7 +157,6 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None:
state, timestep = matrix_game_env.reset(state_key)

state, timestep = step_fn(state, jnp.array([0, 0]))
jax.debug.print("rewards: {r}", r=timestep.reward)
assert jnp.array_equal(timestep.reward, jnp.array([11, 11]))

state, timestep = step_fn(state, jnp.array([1, 0]))
Expand All @@ -174,3 +173,14 @@ def test_matrix_game__reward(matrix_game_env: MatrixGame) -> None:

state, timestep = step_fn(state, jnp.array([2, 2]))
assert jnp.array_equal(timestep.reward, jnp.array([5, 5]))


def test_matrix_game__timesteps_equivalent(matrix_game_env: MatrixGame) -> None:
"""Validate that all timestep attributes have the same dtype and shape over reset and step."""
step_fn = jax.jit(matrix_game_env.step)
state_key = random.PRNGKey(10)
state, init_timestep = matrix_game_env.reset(state_key)

state, new_timestep = step_fn(state, jnp.array([0, 0]))

chex.assert_trees_all_equal_shapes_and_dtypes(init_timestep, new_timestep)

0 comments on commit 7b0acb6

Please sign in to comment.