diff --git a/mypy/checker.py b/mypy/checker.py index 05a5e91552e0..e24beb31c85e 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2874,6 +2874,39 @@ def remove_optional(typ: Type) -> Type: return typ +def builtin_item_type(tp: Type) -> Optional[Type]: + """Get the item type of a builtin container. + + If 'tp' is not one of the built containers (these includes NamedTuple and TypedDict) + or if the container is not parameterized (like List or List[Any]) + return None. This function is used to narrow optional types in situations like this: + + x: Optional[int] + if x in (1, 2, 3): + x + 42 # OK + + Note: this is only OK for built-in containers, where we know the behavior + of __contains__. + """ + if isinstance(tp, Instance): + if tp.type.fullname() in ['builtins.list', 'builtins.tuple', 'builtins.dict', + 'builtins.set', 'builtins.frozenset']: + if not tp.args: + # TODO: fix tuple in lib-stub/builtins.pyi (it should be generic). + return None + if not isinstance(tp.args[0], AnyType): + return tp.args[0] + elif isinstance(tp, TupleType) and all(not isinstance(it, AnyType) for it in tp.items): + return UnionType.make_simplified_union(tp.items) # this type is not externally visible + elif isinstance(tp, TypedDictType): + # TypedDict always has non-optional string keys. + if tp.fallback.type.fullname() == 'typing.Mapping': + return tp.fallback.args[0] + elif tp.fallback.type.bases[0].type.fullname() == 'typing.Mapping': + return tp.fallback.type.bases[0].args[0] + return None + + def and_conditional_maps(m1: TypeMap, m2: TypeMap) -> TypeMap: """Calculate what information we can learn from the truth of (e1 and e2) in terms of the information that we can learn from the truth of e1 and @@ -3020,6 +3053,20 @@ def find_isinstance_check(node: Expression, optional_expr = node.operands[1] if is_overlapping_types(optional_type, comp_type): return {optional_expr: remove_optional(optional_type)}, {} + elif node.operators in [['in'], ['not in']]: + expr = node.operands[0] + left_type = type_map[expr] + right_type = builtin_item_type(type_map[node.operands[1]]) + right_ok = right_type and (not is_optional(right_type) and + (not isinstance(right_type, Instance) or + right_type.type.fullname() != 'builtins.object')) + if (right_type and right_ok and is_optional(left_type) and + literal(expr) == LITERAL_TYPE and not is_literal_none(expr) and + is_overlapping_types(left_type, right_type)): + if node.operators == ['in']: + return {expr: remove_optional(left_type)}, {} + if node.operators == ['not in']: + return {}, {expr: remove_optional(left_type)} elif isinstance(node, RefExpr): # Restrict the type of the variable to True-ish/False-ish in the if and else branches # respectively diff --git a/test-data/unit/check-isinstance.test b/test-data/unit/check-isinstance.test index 8279e1aeafd4..81a1069d8832 100644 --- a/test-data/unit/check-isinstance.test +++ b/test-data/unit/check-isinstance.test @@ -1757,7 +1757,6 @@ if isinstance(x, str, 1): # E: Too many arguments for "isinstance" reveal_type(x) # E: Revealed type is 'builtins.int' [builtins fixtures/isinstancelist.pyi] - [case testIsinstanceNarrowAny] from typing import Any @@ -1770,3 +1769,209 @@ def narrow_any_to_str_then_reassign_to_int() -> None: reveal_type(v) # E: Revealed type is 'Any' [builtins fixtures/isinstance.pyi] + +[case testNarrowTypeAfterInList] +# flags: --strict-optional +from typing import List, Optional + +x: List[int] +y: Optional[int] + +if y in x: + reveal_type(y) # E: Revealed type is 'builtins.int' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +if y not in x: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'builtins.int' +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInListOfOptional] +# flags: --strict-optional +from typing import List, Optional + +x: List[Optional[int]] +y: Optional[int] + +if y not in x: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInListNonOverlapping] +# flags: --strict-optional +from typing import List, Optional + +x: List[str] +y: Optional[int] + +if y in x: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInListNested] +# flags: --strict-optional +from typing import List, Optional, Any + +x: Optional[int] +lst: Optional[List[int]] +nested_any: List[List[Any]] + +if lst in nested_any: + reveal_type(lst) # E: Revealed type is 'builtins.list[builtins.int]' +if x in nested_any: + reveal_type(x) # E: Revealed type is 'Union[builtins.int, builtins.None]' +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInTuple] +# flags: --strict-optional +from typing import Optional +class A: pass +class B(A): pass +class C(A): pass + +y: Optional[B] +if y in (B(), C()): + reveal_type(y) # E: Revealed type is '__main__.B' +else: + reveal_type(y) # E: Revealed type is 'Union[__main__.B, builtins.None]' +[builtins fixtures/tuple.pyi] +[out] + +[case testNarrowTypeAfterInNamedTuple] +# flags: --strict-optional +from typing import NamedTuple, Optional +class NT(NamedTuple): + x: int + y: int +nt: NT + +y: Optional[int] +if y not in nt: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'builtins.int' +[builtins fixtures/tuple.pyi] +[out] + +[case testNarrowTypeAfterInDict] +# flags: --strict-optional +from typing import Dict, Optional +x: Dict[str, int] +y: Optional[str] + +if y in x: + reveal_type(y) # E: Revealed type is 'builtins.str' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]' +if y not in x: + reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'builtins.str' +[builtins fixtures/dict.pyi] +[out] + +[case testNarrowTypeAfterInList_python2] +# flags: --strict-optional +from typing import List, Optional + +x = [] # type: List[int] +y = None # type: Optional[int] + +# TODO: Fix running tests on Python 2: "Iterator[int]" has no attribute "next" +if y in x: # type: ignore + reveal_type(y) # E: Revealed type is 'builtins.int' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +if y not in x: # type: ignore + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'builtins.int' + +[builtins_py2 fixtures/python2.pyi] +[out] + +[case testNarrowTypeAfterInNoAnyOrObject] +# flags: --strict-optional +from typing import Any, List, Optional +x: List[Any] +z: List[object] + +y: Optional[int] +if y in x: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' + +if y not in z: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +[typing fixtures/typing-full.pyi] +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInUserDefined] +# flags: --strict-optional +from typing import Container, Optional + +class C(Container[int]): + def __contains__(self, item: object) -> bool: + return item is 'surprise' + +y: Optional[int] +# We never trust user defined types +if y in C(): + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +if y not in C(): + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.int, builtins.None]' +[typing fixtures/typing-full.pyi] +[builtins fixtures/list.pyi] +[out] + +[case testNarrowTypeAfterInSet] +# flags: --strict-optional +from typing import Optional, Set +s: Set[str] + +y: Optional[str] +if y in {'a', 'b', 'c'}: + reveal_type(y) # E: Revealed type is 'builtins.str' +else: + reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]' +if y not in s: + reveal_type(y) # E: Revealed type is 'Union[builtins.str, builtins.None]' +else: + reveal_type(y) # E: Revealed type is 'builtins.str' +[builtins fixtures/set.pyi] +[out] + +[case testNarrowTypeAfterInTypedDict] +# flags: --strict-optional +from typing import Optional +from mypy_extensions import TypedDict +class TD(TypedDict): + a: int + b: str +td: TD + +def f() -> None: + x: Optional[str] + if x not in td: + return + reveal_type(x) # E: Revealed type is 'builtins.str' +[typing fixtures/typing-full.pyi] +[builtins fixtures/dict.pyi] +[out] diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index cf8b61f9397a..d8fc59f60008 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -19,6 +19,7 @@ class dict(Generic[KT, VT]): def __getitem__(self, key: KT) -> VT: pass def __setitem__(self, k: KT, v: VT) -> None: pass def __iter__(self) -> Iterator[KT]: pass + def __contains__(self, item: object) -> bool: pass def update(self, a: Mapping[KT, VT]) -> None: pass @overload def get(self, k: KT) -> Optional[VT]: pass diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 7b6d1dbd127b..b6e54577fb91 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -16,6 +16,7 @@ class list(Generic[T]): @overload def __init__(self, x: Iterable[T]) -> None: pass def __iter__(self) -> Iterator[T]: pass + def __contains__(self, item: object) -> bool: pass def __add__(self, x: list[T]) -> list[T]: pass def __mul__(self, x: int) -> list[T]: pass def __getitem__(self, x: int) -> T: pass diff --git a/test-data/unit/fixtures/python2.pyi b/test-data/unit/fixtures/python2.pyi index 61e48be4510e..283ba1895a97 100644 --- a/test-data/unit/fixtures/python2.pyi +++ b/test-data/unit/fixtures/python2.pyi @@ -11,6 +11,7 @@ class function: pass class int: pass class str: pass class unicode: pass +class bool: pass T = TypeVar('T') class list(Iterable[T], Generic[T]): pass diff --git a/test-data/unit/fixtures/set.pyi b/test-data/unit/fixtures/set.pyi index 79d53e832291..9de7bdaa8096 100644 --- a/test-data/unit/fixtures/set.pyi +++ b/test-data/unit/fixtures/set.pyi @@ -13,9 +13,11 @@ class function: pass class int: pass class str: pass +class bool: pass class set(Iterable[T], Generic[T]): def __iter__(self) -> Iterator[T]: pass + def __contains__(self, item: object) -> bool: pass def add(self, x: T) -> None: pass def discard(self, x: T) -> None: pass def update(self, x: Set[T]) -> None: pass diff --git a/test-data/unit/fixtures/tuple.pyi b/test-data/unit/fixtures/tuple.pyi index 4e53d12f76e6..e231900cfa20 100644 --- a/test-data/unit/fixtures/tuple.pyi +++ b/test-data/unit/fixtures/tuple.pyi @@ -12,6 +12,7 @@ class type: def __call__(self, *a) -> object: pass class tuple(Sequence[Tco], Generic[Tco]): def __iter__(self) -> Iterator[Tco]: pass + def __contains__(self, item: object) -> bool: pass def __getitem__(self, x: int) -> Tco: pass def count(self, obj: Any) -> int: pass class function: pass diff --git a/test-data/unit/fixtures/typing-full.pyi b/test-data/unit/fixtures/typing-full.pyi index 62fac70034c0..fb6b1d3e596d 100644 --- a/test-data/unit/fixtures/typing-full.pyi +++ b/test-data/unit/fixtures/typing-full.pyi @@ -126,6 +126,7 @@ class Mapping(Iterable[T], Protocol[T, T_co]): def get(self, k: T, default: Union[T_co, V]) -> Union[T_co, V]: pass def values(self) -> Iterable[T_co]: pass # Approximate return type def __len__(self) -> int: ... + def __contains__(self, arg: object) -> int: pass @runtime class MutableMapping(Mapping[T, U], Protocol):