Skip to content

Commit d06d3d9

Browse files
tyrallapre-commit-ci[bot]sterliakovhauntsaninja
authored
Fix --strict-equality for iteratively visited code (#19635)
Fixes #19328 Fixes #20294 The logic is very similar to what we did to report different revealed types that were discovered in multiple iteration steps in one line. I think this fix is the last one needed before I can implement #19256. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Stanislav Terliakov <50529348+sterliakov@users.noreply.github.com> Co-authored-by: Shantanu <12621235+hauntsaninja@users.noreply.github.com>
1 parent eb144ce commit d06d3d9

File tree

3 files changed

+135
-5
lines changed

3 files changed

+135
-5
lines changed

mypy/errors.py

Lines changed: 72 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -229,11 +229,42 @@ def filtered_errors(self) -> list[ErrorInfo]:
229229
return self._filtered
230230

231231

232+
class NonOverlapErrorInfo:
233+
line: int
234+
column: int
235+
end_line: int | None
236+
end_column: int | None
237+
kind: str
238+
239+
def __init__(
240+
self, *, line: int, column: int, end_line: int | None, end_column: int | None, kind: str
241+
) -> None:
242+
self.line = line
243+
self.column = column
244+
self.end_line = end_line
245+
self.end_column = end_column
246+
self.kind = kind
247+
248+
def __eq__(self, other: object) -> bool:
249+
if isinstance(other, NonOverlapErrorInfo):
250+
return (
251+
self.line == other.line
252+
and self.column == other.column
253+
and self.end_line == other.end_line
254+
and self.end_column == other.end_column
255+
and self.kind == other.kind
256+
)
257+
return False
258+
259+
def __hash__(self) -> int:
260+
return hash((self.line, self.column, self.end_line, self.end_column, self.kind))
261+
262+
232263
class IterationDependentErrors:
233264
"""An `IterationDependentErrors` instance serves to collect the `unreachable`,
234-
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types,
235-
handled by the individual `IterationErrorWatcher` instances sequentially applied to
236-
the same code section."""
265+
`redundant-expr`, and `redundant-casts` errors, as well as the revealed types and
266+
non-overlapping types, handled by the individual `IterationErrorWatcher` instances
267+
sequentially applied to the same code section."""
237268

238269
# One set of `unreachable`, `redundant-expr`, and `redundant-casts` errors per
239270
# iteration step. Meaning of the tuple items: ErrorCode, message, line, column,
@@ -249,9 +280,13 @@ class IterationDependentErrors:
249280
# end_line, end_column:
250281
revealed_types: dict[tuple[int, int, int | None, int | None], list[Type]]
251282

283+
# One dictionary of non-overlapping types per iteration step:
284+
nonoverlapping_types: list[dict[NonOverlapErrorInfo, tuple[Type, Type]]]
285+
252286
def __init__(self) -> None:
253287
self.uselessness_errors = []
254288
self.unreachable_lines = []
289+
self.nonoverlapping_types = []
255290
self.revealed_types = defaultdict(list)
256291

257292
def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCode]]:
@@ -271,6 +306,36 @@ def yield_uselessness_error_infos(self) -> Iterator[tuple[str, Context, ErrorCod
271306
context.end_column = error_info[5]
272307
yield error_info[1], context, error_info[0]
273308

309+
def yield_nonoverlapping_types(
310+
self,
311+
) -> Iterator[tuple[tuple[list[Type], list[Type]], str, Context]]:
312+
"""Report expressions where non-overlapping types were detected for all iterations
313+
were the expression was reachable."""
314+
315+
selected = set()
316+
for candidate in set(chain.from_iterable(self.nonoverlapping_types)):
317+
if all(
318+
(candidate in nonoverlap) or (candidate.line in lines)
319+
for nonoverlap, lines in zip(self.nonoverlapping_types, self.unreachable_lines)
320+
):
321+
selected.add(candidate)
322+
323+
persistent_nonoverlaps: dict[NonOverlapErrorInfo, tuple[list[Type], list[Type]]] = (
324+
defaultdict(lambda: ([], []))
325+
)
326+
for nonoverlaps in self.nonoverlapping_types:
327+
for candidate, (left, right) in nonoverlaps.items():
328+
if candidate in selected:
329+
types = persistent_nonoverlaps[candidate]
330+
types[0].append(left)
331+
types[1].append(right)
332+
333+
for error_info, types in persistent_nonoverlaps.items():
334+
context = Context(line=error_info.line, column=error_info.column)
335+
context.end_line = error_info.end_line
336+
context.end_column = error_info.end_column
337+
yield (types[0], types[1]), error_info.kind, context
338+
274339
def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
275340
"""Yield all types revealed in at least one iteration step."""
276341

@@ -283,8 +348,9 @@ def yield_revealed_type_infos(self) -> Iterator[tuple[list[Type], Context]]:
283348

284349
class IterationErrorWatcher(ErrorWatcher):
285350
"""Error watcher that filters and separately collects `unreachable` errors,
286-
`redundant-expr` and `redundant-casts` errors, and revealed types when analysing
287-
code sections iteratively to help avoid making too-hasty reports."""
351+
`redundant-expr` and `redundant-casts` errors, and revealed types and
352+
non-overlapping types when analysing code sections iteratively to help avoid
353+
making too-hasty reports."""
288354

289355
iteration_dependent_errors: IterationDependentErrors
290356

@@ -305,6 +371,7 @@ def __init__(
305371
)
306372
self.iteration_dependent_errors = iteration_dependent_errors
307373
iteration_dependent_errors.uselessness_errors.append(set())
374+
iteration_dependent_errors.nonoverlapping_types.append({})
308375
iteration_dependent_errors.unreachable_lines.append(set())
309376

310377
def on_error(self, file: str, info: ErrorInfo) -> bool:

mypy/messages.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
ErrorWatcher,
3030
IterationDependentErrors,
3131
IterationErrorWatcher,
32+
NonOverlapErrorInfo,
3233
)
3334
from mypy.nodes import (
3435
ARG_NAMED,
@@ -1624,6 +1625,26 @@ def incompatible_typevar_value(
16241625
)
16251626

16261627
def dangerous_comparison(self, left: Type, right: Type, kind: str, ctx: Context) -> None:
1628+
# In loops (and similar cases), the same expression might be analysed multiple
1629+
# times and thereby confronted with different types. We only want to raise a
1630+
# `comparison-overlap` error if it occurs in all cases and therefore collect the
1631+
# respective types of the current iteration here so that we can report the error
1632+
# later if it is persistent over all iteration steps:
1633+
for watcher in self.errors.get_watchers():
1634+
if watcher._filter:
1635+
break
1636+
if isinstance(watcher, IterationErrorWatcher):
1637+
watcher.iteration_dependent_errors.nonoverlapping_types[-1][
1638+
NonOverlapErrorInfo(
1639+
line=ctx.line,
1640+
column=ctx.column,
1641+
end_line=ctx.end_line,
1642+
end_column=ctx.end_column,
1643+
kind=kind,
1644+
)
1645+
] = (left, right)
1646+
return
1647+
16271648
left_str = "element" if kind == "container" else "left operand"
16281649
right_str = "container item" if kind == "container" else "right operand"
16291650
message = "Non-overlapping {} check ({} type: {}, {} type: {})"
@@ -2510,6 +2531,13 @@ def match_statement_inexhaustive_match(self, typ: Type, context: Context) -> Non
25102531
def iteration_dependent_errors(self, iter_errors: IterationDependentErrors) -> None:
25112532
for error_info in iter_errors.yield_uselessness_error_infos():
25122533
self.fail(*error_info[:2], code=error_info[2])
2534+
for nonoverlaps, kind, context in iter_errors.yield_nonoverlapping_types():
2535+
self.dangerous_comparison(
2536+
mypy.typeops.make_simplified_union(nonoverlaps[0]),
2537+
mypy.typeops.make_simplified_union(nonoverlaps[1]),
2538+
kind,
2539+
context,
2540+
)
25132541
for types, context in iter_errors.yield_revealed_type_infos():
25142542
self.reveal_type(mypy.typeops.make_simplified_union(types), context)
25152543

test-data/unit/check-narrowing.test

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2446,6 +2446,41 @@ while x is not None and b():
24462446
x = f()
24472447
[builtins fixtures/primitives.pyi]
24482448

2449+
[case testAvoidFalseNonOverlappingEqualityCheckInLoop1]
2450+
# flags: --allow-redefinition-new --local-partial-types --strict-equality
2451+
2452+
x = 1
2453+
while True:
2454+
if x == str():
2455+
break
2456+
x = str()
2457+
if x == int(): # E: Non-overlapping equality check (left operand type: "str", right operand type: "int")
2458+
break
2459+
[builtins fixtures/primitives.pyi]
2460+
2461+
[case testAvoidFalseNonOverlappingEqualityCheckInLoop2]
2462+
# flags: --allow-redefinition-new --local-partial-types --strict-equality
2463+
2464+
class A: ...
2465+
class B: ...
2466+
class C: ...
2467+
2468+
x = A()
2469+
while True:
2470+
if x == C(): # E: Non-overlapping equality check (left operand type: "A | B", right operand type: "C")
2471+
break
2472+
x = B()
2473+
[builtins fixtures/primitives.pyi]
2474+
2475+
[case testAvoidFalseNonOverlappingEqualityCheckInLoop3]
2476+
# flags: --strict-equality
2477+
2478+
for y in [1.0]:
2479+
if y is not None or y != "None":
2480+
...
2481+
2482+
[builtins fixtures/primitives.pyi]
2483+
24492484
[case testNarrowPromotionsInsideUnions1]
24502485

24512486
from typing import Union

0 commit comments

Comments
 (0)