From fd85e8753eeb6906c2dbfcfb9d627e02abbfe912 Mon Sep 17 00:00:00 2001 From: Jordandev678 <20153053+Jordandev678@users.noreply.github.com> Date: Fri, 7 Jun 2024 15:42:25 +0000 Subject: [PATCH 1/6] Enable narrowing types using the "in" operator Enables the narrowing of variable types when checking a variable is "in" a collection, and the collection type is a subtype of the variable type. --- mypy/checker.py | 16 ++-- test-data/unit/check-narrowing.test | 102 +++++++++++++++++++++++++- test-data/unit/fixtures/narrowing.pyi | 8 +- 3 files changed, 118 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 179ff6e0b4b6..508f39072f0f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5946,11 +5946,17 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map, else_map = {}, {} if left_index in narrowable_operand_index_to_hash: - # We only try and narrow away 'None' for now - if is_overlapping_none(item_type): - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) + collection_item_type = get_proper_type( + builtin_item_type(iterable_type) + ) + # Narrow if the collection is a subtype + if ( + collection_item_type is not None + and is_subtype(collection_item_type, item_type) + ): + if_map[operands[left_index]] = collection_item_type + # Try and narrow away 'None' + elif is_overlapping_none(item_type): if ( collection_item_type is not None and not is_overlapping_none(collection_item_type) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 8612df9bc663..f6815b2689bf 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -1376,13 +1376,13 @@ else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" if val in (None,): - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "None" else: reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" if val not in (None,): reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" else: - reveal_type(val) # N: Revealed type is "Union[__main__.A, None]" + reveal_type(val) # N: Revealed type is "None" [builtins fixtures/primitives.pyi] [case testNarrowingWithTupleOfTypes] @@ -2114,3 +2114,101 @@ else: [typing fixtures/typing-medium.pyi] [builtins fixtures/ops.pyi] + + +[case testTypeNarrowingStringInLiteralUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInLiteralUnionSubset] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b') +strIn: str = "b" +strOut: str = "c" +if strIn in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strIn) # N: Revealed type is "builtins.str" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringNotInLiteralUnion] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'],...] = ('a', 'b', 'c') +strIn: str = "c" +strOut: str = "d" +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "builtins.str" +else: + reveal_type(strIn) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +if strOut in typeAlpha: + reveal_type(strOut) # N: Revealed type is "Union[Literal['a'], Literal['b'], Literal['c']]" +else: + reveal_type(strOut) # N: Revealed type is "builtins.str" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testNarrowingStringInLiteralUnionDontExpand] +from typing import Literal, Tuple +typeAlpha: Tuple[Literal['a', 'b', 'c'], ...] = ('a', 'b', 'c') +strIn: Literal['c'] = "c" +reveal_type(strIn) # N: Revealed type is "Literal['c']" +#Check we don't expand a Literal into the Union type +if strIn not in typeAlpha: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +else: + reveal_type(strIn) # N: Revealed type is "Literal['c']" +[builtins fixtures/primitives.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInMixedUnion] +from typing import Literal, Tuple +typ: Tuple[Literal['a', 'b'], ...] = ('a', 'b') +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +[builtins fixtures/tuple.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInSet] +from typing import Literal, Set +typ: Set[Literal['a', 'b']] = {'a', 'b'} +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingStringInList] +from typing import Literal, List +typ: List[Literal['a', 'b']] = ['a', 'b'] +x: str = "hi!" +if x in typ: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +else: + reveal_type(x) # N: Revealed type is "builtins.str" +if x not in typ: + reveal_type(x) # N: Revealed type is "builtins.str" +else: + reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" +[builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] \ No newline at end of file diff --git a/test-data/unit/fixtures/narrowing.pyi b/test-data/unit/fixtures/narrowing.pyi index 89ee011c1c80..3bf603d7d5ed 100644 --- a/test-data/unit/fixtures/narrowing.pyi +++ b/test-data/unit/fixtures/narrowing.pyi @@ -1,5 +1,5 @@ # Builtins stub used in check-narrowing test cases. -from typing import Generic, Sequence, Tuple, Type, TypeVar, Union +from typing import Generic, Sequence, Tuple, Type, TypeVar, Union, Iterable Tco = TypeVar('Tco', covariant=True) @@ -15,6 +15,12 @@ class function: pass class ellipsis: pass class int: pass class str: pass +class float: pass class dict(Generic[KT, VT]): pass def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass + + +class set(Iterable[Tco], Generic[Tco]): + def __init__(self, iterable: Iterable[Tco] = ...) -> None: ... + def __contains__(self, item: object) -> bool: pass \ No newline at end of file From 3783f1d07d253425b4d8090e56656688433a696b Mon Sep 17 00:00:00 2001 From: Jordandev678 <20153053+Jordandev678@users.noreply.github.com> Date: Fri, 7 Jun 2024 19:38:38 +0000 Subject: [PATCH 2/6] Add test case from #3229 --- test-data/unit/check-narrowing.test | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index f6815b2689bf..6be9578db555 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2211,4 +2211,14 @@ if x not in typ: else: reveal_type(x) # N: Revealed type is "Union[Literal['a'], Literal['b']]" [builtins fixtures/narrowing.pyi] +[typing fixtures/typing-medium.pyi] + +[case testTypeNarrowingUnionStringFloat] +from typing import Union +def foobar(foo: Union[str, float]): + if foo in ['a', 'b']: + reveal_type(foo) # N: Revealed type is "builtins.str" + else: + reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]" +[builtins fixtures/primitives.pyi] [typing fixtures/typing-medium.pyi] \ No newline at end of file From dae707d92357c13fef85af49796d341cfbbaabc8 Mon Sep 17 00:00:00 2001 From: Jordandev678 <20153053+Jordandev678@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:07:22 +0000 Subject: [PATCH 3/6] Add list to narrowing.pyi for testTypeNarrowingStringInList --- test-data/unit/fixtures/narrowing.pyi | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test-data/unit/fixtures/narrowing.pyi b/test-data/unit/fixtures/narrowing.pyi index 3bf603d7d5ed..24457d2f77a9 100644 --- a/test-data/unit/fixtures/narrowing.pyi +++ b/test-data/unit/fixtures/narrowing.pyi @@ -20,7 +20,8 @@ class dict(Generic[KT, VT]): pass def isinstance(x: object, t: Union[Type[object], Tuple[Type[object], ...]]) -> bool: pass - +class list(Sequence[Tco]): + def __contains__(self, other: object) -> bool: pass class set(Iterable[Tco], Generic[Tco]): def __init__(self, iterable: Iterable[Tco] = ...) -> None: ... def __contains__(self, item: object) -> bool: pass \ No newline at end of file From 0c10969283124465e9c3a50d8f0ef8a68e4502fe Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 22:34:13 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 9 +++------ test-data/unit/check-narrowing.test | 2 +- test-data/unit/fixtures/narrowing.pyi | 2 +- 3 files changed, 5 insertions(+), 8 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 508f39072f0f..45837ce235bb 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5946,13 +5946,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: if_map, else_map = {}, {} if left_index in narrowable_operand_index_to_hash: - collection_item_type = get_proper_type( - builtin_item_type(iterable_type) - ) + collection_item_type = get_proper_type(builtin_item_type(iterable_type)) # Narrow if the collection is a subtype - if ( - collection_item_type is not None - and is_subtype(collection_item_type, item_type) + if collection_item_type is not None and is_subtype( + collection_item_type, item_type ): if_map[operands[left_index]] = collection_item_type # Try and narrow away 'None' diff --git a/test-data/unit/check-narrowing.test b/test-data/unit/check-narrowing.test index 6be9578db555..e142fdd5d060 100644 --- a/test-data/unit/check-narrowing.test +++ b/test-data/unit/check-narrowing.test @@ -2221,4 +2221,4 @@ def foobar(foo: Union[str, float]): else: reveal_type(foo) # N: Revealed type is "Union[builtins.str, builtins.float]" [builtins fixtures/primitives.pyi] -[typing fixtures/typing-medium.pyi] \ No newline at end of file +[typing fixtures/typing-medium.pyi] diff --git a/test-data/unit/fixtures/narrowing.pyi b/test-data/unit/fixtures/narrowing.pyi index 24457d2f77a9..a36ac7f29bd2 100644 --- a/test-data/unit/fixtures/narrowing.pyi +++ b/test-data/unit/fixtures/narrowing.pyi @@ -24,4 +24,4 @@ class list(Sequence[Tco]): def __contains__(self, other: object) -> bool: pass class set(Iterable[Tco], Generic[Tco]): def __init__(self, iterable: Iterable[Tco] = ...) -> None: ... - def __contains__(self, item: object) -> bool: pass \ No newline at end of file + def __contains__(self, item: object) -> bool: pass From 800741a38f10c10dda98cac3b3f1de477bcf1cfd Mon Sep 17 00:00:00 2001 From: Jordandev678 <20153053+Jordandev678@users.noreply.github.com> Date: Sat, 8 Jun 2024 09:20:21 +0000 Subject: [PATCH 5/6] Don't update if_map if item_type == collection_item_type --- mypy/checker.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 45837ce235bb..a45464509ea3 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5948,8 +5948,10 @@ def has_no_custom_eq_checks(t: Type) -> bool: if left_index in narrowable_operand_index_to_hash: collection_item_type = get_proper_type(builtin_item_type(iterable_type)) # Narrow if the collection is a subtype - if collection_item_type is not None and is_subtype( - collection_item_type, item_type + if ( + collection_item_type is not None + and collection_item_type != item_type + and is_subtype(collection_item_type, item_type) ): if_map[operands[left_index]] = collection_item_type # Try and narrow away 'None' From 0bb56898805e16a7e00cc483eefa5250f289ef5d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Jun 2024 11:27:29 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- mypy/checker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index 376d02a50fe7..ee3bc468dbe8 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -5951,7 +5951,7 @@ def has_no_custom_eq_checks(t: Type) -> bool: collection_item_type = get_proper_type(builtin_item_type(iterable_type)) # Narrow if the collection is a subtype if ( - collection_item_type is not None + collection_item_type is not None and collection_item_type != item_type and is_subtype(collection_item_type, item_type) ):