diff --git a/pyproject.toml b/pyproject.toml index ac805c89..7dd76da9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dev = [ "ruff==0.6.2", "pre-commit==3.3.2", "pytest", + "hypothesis[numpy]", ] examples = [ "torch>=1.13.1", diff --git a/tests/test_transforms_ops.py b/tests/test_transforms_ops.py index 583db9c0..019626d3 100644 --- a/tests/test_transforms_ops.py +++ b/tests/test_transforms_ops.py @@ -4,7 +4,8 @@ import numpy as onp import numpy.typing as onpt -from jax import numpy as jnp +import viser.transforms as vtf + from utils import ( assert_arrays_close, assert_transforms_close, @@ -12,8 +13,6 @@ sample_transform, ) -import viser.transforms as vtf - @general_group_test def test_sample_uniform_valid( @@ -119,7 +118,7 @@ def test_multiply( T_b_a = sample_transform(Group, batch_axes, dtype) assert_arrays_close( onp.einsum( - "...ij,...jk->...ik", T_w_b.as_matrix(), jnp.linalg.inv(T_w_b.as_matrix()) + "...ij,...jk->...ik", T_w_b.as_matrix(), onp.linalg.inv(T_w_b.as_matrix()) ), onp.broadcast_to( onp.eye(Group.matrix_dim, dtype=dtype),