Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 428105412
  • Loading branch information
ChexDev authored and ChexDev committed Oct 25, 2022
1 parent ce3970e commit f544219
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,8 @@ def _shape_matches(actual_shape: Sequence[int],
def assert_shape(
inputs: Union[Scalar, Union[Array, Sequence[Array]]],
expected_shapes: Union[_ai.TShapeMatcher,
Sequence[_ai.TShapeMatcher]]) -> None:
Sequence[_ai.TShapeMatcher]],
explanation: Optional[Union[str, Callable[[], str]]] = None) -> None:
"""Checks that the shape of all inputs matches specified ``expected_shapes``.
Valid usages include:
Expand All @@ -535,6 +536,8 @@ def assert_shape(
where the expected shape is a sequence of integer and `None` dimensions;
if all inputs have same shape, a single shape may be passed as
``expected_shapes``.
explanation: Additional message to give context when this assertion fails
(or a function/closure that returns such a message).
Raises:
AssertionError: If the lengths of ``inputs`` and ``expected_shapes`` do not
Expand Down Expand Up @@ -564,9 +567,19 @@ def assert_shape(
errors.append((idx, shape, _ai.format_shape_matcher(expected)))

if errors:
if callable(explanation):
try:
explanation: str = explanation()
except Exception as e: # pylint: disable=broad-except
explanation = ("[[`explanation` callback failed: " +
"\n".join(traceback.format_exception(
e.__class__, e, e.__traceback__, limit=4)) + "]]")
if not explanation:
explanation = ""
msg = "; ".join(
f"input {e[0]} has shape {e[1]} but expected {e[2]}" for e in errors)
raise AssertionError(f"Error in shape compatibility check: {msg}.")
raise AssertionError(
f"Error in shape compatibility check: {msg}. {explanation}")


@_static_assertion
Expand Down

0 comments on commit f544219

Please sign in to comment.