Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
Hylke Cornelis Donker committed May 30, 2023
1 parent 42a39e2 commit e658b67
Showing 1 changed file with 112 additions and 33 deletions.
145 changes: 112 additions & 33 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,12 +423,18 @@ def test_equal_shape_at_dims_should_fail(self, shapes, dims):
class ShapeAssertTest(parameterized.TestCase):

@parameterized.named_parameters(
('wrong_rank', [1], (1,)),
('wrong_shape', [1, 2], (1, 3)),
('some_wrong_shape', [[1, 2], [3, 4]], [(1, 2), (1, 3)]),
('wrong_common_shape', [[1, 2], [3, 4, 3]], (2,)),
('wrong_common_shape_2', [[1, 2, 3], [1, 2]], (2,)),
('some_wrong_shape_set', [[1, 2], [3, 4]], [(1, 2), (1, {3, 4})]),
('wrong_rank_tuple', [1], (1,)),
('wrong_rank_list', [1], [1]),
('wrong_shape_tuple', [1, 2], (1, 3)),
('wrong_shape_list', [1, 2], [1, 3]),
('some_wrong_shape_tuple', [[1, 2], [3, 4]], [(1, 2), (1, 3)]),
('some_wrong_shape_list', [[1, 2], [3, 4]], [[1, 2], [1, 3]]),
('wrong_common_shape_tuple', [[1, 2], [3, 4, 3]], (2,)),
('wrong_common_shape_list', [[1, 2], [3, 4, 3]], [2]),
('wrong_common_shape_tuple_2', [[1, 2, 3], [1, 2]], (2,)),
('wrong_common_shape_list_2', [[1, 2, 3], [1, 2]], [2]),
('some_wrong_shape_tuple_set', [[1, 2], [3, 4]], [(1, 2), (1, {3, 4})]),
('some_wrong_shape_list_set', [[1, 2], [3, 4]], [[1, 2], [1, {3, 4}]]),
)
def test_shape_should_fail(self, arrays, shapes):
arrays = as_arrays(arrays)
Expand All @@ -438,8 +444,10 @@ def test_shape_should_fail(self, arrays, shapes):
asserts.assert_shape(arrays, shapes)

@parameterized.named_parameters(
('too_many_shapes', [[1]], [(1,), (2,)]),
('not_enough_shapes', [[1, 2], [3, 4]], [(3,)]),
('too_many_shapes_tuple', [[1]], [(1,), (2,)]),
('too_many_shapes_list', [[1]], [[1], [2]]),
('not_enough_shapes_tuple', [[1, 2], [3, 4]], [(3,)]),
('not_enough_shapes_list', [[1, 2], [3, 4]], [[3]]),
)
def test_shape_should_fail_wrong_length(self, arrays, shapes):
arrays = as_arrays(arrays)
Expand All @@ -449,13 +457,20 @@ def test_shape_should_fail_wrong_length(self, arrays, shapes):
asserts.assert_shape(arrays, shapes)

@parameterized.named_parameters(
('scalars', [1, 2], ()),
('vectors', [[1, 2], [3, 4, 5]], [(2,), (3,)]),
('matrices', [[[1, 2], [3, 4]]], (2, 2)),
('matrices_variable_shape', [[[1, 2], [3, 4]]], (None, 2)),
('vectors_common_shape', [[1, 2], [3, 4]], (2,)),
('variable_common_shape', [[[1, 2], [3, 4]], [[1], [3]]], (2, None)),
('common_shape_set', [[[1, 2], [3, 4]], [[1], [3]]], (2, {1, 2})),
('scalars_tuple_shape', [1, 2], ()),
('scalars_list_shape', [1, 2], []),
('vectors_tuple_shape', [[1, 2], [3, 4, 5]], [(2,), (3,)]),
('vectors_list_shape', [[1, 2], [3, 4, 5]], [[2], [3]]),
('matrices_tuple_shape', [[[1, 2], [3, 4]]], (2, 2)),
('matrices_list_shape', [[[1, 2], [3, 4]]], [2, 2]),
('matrices_variable_tuple_shape', [[[1, 2], [3, 4]]], (None, 2)),
('matrices_variable_list_shape', [[[1, 2], [3, 4]]], [None, 2]),
('vectors_common_tuple_shape', [[1, 2], [3, 4]], (2,)),
('vectors_common_list_shape', [[1, 2], [3, 4]], [2]),
('variable_common_tuple_shape', [[[1, 2], [3, 4]], [[1], [3]]], (2, None)),
('variable_common_list_shape', [[[1, 2], [3, 4]], [[1], [3]]], [2, None]),
('common_shape_tuple_set', [[[1, 2], [3, 4]], [[1], [3]]], (2, {1, 2})),
('common_shape_list_set', [[[1, 2], [3, 4]], [[1], [3]]], [2, {1, 2}]),
)
def test_shape_should_pass(self, arrays, shapes):
arrays = as_arrays(arrays)
Expand All @@ -465,31 +480,51 @@ def test_pytypes_pass(self):
arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
asserts.assert_shape(arrays, (2, None))
asserts.assert_shape(arrays, (2, {1, 2}))
asserts.assert_shape(arrays, (2, ...))
asserts.assert_shape(arrays, [2, ...])
asserts.assert_shape(arrays, [2, None])
asserts.assert_shape(arrays, [2, {1, 2}])
asserts.assert_shape(arrays, [2, ...])

