Skip to content

Commit

Permalink
Update chex.assert_type to check concrete types instead of just ass…
Browse files Browse the repository at this point in the history
…erting that the type is a floating/integer sub-type.

Previously, `assert_type` would only check that the input was of the same parent type. For example:
```
x = np.ones((1,), dtype=np.float32)
chex.assert_type(x, np.float64)  # Succeeds
chex.assert_type(x, np.int32)  # Fails.
```

Instead, if a concrete dtype is provided we check that the input has the same type. If `float` or `np.floating` is provided, we continue to only assert that the input is the same parent.

```
x = np.ones((1,), dtype=np.float32)
chex.assert_type(x, np.float64) # Fails
chex.assert_type(x, float) # Succeeds.
```
PiperOrigin-RevId: 609283182
  • Loading branch information
tomwardio authored and ChexDev committed Feb 22, 2024
1 parent d4d3467 commit fd62ce0
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
42 changes: 25 additions & 17 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
import inspect
import traceback
from typing import Any, Callable, List, Optional, Sequence, Set, Type, Union, cast
from typing import Any, Callable, List, Optional, Sequence, Set, Union, cast
import unittest
from unittest import mock

Expand All @@ -33,8 +33,10 @@

Scalar = pytypes.Scalar
Array = pytypes.Array
ArrayDType = pytypes.ArrayDType # pylint:disable=invalid-name
ArrayTree = pytypes.ArrayTree


_value_assertion = _ai.chex_assertion
_static_assertion = functools.partial(
_ai.chex_assertion, jittable_assert_fn=None)
Expand Down Expand Up @@ -782,10 +784,14 @@ def assert_rank(
@_static_assertion
def assert_type(
inputs: Union[Scalar, Union[Array, Sequence[Array]]],
expected_types: Union[Type[Scalar], Sequence[Type[Scalar]]]) -> None:
expected_types: Union[ArrayDType, Sequence[ArrayDType]]) -> None:
"""Checks that the type of all inputs matches specified ``expected_types``.
Valid usages include:
If the expected type is a Python type or abstract dtype (e.g. `np.floating`),
assert that the input has the same sub-type. If the expected type is a
concrete dtype (e.g. np.float32), assert that the input's type is the same.
Example usage:
.. code-block:: python
Expand All @@ -796,8 +802,9 @@ def assert_type(
assert_type([7, 7.1], [int, float])
assert_type(np.array(7), int)
assert_type(np.array(7.1), float)
assert_type(jnp.array(7), int)
assert_type([jnp.array([7, 8]), np.array(7.1)], [int, float])
assert_type(jnp.array(1., dtype=jnp.bfloat16)), jnp.bfloat16)
assert_type(jnp.ones(1, dtype=np.int8), np.int8)
Args:
inputs: An array or a sequence of arrays or scalars.
Expand All @@ -817,21 +824,22 @@ def assert_type(

errors = []
if len(inputs) != len(expected_types):
raise AssertionError(f"Length of `inputs` and `expected_types` must match, "
f"got {len(inputs)} != {len(expected_types)}.")
raise AssertionError(
"Length of `inputs` and `expected_types` must match, "
f"got {len(inputs)} != {len(expected_types)}."
)
for idx, (x, expected) in enumerate(zip(inputs, expected_types)):
if jnp.issubdtype(expected, jnp.floating):
parent = jnp.floating
elif jnp.issubdtype(expected, jnp.integer):
parent = jnp.integer
elif jnp.issubdtype(expected, jnp.bool_):
parent = jnp.bool_
dtype = np.result_type(x)
if expected in {float, jnp.floating}:
if not jnp.issubdtype(dtype, jnp.floating):
errors.append((idx, dtype, expected))
elif expected in {int, jnp.integer}:
if not jnp.issubdtype(dtype, jnp.integer):
errors.append((idx, dtype, expected))
else:
raise AssertionError(
f"Error in type compatibility check, unsupported dtype '{expected}'.")

if not jnp.issubdtype(jnp.result_type(x), parent):
errors.append((idx, jnp.result_type(x), expected))
expected = np.dtype(expected)
if dtype != expected:
errors.append((idx, dtype, expected))

if errors:
msg = "; ".join(
Expand Down
41 changes: 18 additions & 23 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def array_from_shape(*shape):
return np.ones(shape=shape)


def emplace(arrays):
return arrays
def emplace(arrays, dtype):
return jnp.array(arrays, dtype=dtype)


class AssertsSwitchTest(parameterized.TestCase):
Expand Down Expand Up @@ -664,11 +664,14 @@ def test_type_should_fail_scalar(self, scalars, wrong_type):

@variants.variants(with_device=True, without_device=True)
@parameterized.named_parameters(
('one_float_array', [1., 2.], int),
('one_int_array', [1, 2], float),
('one_float_array', [1., 2.], float, int),
('one_int_array', [1, 2], int, float),
('bfloat16_array', [1, 2], jnp.bfloat16, jnp.float32),
('int8_array', [1, 2], jnp.int8, jnp.int32),
('float32_array', [1, 2], jnp.float32, np.integer),
)
def test_type_should_fail_array(self, array, wrong_type):
array = self.variant(emplace)(array)
def test_type_should_fail_array(self, array, dtype, wrong_type):
array = self.variant(emplace)(array, dtype)
with self.assertRaisesRegex(
AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
asserts.assert_type(array, wrong_type)
Expand All @@ -680,17 +683,19 @@ def test_type_should_fail_array(self, array, wrong_type):
('many_floats', [1., 2., 3.], float),
('many_floats_verbose', [1., 2., 3.], [float, float, float]),
)
def test_type_should_pass_scalar(self, array, wrong_type):
asserts.assert_type(array, wrong_type)
def test_type_should_pass_scalar(self, array, expected_type):
asserts.assert_type(array, expected_type)

@variants.variants(with_device=True, without_device=True)
@parameterized.named_parameters(
('one_float_array', [1., 2.], float),
('one_int_array', [1, 2], int),
('one_float_array', [1., 2.], float, float),
('one_int_array', [1, 2], int, int),
('one_integer_array', [1, 2], int, np.integer),
('one_bool_array', [True], bool, bool),
)
def test_type_should_pass_array(self, array, wrong_type):
array = self.variant(emplace)(array)
asserts.assert_type(array, wrong_type)
def test_type_should_pass_array(self, array, dtype, expected_type):
array = self.variant(emplace)(array, dtype)
asserts.assert_type(array, expected_type)

def test_type_should_fail_mixed(self):
a_float = 1.
Expand Down Expand Up @@ -720,16 +725,6 @@ def test_type_should_fail_wrong_length(self, array, wrong_type):
_get_err_regex('Length of `inputs` and `expected_types` must match')):
asserts.assert_type(array, wrong_type)

def test_type_should_fail_unsupported_dtype(self):
a_float = 1.
an_int = 2
a_np_float = np.asarray([3., 4.])
a_jax_int = jnp.asarray([5, 6])
with self.assertRaisesRegex(AssertionError,
_get_err_regex('unsupported dtype')):
asserts.assert_type([a_float, an_int, a_np_float, a_jax_int],
[complex, complex, float, int])


class AxisDimensionAssertionsTest(parameterized.TestCase):

Expand Down

0 comments on commit fd62ce0

Please sign in to comment.