Skip to content

Commit

Permalink
Add in_axes test
Browse files Browse the repository at this point in the history
  • Loading branch information
michalk8 committed Oct 15, 2024
1 parent f9a41bd commit 78003d9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 7 deletions.
5 changes: 1 addition & 4 deletions src/ott/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,7 @@ def _prepare_axes(
) -> Any:
axes = jax.api_util.flatten_axes(name, treedef, axes, kws=False)
assert len(leaves) == len(axes), (len(leaves), len(axes))
axes = [
axis if axis is None else _canonicalize_axis(axis, jnp.ndim(leaf))
for axis, leaf in zip(axes, leaves)
]
# TODO(michalk8): enable negative axes
return axes if return_flat else treedef.unflatten(axes)


Expand Down
20 changes: 17 additions & 3 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,22 @@ def f(x: Any) -> jnp.ndarray:

np.testing.assert_array_equal(gt_fn(x), fn(x))

def test_in_axes(self):
pass
@pytest.mark.parametrize("in_axes", [0, 1, [0, None]])
def test_in_axes(self, rng: jax.Array, in_axes: Any):

def f(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
x = jnp.atleast_2d(x)
y = jnp.atleast_2d(y)
return jnp.dot(x, y.T)

rng1, rng2 = jax.random.split(rng, 2)
x = jax.random.normal(rng1, (15, 3)) + 10.0
y = jax.random.normal(rng2, (15, 3))

gt_fn = jax.jit(jax.vmap(f, in_axes=in_axes))
fn = jax.jit(utils.batched_vmap(f, batch_size=5, in_axes=in_axes))

np.testing.assert_array_equal(gt_fn(x, y), fn(x, y))

@pytest.mark.parametrize("out_axes", [0, 1, 2])
def test_out_axes(self, rng: jax.Array, out_axes: int):
Expand All @@ -81,7 +95,7 @@ def f(x: jnp.ndarray, y: jnp.ndarray) -> Any:
}
}, (1,))]
)
def test_multiple_out_axes(self, rng: jax.Array, out_axes: Any):
def test_out_axes_pytree(self, rng: jax.Array, out_axes: Any):

def f(x: jnp.ndarray) -> Any:
z = jnp.arange(9).reshape(3, 3)
Expand Down

0 comments on commit 78003d9

Please sign in to comment.