diff --git a/src/compyre/api.py b/src/compyre/api.py index b8c40f4..59f2852 100644 --- a/src/compyre/api.py +++ b/src/compyre/api.py @@ -59,9 +59,9 @@ class Pair: @dataclasses.dataclass class CompareError: - """Comparison exception with position information.""" + """Comparison exception with pair that caused it.""" - index: tuple[str | int, ...] + pair: Pair exception: Exception @@ -130,7 +130,7 @@ def compare( if unpack_result is not None: if isinstance(unpack_result, Exception): - errors.append(CompareError(index=pair.index, exception=unpack_result)) + errors.append(CompareError(pair=pair, exception=unpack_result)) else: for p in reversed(unpack_result): pairs.appendleft(p) @@ -153,7 +153,7 @@ def compare( ) if isinstance(equal_result, Exception): - errors.append(CompareError(index=pair.index, exception=equal_result)) + errors.append(CompareError(pair, exception=equal_result)) return errors @@ -366,7 +366,7 @@ def _extract_equal_errors(errors: list[CompareError]) -> list[CompareError]: def _format_compare_errors(errors: list[CompareError]) -> str: parts = [] for e in errors: - i = ".".join(map(str, e.index)) + i = ".".join(map(str, e.pair.index)) m = f"{type(e.exception).__name__}: {e.exception}" parts.append(f"{i}\n{indent(m, ' ' * 4)}") return "\n".join(parts) diff --git a/tests/test_api.py b/tests/test_api.py index db310eb..02e1f89 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -244,7 +244,7 @@ def unpack_fn(pair, /): assert len(errors) == 1 error = errors[0] - assert error.index == () + assert error.pair.index == () assert error.exception is exc def test_unpack_pairs_order(self): @@ -286,7 +286,7 @@ def equal_fn(pair, /): assert len(errors) == 1 error = errors[0] - assert error.index == () + assert error.pair.index == () assert isinstance(error.exception, api.CompyreError) assert all( @@ -313,7 +313,7 @@ def equal_fn(pair, /): assert len(errors) == 1 error = errors[0] - assert error.index == () + assert error.pair.index == () assert isinstance(error.exception, AssertionError) assert repr(actual) in str(error.exception)