diff --git a/mypy/errors.py b/mypy/errors.py index f9d952af2297..edfb3bd1607a 100644 --- a/mypy/errors.py +++ b/mypy/errors.py @@ -229,11 +229,42 @@ def filtered_errors(self) -> list[ErrorInfo]: return self._filtered +class NonOverlapErrorInfo: + line: int + column: int + end_line: int | None + end_column: int | None + kind: str + + def __init__( + self, *, line: int, column: int, end_line: int | None, end_column: int | None, kind: str + ) -> None: + self.line = line + self.column = column + self.end_line = end_line + self.end_column = end_column + self.kind = kind + + def __eq__(self, other: object) -> bool: + if isinstance(other, NonOverlapErrorInfo): + return ( + self.line == other.line + and self.column == other.column + and self.end_line == other.end_line + and self.end_column == other.end_column + and self.kind == other.kind + ) + return False + + def __hash__(self) -> int: + return hash((self.line, self.column, self.end_line, self.end_column, self.kind)) + + class IterationDependentErrors: """An `IterationDependentErrors` instance serves to collect the `unreachable`, - `redundant-expr`, and `redundant-casts` errors, as well as the revealed types, - handled by the individual `IterationErrorWatcher` instances sequentially applied to - the same code section.""" + `redundant-expr`, and `redundant-casts` errors, as well as the revealed types and + non-overlapping types, handled by the individual `IterationErrorWatcher` instances + sequentially applied to the same code section.""" # One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per # iteration step. Meaning of the tuple items: ErrorCode, message, line, column, @@ -249,9 +280,13 @@ class IterationDependentErrors: # end_line, end_column: revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]] + # One dictionary of non-overlapping types per iteration step: + nonoverlapping_types: list[dict[NonOverlapErrorInfo, tuple[Type, Type]]] + def __init__(self) -> None: self.uselessness_errors = [] self.unreachable_lines = [] + self.nonoverlapping_types = [] self.revealed_types = defaultdict(list) def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]: @@ -271,6 +306,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod context.end_column = error_info[5] yield error_info[1], context, error_info[0] + def yield_nonoverlapping_types( + self, + ) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]: + """Report expressions where non-overlapping types were detected for all iterations + were the expression was reachable.""" + + selected = set() + for candidate in set(chain.from_iterable(self.nonoverlapping_types)): + if all( + (candidate in nonoverlap) or (candidate.line in lines) + for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines) + ): + selected.add(candidate) + + persistent_nonoverlaps: dict[NonOverlapErrorInfo, tuple[list[Type], list[Type]]] = ( + defaultdict(lambda: ([], [])) + ) + for nonoverlaps in self.nonoverlapping_types: + for candidate, (left, right) in nonoverlaps.items(): + if candidate in selected: + types = persistent_nonoverlaps[candidate] + types[0].append(left) + types[1].append(right) + + for error_info, types in persistent_nonoverlaps.items(): + context = Context(line=error_info.line, column=error_info.column) + context.end_line = error_info.end_line + context.end_column = error_info.end_column + yield (types[0], types[1]), error_info.kind, context + def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: """Yield all types revealed in at least one iteration step.""" @@ -283,8 +348,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]: class IterationErrorWatcher(ErrorWatcher): """Error watcher that filters and separately collects `unreachable` errors, - `redundant-expr` and `redundant-casts` errors, and revealed types when analysing - code sections iteratively to help avoid making too-hasty reports.""" + `redundant-expr` and `redundant-casts` errors, and revealed types and + non-overlapping types when analysing code sections iteratively to help avoid + making too-hasty reports.""" iteration_dependent_errors: IterationDependentErrors @@ -305,6 +371,7 @@ def __init__( ) self.iteration_dependent_errors = iteration_dependent_errors iteration_dependent_errors.uselessness_errors.append(set()) + iteration_dependent_errors.nonoverlapping_types.append({}) iteration_dependent_errors.unreachable_lines.append(set()) def on_error(self, file: str, info: ErrorInfo) -> bool: diff --git a/mypy/messages.py b/mypy/messages.py index 65bcbd4049e2..bbcc93ebfb25 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -29,6 +29,7 @@ ErrorWatcher, IterationDependentErrors, IterationErrorWatcher, + NonOverlapErrorInfo, ) from mypy.nodes import ( ARG_NAMED, @@ -1624,6 +1625,26 @@ def incompatible_typevar_value( ) def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None: + # In loops (and similar cases), the same expression might be analysed multiple + # times and thereby confronted with different types. We only want to raise a + # `comparison-overlap` error if it occurs in all cases and therefore collect the + # respective types of the current iteration here so that we can report the error + # later if it is persistent over all iteration steps: + for watcher in self.errors.get_watchers(): + if watcher._filter: + break + if isinstance(watcher, IterationErrorWatcher): + watcher.iteration_dependent_errors.nonoverlapping_types[-1][ + NonOverlapErrorInfo( + line=ctx.line, + column=ctx.column, + end_line=ctx.end_line, + end_column=ctx.end_column, + kind=kind, + ) + ] = (left, right) + return + left_str = "element" if kind == "container" else "left operand" right_str = "container item" if kind == "container" else "right operand" message = "Non-overlapping {} check ({} type: {}, {} type: {})" @@ -2510,6 +2531,13 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None: for error_info in iter_errors.yield_uselessness_error_infos(): self.fail(*error_info[:2], code=error_info[2]) + for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types(): + self.dangerous_comparison( + mypy.typeops.make_simplified_union(nonoverlaps[0]), + mypy.typeops.make_simplified_union(nonoverlaps[1]), + kind, + context, + ) for types, context in iter_errors.yield_revealed_type_infos(): self.reveal_type(mypy.typeops.make_simplified_union(types), context) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 8dd8ba01eca8..03586e4109f6 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2446,6 +2446,41 @@ while x is not None and b(): x = f() [builtins fixtures/primitives.pyi] +[case testAvoidFalseNonOverlappingEqualityCheckInLoop1] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +x = 1 +while True: + if x == str(): + break + x = str() + if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int") + break +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop2] +# flags: --allow-redefinition-new --local-partial-types --strict-equality + +class A: ... +class B: ... +class C: ... + +x = A() +while True: + if x == C(): # E: Non-overlapping equality check (left operand type: "A | B", right operand type: "C") + break + x = B() +[builtins fixtures/primitives.pyi] + +[case testAvoidFalseNonOverlappingEqualityCheckInLoop3] +# flags: --strict-equality + +for y in [1.0]: + if y is not None or y != "None": + ... + +[builtins fixtures/primitives.pyi] + [case testNarrowPromotionsInsideUnions1] from typing import Union