From a28617f43a93b2d88e33fc4752c276f929b50ad5 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Tue, 3 Oct 2023 00:00:37 +0100 Subject: [PATCH 01/18] Copy over the tests --- test-data/unit/check-narrowing.test | 147 ++++++++++++++++++++++++++++ test-data/unit/fixtures/len.pyi | 38 +++++++ test-data/unit/lib-stub/typing.pyi | 1 + 3 files changed, 186 insertions(+) create mode 100644 test-data/unit/fixtures/len.pyi diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index c86cffd453dfc..ee06954160994 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1334,3 +1334,150 @@ if isinstance(some, raw): else: reveal_type(some) # N: Revealed type is "Union[builtins.int, __main__.Base]" [builtins fixtures/dict.pyi] + +[case testNarrowingLenItemAndLenCompare] +from typing import Tuple, Union, Any + +x: Any +if len(x) == x: + reveal_type(x) # N: Revealed type is "Any" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTuple] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +a = b = c = 0 +if len(x) == 3: + a, b, c = x +else: + a, b = x + +if len(x) != 3: + a, b = x +else: + a, b, c = x +[builtins fixtures/len.pyi] + +[case testNarrowingLenVariantLengthTuple] +from typing import Tuple, Union + +def make_tuple() -> Tuple[int, ...]: + return (1, 1) + +x = make_tuple() + +if len(x) == 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +else: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" + +if len(x) != 3: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTypeUnaffected] +from typing import Tuple, Union, List, Any + +def make() -> Union[str, List[int]]: + return "" + +x = make() + +if len(x) == 3: + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenAnyListElseNotAffected] +from typing import Any +def f(self, value: Any) -> Any: + if isinstance(value, list) and len(value) == 0: + reveal_type(value) # N: Revealed type is "builtins.list[Any]" + return value + reveal_type(value) # N: Revealed type is "Any" + return None +[builtins fixtures/len.pyi] + +[case testNarrowingLenMultiple] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +y = make_tuple() +if len(x) == len(y) == 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" + reveal_type(y) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenFinal] +from typing import Tuple, Union +from typing_extensions import Final + +VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +fin: Final = 3 +if len(x) == fin: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenBiggerThan] +from typing import Tuple, Union + +VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +if len(x) > 1: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" + +if len(x) < 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" + +if len(x) >= 2: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int]" + +if len(x) <= 2: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenBiggerThanVariantTuple] +from typing import Tuple + +VarTuple = Tuple[int, ...] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +if len(x) < 3: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +else: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +[builtins fixtures/len.pyi] diff --git a/test-data/unit/fixtures/len.pyi b/test-data/unit/fixtures/len.pyi new file mode 100644 index 0000000000000..13fc8829651ac --- /dev/null +++ b/test-data/unit/fixtures/len.pyi @@ -0,0 +1,38 @@ +from typing import Tuple, TypeVar, Generic, Union, Type, Sequence, Mapping +from typing_extensions import Protocol + +T = TypeVar("T") +V = TypeVar("V") + +class object: + def __init__(self) -> None: pass + +class type: + def __init__(self, x) -> None: pass + +class tuple(Generic[T]): + def __len__(self) -> int: pass + +class list(Sequence[T]): pass +class dict(Mapping[T, V]): pass + +class function: pass + +class Sized(Protocol): + def __len__(self) -> int: pass + +def len(__obj: Sized) -> int: ... +def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass + +class int: + def __add__(self, other: int) -> int: pass + def __eq__(self, other: int) -> bool: pass + def __ne__(self, other: int) -> bool: pass + def __lt__(self, n: int) -> bool: pass + def __gt__(self, n: int) -> bool: pass + def __le__(self, n: int) -> bool: pass + def __ge__(self, n: int) -> bool: pass +class float: pass +class bool(int): pass +class str(Sequence[str]): pass +class ellipsis: pass diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index b35b64a383c94..5f458ca687c0f 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -48,6 +48,7 @@ class Generator(Iterator[T], Generic[T, U, V]): class Sequence(Iterable[T_co]): def __getitem__(self, n: Any) -> T_co: pass + def __len__(self) -> int: pass # Mapping type is oversimplified intentionally. class Mapping(Iterable[T], Generic[T, T_co]): From 77925e1f683fc1830b1e27880eea533e32b7e0d8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 5 Oct 2023 23:50:38 +0100 Subject: [PATCH 02/18] Add type narrowing logic --- mypy/checker.py | 120 +++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 119 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1a7a7e25d5250..5a141bdafba56 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -134,7 +134,7 @@ YieldExpr, is_final_node, ) -from mypy.options import Options +from mypy.options import TYPE_VAR_TUPLE, Options from mypy.patterns import AsPattern, StarredPattern from mypy.plugin import CheckerPluginInterface, Plugin from mypy.plugins import dataclasses as dataclasses_plugin @@ -205,10 +205,13 @@ TypeType, TypeVarId, TypeVarLikeType, + TypeVarTupleType, TypeVarType, UnboundType, UninhabitedType, UnionType, + UnpackType, + find_unpack_in_list, flatten_nested_unions, get_proper_type, get_proper_types, @@ -6130,6 +6133,121 @@ def refine_away_none_in_comparison( return if_map, {} + def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type: + unpack_index = find_unpack_in_list(typ.items) + if unpack_index is None: + if op in ("==", "is"): + if typ.length() == size: + return typ + return UninhabitedType() + elif op in ("!=", "is not"): + if typ.length() != size: + return typ + return UninhabitedType() + elif op == ">": + if typ.length() > size: + return typ + return UninhabitedType() + elif op == ">=": + if typ.length() >= size: + return typ + return UninhabitedType() + elif op == "<": + if typ.length() < size: + return typ + return UninhabitedType() + elif op == "<=": + if typ.length() <= size: + return typ + return UninhabitedType() + else: + assert False, "Unsupported op for tuple len comparison" + unpack = typ.items[unpack_index] + assert isinstance(unpack, UnpackType) + unpacked = get_proper_type(unpack.type) + min_len = typ.length() - 1 + if isinstance(unpacked, TypeVarTupleType): + if op in ("==", "is"): + if min_len <= size: + return typ + return UninhabitedType() + elif op in ("!=", "is not"): + return typ + elif op == ">": + return typ + elif op == ">=": + return typ + elif op == "<": + if min_len < size: + return typ + return UninhabitedType() + elif op == "<=": + if min_len <= size: + return typ + return UninhabitedType() + else: + assert False, "Unsupported op for tuple len comparison" + assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple" + arg = unpacked.args[0] + prefix = typ.items[:unpack_index] + suffix = typ.items[unpack_index + 1 :] + if op in ("==", "is"): + if min_len <= size: + return typ.copy_modified(items=prefix + [arg] * (size - min_len) + suffix) + return UninhabitedType() + elif op in ("!=", "is not"): + return typ + elif op == ">": + return typ.copy_modified( + items=prefix + [arg] * (size - min_len + 1) + [unpack] + suffix + ) + elif op == ">=": + return typ.copy_modified(items=prefix + [arg] * (size - min_len) + [unpack] + suffix) + elif op == "<": + if min_len < size: + items = [] + for n in range(size - min_len): + items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) + return UnionType.make_union(items, typ.line, typ.column) + return UninhabitedType() + elif op == "<=": + if min_len <= size: + items = [] + for n in range(size - min_len + 1): + items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) + return UnionType.make_union(items, typ.line, typ.column) + return UninhabitedType() + else: + assert False, "Unsupported op for tuple len comparison" + + def refine_instance_type_with_len(self, typ: Instance, op: str, size: int) -> Type: + arg = typ.args[0] + if op in ("==", "is"): + return TupleType(items=[arg] * size, fallback=typ) + elif op in ("!=", "is not"): + # TODO: return fixed union + prefixed variadic tuple? + return typ + elif op == ">": + if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + return TupleType(items=[arg] * (size + 1) + [UnpackType(typ)], fallback=typ) + return typ + elif op == ">=": + if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + return TupleType(items=[arg] * size + [UnpackType(typ)], fallback=typ) + return typ + elif op == "<": + items = [] + for n in range(size): + items.append(TupleType([arg] * n, fallback=typ)) + return UnionType.make_union(items, typ.line, typ.column) + elif op == "<=": + items = [] + for n in range(size + 1): + items.append(TupleType([arg] * n, fallback=typ)) + return UnionType.make_union(items, typ.line, typ.column) + else: + assert False, "Unsupported op for tuple len comparison" + # # Helpers # From 81f10e1aaefb9f86ba36aa9583ee3bd238c1fa33 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Fri, 6 Oct 2023 01:07:00 +0100 Subject: [PATCH 03/18] Complete type narrowing logic --- mypy/checker.py | 92 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 77 insertions(+), 15 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 5a141bdafba56..3b7df3e3f260c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6133,33 +6133,94 @@ def refine_away_none_in_comparison( return if_map, {} - def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type: + def can_be_narrowed_with_len(self, typ: Type) -> bool: + p_typ = get_proper_type(typ) + if isinstance(p_typ, TupleType): + return True + if isinstance(p_typ, Instance): + # TODO: support tuple subclasses? + return p_typ.type.fullname == "builtins.tuple" + if isinstance(p_typ, UnionType): + # TODO: support mixed unions + return all(self.can_be_narrowed_with_len(t) for t in p_typ.items) + return False + + def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, Type | None]: + if op in "==": + neg_op = "!=" + elif op in "is": + neg_op = "is not" + elif op == "!=": + neg_op = "==" + elif op == "is not": + neg_op = "is" + elif op == ">": + neg_op = "<=" + elif op == ">=": + neg_op = "<" + elif op == "<": + neg_op = ">=" + elif op == "<=": + neg_op = ">" + else: + assert False, "Unsupported op for tuple len comparison" + typ = get_proper_type(typ) + if isinstance(typ, TupleType): + yes_type = self.refine_tuple_type_with_len(typ, op, size) + no_type = self.refine_tuple_type_with_len(typ, neg_op, size) + return yes_type, no_type + elif isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple": + yes_type = self.refine_instance_type_with_len(typ, op, size) + no_type = self.refine_instance_type_with_len(typ, neg_op, size) + return yes_type, no_type + elif isinstance(typ, UnionType): + yes_types = [] + no_types = [] + for t in typ.items: + yt, nt = self.narrow_with_len(t, op, size) + if yt is not None: + yes_types.append(yt) + if nt is not None: + no_types.append(nt) + if yes_types: + yes_type = make_simplified_union(yes_types) + else: + yes_type = None + if no_types: + no_type = make_simplified_union(no_types) + else: + no_type = None + return yes_type, no_type + else: + assert False, "Unsupported type for len narrowing" + + def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type | None: unpack_index = find_unpack_in_list(typ.items) if unpack_index is None: if op in ("==", "is"): if typ.length() == size: return typ - return UninhabitedType() + return None elif op in ("!=", "is not"): if typ.length() != size: return typ - return UninhabitedType() + return None elif op == ">": if typ.length() > size: return typ - return UninhabitedType() + return None elif op == ">=": if typ.length() >= size: return typ - return UninhabitedType() + return None elif op == "<": if typ.length() < size: return typ - return UninhabitedType() + return None elif op == "<=": if typ.length() <= size: return typ - return UninhabitedType() + return None else: assert False, "Unsupported op for tuple len comparison" unpack = typ.items[unpack_index] @@ -6170,7 +6231,7 @@ def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type if op in ("==", "is"): if min_len <= size: return typ - return UninhabitedType() + return None elif op in ("!=", "is not"): return typ elif op == ">": @@ -6180,11 +6241,11 @@ def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type elif op == "<": if min_len < size: return typ - return UninhabitedType() + return None elif op == "<=": if min_len <= size: return typ - return UninhabitedType() + return None else: assert False, "Unsupported op for tuple len comparison" assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple" @@ -6194,7 +6255,7 @@ def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type if op in ("==", "is"): if min_len <= size: return typ.copy_modified(items=prefix + [arg] * (size - min_len) + suffix) - return UninhabitedType() + return None elif op in ("!=", "is not"): return typ elif op == ">": @@ -6209,19 +6270,20 @@ def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type for n in range(size - min_len): items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) return UnionType.make_union(items, typ.line, typ.column) - return UninhabitedType() + return None elif op == "<=": if min_len <= size: items = [] for n in range(size - min_len + 1): items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) return UnionType.make_union(items, typ.line, typ.column) - return UninhabitedType() + return None else: assert False, "Unsupported op for tuple len comparison" def refine_instance_type_with_len(self, typ: Instance, op: str, size: int) -> Type: arg = typ.args[0] + unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) if op in ("==", "is"): return TupleType(items=[arg] * size, fallback=typ) elif op in ("!=", "is not"): @@ -6229,11 +6291,11 @@ def refine_instance_type_with_len(self, typ: Instance, op: str, size: int) -> Ty return typ elif op == ">": if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: - return TupleType(items=[arg] * (size + 1) + [UnpackType(typ)], fallback=typ) + return TupleType(items=[arg] * (size + 1) + [unpack], fallback=typ) return typ elif op == ">=": if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: - return TupleType(items=[arg] * size + [UnpackType(typ)], fallback=typ) + return TupleType(items=[arg] * size + [unpack], fallback=typ) return typ elif op == "<": items = [] From b854c4f4565684729b5ad5427afc5301eaf6c4d2 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 7 Oct 2023 00:58:27 +0100 Subject: [PATCH 04/18] End-to-end implementation --- mypy/checker.py | 87 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index 3b7df3e3f260c..1b92c8a6435fb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5741,6 +5741,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: partial_type_maps.append((if_map, else_map)) + partial_type_maps.extend(self.find_tuple_len_narrowing(node)) return reduce_conditional_maps(partial_type_maps) elif isinstance(node, AssignmentExpr): if_map = {} @@ -5800,6 +5801,92 @@ def has_no_custom_eq_checks(t: Type) -> bool: else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None return if_map, else_map + def is_len_of_tuple(self, expr: Expression) -> bool: + if not isinstance(expr, CallExpr): + return False + if not refers_to_fullname(expr.callee, "builtins.len"): + return False + if len(expr.args) != 1: + return False + expr = expr.args[0] + if literal(expr) != LITERAL_TYPE: + return False + if not self.has_type(expr): + return False + return self.can_be_narrowed_with_len(self.lookup_type(expr)) + + def literal_int_expr(self, expr: Expression) -> int | None: + if not self.has_type(expr): + return None + expr_type = self.lookup_type(expr) + expr_type = coerce_to_literal(expr_type) + proper_type = get_proper_type(expr_type) + if not isinstance(proper_type, LiteralType): + return None + if not isinstance(proper_type.value, int): + return None + return proper_type.value + + def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, TypeMap]]: + opposite = {"<": ">=", "<=": ">", ">": "<=", ">=": "<"} + type_maps = [] + chained = [] + last_group = set() + for op, left, right in node.pairwise(): + if isinstance(left, AssignmentExpr): + left = left.value + if isinstance(right, AssignmentExpr): + right = right.value + if op in ("is", "=="): + last_group.add(left) + last_group.add(right) + else: + chained.append(("==", list(last_group))) + last_group = set() + chained.append((op, [left, right])) + if last_group: + chained.append(("==", list(last_group))) + + for op, items in chained: + if not any(self.literal_int_expr(it) is not None for it in items): + continue + if not any(self.is_len_of_tuple(it) for it in items): + continue + if op in ("is", "=="): + literal_values = set() + tuples = [] + for it in items: + lit = self.literal_int_expr(it) + if lit is not None: + literal_values.add(lit) + continue + if self.is_len_of_tuple(it): + assert isinstance(it, CallExpr) + tuples.append(it.args[0]) + if len(literal_values) > 1: + return [(None, {})] + size = literal_values.pop() + for tpl in tuples: + yes_type, no_type = self.narrow_with_len(self.lookup_type(tpl), op, size) + yes_map = None if yes_type is None else {tpl: yes_type} + no_map = None if no_type is None else {tpl: no_type} + type_maps.append((yes_map, no_map)) + else: + left, right = items + if self.is_len_of_tuple(right): + left, right = right, left + op = opposite.get(op, op) + r_size = self.literal_int_expr(right) + assert r_size is not None + assert isinstance(left, CallExpr) + yes_type, no_type = self.narrow_with_len( + self.lookup_type(left.args[0]), op, r_size + ) + yes_map = None if yes_type is None else {left.args[0]: yes_type} + no_map = None if no_type is None else {left.args[0]: no_type} + type_maps.append((yes_map, no_map)) + return type_maps + def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap: """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. From 5683f9dcaa994d7b52da09894907a53ab05a9e16 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 7 Oct 2023 14:26:59 +0100 Subject: [PATCH 05/18] Tweaks and fixes --- mypy/checker.py | 190 +++++++++++----------------- test-data/unit/check-narrowing.test | 25 +++- 2 files changed, 95 insertions(+), 120 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 1b92c8a6435fb..68c663f21b61c 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -228,6 +228,23 @@ DEFAULT_LAST_PASS: Final = 1 # Pass numbers start at 0 +# Maximum length of fixed tuple types inferred when narrowing from variadic tuples. +MAX_PRECISE_TUPLE_SIZE: Final = 15 + +int_op_to_method: Final = { + "==": int.__eq__, + "is": int.__eq__, + "<": int.__lt__, + "<=": int.__le__, + "!=": int.__ne__, + "is not": int.__ne__, + ">": int.__gt__, + ">=": int.__ge__, +} + +flip_ops: Final = {"<": ">=", "<=": ">", ">": "<=", ">=": "<"} +neg_ops: Final = {**flip_ops, "==": "!=", "!=": "==", "is": "is not", "is not": "is"} + DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator] FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef] @@ -5741,8 +5758,11 @@ def has_no_custom_eq_checks(t: Type) -> bool: partial_type_maps.append((if_map, else_map)) - partial_type_maps.extend(self.find_tuple_len_narrowing(node)) - return reduce_conditional_maps(partial_type_maps) + if any(m != ({}, {}) for m in partial_type_maps): + return reduce_conditional_maps(partial_type_maps) + else: + # TODO: support regular and len() narrowing in the same chain. + return reduce_conditional_maps(self.find_tuple_len_narrowing(node)) elif isinstance(node, AssignmentExpr): if_map = {} else_map = {} @@ -5828,7 +5848,6 @@ def literal_int_expr(self, expr: Expression) -> int | None: return proper_type.value def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, TypeMap]]: - opposite = {"<": ">=", "<=": ">", ">": "<=", ">=": "<"} type_maps = [] chained = [] last_group = set() @@ -5843,7 +5862,8 @@ def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, else: chained.append(("==", list(last_group))) last_group = set() - chained.append((op, [left, right])) + if op in {"is not", "!=", "<", "<=", ">", ">="}: + chained.append((op, [left, right])) if last_group: chained.append(("==", list(last_group))) @@ -5866,6 +5886,8 @@ def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, if len(literal_values) > 1: return [(None, {})] size = literal_values.pop() + if size > MAX_PRECISE_TUPLE_SIZE: + continue for tpl in tuples: yes_type, no_type = self.narrow_with_len(self.lookup_type(tpl), op, size) yes_map = None if yes_type is None else {tpl: yes_type} @@ -5875,9 +5897,11 @@ def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, left, right = items if self.is_len_of_tuple(right): left, right = right, left - op = opposite.get(op, op) + op = flip_ops.get(op, op) r_size = self.literal_int_expr(right) assert r_size is not None + if r_size > MAX_PRECISE_TUPLE_SIZE: + continue assert isinstance(left, CallExpr) yes_type, no_type = self.narrow_with_len( self.lookup_type(left.args[0]), op, r_size @@ -6233,33 +6257,11 @@ def can_be_narrowed_with_len(self, typ: Type) -> bool: return False def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, Type | None]: - if op in "==": - neg_op = "!=" - elif op in "is": - neg_op = "is not" - elif op == "!=": - neg_op = "==" - elif op == "is not": - neg_op = "is" - elif op == ">": - neg_op = "<=" - elif op == ">=": - neg_op = "<" - elif op == "<": - neg_op = ">=" - elif op == "<=": - neg_op = ">" - else: - assert False, "Unsupported op for tuple len comparison" typ = get_proper_type(typ) if isinstance(typ, TupleType): - yes_type = self.refine_tuple_type_with_len(typ, op, size) - no_type = self.refine_tuple_type_with_len(typ, neg_op, size) - return yes_type, no_type + return self.refine_tuple_type_with_len(typ, op, size) elif isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple": - yes_type = self.refine_instance_type_with_len(typ, op, size) - no_type = self.refine_instance_type_with_len(typ, neg_op, size) - return yes_type, no_type + return self.refine_instance_type_with_len(typ, op, size) elif isinstance(typ, UnionType): yes_types = [] no_types = [] @@ -6281,35 +6283,15 @@ def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, T else: assert False, "Unsupported type for len narrowing" - def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type | None: + def refine_tuple_type_with_len( + self, typ: TupleType, op: str, size: int + ) -> tuple[Type | None, Type | None]: unpack_index = find_unpack_in_list(typ.items) if unpack_index is None: - if op in ("==", "is"): - if typ.length() == size: - return typ - return None - elif op in ("!=", "is not"): - if typ.length() != size: - return typ - return None - elif op == ">": - if typ.length() > size: - return typ - return None - elif op == ">=": - if typ.length() >= size: - return typ - return None - elif op == "<": - if typ.length() < size: - return typ - return None - elif op == "<=": - if typ.length() <= size: - return typ - return None - else: - assert False, "Unsupported op for tuple len comparison" + method = int_op_to_method[op] + if method(typ.length(), size): + return typ, None + return None, typ unpack = typ.items[unpack_index] assert isinstance(unpack, UnpackType) unpacked = get_proper_type(unpack.type) @@ -6317,85 +6299,63 @@ def refine_tuple_type_with_len(self, typ: TupleType, op: str, size: int) -> Type if isinstance(unpacked, TypeVarTupleType): if op in ("==", "is"): if min_len <= size: - return typ - return None - elif op in ("!=", "is not"): - return typ - elif op == ">": - return typ - elif op == ">=": - return typ - elif op == "<": + return typ, typ + return None, typ + elif op in ("<", "<="): + if op == "<=": + size += 1 if min_len < size: - return typ - return None - elif op == "<=": - if min_len <= size: - return typ - return None + return typ, typ + return None, typ else: - assert False, "Unsupported op for tuple len comparison" + yes_type, no_type = self.refine_tuple_type_with_len(typ, neg_ops[op], size) + return no_type, yes_type assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple" arg = unpacked.args[0] prefix = typ.items[:unpack_index] suffix = typ.items[unpack_index + 1 :] if op in ("==", "is"): if min_len <= size: - return typ.copy_modified(items=prefix + [arg] * (size - min_len) + suffix) - return None - elif op in ("!=", "is not"): - return typ - elif op == ">": - return typ.copy_modified( - items=prefix + [arg] * (size - min_len + 1) + [unpack] + suffix - ) - elif op == ">=": - return typ.copy_modified(items=prefix + [arg] * (size - min_len) + [unpack] + suffix) - elif op == "<": + return typ.copy_modified(items=prefix + [arg] * (size - min_len) + suffix), typ + return None, typ + elif op in ("<", "<="): + if op == "<=": + size += 1 if min_len < size: - items = [] + no_type = typ.copy_modified( + items=prefix + [arg] * (size - min_len) + [unpack] + suffix + ) + yes_items = [] for n in range(size - min_len): - items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) - return UnionType.make_union(items, typ.line, typ.column) - return None - elif op == "<=": - if min_len <= size: - items = [] - for n in range(size - min_len + 1): - items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) - return UnionType.make_union(items, typ.line, typ.column) - return None + yes_items.append(typ.copy_modified(items=prefix + [arg] * n + suffix)) + return UnionType.make_union(yes_items, typ.line, typ.column), no_type + return None, typ else: - assert False, "Unsupported op for tuple len comparison" + yes_type, no_type = self.refine_tuple_type_with_len(typ, neg_ops[op], size) + return no_type, yes_type - def refine_instance_type_with_len(self, typ: Instance, op: str, size: int) -> Type: + def refine_instance_type_with_len( + self, typ: Instance, op: str, size: int + ) -> tuple[Type | None, Type | None]: arg = typ.args[0] unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) if op in ("==", "is"): - return TupleType(items=[arg] * size, fallback=typ) - elif op in ("!=", "is not"): - # TODO: return fixed union + prefixed variadic tuple? - return typ - elif op == ">": + # TODO: return fixed union + prefixed variadic tuple for no_type? + return TupleType(items=[arg] * size, fallback=typ), typ + elif op in ("<", "<="): + if op == "<=": + size += 1 if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: - return TupleType(items=[arg] * (size + 1) + [unpack], fallback=typ) - return typ - elif op == ">=": - if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: - return TupleType(items=[arg] * size + [unpack], fallback=typ) - return typ - elif op == "<": + no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) + else: + no_type = typ items = [] for n in range(size): items.append(TupleType([arg] * n, fallback=typ)) - return UnionType.make_union(items, typ.line, typ.column) - elif op == "<=": - items = [] - for n in range(size + 1): - items.append(TupleType([arg] * n, fallback=typ)) - return UnionType.make_union(items, typ.line, typ.column) + return UnionType.make_union(items, typ.line, typ.column), no_type else: - assert False, "Unsupported op for tuple len comparison" + yes_type, no_type = self.refine_instance_type_with_len(typ, neg_ops[op], size) + return no_type, yes_type # # Helpers diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index ee06954160994..7d2a07fca1077 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1336,7 +1336,7 @@ else: [builtins fixtures/dict.pyi] [case testNarrowingLenItemAndLenCompare] -from typing import Tuple, Union, Any +from typing import Any x: Any if len(x) == x: @@ -1364,8 +1364,8 @@ else: a, b, c = x [builtins fixtures/len.pyi] -[case testNarrowingLenVariantLengthTuple] -from typing import Tuple, Union +[case testNarrowingLenHomogeneousTuple] +from typing import Tuple def make_tuple() -> Tuple[int, ...]: return (1, 1) @@ -1384,7 +1384,7 @@ else: [builtins fixtures/len.pyi] [case testNarrowingLenTypeUnaffected] -from typing import Tuple, Union, List, Any +from typing import Union, List def make() -> Union[str, List[int]]: return "" @@ -1467,7 +1467,7 @@ else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" [builtins fixtures/len.pyi] -[case testNarrowingLenBiggerThanVariantTuple] +[case testNarrowingLenBiggerThanHomogeneousTupleShort] from typing import Tuple VarTuple = Tuple[int, ...] @@ -1477,6 +1477,21 @@ def make_tuple() -> VarTuple: x = make_tuple() if len(x) < 3: + reveal_type(x) # N: Revealed type is "Union[Tuple[()], Tuple[builtins.int], Tuple[builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenBiggerThanHomogeneousTupleLong] +from typing import Tuple + +VarTuple = Tuple[int, ...] + +def make_tuple() -> VarTuple: + return (1, 1) + +x = make_tuple() +if len(x) < 30: reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" else: reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" From e9cb9b7ffba44eda81d31a36c63d16ec2ef60352 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 7 Oct 2023 15:14:34 +0100 Subject: [PATCH 06/18] Some more tests and fixes --- mypy/checker.py | 35 +++++++++--- test-data/unit/check-narrowing.test | 88 +++++++++++++++++------------ 2 files changed, 79 insertions(+), 44 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 68c663f21b61c..91b997cedd558 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -43,7 +43,7 @@ from mypy.expandtype import expand_self_type, expand_type, expand_type_by_instance from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash from mypy.maptype import map_instance_to_supertype -from mypy.meet import is_overlapping_erased_types, is_overlapping_types +from mypy.meet import is_overlapping_erased_types, is_overlapping_types, meet_types from mypy.message_registry import ErrorMessage from mypy.messages import ( SUGGESTED_TEST_FIXTURES, @@ -242,8 +242,17 @@ ">=": int.__ge__, } -flip_ops: Final = {"<": ">=", "<=": ">", ">": "<=", ">=": "<"} -neg_ops: Final = {**flip_ops, "==": "!=", "!=": "==", "is": "is not", "is not": "is"} +flip_ops: Final = {"<": ">", "<=": ">=", ">": "<", ">=": "<="} +neg_ops: Final = { + "==": "!=", + "!=": "==", + "is": "is not", + "is not": "is", + "<": ">=", + "<=": ">", + ">": "<=", + ">=": "<", +} DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator] FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef] @@ -5762,7 +5771,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: return reduce_conditional_maps(partial_type_maps) else: # TODO: support regular and len() narrowing in the same chain. - return reduce_conditional_maps(self.find_tuple_len_narrowing(node)) + return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True) elif isinstance(node, AssignmentExpr): if_map = {} else_map = {} @@ -7304,7 +7313,7 @@ def builtin_item_type(tp: Type) -> Type | None: return None -def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: +def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> TypeMap: """Calculate what information we can learn from the truth of (e1 and e2) in terms of the information that we can learn from the truth of e1 and the truth of e2. @@ -7314,15 +7323,21 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: # One of the conditions can never be true. return None # Both conditions can be true; combine the information. Anything - # we learn from either conditions's truth is valid. If the same + # we learn from either conditions' truth is valid. If the same # expression's type is refined by both conditions, we somewhat # arbitrarily give precedence to m2. (In the future, we could use - # an intersection type.) + # an intersection type or meet.) result = m2.copy() m2_keys = {literal_hash(n2) for n2 in m2} for n1 in m1: if literal_hash(n1) not in m2_keys: result[n1] = m1[n1] + if use_meet: + # For now, meet common keys only if specifically requested. + for n1 in m1: + for n2 in m2: + if literal_hash(n1) == literal_hash(n2): + result[n1] = meet_types(m1[n1], m2[n2]) return result @@ -7348,7 +7363,9 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: return result -def reduce_conditional_maps(type_maps: list[tuple[TypeMap, TypeMap]]) -> tuple[TypeMap, TypeMap]: +def reduce_conditional_maps( + type_maps: list[tuple[TypeMap, TypeMap]], use_meet: bool = False +) -> tuple[TypeMap, TypeMap]: """Reduces a list containing pairs of if/else TypeMaps into a single pair. We "and" together all of the if TypeMaps and "or" together the else TypeMaps. So @@ -7379,7 +7396,7 @@ def reduce_conditional_maps(type_maps: list[tuple[TypeMap, TypeMap]]) -> tuple[T else: final_if_map, final_else_map = type_maps[0] for if_map, else_map in type_maps[1:]: - final_if_map = and_conditional_maps(final_if_map, if_map) + final_if_map = and_conditional_maps(final_if_map, if_map, use_meet=use_meet) final_else_map = or_conditional_maps(final_else_map, else_map) return final_if_map, final_else_map diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7d2a07fca1077..158c8078cc11e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1348,10 +1348,7 @@ from typing import Tuple, Union VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() +x: VarTuple a = b = c = 0 if len(x) == 3: a, b, c = x @@ -1367,11 +1364,7 @@ else: [case testNarrowingLenHomogeneousTuple] from typing import Tuple -def make_tuple() -> Tuple[int, ...]: - return (1, 1) - -x = make_tuple() - +x: Tuple[int, ...] if len(x) == 3: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" else: @@ -1386,11 +1379,7 @@ else: [case testNarrowingLenTypeUnaffected] from typing import Union, List -def make() -> Union[str, List[int]]: - return "" - -x = make() - +x: Union[str, List[int]] if len(x) == 3: reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.list[builtins.int]]" else: @@ -1399,6 +1388,7 @@ else: [case testNarrowingLenAnyListElseNotAffected] from typing import Any + def f(self, value: Any) -> Any: if isinstance(value, list) and len(value) == 0: reveal_type(value) # N: Revealed type is "builtins.list[Any]" @@ -1412,11 +1402,8 @@ from typing import Tuple, Union VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() -y = make_tuple() +x: VarTuple +y: VarTuple if len(x) == len(y) == 3: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" reveal_type(y) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" @@ -1428,10 +1415,7 @@ from typing_extensions import Final VarTuple = Union[Tuple[int, int], Tuple[int, int, int]] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() +x: VarTuple fin: Final = 3 if len(x) == fin: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" @@ -1442,10 +1426,7 @@ from typing import Tuple, Union VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() +x: VarTuple if len(x) > 1: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" else: @@ -1467,15 +1448,29 @@ else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" [builtins fixtures/len.pyi] +[case testNarrowingLenBothSidesUnionTuples] +from typing import Tuple, Union + +VarTuple = Union[ + Tuple[int], + Tuple[int, int], + Tuple[int, int, int], + Tuple[int, int, int, int], +] + +x: VarTuple +if 2 <= len(x) <= 3: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int, builtins.int, builtins.int]]" +[builtins fixtures/len.pyi] + [case testNarrowingLenBiggerThanHomogeneousTupleShort] from typing import Tuple VarTuple = Tuple[int, ...] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() +x: VarTuple if len(x) < 3: reveal_type(x) # N: Revealed type is "Union[Tuple[()], Tuple[builtins.int], Tuple[builtins.int, builtins.int]]" else: @@ -1487,12 +1482,35 @@ from typing import Tuple VarTuple = Tuple[int, ...] -def make_tuple() -> VarTuple: - return (1, 1) - -x = make_tuple() +x: VarTuple if len(x) < 30: reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" else: reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" [builtins fixtures/len.pyi] + +[case testNarrowingLenBothSidesHomogeneousTuple] +from typing import Tuple + +x: Tuple[int, ...] +if 1 < len(x) < 4: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[()], Tuple[builtins.int], Tuple[builtins.int, builtins.int, builtins.int, builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenUnionTupleUnreachable] +# flags: --warn-unreachable +from typing import Tuple, Union + +x: Union[Tuple[int, int], Tuple[int, int, int]] +if len(x) >= 4: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" + +if len(x) < 2: + reveal_type(x) # E: Statement is unreachable +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +[builtins fixtures/len.pyi] From 9815637f9cc3ce2188e5359be9c69850cb1262a6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 7 Oct 2023 23:24:46 +0100 Subject: [PATCH 07/18] Add TypeVarTuple tests --- test-data/unit/check-narrowing.test | 67 ++++++++++++++++++++++++++++- 1 file changed, 65 insertions(+), 2 deletions(-) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 158c8078cc11e..803f77f8f10a7 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1421,7 +1421,7 @@ if len(x) == fin: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int]" [builtins fixtures/len.pyi] -[case testNarrowingLenBiggerThan] +[case testNarrowingLenGreaterThan] from typing import Tuple, Union VarTuple = Union[Tuple[int], Tuple[int, int], Tuple[int, int, int]] @@ -1465,7 +1465,7 @@ else: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int, builtins.int, builtins.int]]" [builtins fixtures/len.pyi] -[case testNarrowingLenBiggerThanHomogeneousTupleShort] +[case testNarrowingLenGreaterThanHomogeneousTupleShort] from typing import Tuple VarTuple = Tuple[int, ...] @@ -1514,3 +1514,66 @@ if len(x) < 2: else: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" [builtins fixtures/len.pyi] + +[case testNarrowingLenTypeVarTupleEquals] +from typing import Tuple +from typing_extensions import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +def foo(x: Tuple[int, Unpack[Ts], str]) -> None: + if len(x) == 5: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + + if len(x) != 5: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTypeVarTupleGreaterThan] +from typing import Tuple +from typing_extensions import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +def foo(x: Tuple[int, Unpack[Ts], str]) -> None: + if len(x) > 5: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + + if len(x) < 5: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTypeVarTupleUnreachable] +# flags: --warn-unreachable +from typing import Tuple +from typing_extensions import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +def foo(x: Tuple[int, Unpack[Ts], str]) -> None: + if len(x) == 1: + reveal_type(x) # E: Statement is unreachable + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + + if len(x) != 1: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # E: Statement is unreachable + +def bar(x: Tuple[int, Unpack[Ts], str]) -> None: + if len(x) >= 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + else: + reveal_type(x) # E: Statement is unreachable + + if len(x) < 2: + reveal_type(x) # E: Statement is unreachable + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" +[builtins fixtures/len.pyi] From 8eb8900fc4e61d47380bf67b634df24b2883a55f Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 8 Oct 2023 00:53:13 +0100 Subject: [PATCH 08/18] More variadic tests; binder special-casing --- mypy/binder.py | 41 ++++++++++++++++++ test-data/unit/check-narrowing.test | 64 +++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) diff --git a/mypy/binder.py b/mypy/binder.py index 8a68f24f661e5..c390fa5431a26 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -12,13 +12,18 @@ from mypy.subtypes import is_same_type, is_subtype from mypy.types import ( AnyType, + Instance, NoneType, PartialType, + TupleType, Type, TypeOfAny, TypeType, UnionType, + UnpackType, + find_unpack_in_list, get_proper_type, + get_proper_types, ) from mypy.typevars import fill_typevars_with_any @@ -213,6 +218,8 @@ def update_from_options(self, frames: list[Frame]) -> bool: for other in resulting_values[1:]: assert other is not None type = join_simple(self.declarations[key], type, other) + if isinstance(type, UnionType) and len(type.items) > 1: + type = collapse_variadic_union(type) if current_value is None or not is_same_type(type, current_value): self._put(key, type) changed = True @@ -453,3 +460,37 @@ def get_declaration(expr: BindableExpression) -> Type | None: elif isinstance(expr.node, TypeInfo): return TypeType(fill_typevars_with_any(expr.node)) return None + + +def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: + items = get_proper_types(typ.items) + if not all(isinstance(it, TupleType) for it in items): + return typ + tuple_items = cast("list[TupleType]", items) + tuple_items = sorted(tuple_items, key=lambda t: len(t.items)) + first = tuple_items[0] + last = tuple_items[-1] + unpack_index = find_unpack_in_list(last.items) + if unpack_index is None: + return typ + unpack = last.items[unpack_index] + assert isinstance(unpack, UnpackType) + unpacked = get_proper_type(unpack.type) + if not isinstance(unpacked, Instance): + return typ + assert unpacked.type.fullname == "builtins.tuple" + suffix = last.items[unpack_index + 1 :] + if len(first.items) < len(suffix): + return typ + if suffix and first.items[-len(suffix) :] != suffix: + return typ + prefix = first.items[: -len(suffix)] + arg = unpacked.args[0] + for i, it in enumerate(tuple_items[1:-1]): + if it.items != prefix + [arg] * (i + 1) + suffix: + return typ + if last.items != prefix + [arg] * (len(typ.items) - 1) + [unpack] + suffix: + return typ + if len(first.items) == 0: + return unpacked.copy_modified() + return TupleType(prefix + [unpack] + suffix, fallback=last.partial_fallback) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 803f77f8f10a7..7b2cdcf3e977b 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1475,6 +1475,7 @@ if len(x) < 3: reveal_type(x) # N: Revealed type is "Union[Tuple[()], Tuple[builtins.int], Tuple[builtins.int, builtins.int]]" else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" +reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" [builtins fixtures/len.pyi] [case testNarrowingLenBiggerThanHomogeneousTupleLong] @@ -1497,6 +1498,7 @@ if 1 < len(x) < 4: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" else: reveal_type(x) # N: Revealed type is "Union[Tuple[()], Tuple[builtins.int], Tuple[builtins.int, builtins.int, builtins.int, builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]]" +reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" [builtins fixtures/len.pyi] [case testNarrowingLenUnionTupleUnreachable] @@ -1577,3 +1579,65 @@ def bar(x: Tuple[int, Unpack[Ts], str]) -> None: else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" [builtins fixtures/len.pyi] + +[case testNarrowingLenVariadicTupleEquals] +from typing import Tuple +from typing_extensions import Unpack + +def foo(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: + if len(x) == 4: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.float, builtins.float, builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + + if len(x) != 4: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.float, builtins.float, builtins.str]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenVariadicTupleGreaterThan] +from typing import Tuple +from typing_extensions import Unpack + +def foo(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: + if len(x) > 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.float, builtins.float, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.str], Tuple[builtins.int, builtins.float, builtins.str]]" + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + + if len(x) < 3: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.str]" + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.float, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenVariadicTupleUnreachable] +# flags: --warn-unreachable +from typing import Tuple +from typing_extensions import Unpack + +def foo(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: + if len(x) == 1: + reveal_type(x) # E: Statement is unreachable + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + + if len(x) != 1: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + else: + reveal_type(x) # E: Statement is unreachable + +def bar(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: + if len(x) >= 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" + else: + reveal_type(x) # E: Statement is unreachable + + if len(x) < 2: + reveal_type(x) # E: Statement is unreachable + else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" +[builtins fixtures/len.pyi] From a03d3bfc04058ad86de88f46f38a88309791a937 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 8 Oct 2023 13:08:56 +0100 Subject: [PATCH 09/18] Docstring/comments; reorg code --- mypy/binder.py | 27 +++++- mypy/checker.py | 245 +++++++++++++++++++++++++++++------------------- 2 files changed, 177 insertions(+), 95 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index c390fa5431a26..ee686944ea794 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -218,6 +218,14 @@ def update_from_options(self, frames: list[Frame]) -> bool: for other in resulting_values[1:]: assert other is not None type = join_simple(self.declarations[key], type, other) + # Try simplifying resulting type for unions involving variadic tuples. + # Technically, everything is still valid without this step, but if we do + # not do this, this may create long unions after exiting an if check like: + # x: tuple[int, ...] + # if len(x) < 10: + # ... + # We want the type of x to be tuple[int, ...] after this block (if it is + # equivalent to such type). if isinstance(type, UnionType) and len(type.items) > 1: type = collapse_variadic_union(type) if current_value is None or not is_same_type(type, current_value): @@ -463,6 +471,14 @@ def get_declaration(expr: BindableExpression) -> Type | None: def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: + """Simplify a union involving variadic tuple if possible. + + This will collapse a type like e.g. + tuple[X, Z] | tuple[X, Y, Z] | tuple[X, Y, Y, *tuple[Y, ...], Z] + back to + tuple[X, *tuple[Y, ...], Z] + which is equivalent, but much simpler form of the same type. + """ items = get_proper_types(typ.items) if not all(isinstance(it, TupleType) for it in items): return typ @@ -480,15 +496,24 @@ def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: return typ assert unpacked.type.fullname == "builtins.tuple" suffix = last.items[unpack_index + 1 :] + + # Check that first item matches the expected pattern and infer prefix. if len(first.items) < len(suffix): return typ if suffix and first.items[-len(suffix) :] != suffix: return typ - prefix = first.items[: -len(suffix)] + if suffix: + prefix = first.items[: -len(suffix)] + else: + prefix = first.items + + # Check that all middle types match the expected pattern as well. arg = unpacked.args[0] for i, it in enumerate(tuple_items[1:-1]): if it.items != prefix + [arg] * (i + 1) + suffix: return typ + + # Check the last item (the one with unpack), and choose an appropriate simplified type. if last.items != prefix + [arg] * (len(typ.items) - 1) + [unpack] + suffix: return typ if len(first.items) == 0: diff --git a/mypy/checker.py b/mypy/checker.py index 91b997cedd558..9023e910b0882 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5767,10 +5767,14 @@ def has_no_custom_eq_checks(t: Type) -> bool: partial_type_maps.append((if_map, else_map)) + # If we have found non-trivial restrictions from the regular comparisons, + # then return soon. Otherwise try to infer restrictions involving `len(x)`. + # TODO: support regular and len() narrowing in the same chain. if any(m != ({}, {}) for m in partial_type_maps): return reduce_conditional_maps(partial_type_maps) else: - # TODO: support regular and len() narrowing in the same chain. + # Use meet for `and` maps to get correct results for chained checks + # like `if 1 < len(x) < 4: ...` return reduce_conditional_maps(self.find_tuple_len_narrowing(node), use_meet=True) elif isinstance(node, AssignmentExpr): if_map = {} @@ -5830,96 +5834,6 @@ def has_no_custom_eq_checks(t: Type) -> bool: else_map = {node: else_type} if not isinstance(else_type, UninhabitedType) else None return if_map, else_map - def is_len_of_tuple(self, expr: Expression) -> bool: - if not isinstance(expr, CallExpr): - return False - if not refers_to_fullname(expr.callee, "builtins.len"): - return False - if len(expr.args) != 1: - return False - expr = expr.args[0] - if literal(expr) != LITERAL_TYPE: - return False - if not self.has_type(expr): - return False - return self.can_be_narrowed_with_len(self.lookup_type(expr)) - - def literal_int_expr(self, expr: Expression) -> int | None: - if not self.has_type(expr): - return None - expr_type = self.lookup_type(expr) - expr_type = coerce_to_literal(expr_type) - proper_type = get_proper_type(expr_type) - if not isinstance(proper_type, LiteralType): - return None - if not isinstance(proper_type.value, int): - return None - return proper_type.value - - def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, TypeMap]]: - type_maps = [] - chained = [] - last_group = set() - for op, left, right in node.pairwise(): - if isinstance(left, AssignmentExpr): - left = left.value - if isinstance(right, AssignmentExpr): - right = right.value - if op in ("is", "=="): - last_group.add(left) - last_group.add(right) - else: - chained.append(("==", list(last_group))) - last_group = set() - if op in {"is not", "!=", "<", "<=", ">", ">="}: - chained.append((op, [left, right])) - if last_group: - chained.append(("==", list(last_group))) - - for op, items in chained: - if not any(self.literal_int_expr(it) is not None for it in items): - continue - if not any(self.is_len_of_tuple(it) for it in items): - continue - if op in ("is", "=="): - literal_values = set() - tuples = [] - for it in items: - lit = self.literal_int_expr(it) - if lit is not None: - literal_values.add(lit) - continue - if self.is_len_of_tuple(it): - assert isinstance(it, CallExpr) - tuples.append(it.args[0]) - if len(literal_values) > 1: - return [(None, {})] - size = literal_values.pop() - if size > MAX_PRECISE_TUPLE_SIZE: - continue - for tpl in tuples: - yes_type, no_type = self.narrow_with_len(self.lookup_type(tpl), op, size) - yes_map = None if yes_type is None else {tpl: yes_type} - no_map = None if no_type is None else {tpl: no_type} - type_maps.append((yes_map, no_map)) - else: - left, right = items - if self.is_len_of_tuple(right): - left, right = right, left - op = flip_ops.get(op, op) - r_size = self.literal_int_expr(right) - assert r_size is not None - if r_size > MAX_PRECISE_TUPLE_SIZE: - continue - assert isinstance(left, CallExpr) - yes_type, no_type = self.narrow_with_len( - self.lookup_type(left.args[0]), op, r_size - ) - yes_map = None if yes_type is None else {left.args[0]: yes_type} - no_map = None if no_type is None else {left.args[0]: no_type} - type_maps.append((yes_map, no_map)) - return type_maps - def propagate_up_typemap_info(self, new_types: TypeMap) -> TypeMap: """Attempts refining parent expressions of any MemberExpr or IndexExprs in new_types. @@ -6253,19 +6167,147 @@ def refine_away_none_in_comparison( return if_map, {} + def is_len_of_tuple(self, expr: Expression) -> bool: + """Is this expression a `len(x)` call where x is a tuple or union of tuples?""" + if not isinstance(expr, CallExpr): + return False + if not refers_to_fullname(expr.callee, "builtins.len"): + return False + if len(expr.args) != 1: + return False + expr = expr.args[0] + if literal(expr) != LITERAL_TYPE: + return False + if not self.has_type(expr): + return False + return self.can_be_narrowed_with_len(self.lookup_type(expr)) + def can_be_narrowed_with_len(self, typ: Type) -> bool: + """Is this a type that can benefit from length check type restrictions? + + Currently supported types are TupleTypes, Instances of builtins.tuple, and + unions of such types. + """ p_typ = get_proper_type(typ) if isinstance(p_typ, TupleType): return True if isinstance(p_typ, Instance): - # TODO: support tuple subclasses? + # TODO: support tuple subclasses as well? return p_typ.type.fullname == "builtins.tuple" if isinstance(p_typ, UnionType): # TODO: support mixed unions return all(self.can_be_narrowed_with_len(t) for t in p_typ.items) return False + def literal_int_expr(self, expr: Expression) -> int | None: + """Is this expression an int literal, or a reference to an int constant? + + If yes, return the corresponding int value, otherwise return None. + """ + if not self.has_type(expr): + return None + expr_type = self.lookup_type(expr) + expr_type = coerce_to_literal(expr_type) + proper_type = get_proper_type(expr_type) + if not isinstance(proper_type, LiteralType): + return None + if not isinstance(proper_type.value, int): + return None + return proper_type.value + + def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, TypeMap]]: + """Top-level logic to find type restrictions from a length check on tuples. + + We try to detect `if` checks like the following: + x: tuple[int, int] | tuple[int, int, int] + y: tuple[int, int] | tuple[int, int, int] + if len(x) == len(y) == 2: + a, b = x # OK + c, d = y # OK + + z: tuple[int, ...] + if 1 < len(z) < 4: + x = z # OK + and report corresponding type restrictions to the binder. + """ + # First step: group consecutive `is` and `==` comparisons together. + # This is essentially a simplified version of group_comparison_operands(), + # tuned to the len()-like checks. Note that we don't propagate indirect + # restrictions like e.g. `len(x) > foo() > 1` yet, since it is tricky. + # TODO: propagate indirect len() comparison restrictions. + chained = [] + last_group = set() + for op, left, right in node.pairwise(): + if isinstance(left, AssignmentExpr): + left = left.value + if isinstance(right, AssignmentExpr): + right = right.value + if op in ("is", "=="): + last_group.add(left) + last_group.add(right) + else: + chained.append(("==", list(last_group))) + last_group = set() + if op in {"is not", "!=", "<", "<=", ">", ">="}: + chained.append((op, [left, right])) + if last_group: + chained.append(("==", list(last_group))) + + # Second step: infer type restrictions from each group found above. + type_maps = [] + for op, items in chained: + if not any(self.literal_int_expr(it) is not None for it in items): + continue + if not any(self.is_len_of_tuple(it) for it in items): + continue + + # At this step we know there is at least one len(x) and one literal in the group. + if op in ("is", "=="): + literal_values = set() + tuples = [] + for it in items: + lit = self.literal_int_expr(it) + if lit is not None: + literal_values.add(lit) + continue + if self.is_len_of_tuple(it): + assert isinstance(it, CallExpr) + tuples.append(it.args[0]) + if len(literal_values) > 1: + # More than one different literal value found, like 1 == len(x) == 2, + # so the corresponding branch is unreachable. + return [(None, {})] + size = literal_values.pop() + if size > MAX_PRECISE_TUPLE_SIZE: + # Avoid creating huge tuples from checks like if len(x) == 300. + continue + for tpl in tuples: + yes_type, no_type = self.narrow_with_len(self.lookup_type(tpl), op, size) + yes_map = None if yes_type is None else {tpl: yes_type} + no_map = None if no_type is None else {tpl: no_type} + type_maps.append((yes_map, no_map)) + else: + left, right = items + if self.is_len_of_tuple(right): + # Normalize `1 < len(x)` and similar as `len(x) > 1`. + left, right = right, left + op = flip_ops.get(op, op) + r_size = self.literal_int_expr(right) + assert r_size is not None + if r_size > MAX_PRECISE_TUPLE_SIZE: + # Avoid creating huge unions from checks like if len(x) > 300. + continue + assert isinstance(left, CallExpr) + yes_type, no_type = self.narrow_with_len( + self.lookup_type(left.args[0]), op, r_size + ) + yes_map = None if yes_type is None else {left.args[0]: yes_type} + no_map = None if no_type is None else {left.args[0]: no_type} + type_maps.append((yes_map, no_map)) + return type_maps + def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, Type | None]: + """Dispatch tuple type narrowing logic depending on the kind of type we got.""" typ = get_proper_type(typ) if isinstance(typ, TupleType): return self.refine_tuple_type_with_len(typ, op, size) @@ -6295,8 +6337,11 @@ def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, T def refine_tuple_type_with_len( self, typ: TupleType, op: str, size: int ) -> tuple[Type | None, Type | None]: + """Narrow a TupleType using length restrictions.""" unpack_index = find_unpack_in_list(typ.items) if unpack_index is None: + # For fixed length tuple situation is trivial, it is either reachable or not, + # depending on the current length, expected length, and the comparison op. method = int_op_to_method[op] if method(typ.length(), size): return typ, None @@ -6306,6 +6351,9 @@ def refine_tuple_type_with_len( unpacked = get_proper_type(unpack.type) min_len = typ.length() - 1 if isinstance(unpacked, TypeVarTupleType): + # For tuples involving TypeVarTuple unpack we can't do much except + # inferring reachability, since we cannot really split a TypeVarTuple. + # TODO: support some cases by adding a min_len attribute to TypeVarTupleType. if op in ("==", "is"): if min_len <= size: return typ, typ @@ -6319,18 +6367,24 @@ def refine_tuple_type_with_len( else: yes_type, no_type = self.refine_tuple_type_with_len(typ, neg_ops[op], size) return no_type, yes_type + # Homogeneous variadic item is the case where we are most flexible. Essentially, + # we adjust the variadic item by "eating away" from it to satisfy the restriction. assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple" arg = unpacked.args[0] prefix = typ.items[:unpack_index] suffix = typ.items[unpack_index + 1 :] if op in ("==", "is"): if min_len <= size: + # TODO: return fixed union + prefixed variadic tuple for no_type? return typ.copy_modified(items=prefix + [arg] * (size - min_len) + suffix), typ return None, typ elif op in ("<", "<="): if op == "<=": size += 1 if min_len < size: + # Note: there is some ambiguity w.r.t. to where to put the additional + # items: before or after the unpack. However, such types are equivalent, + # so we always put them before for consistency. no_type = typ.copy_modified( items=prefix + [arg] * (size - min_len) + [unpack] + suffix ) @@ -6346,8 +6400,8 @@ def refine_tuple_type_with_len( def refine_instance_type_with_len( self, typ: Instance, op: str, size: int ) -> tuple[Type | None, Type | None]: + """Narrow a homogeneous tuple using length restrictions.""" arg = typ.args[0] - unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) if op in ("==", "is"): # TODO: return fixed union + prefixed variadic tuple for no_type? return TupleType(items=[arg] * size, fallback=typ), typ @@ -6355,6 +6409,7 @@ def refine_instance_type_with_len( if op == "<=": size += 1 if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) else: no_type = typ @@ -7326,7 +7381,7 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty # we learn from either conditions' truth is valid. If the same # expression's type is refined by both conditions, we somewhat # arbitrarily give precedence to m2. (In the future, we could use - # an intersection type or meet.) + # an intersection type or meet_types().) result = m2.copy() m2_keys = {literal_hash(n2) for n2 in m2} for n1 in m1: @@ -7334,6 +7389,8 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty result[n1] = m1[n1] if use_meet: # For now, meet common keys only if specifically requested. + # This is currently used for tuple types narrowing, where having + # a precise result is important. for n1 in m1: for n2 in m2: if literal_hash(n1) == literal_hash(n2): From 4cc35ec7e0cac4abbb0412cf962f11befdc2d253 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 8 Oct 2023 13:50:44 +0100 Subject: [PATCH 10/18] Fix one TODO (mixed unions) --- mypy/binder.py | 26 +++++++++++++++++--------- mypy/checker.py | 11 ++++++++--- test-data/unit/check-narrowing.test | 20 ++++++++++++++++++++ 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index ee686944ea794..09ee4a92aa40e 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -23,7 +23,6 @@ UnpackType, find_unpack_in_list, get_proper_type, - get_proper_types, ) from mypy.typevars import fill_typevars_with_any @@ -225,8 +224,8 @@ def update_from_options(self, frames: list[Frame]) -> bool: # if len(x) < 10: # ... # We want the type of x to be tuple[int, ...] after this block (if it is - # equivalent to such type). - if isinstance(type, UnionType) and len(type.items) > 1: + # still equivalent to such type). + if isinstance(type, UnionType): type = collapse_variadic_union(type) if current_value is None or not is_same_type(type, current_value): self._put(key, type) @@ -470,7 +469,7 @@ def get_declaration(expr: BindableExpression) -> Type | None: return None -def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: +def collapse_variadic_union(typ: UnionType) -> Type: """Simplify a union involving variadic tuple if possible. This will collapse a type like e.g. @@ -479,10 +478,17 @@ def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: tuple[X, *tuple[Y, ...], Z] which is equivalent, but much simpler form of the same type. """ - items = get_proper_types(typ.items) - if not all(isinstance(it, TupleType) for it in items): + tuple_items = [] + other_items = [] + for t in typ.items: + p_t = get_proper_type(t) + if isinstance(p_t, TupleType): + tuple_items.append(p_t) + else: + other_items.append(t) + if len(tuple_items) <= 1: + # This type cannot be simplified further. return typ - tuple_items = cast("list[TupleType]", items) tuple_items = sorted(tuple_items, key=lambda t: len(t.items)) first = tuple_items[0] last = tuple_items[-1] @@ -517,5 +523,7 @@ def collapse_variadic_union(typ: UnionType) -> UnionType | TupleType | Instance: if last.items != prefix + [arg] * (len(typ.items) - 1) + [unpack] + suffix: return typ if len(first.items) == 0: - return unpacked.copy_modified() - return TupleType(prefix + [unpack] + suffix, fallback=last.partial_fallback) + simplified: Type = unpacked.copy_modified() + else: + simplified = TupleType(prefix + [unpack] + suffix, fallback=last.partial_fallback) + return UnionType.make_union([simplified] + other_items) diff --git a/mypy/checker.py b/mypy/checker.py index 9023e910b0882..a679cb09e10c6 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6186,7 +6186,7 @@ def can_be_narrowed_with_len(self, typ: Type) -> bool: """Is this a type that can benefit from length check type restrictions? Currently supported types are TupleTypes, Instances of builtins.tuple, and - unions of such types. + unions involving such types. """ p_typ = get_proper_type(typ) if isinstance(p_typ, TupleType): @@ -6195,8 +6195,7 @@ def can_be_narrowed_with_len(self, typ: Type) -> bool: # TODO: support tuple subclasses as well? return p_typ.type.fullname == "builtins.tuple" if isinstance(p_typ, UnionType): - # TODO: support mixed unions - return all(self.can_be_narrowed_with_len(t) for t in p_typ.items) + return any(self.can_be_narrowed_with_len(t) for t in p_typ.items) return False def literal_int_expr(self, expr: Expression) -> int | None: @@ -6316,12 +6315,18 @@ def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, T elif isinstance(typ, UnionType): yes_types = [] no_types = [] + other_types = [] for t in typ.items: + if not self.can_be_narrowed_with_len(t): + other_types.append(t) + continue yt, nt = self.narrow_with_len(t, op, size) if yt is not None: yes_types.append(yt) if nt is not None: no_types.append(nt) + yes_types += other_types + no_types += other_types if yes_types: yes_type = make_simplified_union(yes_types) else: diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 7b2cdcf3e977b..8e6e6f86ac2a7 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1517,6 +1517,26 @@ else: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" [builtins fixtures/len.pyi] +[case testNarrowingLenMixedTypes] +from typing import Tuple, List, Union + +x: Union[Tuple[int, int], Tuple[int, int, int], List[int]] +a = b = c = 0 +if len(x) == 3: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int, builtins.int], builtins.list[builtins.int]]" + a, b, c = x +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], builtins.list[builtins.int]]" + a, b = x + +if len(x) != 3: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], builtins.list[builtins.int]]" + a, b = x +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int, builtins.int], builtins.list[builtins.int]]" + a, b, c = x +[builtins fixtures/len.pyi] + [case testNarrowingLenTypeVarTupleEquals] from typing import Tuple from typing_extensions import TypeVarTuple, Unpack From 34d2dd4301b4ebff86f1d0d3303c882233dbb1f3 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 8 Oct 2023 18:59:20 +0100 Subject: [PATCH 11/18] Partially fix another TODO --- mypy/checker.py | 13 ++++++--- mypy/checkexpr.py | 45 ++++++++++++++++++++--------- mypy/meet.py | 6 ++-- mypy/subtypes.py | 2 +- mypy/types.py | 14 +++++++-- test-data/unit/check-narrowing.test | 11 +++++++ test-data/unit/fixtures/len.pyi | 1 + 7 files changed, 69 insertions(+), 23 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a679cb09e10c6..ad92d79700223 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6354,11 +6354,11 @@ def refine_tuple_type_with_len( unpack = typ.items[unpack_index] assert isinstance(unpack, UnpackType) unpacked = get_proper_type(unpack.type) - min_len = typ.length() - 1 if isinstance(unpacked, TypeVarTupleType): # For tuples involving TypeVarTuple unpack we can't do much except - # inferring reachability, since we cannot really split a TypeVarTuple. - # TODO: support some cases by adding a min_len attribute to TypeVarTupleType. + # inferring reachability, and recording the restrictions on TypeVarTuple + # for further "manual" use elsewhere. + min_len = typ.length() - 1 + unpacked.min_len if op in ("==", "is"): if min_len <= size: return typ, typ @@ -6367,7 +6367,11 @@ def refine_tuple_type_with_len( if op == "<=": size += 1 if min_len < size: - return typ, typ + prefix = typ.items[:unpack_index] + suffix = typ.items[unpack_index + 1 :] + # TODO: also record max_len to avoid false negatives? + unpack = UnpackType(unpacked.copy_modified(min_len=size - typ.length() + 1)) + return typ, typ.copy_modified(items=prefix + [unpack] + suffix) return None, typ else: yes_type, no_type = self.refine_tuple_type_with_len(typ, neg_ops[op], size) @@ -6375,6 +6379,7 @@ def refine_tuple_type_with_len( # Homogeneous variadic item is the case where we are most flexible. Essentially, # we adjust the variadic item by "eating away" from it to satisfy the restriction. assert isinstance(unpacked, Instance) and unpacked.type.fullname == "builtins.tuple" + min_len = typ.length() - 1 arg = unpacked.args[0] prefix = typ.items[:unpack_index] suffix = typ.items[unpack_index + 1 :] diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a2141680b6cbb..08aaab908fbaf 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -4180,9 +4180,8 @@ def visit_index_with_type( else: self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e) if any(isinstance(t, UnpackType) for t in left_type.items): - self.chk.note( - f"Variadic tuple can have length {left_type.length() - 1}", e - ) + min_len = self.min_tuple_length(left_type) + self.chk.note(f"Variadic tuple can have length {min_len}", e) return AnyType(TypeOfAny.from_error) return make_simplified_union(out) else: @@ -4206,6 +4205,16 @@ def visit_index_with_type( e.method_type = method_type return result + def min_tuple_length(self, left: TupleType) -> int: + unpack_index = find_unpack_in_list(left.items) + if unpack_index is None: + return left.length() + unpack = left.items[unpack_index] + assert isinstance(unpack, UnpackType) + if isinstance(unpack.type, TypeVarTupleType): + return left.length() - 1 + unpack.type.min_len + return left.length() - 1 + def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None: unpack_index = find_unpack_in_list(left.items) if unpack_index is None: @@ -4219,31 +4228,39 @@ def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None: unpacked = get_proper_type(unpack.type) if isinstance(unpacked, TypeVarTupleType): # Usually we say that TypeVarTuple can't be split, be in case of - # indexing it seems benign to just return the fallback item, similar + # indexing it seems benign to just return the upper bound item, similar # to what we do when indexing a regular TypeVar. - middle = unpacked.tuple_fallback.args[0] + bound = get_proper_type(unpacked.upper_bound) + assert isinstance(bound, Instance) + assert bound.type.fullname == "builtins.tuple" + middle = bound.args[0] else: assert isinstance(unpacked, Instance) assert unpacked.type.fullname == "builtins.tuple" middle = unpacked.args[0] + + extra_items = self.min_tuple_length(left) - left.length() + 1 if n >= 0: - if n < unpack_index: - return left.items[n] - if n >= len(left.items) - 1: + if n >= self.min_tuple_length(left): # For tuple[int, *tuple[str, ...], int] we allow either index 0 or 1, # since variadic item may have zero items. return None + if n < unpack_index: + return left.items[n] return UnionType.make_union( - [middle] + left.items[unpack_index + 1 : n + 2], left.line, left.column + [middle] + + left.items[unpack_index + 1 : max(n - extra_items + 2, unpack_index + 1)], + left.line, + left.column, ) - n += len(left.items) - if n <= 0: + n += self.min_tuple_length(left) + if n < 0: # Similar to above, we only allow -1, and -2 for tuple[int, *tuple[str, ...], int] return None - if n > unpack_index: - return left.items[n] + if n >= unpack_index + extra_items: + return left.items[n - extra_items + 1] return UnionType.make_union( - left.items[n - 1 : unpack_index] + [middle], left.line, left.column + left.items[min(n, unpack_index) : unpack_index] + [middle], left.line, left.column ) def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type: diff --git a/mypy/meet.py b/mypy/meet.py index 0fa500d32c303..e3645c7b58799 100644 --- a/mypy/meet.py +++ b/mypy/meet.py @@ -221,6 +221,8 @@ def get_possible_variants(typ: Type) -> list[Type]: return [typ.upper_bound] elif isinstance(typ, ParamSpecType): return [typ.upper_bound] + elif isinstance(typ, TypeVarTupleType): + return [typ.upper_bound] elif isinstance(typ, UnionType): return list(typ.items) elif isinstance(typ, Overloaded): @@ -694,8 +696,8 @@ def visit_param_spec(self, t: ParamSpecType) -> ProperType: return self.default(self.s) def visit_type_var_tuple(self, t: TypeVarTupleType) -> ProperType: - if self.s == t: - return self.s + if isinstance(self.s, TypeVarTupleType) and self.s.id == t.id: + return self.s if self.s.min_len > t.min_len else t else: return self.default(self.s) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 822c4b0ebf327..753ea5687fe11 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -638,7 +638,7 @@ def visit_param_spec(self, left: ParamSpecType) -> bool: def visit_type_var_tuple(self, left: TypeVarTupleType) -> bool: right = self.right if isinstance(right, TypeVarTupleType) and right.id == left.id: - return True + return left.min_len >= right.min_len return self._is_subtype(left.upper_bound, self.right) def visit_unpack_type(self, left: UnpackType) -> bool: diff --git a/mypy/types.py b/mypy/types.py index 34ea96be25ee3..7e619e4df5ecd 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -805,6 +805,8 @@ class TypeVarTupleType(TypeVarLikeType): See PEP646 for more information. """ + __slots__ = ("tuple_fallback", "min_len") + def __init__( self, name: str, @@ -816,9 +818,13 @@ def __init__( *, line: int = -1, column: int = -1, + min_len: int = 0, ) -> None: super().__init__(name, fullname, id, upper_bound, default, line=line, column=column) self.tuple_fallback = tuple_fallback + # This value is not settable by a user. It is an internal-only thing to support + # len()-narrowing of variadic tuples. + self.min_len = min_len def serialize(self) -> JsonDict: assert not self.id.is_meta_var() @@ -830,6 +836,7 @@ def serialize(self) -> JsonDict: "upper_bound": self.upper_bound.serialize(), "tuple_fallback": self.tuple_fallback.serialize(), "default": self.default.serialize(), + "min_len": self.min_len, } @classmethod @@ -842,18 +849,19 @@ def deserialize(cls, data: JsonDict) -> TypeVarTupleType: deserialize_type(data["upper_bound"]), Instance.deserialize(data["tuple_fallback"]), deserialize_type(data["default"]), + min_len=data["min_len"], ) def accept(self, visitor: TypeVisitor[T]) -> T: return visitor.visit_type_var_tuple(self) def __hash__(self) -> int: - return hash(self.id) + return hash((self.id, self.min_len)) def __eq__(self, other: object) -> bool: if not isinstance(other, TypeVarTupleType): return NotImplemented - return self.id == other.id + return self.id == other.id and self.min_len == other.min_len def copy_modified( self, @@ -861,6 +869,7 @@ def copy_modified( id: Bogus[TypeVarId | int] = _dummy, upper_bound: Bogus[Type] = _dummy, default: Bogus[Type] = _dummy, + min_len: Bogus[int] = _dummy, **kwargs: Any, ) -> TypeVarTupleType: return TypeVarTupleType( @@ -872,6 +881,7 @@ def copy_modified( self.default if default is _dummy else default, line=self.line, column=self.column, + min_len=self.min_len if min_len is _dummy else min_len, ) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 8e6e6f86ac2a7..3896eed8b216e 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1562,6 +1562,9 @@ Ts = TypeVarTuple("Ts") def foo(x: Tuple[int, Unpack[Ts], str]) -> None: if len(x) > 5: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + reveal_type(x[5]) # N: Revealed type is "builtins.object" + reveal_type(x[-6]) # N: Revealed type is "builtins.object" + reveal_type(x[-1]) # N: Revealed type is "builtins.str" else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" @@ -1569,6 +1572,14 @@ def foo(x: Tuple[int, Unpack[Ts], str]) -> None: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[Ts`-1], builtins.str]" + x[5] # E: Tuple index out of range \ + # N: Variadic tuple can have length 5 + x[-6] # E: Tuple index out of range \ + # N: Variadic tuple can have length 5 + x[2] # E: Tuple index out of range \ + # N: Variadic tuple can have length 2 + x[-3] # E: Tuple index out of range \ + # N: Variadic tuple can have length 2 [builtins fixtures/len.pyi] [case testNarrowingLenTypeVarTupleUnreachable] diff --git a/test-data/unit/fixtures/len.pyi b/test-data/unit/fixtures/len.pyi index 13fc8829651ac..c725966618589 100644 --- a/test-data/unit/fixtures/len.pyi +++ b/test-data/unit/fixtures/len.pyi @@ -32,6 +32,7 @@ class int: def __gt__(self, n: int) -> bool: pass def __le__(self, n: int) -> bool: pass def __ge__(self, n: int) -> bool: pass + def __neg__(self) -> int: pass class float: pass class bool(int): pass class str(Sequence[str]): pass From 8844f69e61c5d39795eff9ff896a76769f016d11 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 8 Oct 2023 23:48:28 +0100 Subject: [PATCH 12/18] Fix primer --- mypy/checker.py | 27 ++++++++++++++++++++------- test-data/unit/check-namedtuple.test | 4 ++-- test-data/unit/check-narrowing.test | 8 ++++++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 4ade9676e8f00..14c69ffe54f72 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5909,6 +5909,15 @@ def has_no_custom_eq_checks(t: Type) -> bool: elif isinstance(node, UnaryExpr) and node.op == "not": left, right = self.find_isinstance_check(node.expr) return right, left + elif ( + literal(node) == LITERAL_TYPE + and self.has_type(node) + and self.can_be_narrowed_with_len(self.lookup_type(node)) + ): + yes_type, no_type = self.narrow_with_len(self.lookup_type(node), ">", 0) + yes_map = None if yes_type is None else {node: yes_type} + no_map = None if no_type is None else {node: no_type} + return yes_map, no_map # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -6277,10 +6286,10 @@ def can_be_narrowed_with_len(self, typ: Type) -> bool: unions involving such types. """ p_typ = get_proper_type(typ) + # TODO: support tuple subclasses as well? if isinstance(p_typ, TupleType): - return True + return p_typ.partial_fallback.type.fullname == "builtins.tuple" if isinstance(p_typ, Instance): - # TODO: support tuple subclasses as well? return p_typ.type.fullname == "builtins.tuple" if isinstance(p_typ, UnionType): return any(self.can_be_narrowed_with_len(t) for t in p_typ.items) @@ -6508,13 +6517,17 @@ def refine_instance_type_with_len( size += 1 if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) - no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) + no_type: Type = TupleType(items=[arg] * size + [unpack], fallback=typ) else: no_type = typ - items = [] - for n in range(size): - items.append(TupleType([arg] * n, fallback=typ)) - return UnionType.make_union(items, typ.line, typ.column), no_type + if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + items = [] + for n in range(size): + items.append(TupleType([arg] * n, fallback=typ)) + yes_type: Type = UnionType.make_union(items, typ.line, typ.column) + else: + yes_type = typ + return yes_type, no_type else: yes_type, no_type = self.refine_instance_type_with_len(typ, neg_ops[op], size) return no_type, yes_type diff --git a/test-data/unit/check-namedtuple.test b/test-data/unit/check-namedtuple.test index 6e3628060617d..c839fd6caa1c1 100644 --- a/test-data/unit/check-namedtuple.test +++ b/test-data/unit/check-namedtuple.test @@ -878,7 +878,7 @@ reveal_type(Child.class_method()) # N: Revealed type is "Tuple[builtins.str, fa [builtins fixtures/classmethod.pyi] [case testNamedTupleAsConditionalStrictOptionalDisabled] -# flags: --no-strict-optional +# flags: --no-strict-optional --warn-unreachable from typing import NamedTuple class C(NamedTuple): @@ -890,7 +890,7 @@ if not a: 1() # E: "int" not callable b = (1, 2) if not b: - ''() # E: "str" not callable + ''() # E: Statement is unreachable [builtins fixtures/tuple.pyi] [case testNamedTupleDoubleForward] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 3896eed8b216e..9437011ad24c4 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1672,3 +1672,11 @@ def bar(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: else: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" [builtins fixtures/len.pyi] + +[case testNarrowingLenBareExpression] +from typing import Tuple + +x: Tuple[int, ...] +assert x +reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" +[builtins fixtures/len.pyi] From 12052ea65e0ccc07653b49d9d6551654cf9ea324 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 9 Oct 2023 00:39:13 +0100 Subject: [PATCH 13/18] Fix the fix --- mypy/checker.py | 19 ++++++++++++++----- mypy/types.py | 13 ++++++++++++- test-data/unit/check-narrowing.test | 10 ++++++++++ 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 14c69ffe54f72..8d22205060405 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5914,10 +5914,19 @@ def has_no_custom_eq_checks(t: Type) -> bool: and self.has_type(node) and self.can_be_narrowed_with_len(self.lookup_type(node)) ): + # Combine a `len(x) > 0` check with the default logic below. yes_type, no_type = self.narrow_with_len(self.lookup_type(node), ">", 0) - yes_map = None if yes_type is None else {node: yes_type} - no_map = None if no_type is None else {node: no_type} - return yes_map, no_map + if yes_type is not None: + yes_type = true_only(yes_type) + else: + yes_type = UninhabitedType() + if no_type is not None: + no_type = false_only(no_type) + else: + no_type = UninhabitedType() + if_map = {node: yes_type} if not isinstance(yes_type, UninhabitedType) else None + else_map = {node: no_type} if not isinstance(no_type, UninhabitedType) else None + return if_map, else_map # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively @@ -6517,14 +6526,14 @@ def refine_instance_type_with_len( size += 1 if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) - no_type: Type = TupleType(items=[arg] * size + [unpack], fallback=typ) + no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) else: no_type = typ if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: items = [] for n in range(size): items.append(TupleType([arg] * n, fallback=typ)) - yes_type: Type = UnionType.make_union(items, typ.line, typ.column) + yes_type: Type | None = UnionType.make_union(items, typ.line, typ.column) else: yes_type = typ return yes_type, no_type diff --git a/mypy/types.py b/mypy/types.py index 7e619e4df5ecd..8f20c1232677e 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -2377,7 +2377,18 @@ def can_be_false_default(self) -> bool: # Corner case: it is a `NamedTuple` with `__bool__` method defined. # It can be anything: both `True` and `False`. return True - return self.length() == 0 + if self.length() == 0: + return True + if self.length() > 1: + return False + # Special case tuple[*Ts] may or may not be false. + item = self.items[0] + if not isinstance(item, UnpackType): + return False + if not isinstance(item.type, TypeVarTupleType): + # Non-normalized tuple[int, ...] can be false. + return True + return item.type.min_len == 0 def can_be_any_bool(self) -> bool: return bool( diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 9437011ad24c4..904044fc758e3 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1680,3 +1680,13 @@ x: Tuple[int, ...] assert x reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" [builtins fixtures/len.pyi] + +[case testNarrowingLenBareExpressionWithNone] +from typing import Tuple, Optional + +x: Optional[Tuple[int, ...]] +if x: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[()], None]" +[builtins fixtures/len.pyi] From 9c21e44155e0c84054b1ad86d763dd50b9fadcb6 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 9 Oct 2023 15:20:05 +0100 Subject: [PATCH 14/18] More consistent handling for Any in narrowing --- mypy/binder.py | 9 +++++ mypy/checker.py | 50 +++++++++-------------- mypy/operators.py | 23 +++++++++++ test-data/unit/check-narrowing.test | 61 +++++++++++++++++++++++++++++ 4 files changed, 111 insertions(+), 32 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index 09ee4a92aa40e..3b67d09f16c3b 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -15,6 +15,7 @@ Instance, NoneType, PartialType, + ProperType, TupleType, Type, TypeOfAny, @@ -227,6 +228,14 @@ def update_from_options(self, frames: list[Frame]) -> bool: # still equivalent to such type). if isinstance(type, UnionType): type = collapse_variadic_union(type) + if isinstance(type, ProperType) and isinstance(type, UnionType): + # Simplify away any extra Any's that were added to the declared + # type when popping a frame. + simplified = UnionType.make_union( + [t for t in type.items if not isinstance(get_proper_type(t), AnyType)] + ) + if simplified == self.declarations[key]: + type = simplified if current_value is None or not is_same_type(type, current_value): self._put(key, type) changed = True diff --git a/mypy/checker.py b/mypy/checker.py index 8d22205060405..3d719b6a3e423 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -134,6 +134,7 @@ YieldExpr, is_final_node, ) +from mypy.operators import flip_ops, int_op_to_method, neg_ops from mypy.options import TYPE_VAR_TUPLE, Options from mypy.patterns import AsPattern, StarredPattern from mypy.plugin import CheckerPluginInterface, Plugin @@ -231,29 +232,6 @@ # Maximum length of fixed tuple types inferred when narrowing from variadic tuples. MAX_PRECISE_TUPLE_SIZE: Final = 15 -int_op_to_method: Final = { - "==": int.__eq__, - "is": int.__eq__, - "<": int.__lt__, - "<=": int.__le__, - "!=": int.__ne__, - "is not": int.__ne__, - ">": int.__gt__, - ">=": int.__ge__, -} - -flip_ops: Final = {"<": ">", "<=": ">=", ">": "<", ">=": "<="} -neg_ops: Final = { - "==": "!=", - "!=": "==", - "is": "is not", - "is not": "is", - "<": ">=", - "<=": ">", - ">": "<=", - ">=": "<", -} - DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator] FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef] @@ -5894,7 +5872,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: # and false if at least one of e1 and e2 is false. return ( and_conditional_maps(left_if_vars, right_if_vars), - or_conditional_maps(left_else_vars, right_else_vars), + # Note that if left else type is Any, we can't add any additional + # types to it, since the right maps were computed assuming + # the left is True, which may be not the case in the else branch. + or_conditional_maps(left_else_vars, right_else_vars, coalesce_any=True), ) elif isinstance(node, OpExpr) and node.op == "or": left_if_vars, left_else_vars = self.find_isinstance_check(node.left) @@ -6351,8 +6332,9 @@ def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, last_group.add(left) last_group.add(right) else: - chained.append(("==", list(last_group))) - last_group = set() + if last_group: + chained.append(("==", list(last_group))) + last_group = set() if op in {"is not", "!=", "<", "<=", ">", ">="}: chained.append((op, [left, right])) if last_group: @@ -7500,12 +7482,12 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty # Both conditions can be true; combine the information. Anything # we learn from either conditions' truth is valid. If the same # expression's type is refined by both conditions, we somewhat - # arbitrarily give precedence to m2. (In the future, we could use - # an intersection type or meet_types().) + # arbitrarily give precedence to m2 unless m1 value is Any. + # In the future, we could use an intersection type or meet_types(). result = m2.copy() m2_keys = {literal_hash(n2) for n2 in m2} for n1 in m1: - if literal_hash(n1) not in m2_keys: + if literal_hash(n1) not in m2_keys or isinstance(get_proper_type(m1[n1]), AnyType): result[n1] = m1[n1] if use_meet: # For now, meet common keys only if specifically requested. @@ -7518,10 +7500,11 @@ def and_conditional_maps(m1: TypeMap, m2: TypeMap, use_meet: bool = False) -> Ty return result -def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: +def or_conditional_maps(m1: TypeMap, m2: TypeMap, coalesce_any: bool = False) -> TypeMap: """Calculate what information we can learn from the truth of (e1 or e2) in terms of the information that we can learn from the truth of e1 and - the truth of e2. + the truth of e2. If coalesce_any is True, consider Any a supertype when + joining restrictions. """ if m1 is None: @@ -7536,7 +7519,10 @@ def or_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: for n1 in m1: for n2 in m2: if literal_hash(n1) == literal_hash(n2): - result[n1] = make_simplified_union([m1[n1], m2[n2]]) + if coalesce_any and isinstance(get_proper_type(m1[n1]), AnyType): + result[n1] = m1[n1] + else: + result[n1] = make_simplified_union([m1[n1], m2[n2]]) return result diff --git a/mypy/operators.py b/mypy/operators.py index 07ec5a24fa77c..d1f050b58faeb 100644 --- a/mypy/operators.py +++ b/mypy/operators.py @@ -101,3 +101,26 @@ reverse_op_method_set: Final = set(reverse_op_methods.values()) unary_op_methods: Final = {"-": "__neg__", "+": "__pos__", "~": "__invert__"} + +int_op_to_method: Final = { + "==": int.__eq__, + "is": int.__eq__, + "<": int.__lt__, + "<=": int.__le__, + "!=": int.__ne__, + "is not": int.__ne__, + ">": int.__gt__, + ">=": int.__ge__, +} + +flip_ops: Final = {"<": ">", "<=": ">=", ">": "<", ">=": "<="} +neg_ops: Final = { + "==": "!=", + "!=": "==", + "is": "is not", + "is not": "is", + "<": ">=", + "<=": ">", + ">": "<=", + ">=": "<", +} diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 904044fc758e3..655fa50be1975 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1335,6 +1335,49 @@ else: reveal_type(some) # N: Revealed type is "Union[builtins.int, __main__.Base]" [builtins fixtures/dict.pyi] +[case testNarrowingWithAnyOps] +from typing import Any + +class C: ... +class D(C): ... +tp: Any + +c: C +if isinstance(c, tp) or isinstance(c, D): + reveal_type(c) # N: Revealed type is "Union[Any, __main__.D]" +else: + reveal_type(c) # N: Revealed type is "__main__.C" +reveal_type(c) # N: Revealed type is "__main__.C" + +c1: C +if isinstance(c1, tp) and isinstance(c1, D): + reveal_type(c1) # N: Revealed type is "Any" +else: + reveal_type(c1) # N: Revealed type is "__main__.C" +reveal_type(c1) # N: Revealed type is "__main__.C" + +c2: C +if isinstance(c2, D) or isinstance(c2, tp): + reveal_type(c2) # N: Revealed type is "Union[__main__.D, Any]" +else: + reveal_type(c2) # N: Revealed type is "__main__.C" +reveal_type(c2) # N: Revealed type is "__main__.C" + +c3: C +if isinstance(c3, D) and isinstance(c3, tp): + reveal_type(c3) # N: Revealed type is "Any" +else: + reveal_type(c3) # N: Revealed type is "__main__.C" +reveal_type(c3) # N: Revealed type is "__main__.C" + +t: Any +if isinstance(t, (list, tuple)) and isinstance(t, tuple): + reveal_type(t) # N: Revealed type is "builtins.tuple[Any, ...]" +else: + reveal_type(t) # N: Revealed type is "Any" +reveal_type(t) # N: Revealed type is "Any" +[builtins fixtures/isinstancelist.pyi] + [case testNarrowingLenItemAndLenCompare] from typing import Any @@ -1690,3 +1733,21 @@ if x: else: reveal_type(x) # N: Revealed type is "Union[Tuple[()], None]" [builtins fixtures/len.pyi] + +[case testNarrowingLenMixWithAny] +from typing import Any + +x: Any +if isinstance(x, (list, tuple)) and len(x) == 0: + reveal_type(x) # N: Revealed type is "Union[Tuple[()], builtins.list[Any]]" +else: + reveal_type(x) # N: Revealed type is "Any" +reveal_type(x) # N: Revealed type is "Any" + +x1: Any +if isinstance(x1, (list, tuple)) and len(x1) > 1: + reveal_type(x1) # N: Revealed type is "Union[Tuple[Any, Any, Unpack[builtins.tuple[Any, ...]]], builtins.list[Any]]" +else: + reveal_type(x1) # N: Revealed type is "Any" +reveal_type(x1) # N: Revealed type is "Any" +[builtins fixtures/len.pyi] From 25396cfab9b7dcb65407987cc4d8561c483b3814 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 11 Oct 2023 00:30:02 +0100 Subject: [PATCH 15/18] Gate precise tuple types behind a separate feature --- mypy/checker.py | 6 ++-- mypy/options.py | 3 +- mypy/test/testcheck.py | 2 +- mypy_self_check.ini | 1 + test-data/unit/check-narrowing.test | 49 +++++++++++++++++++++++++++-- 5 files changed, 53 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 3d719b6a3e423..e633a0e9a2519 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -135,7 +135,7 @@ is_final_node, ) from mypy.operators import flip_ops, int_op_to_method, neg_ops -from mypy.options import TYPE_VAR_TUPLE, Options +from mypy.options import PRECISE_TUPLE_TYPES, Options from mypy.patterns import AsPattern, StarredPattern from mypy.plugin import CheckerPluginInterface, Plugin from mypy.plugins import dataclasses as dataclasses_plugin @@ -6506,12 +6506,12 @@ def refine_instance_type_with_len( elif op in ("<", "<="): if op == "<=": size += 1 - if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + if PRECISE_TUPLE_TYPES in self.options.enable_incomplete_feature: unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) else: no_type = typ - if TYPE_VAR_TUPLE in self.options.enable_incomplete_feature: + if PRECISE_TUPLE_TYPES in self.options.enable_incomplete_feature: items = [] for n in range(size): items.append(TupleType([arg] * n, fallback=typ)) diff --git a/mypy/options.py b/mypy/options.py index 007ae0a78aa14..04f745d053f3a 100644 --- a/mypy/options.py +++ b/mypy/options.py @@ -72,7 +72,8 @@ class BuildType: # Features that are currently incomplete/experimental TYPE_VAR_TUPLE: Final = "TypeVarTuple" UNPACK: Final = "Unpack" -INCOMPLETE_FEATURES: Final = frozenset((TYPE_VAR_TUPLE, UNPACK)) +PRECISE_TUPLE_TYPES: Final = "PreciseTupleTypes" +INCOMPLETE_FEATURES: Final = frozenset((TYPE_VAR_TUPLE, UNPACK, PRECISE_TUPLE_TYPES)) class Options: diff --git a/mypy/test/testcheck.py b/mypy/test/testcheck.py index 85fbe5dc2990d..591421465a971 100644 --- a/mypy/test/testcheck.py +++ b/mypy/test/testcheck.py @@ -126,7 +126,7 @@ def run_case_once( options = parse_options(original_program_text, testcase, incremental_step) options.use_builtins_fixtures = True if not testcase.name.endswith("_no_incomplete"): - options.enable_incomplete_feature = [TYPE_VAR_TUPLE, UNPACK] + options.enable_incomplete_feature += [TYPE_VAR_TUPLE, UNPACK] options.show_traceback = True # Enable some options automatically based on test file name. diff --git a/mypy_self_check.ini b/mypy_self_check.ini index 6e1ad8187b7a7..093926d4c4155 100644 --- a/mypy_self_check.ini +++ b/mypy_self_check.ini @@ -10,6 +10,7 @@ python_version = 3.8 exclude = mypy/typeshed/|mypyc/test-data/|mypyc/lib-rt/ new_type_inference = True enable_error_code = ignore-without-code,redundant-expr +enable_incomplete_feature = PreciseTupleTypes show_error_code_links = True [mypy-mypy.visitor] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 655fa50be1975..a4f805fc7b812 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1509,6 +1509,7 @@ else: [builtins fixtures/len.pyi] [case testNarrowingLenGreaterThanHomogeneousTupleShort] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Tuple VarTuple = Tuple[int, ...] @@ -1534,6 +1535,7 @@ else: [builtins fixtures/len.pyi] [case testNarrowingLenBothSidesHomogeneousTuple] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Tuple x: Tuple[int, ...] @@ -1716,7 +1718,8 @@ def bar(x: Tuple[int, Unpack[Tuple[float, ...]], str]) -> None: reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.float, ...]], builtins.str]" [builtins fixtures/len.pyi] -[case testNarrowingLenBareExpression] +[case testNarrowingLenBareExpressionPrecise] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Tuple x: Tuple[int, ...] @@ -1724,7 +1727,18 @@ assert x reveal_type(x) # N: Revealed type is "Tuple[builtins.int, Unpack[builtins.tuple[builtins.int, ...]]]" [builtins fixtures/len.pyi] -[case testNarrowingLenBareExpressionWithNone] +[case testNarrowingLenBareExpressionTypeVarTuple] +from typing import Tuple +from typing_extensions import TypeVarTuple, Unpack + +Ts = TypeVarTuple("Ts") +def test(*xs: Unpack[Ts]) -> None: + assert xs + xs[0] # OK +[builtins fixtures/len.pyi] + +[case testNarrowingLenBareExpressionWithNonePrecise] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Tuple, Optional x: Optional[Tuple[int, ...]] @@ -1734,7 +1748,18 @@ else: reveal_type(x) # N: Revealed type is "Union[Tuple[()], None]" [builtins fixtures/len.pyi] -[case testNarrowingLenMixWithAny] +[case testNarrowingLenBareExpressionWithNoneImprecise] +from typing import Tuple, Optional + +x: Optional[Tuple[int, ...]] +if x: + reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" +else: + reveal_type(x) # N: Revealed type is "Union[builtins.tuple[builtins.int, ...], None]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenMixWithAnyPrecise] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Any x: Any @@ -1751,3 +1776,21 @@ else: reveal_type(x1) # N: Revealed type is "Any" reveal_type(x1) # N: Revealed type is "Any" [builtins fixtures/len.pyi] + +[case testNarrowingLenMixWithAnyImprecise] +from typing import Any + +x: Any +if isinstance(x, (list, tuple)) and len(x) == 0: + reveal_type(x) # N: Revealed type is "Union[Tuple[()], builtins.list[Any]]" +else: + reveal_type(x) # N: Revealed type is "Any" +reveal_type(x) # N: Revealed type is "Any" + +x1: Any +if isinstance(x1, (list, tuple)) and len(x1) > 1: + reveal_type(x1) # N: Revealed type is "Union[builtins.tuple[Any, ...], builtins.list[Any]]" +else: + reveal_type(x1) # N: Revealed type is "Any" +reveal_type(x1) # N: Revealed type is "Any" +[builtins fixtures/len.pyi] From d84f38f752bda7643fb8115edf90384eb456e390 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 21 Oct 2023 11:27:27 +0100 Subject: [PATCH 16/18] Address some feedback --- mypy/checker.py | 1 + test-data/unit/check-narrowing.test | 25 +++++++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/mypy/checker.py b/mypy/checker.py index e633a0e9a2519..6481d2abaab66 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -6343,6 +6343,7 @@ def find_tuple_len_narrowing(self, node: ComparisonExpr) -> list[tuple[TypeMap, # Second step: infer type restrictions from each group found above. type_maps = [] for op, items in chained: + # TODO: support unions of literal types as len() comparison targets. if not any(self.literal_int_expr(it) is not None for it in items): continue if not any(self.is_len_of_tuple(it) for it in items): diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index a4f805fc7b812..5018390b2ca8f 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1523,6 +1523,7 @@ reveal_type(x) # N: Revealed type is "builtins.tuple[builtins.int, ...]" [builtins fixtures/len.pyi] [case testNarrowingLenBiggerThanHomogeneousTupleLong] +# flags: --enable-incomplete-feature=PreciseTupleTypes from typing import Tuple VarTuple = Tuple[int, ...] @@ -1794,3 +1795,27 @@ else: reveal_type(x1) # N: Revealed type is "Any" reveal_type(x1) # N: Revealed type is "Any" [builtins fixtures/len.pyi] + +[case testNarrowingLenExplicitLiteralTypes] +from typing import Tuple, Union +from typing_extensions import Literal + +VarTuple = Union[ + Tuple[int], + Tuple[int, int], + Tuple[int, int, int], +] +x: VarTuple + +supported: Literal[2] +if len(x) == supported: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" + +not_supported_yet: Literal[2, 3] +if len(x) == not_supported_yet: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +[builtins fixtures/len.pyi] From ba196fda17774a2e48d1a68bbd78077107d5ca6e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 21 Oct 2023 12:46:52 +0100 Subject: [PATCH 17/18] More feedback; more corner cases; more tests --- mypy/checker.py | 30 ++++++--- mypy/typeops.py | 2 +- test-data/unit/check-namedtuple.test | 2 +- test-data/unit/check-narrowing.test | 91 ++++++++++++++++++++++++++++ 4 files changed, 115 insertions(+), 10 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 6481d2abaab66..02bab37aa13f1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -230,7 +230,7 @@ DEFAULT_LAST_PASS: Final = 1 # Pass numbers start at 0 # Maximum length of fixed tuple types inferred when narrowing from variadic tuples. -MAX_PRECISE_TUPLE_SIZE: Final = 15 +MAX_PRECISE_TUPLE_SIZE: Final = 8 DeferredNodeType: _TypeAlias = Union[FuncDef, LambdaExpr, OverloadedFuncDef, Decorator] FineGrainedDeferredNodeType: _TypeAlias = Union[FuncDef, MypyFile, OverloadedFuncDef] @@ -5894,6 +5894,9 @@ def has_no_custom_eq_checks(t: Type) -> bool: literal(node) == LITERAL_TYPE and self.has_type(node) and self.can_be_narrowed_with_len(self.lookup_type(node)) + # Only translate `if x` to `if len(x) > 0` when possible. + and not custom_special_method(self.lookup_type(node), "__bool__") + and self.options.strict_optional ): # Combine a `len(x) > 0` check with the default logic below. yes_type, no_type = self.narrow_with_len(self.lookup_type(node), ">", 0) @@ -6275,12 +6278,18 @@ def can_be_narrowed_with_len(self, typ: Type) -> bool: Currently supported types are TupleTypes, Instances of builtins.tuple, and unions involving such types. """ + if custom_special_method(typ, "__len__"): + # If user overrides builtin behavior, we can't do anything. + return False p_typ = get_proper_type(typ) - # TODO: support tuple subclasses as well? + # Note: we are conservative about tuple subclasses, because some code may rely on + # the fact that tuple_type of fallback TypeInfo matches the original TupleType. if isinstance(p_typ, TupleType): - return p_typ.partial_fallback.type.fullname == "builtins.tuple" + if any(isinstance(t, UnpackType) for t in p_typ.items): + return p_typ.partial_fallback.type.fullname == "builtins.tuple" + return True if isinstance(p_typ, Instance): - return p_typ.type.fullname == "builtins.tuple" + return p_typ.type.has_base("builtins.tuple") if isinstance(p_typ, UnionType): return any(self.can_be_narrowed_with_len(t) for t in p_typ.items) return False @@ -6399,7 +6408,7 @@ def narrow_with_len(self, typ: Type, op: str, size: int) -> tuple[Type | None, T typ = get_proper_type(typ) if isinstance(typ, TupleType): return self.refine_tuple_type_with_len(typ, op, size) - elif isinstance(typ, Instance) and typ.type.fullname == "builtins.tuple": + elif isinstance(typ, Instance): return self.refine_instance_type_with_len(typ, op, size) elif isinstance(typ, UnionType): yes_types = [] @@ -6500,19 +6509,24 @@ def refine_instance_type_with_len( self, typ: Instance, op: str, size: int ) -> tuple[Type | None, Type | None]: """Narrow a homogeneous tuple using length restrictions.""" - arg = typ.args[0] + base = map_instance_to_supertype(typ, self.lookup_typeinfo("builtins.tuple")) + arg = base.args[0] + # Again, we are conservative about subclasses until we gain more confidence. + allow_precise = ( + PRECISE_TUPLE_TYPES in self.options.enable_incomplete_feature + ) and typ.type.fullname == "builtins.tuple" if op in ("==", "is"): # TODO: return fixed union + prefixed variadic tuple for no_type? return TupleType(items=[arg] * size, fallback=typ), typ elif op in ("<", "<="): if op == "<=": size += 1 - if PRECISE_TUPLE_TYPES in self.options.enable_incomplete_feature: + if allow_precise: unpack = UnpackType(self.named_generic_type("builtins.tuple", [arg])) no_type: Type | None = TupleType(items=[arg] * size + [unpack], fallback=typ) else: no_type = typ - if PRECISE_TUPLE_TYPES in self.options.enable_incomplete_feature: + if allow_precise: items = [] for n in range(size): items.append(TupleType([arg] * n, fallback=typ)) diff --git a/mypy/typeops.py b/mypy/typeops.py index 37817933a3970..dff43775fe3d4 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -981,7 +981,7 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool method = typ.type.get(name) if method and isinstance(method.node, (SYMBOL_FUNCBASE_TYPES, Decorator, Var)): if method.node.info: - return not method.node.info.fullname.startswith("builtins.") + return not method.node.info.fullname.startswith(("builtins.", "typing.")) return False if isinstance(typ, UnionType): if check_all: diff --git a/test-data/unit/check-namedtuple.test b/test-data/unit/check-namedtuple.test index c839fd6caa1c1..9fa098b28dee9 100644 --- a/test-data/unit/check-namedtuple.test +++ b/test-data/unit/check-namedtuple.test @@ -890,7 +890,7 @@ if not a: 1() # E: "int" not callable b = (1, 2) if not b: - ''() # E: Statement is unreachable + ''() # E: "str" not callable [builtins fixtures/tuple.pyi] [case testNamedTupleDoubleForward] diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 5018390b2ca8f..5b7fadf41c793 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1819,3 +1819,94 @@ if len(x) == not_supported_yet: else: reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int], Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" [builtins fixtures/len.pyi] + +[case testNarrowingLenUnionOfVariadicTuples] +from typing import Tuple, Union + +x: Union[Tuple[int, ...], Tuple[str, ...]] +if len(x) == 2: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.str, builtins.str]]" +else: + reveal_type(x) # N: Revealed type is "Union[builtins.tuple[builtins.int, ...], builtins.tuple[builtins.str, ...]]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenUnionOfNamedTuples] +from typing import NamedTuple, Union + +class Point2D(NamedTuple): + x: int + y: int +class Point3D(NamedTuple): + x: int + y: int + z: int + +x: Union[Point2D, Point3D] +if len(x) == 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, fallback=__main__.Point2D]" +else: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, builtins.int, fallback=__main__.Point3D]" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTupleSubclass] +from typing import Tuple + +class Ints(Tuple[int, ...]): + size: int + +x: Ints +if len(x) == 2: + reveal_type(x) # N: Revealed type is "Tuple[builtins.int, builtins.int, fallback=__main__.Ints]" + reveal_type(x.size) # N: Revealed type is "builtins.int" +else: + reveal_type(x) # N: Revealed type is "__main__.Ints" + reveal_type(x.size) # N: Revealed type is "builtins.int" + +reveal_type(x) # N: Revealed type is "__main__.Ints" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTupleSubclassCustomNotAllowed] +from typing import Tuple + +class Ints(Tuple[int, ...]): + def __len__(self) -> int: + return 0 + +x: Ints +if len(x) > 2: + reveal_type(x) # N: Revealed type is "__main__.Ints" +else: + reveal_type(x) # N: Revealed type is "__main__.Ints" +[builtins fixtures/len.pyi] + +[case testNarrowingLenTupleSubclassPreciseNotAllowed] +# flags: --enable-incomplete-feature=PreciseTupleTypes +from typing import Tuple + +class Ints(Tuple[int, ...]): + size: int + +x: Ints +if len(x) > 2: + reveal_type(x) # N: Revealed type is "__main__.Ints" +else: + reveal_type(x) # N: Revealed type is "__main__.Ints" +[builtins fixtures/len.pyi] + +[case testNarrowingLenUnknownLen] +from typing import Any, Tuple, Union + +x: Union[Tuple[int, int], Tuple[int, int, int]] + +n: int +if len(x) == n: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" + +a: Any +if len(x) == a: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +else: + reveal_type(x) # N: Revealed type is "Union[Tuple[builtins.int, builtins.int], Tuple[builtins.int, builtins.int, builtins.int]]" +[builtins fixtures/len.pyi] From d676b235bdfe4f5c6652dfc50ae85376122f2973 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 21 Oct 2023 13:46:23 +0100 Subject: [PATCH 18/18] Handle unrelated corner case to avoid new false positives --- mypy/checkexpr.py | 8 ++++++++ test-data/unit/check-expressions.test | 13 +++++++++++++ 2 files changed, 21 insertions(+) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index a1ed0609134bf..2dc5a93a1de96 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -3643,6 +3643,14 @@ def dangerous_comparison( left = map_instance_to_supertype(left, abstract_set) right = map_instance_to_supertype(right, abstract_set) return self.dangerous_comparison(left.args[0], right.args[0]) + elif left.type.has_base("typing.Mapping") and right.type.has_base("typing.Mapping"): + # Similar to above: Mapping ignores the classes, it just compares items. + abstract_map = self.chk.lookup_typeinfo("typing.Mapping") + left = map_instance_to_supertype(left, abstract_map) + right = map_instance_to_supertype(right, abstract_map) + return self.dangerous_comparison( + left.args[0], right.args[0] + ) or self.dangerous_comparison(left.args[1], right.args[1]) elif left_name in ("builtins.list", "builtins.tuple") and right_name == left_name: return self.dangerous_comparison(left.args[0], right.args[0]) elif left_name in OVERLAPPING_BYTES_ALLOWLIST and right_name in ( diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index a3c1bc8795f2a..4ac5512580d2c 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -2365,6 +2365,19 @@ b"x" in data [builtins fixtures/primitives.pyi] [typing fixtures/typing-full.pyi] +[case testStrictEqualityWithDifferentMapTypes] +# flags: --strict-equality +from typing import Mapping + +class A(Mapping[int, str]): ... +class B(Mapping[int, str]): ... + +a: A +b: B +assert a == b +[builtins fixtures/dict.pyi] +[typing fixtures/typing-full.pyi] + [case testUnimportedHintAny] def f(x: Any) -> None: # E: Name "Any" is not defined \ # N: Did you forget to import it from "typing"? (Suggestion: "from typing import Any")