Skip to content

Commit

Permalink
Fixes #274
Browse files Browse the repository at this point in the history
  • Loading branch information
Hylke Cornelis Donker committed Jun 2, 2023
1 parent 42a39e2 commit 968d7f3
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 35 deletions.
7 changes: 7 additions & 0 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,9 @@ def assert_tree_shape_prefix(tree: ArrayTree,
if not ignore_nones:
assert_tree_no_nones(tree)

# To compare with the leaf's `shape`, convert int sequence to tuple.
shape_prefix = tuple(shape_prefix)

if not shape_prefix:
return # No prefix, this is trivially true.

Expand Down Expand Up @@ -1328,6 +1331,10 @@ def assert_tree_shape_suffix(tree: ArrayTree,
"""
if not ignore_nones:
assert_tree_no_nones(tree)

# To compare with the leaf's `shape`, convert int sequence to tuple.
shape_suffix = tuple(shape_suffix)

if not shape_suffix:
return # No suffix, this is trivially true.

Expand Down
141 changes: 106 additions & 35 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,10 +432,16 @@ class ShapeAssertTest(parameterized.TestCase):
)
def test_shape_should_fail(self, arrays, shapes):
arrays = as_arrays(arrays)
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(arrays, shapes)
with self.subTest('list'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(arrays, list(shapes))
with self.subTest('tuple'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(arrays, tuple(shapes))

@parameterized.named_parameters(
('too_many_shapes', [[1]], [(1,), (2,)]),
Expand All @@ -446,7 +452,11 @@ def test_shape_should_fail_wrong_length(self, arrays, shapes):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('Length of `inputs` and `expected_shapes` must match')):
asserts.assert_shape(arrays, shapes)
asserts.assert_shape(arrays, tuple(shapes))
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('Length of `inputs` and `expected_shapes` must match')):
asserts.assert_shape(arrays, list(shapes))

@parameterized.named_parameters(
('scalars', [1, 2], ()),
Expand All @@ -459,13 +469,22 @@ def test_shape_should_fail_wrong_length(self, arrays, shapes):
)
def test_shape_should_pass(self, arrays, shapes):
arrays = as_arrays(arrays)
asserts.assert_shape(arrays, shapes)
with self.subTest('tuple'):
asserts.assert_shape(arrays, tuple(shapes))
with self.subTest('list'):
asserts.assert_shape(arrays, list(shapes))

def test_pytypes_pass(self):
@parameterized.named_parameters(
('variable_shape', (2, None)),
('shape_set', (2, {1, 2})),
('suffix', (2, ...)),
)
def test_pytypes_pass(self, shape):
arrays = as_arrays([[[1, 2], [3, 4]], [[1], [3]]])
asserts.assert_shape(arrays, (2, None))
asserts.assert_shape(arrays, (2, {1, 2}))
asserts.assert_shape(arrays, (2, ...))
with self.subTest('tuple'):
asserts.assert_shape(arrays, tuple(shape))
with self.subTest('list'):
asserts.assert_shape(arrays, list(shape))

@parameterized.named_parameters(
('prefix_2', array_from_shape(2, 3, 4, 5, 6), (..., 4, 5, 6)),
Expand All @@ -479,7 +498,10 @@ def test_pytypes_pass(self):
('suffix_0', array_from_shape(2, 3, 4), (2, 3, 4, ...)),
)
def test_ellipsis_should_pass(self, array, expected_shape):
asserts.assert_shape(array, expected_shape)
with self.subTest('list'):
asserts.assert_shape(array, list(expected_shape))
with self.subTest('tuple'):
asserts.assert_shape(array, tuple(expected_shape))

@parameterized.named_parameters(
('prefix', array_from_shape(3, 1, 5), (..., 4, 5, 6)),
Expand All @@ -492,10 +514,16 @@ def test_ellipsis_should_pass(self, array, expected_shape):
('short_rank_suffix', array_from_shape(2, 3), (2, 3, 4, ...)),
)
def test_ellipsis_should_fail(self, array, expected_shape):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(array, expected_shape)
with self.subTest('tuple'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(array, tuple(expected_shape))
with self.subTest('list'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex('input .+ has shape .+ but expected .+')):
asserts.assert_shape(array, list(expected_shape))

@parameterized.named_parameters(
('prefix_and_suffix', array_from_shape(2, 3), (..., 2, 3, ...)),)
Expand Down Expand Up @@ -1438,63 +1466,106 @@ def test_assert_trees_all_equal_structs(self):
asserts.assert_trees_all_equal_structs(tree3, tree3)
self._assert_tree_structs_validation(asserts.assert_trees_all_equal_structs)

def test_assert_tree_shape_prefix(self):
@parameterized.named_parameters(
('scalars', ()),
('vectors', (3,)),
('matrices', (3, 2)),
)
def test_assert_tree_shape_prefix(self, shape):
tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
asserts.assert_tree_shape_prefix(tree, ())
asserts.assert_tree_shape_prefix(tree, (3,))
asserts.assert_tree_shape_prefix(tree, (3, 2))
with self.subTest('tuple'):
asserts.assert_tree_shape_prefix(tree, tuple(shape))
with self.subTest('list'):
asserts.assert_tree_shape_prefix(tree, list(shape))

def test_leaf_shape_should_fail_wrong_length(self):
tree = {'x': {'y': np.zeros([3, 2])}, 'z': np.zeros([3, 2, 1])}
with self.assertRaisesRegex(
AssertionError,
_get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
asserts.assert_tree_shape_prefix(tree, (3, 2, 1))
with self.assertRaisesRegex(
AssertionError,
_get_err_regex(r'leaf \'x/y\' has a shape of length 2')):
asserts.assert_tree_shape_prefix(tree, [3, 2, 1])

def test_assert_tree_shape_prefix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, (3,), ignore_nones=False)

def test_assert_tree_shape_suffix_matching(self):
with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_prefix(tree, [3], ignore_nones=False)

@parameterized.named_parameters(
('scalars', ()),
('vectors', (1,)),
('matrices', (2, 1)),
)
def test_assert_tree_shape_suffix_matching(self, shape):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([2, 1])}
asserts.assert_tree_shape_suffix(tree, ())
asserts.assert_tree_shape_suffix(tree, (1,))
asserts.assert_tree_shape_suffix(tree, (2, 1))
with self.subTest('tuple'):
asserts.assert_tree_shape_suffix(tree, tuple(shape))
with self.subTest('list'):
asserts.assert_tree_shape_suffix(tree, list(shape))

def test_assert_tree_shape_suffix_mismatch(self):
@parameterized.named_parameters(
('bad_suffix_leaf_1', 'z', (1, 1), (2, 1)),
('bad_suffix_leaf_2', 'x/y', (2, 1), (1, 1)),
)
def test_assert_tree_shape_suffix_mismatch(self, leaf, shape_true, shape):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([1, 1])}

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Tree leaf \'z\'.*different from expected: \(1, 1\) != \(2, 1\)')):
asserts.assert_tree_shape_suffix(tree, (2, 1))
error_msg = (
r'Tree leaf \'' + str(leaf) + '\'.*different from expected: '
+ re.escape(str(shape_true)) + ' != ' + re.escape(str(shape))
)
with self.subTest('tuple'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
error_msg)):
asserts.assert_tree_shape_suffix(tree, tuple(shape))

with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
r'Tree leaf \'x/y\'.*different from expected: \(2, 1\) != \(1, 1\)')
):
asserts.assert_tree_shape_suffix(tree, (1, 1))
with self.subTest('list'):
with self.assertRaisesRegex(
AssertionError,
_get_err_regex(
error_msg)):
asserts.assert_tree_shape_suffix(tree, list(shape))

def test_assert_tree_shape_suffix_long_suffix(self):
tree = {'x': {'y': np.zeros([4, 2, 1])}, 'z': np.zeros([4, 2, 1])}
asserts.assert_tree_shape_suffix(tree, (4, 2, 1))
asserts.assert_tree_shape_suffix(tree, [4, 2, 1])

with self.assertRaisesRegex(
AssertionError, _get_err_regex('which is smaller than the expected')):
asserts.assert_tree_shape_suffix(tree, (3, 4, 2, 1))

with self.assertRaisesRegex(
AssertionError, _get_err_regex('which is smaller than the expected')):
asserts.assert_tree_shape_suffix(tree, [3, 4, 2, 1])

def test_assert_tree_shape_suffix_none(self):
tree = {'x': np.zeros([3]), 'n': None}
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=True)
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=True)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, (3,), ignore_nones=False)

with self.assertRaisesRegex(AssertionError,
_get_err_regex('`None` detected')):
asserts.assert_tree_shape_suffix(tree, [3], ignore_nones=False)


def test_assert_trees_all_equal_dtypes(self):
t_0 = {'x': np.zeros(3, dtype=np.int16), 'y': np.ones(2, dtype=np.float32)}
t_1 = {'x': np.zeros(5, dtype=np.uint16), 'y': np.ones(4, dtype=np.float32)}
Expand Down

0 comments on commit 968d7f3

Please sign in to comment.