Skip to content

Commit

Permalink
Add more out_axes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Oct 15, 2024
1 parent 721eca9 commit f9a41bd
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ def f(x: jnp.ndarray, y: jnp.ndarray) -> Any:

chex.assert_trees_all_equal(gt_fn(x, y), fn(x, y))

@pytest.mark.parametrize(
"out_axes", [0, (0, 0, 1), (0, {
"x": {
"y": 1
}
}, (1,))]
)
def test_multiple_out_axes(self, rng: jax.Array, out_axes: Any):

def f(x: jnp.ndarray) -> Any:
z = jnp.arange(9).reshape(3, 3)
return x.mean(), {"x": {"y": jnp.ones(13)}}, (z,)

x = jax.random.normal(rng, (13, 5))

fn = utils.batched_vmap(f, batch_size=12, out_axes=out_axes)
gt_fn = jax.vmap(f, out_axes=out_axes)

chex.assert_trees_all_equal(gt_fn(x), fn(x))

@pytest.mark.parametrize("n", [16, 7])
@pytest.mark.parametrize("batch_size", [1, 4, 5, 7, 16])
def test_max_traces(self, rng: jax.Array, batch_size: int, n: int):
Expand Down

0 comments on commit f9a41bd

Please sign in to comment.