@parameterized.named_parameters(
('prefix_2', array_from_shape(2, 3, 4, 5, 6), (..., 4, 5, 6)),
('prefix_1', array_from_shape(3, 4, 5, 6), (..., 4, 5, 6)),
('prefix_0', array_from_shape(4, 5, 6), (..., 4, 5, 6)),
('inner_2', array_from_shape(2, 3, 4, 5, 6), (2, 3, ..., 6)),
('inner_1', array_from_shape(2, 3, 4, 6), (2, 3, ..., 6)),
('inner_0', array_from_shape(2, 3, 6), (2, 3, ..., 6)),
('suffix_2', array_from_shape(2, 3, 4, 5, 6), (2, 3, 4, ...)),
('suffix_1', array_from_shape(2, 3, 4, 5), (2, 3, 4, ...)),
('suffix_0', array_from_shape(2, 3, 4), (2, 3, 4, ...)),
('prefix_2_tuple', array_from_shape(2, 3, 4, 5, 6), (..., 4, 5, 6)),
('prefix_2_list', array_from_shape(2, 3, 4, 5, 6), [..., 4, 5, 6]),
('prefix_1_tuple', array_from_shape(3, 4, 5, 6), (..., 4, 5, 6)),
('prefix_1_list', array_from_shape(3, 4, 5, 6), [..., 4, 5, 6]),
('prefix_0_tuple', array_from_shape(4, 5, 6), (..., 4, 5, 6)),
('prefix_0_list', array_from_shape(4, 5, 6), [..., 4, 5, 6]),
('inner_2_tuple', array_from_shape(2, 3, 4, 5, 6), (2, 3, ..., 6)),
('inner_2_list', array_from_shape(2, 3, 4, 5, 6), [2, 3, ..., 6]),
('inner_1_tuple', array_from_shape(2, 3, 4, 6), (2, 3, ..., 6)),
('inner_1_list', array_from_shape(2, 3, 4, 6), [2, 3, ..., 6]),
('inner_0_tuple', array_from_shape(2, 3, 6), (2, 3, ..., 6)),
('inner_0_list', array_from_shape(2, 3, 6), [2, 3, ..., 6]),
('suffix_2_tuple', array_from_shape(2, 3, 4, 5, 6), (2, 3, 4, ...)),
('suffix_2_list', array_from_shape(2, 3, 4, 5, 6), [2, 3, 4, ...]),
('suffix_1_tuple', array_from_shape(2, 3, 4, 5), (2, 3, 4, ...)),
('suffix_1_list', array_from_shape(2, 3, 4, 5), [2, 3, 4, ...]),
('suffix_0_tuple', array_from_shape(2, 3, 4), (2, 3, 4, ...)),
('suffix_0_list', array_from_shape(2, 3, 4), [2, 3, 4, ...]),
)
def test_ellipsis_should_pass(self, array, expected_shape):
asserts.assert_shape(array, expected_shape)

