Skip to content

Commit

Permalink
Fix type narrowing of == None and in (None,) conditions (#15760)
Browse files Browse the repository at this point in the history
  • Loading branch information
intgr authored Aug 2, 2023
1 parent 54bc37c commit 2b613e5
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 16 deletions.
10 changes: 5 additions & 5 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@
is_literal_type,
is_named_instance,
)
from mypy.types_utils import is_optional, remove_optional, store_argument_type, strip_type
from mypy.types_utils import is_overlapping_none, remove_optional, store_argument_type, strip_type
from mypy.typetraverser import TypeTraverserVisitor
from mypy.typevars import fill_typevars, fill_typevars_with_any, has_no_typevars
from mypy.util import is_dunder, is_sunder, is_typeshed_file
Expand Down Expand Up @@ -5660,13 +5660,13 @@ def has_no_custom_eq_checks(t: Type) -> bool:

if left_index in narrowable_operand_index_to_hash:
# We only try and narrow away 'None' for now
if is_optional(item_type):
if is_overlapping_none(item_type):
collection_item_type = get_proper_type(
builtin_item_type(iterable_type)
)
if (
collection_item_type is not None
and not is_optional(collection_item_type)
and not is_overlapping_none(collection_item_type)
and not (
isinstance(collection_item_type, Instance)
and collection_item_type.type.fullname == "builtins.object"
Expand Down Expand Up @@ -6073,7 +6073,7 @@ def refine_away_none_in_comparison(
non_optional_types = []
for i in chain_indices:
typ = operand_types[i]
if not is_optional(typ):
if not is_overlapping_none(typ):
non_optional_types.append(typ)

# Make sure we have a mixture of optional and non-optional types.
Expand All @@ -6083,7 +6083,7 @@ def refine_away_none_in_comparison(
if_map = {}
for i in narrowable_operand_indices:
expr_type = operand_types[i]
if not is_optional(expr_type):
if not is_overlapping_none(expr_type):
continue
if any(is_overlapping_erased_types(expr_type, t) for t in non_optional_types):
if_map[operands[i]] = remove_optional(expr_type)
Expand Down
9 changes: 7 additions & 2 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@
is_named_instance,
split_with_prefix_and_suffix,
)
from mypy.types_utils import is_generic_instance, is_optional, is_self_type_like, remove_optional
from mypy.types_utils import (
is_generic_instance,
is_overlapping_none,
is_self_type_like,
remove_optional,
)
from mypy.typestate import type_state
from mypy.typevars import fill_typevars
from mypy.typevartuples import find_unpack_in_list
Expand Down Expand Up @@ -1809,7 +1814,7 @@ def infer_function_type_arguments_using_context(
# valid results.
erased_ctx = replace_meta_vars(ctx, ErasedType())
ret_type = callable.ret_type
if is_optional(ret_type) and is_optional(ctx):
if is_overlapping_none(ret_type) and is_overlapping_none(ctx):
# If both the context and the return type are optional, unwrap the optional,
# since in 99% cases this is what a user expects. In other words, we replace
# Optional[T] <: Optional[int]
Expand Down
4 changes: 2 additions & 2 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
deserialize_type,
get_proper_type,
)
from mypy.types_utils import is_optional
from mypy.types_utils import is_overlapping_none
from mypy.typevars import fill_typevars
from mypy.util import get_unique_redefinition_name

Expand Down Expand Up @@ -141,7 +141,7 @@ def find_shallow_matching_overload_item(overload: Overloaded, call: CallExpr) ->
break
elif (
arg_none
and not is_optional(arg_type)
and not is_overlapping_none(arg_type)
and not (
isinstance(arg_type, Instance)
and arg_type.type.fullname == "builtins.object"
Expand Down
6 changes: 3 additions & 3 deletions mypy/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
UnionType,
get_proper_type,
)
from mypy.types_utils import is_optional, remove_optional
from mypy.types_utils import is_overlapping_none, remove_optional
from mypy.util import split_target


Expand Down Expand Up @@ -752,7 +752,7 @@ def score_type(self, t: Type, arg_pos: bool) -> int:
return 20
if any(has_any_type(x) for x in t.items):
return 15
if not is_optional(t):
if not is_overlapping_none(t):
return 10
if isinstance(t, CallableType) and (has_any_type(t) or is_tricky_callable(t)):
return 10
Expand Down Expand Up @@ -868,7 +868,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> str:
return t.fallback.accept(self)

def visit_union_type(self, t: UnionType) -> str:
if len(t.items) == 2 and is_optional(t):
if len(t.items) == 2 and is_overlapping_none(t):
return f"Optional[{remove_optional(t).accept(self)}]"
else:
return super().visit_union_type(t)
Expand Down
6 changes: 3 additions & 3 deletions mypy/types_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def is_generic_instance(tp: Type) -> bool:
return isinstance(tp, Instance) and bool(tp.args)


def is_optional(t: Type) -> bool:
def is_overlapping_none(t: Type) -> bool:
t = get_proper_type(t)
return isinstance(t, UnionType) and any(
isinstance(get_proper_type(e), NoneType) for e in t.items
return isinstance(t, NoneType) or (
isinstance(t, UnionType) and any(isinstance(get_proper_type(e), NoneType) for e in t.items)
)


Expand Down
26 changes: 26 additions & 0 deletions test-data/unit/check-narrowing.test
Original file line number Diff line number Diff line change
Expand Up @@ -1263,6 +1263,32 @@ def g() -> None:
[builtins fixtures/dict.pyi]


[case testNarrowingOptionalEqualsNone]
from typing import Optional

class A: ...

val: Optional[A]

if val == None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
if val != None:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"

if val in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
if val not in (None,):
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
else:
reveal_type(val) # N: Revealed type is "Union[__main__.A, None]"
[builtins fixtures/primitives.pyi]

[case testNarrowingWithTupleOfTypes]
from typing import Tuple, Type

Expand Down
3 changes: 2 additions & 1 deletion test-data/unit/fixtures/primitives.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class memoryview(Sequence[int]):
def __iter__(self) -> Iterator[int]: pass
def __contains__(self, other: object) -> bool: pass
def __getitem__(self, item: int) -> int: pass
class tuple(Generic[T]): pass
class tuple(Generic[T]):
def __contains__(self, other: object) -> bool: pass
class list(Sequence[T]):
def __iter__(self) -> Iterator[T]: pass
def __contains__(self, other: object) -> bool: pass
Expand Down

0 comments on commit 2b613e5

Please sign in to comment.