@parameterized.named_parameters(
('prefix', array_from_shape(3, 1, 5), (..., 4, 5, 6)),
('inner_bad_prefix', array_from_shape(2, 1, 4, 6), (2, 3, ..., 6)),
('inner_bad_suffix', array_from_shape(2, 3, 1, 5), (2, 3, ..., 6)),
('inner_both_bad', array_from_shape(2, 1, 4, 5), (2, 3, ..., 6)),
('suffix', array_from_shape(2, 3, 1, 5), (2, 3, 4, ...)),
('short_rank_prefix', array_from_shape(2, 3), (..., 4, 5, 6)),
('short_rank_inner', array_from_shape(2, 3), (2, 3, ..., 6)),
('short_rank_suffix', array_from_shape(2, 3), (2, 3, 4, ...)),
('prefix_tuple', array_from_shape(3, 1, 5), (..., 4, 5, 6)),
('prefix_list', array_from_shape(3, 1, 5), [..., 4, 5, 6]),
('inner_bad_prefix_tuple', array_from_shape(2, 1, 4, 6), (2, 3, ..., 6)),
('inner_bad_prefix_list', array_from_shape(2, 1, 4, 6), [2, 3, ..., 6]),
('inner_bad_suffix_tuple', array_from_shape(2, 3, 1, 5), (2, 3, ..., 6)),
('inner_bad_suffix_list', array_from_shape(2, 3, 1, 5), [2, 3, ..., 6]),
('inner_both_bad_tuple', array_from_shape(2, 1, 4, 5), (2, 3, ..., 6)),
('inner_both_bad_list', array_from_shape(2, 1, 4, 5), [2, 3, ..., 6]),
('suffix_tuple', array_from_shape(2, 3, 1, 5), (2, 3, 4, ...)),
('suffix_list', array_from_shape(2, 3, 1, 5), [2, 3, 4, ...]),
('short_rank_prefix_tuple', array_from_shape(2, 3), (..., 4, 5, 6)),
('short_rank_prefix_list', array_from_shape(2, 3), [..., 4, 5, 6]),
('short_rank_inner_tuple', array_from_shape(2, 3), (2, 3, ..., 6)),
('short_rank_inner_list', array_from_shape(2, 3), [2, 3, ..., 6]),
('short_rank_suffix_tuple', array_from_shape(2, 3), (2, 3, 4, ...)),
('short_rank_suffix_list', array_from_shape(2, 3), [2, 3, 4, ...]),
)
def test_ellipsis_should_fail(self, array, expected_shape):
with self.assertRaisesRegex(
Expand Down Expand Up @@ -1449,20 +1484,40 @@ def test_assert_tree_shape_prefix(self):
_get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
asserts.assert_tree_shape_prefix(tree, (3, 2, 1))

# Also test a shape_prefix that's a list instead of tuple.
asserts.assert_tree_shape_prefix(tree, [])
asserts.assert_tree_shape_prefix(tree, [3])
asserts.assert_tree_shape_prefix(tree, [3, 2])

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
asserts.assert_tree_shape_prefix(tree, [3, 2, 1])

def test_assert_tree_shape_prefix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=False)

def test_assert_tree_shape_suffix_matching(self):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([2, 1])}
asserts.assert_tree_shape_suffix(tree, ())
asserts.assert_tree_shape_suffix(tree, (1,))
asserts.assert_tree_shape_suffix(tree, (2, 1))

# Also test shape_suffix that's a list instead of a tuple.
asserts.assert_tree_shape_suffix(tree, [])
asserts.assert_tree_shape_suffix(tree, [1])
asserts.assert_tree_shape_suffix(tree, [2, 1])

def test_assert_tree_shape_suffix_mismatch(self):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([1, 1])}

Expand All @@ -1472,29 +1527,53 @@ def test_assert_tree_shape_suffix_mismatch(self):
r'Tree leaf \'z\'.*different from expected: \(1, 1\) != \(2, 1\)')):
asserts.assert_tree_shape_suffix(tree, (2, 1))

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Tree leaf \'z\'.*different from expected: \(1, 1\) != \(2, 1\)')):
asserts.assert_tree_shape_suffix(tree, [2, 1])

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Tree leaf \'x/y\'.*different from expected: \(2, 1\) != \(1, 1\)')
):
asserts.assert_tree_shape_suffix(tree, (1, 1))

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Tree leaf \'x/y\'.*different from expected: \(2, 1\) != \(1, 1\)')
):
asserts.assert_tree_shape_suffix(tree, [1, 1])

def test_assert_tree_shape_suffix_long_suffix(self):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([4, 2, 1])}
asserts.assert_tree_shape_suffix(tree, (4, 2, 1))
asserts.assert_tree_shape_suffix(tree, [4, 2, 1])

with self.assertRaisesRegex(
AssertionError, _get_err_regex('which is smaller than the expected')):
asserts.assert_tree_shape_suffix(tree, (3, 4, 2, 1))

with self.assertRaisesRegex(
AssertionError, _get_err_regex('which is smaller than the expected')):
asserts.assert_tree_shape_suffix(tree, [3, 4, 2, 1])

def test_assert_tree_shape_suffix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=False)


def test_assert_trees_all_equal_dtypes(self):
t_0 = {'x': np.zeros(3, dtype=np.int16), 'y': np.ones(2, dtype=np.float32)}
t_1 = {'x': np.zeros(5, dtype=np.uint16), 'y': np.ones(4, dtype=np.float32)}
Expand Down

0 comments on commit e658b67

Please sign in to comment.