Skip to content

Commit

Permalink
Try empty context when assigning to union typed variables (#14151)
Browse files Browse the repository at this point in the history
Fixes #4805
Fixes #13936

It is known that mypy can overuse outer type context sometimes
(especially when it is a union). This prevents a common use case for
narrowing types (see issue and test cases). This is a somewhat major
semantic change, but I think it should match better what a user would
expect.
  • Loading branch information
ilevkivskyi authored Nov 22, 2022
1 parent 3c5f368 commit b83ac9c
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 1 deletion.
44 changes: 44 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
AssignmentStmt,
Block,
BreakStmt,
BytesExpr,
CallExpr,
ClassDef,
ComparisonExpr,
Expand All @@ -86,6 +87,7 @@
EllipsisExpr,
Expression,
ExpressionStmt,
FloatExpr,
ForStmt,
FuncBase,
FuncDef,
Expand Down Expand Up @@ -115,6 +117,7 @@
ReturnStmt,
StarExpr,
Statement,
StrExpr,
SymbolNode,
SymbolTable,
SymbolTableNode,
Expand Down Expand Up @@ -3826,6 +3829,23 @@ def inference_error_fallback_type(self, type: Type) -> Type:
# we therefore need to erase them.
return erase_typevars(fallback)

def simple_rvalue(self, rvalue: Expression) -> bool:
"""Returns True for expressions for which inferred type should not depend on context.
Note that this function can still return False for some expressions where inferred type
does not depend on context. It only exists for performance optimizations.
"""
if isinstance(rvalue, (IntExpr, StrExpr, BytesExpr, FloatExpr, RefExpr)):
return True
if isinstance(rvalue, CallExpr):
if isinstance(rvalue.callee, RefExpr) and isinstance(rvalue.callee.node, FuncBase):
typ = rvalue.callee.node.type
if isinstance(typ, CallableType):
return not typ.variables
elif isinstance(typ, Overloaded):
return not any(item.variables for item in typ.items)
return False

def check_simple_assignment(
self,
lvalue_type: Type | None,
Expand All @@ -3847,6 +3867,30 @@ def check_simple_assignment(
rvalue_type = self.expr_checker.accept(
rvalue, lvalue_type, always_allow_any=always_allow_any
)
if (
isinstance(get_proper_type(lvalue_type), UnionType)
# Skip literal types, as they have special logic (for better errors).
and not isinstance(get_proper_type(rvalue_type), LiteralType)
and not self.simple_rvalue(rvalue)
):
# Try re-inferring r.h.s. in empty context, and use that if it
# results in a narrower type. We don't do this always because this
# may cause some perf impact, plus we want to partially preserve
# the old behavior. This helps with various practical examples, see
# e.g. testOptionalTypeNarrowedByGenericCall.
with self.msg.filter_errors() as local_errors, self.local_type_map() as type_map:
alt_rvalue_type = self.expr_checker.accept(
rvalue, None, always_allow_any=always_allow_any
)
if (
not local_errors.has_new_errors()
# Skip Any type, since it is special cased in binder.
and not isinstance(get_proper_type(alt_rvalue_type), AnyType)
and is_valid_inferred_type(alt_rvalue_type)
and is_proper_subtype(alt_rvalue_type, rvalue_type)
):
rvalue_type = alt_rvalue_type
self.store_types(type_map)
if isinstance(rvalue_type, DeletedType):
self.msg.deleted_as_rvalue(rvalue_type, context)
if isinstance(lvalue_type, DeletedType):
Expand Down
57 changes: 57 additions & 0 deletions test-data/unit/check-inference-context.test
Original file line number Diff line number Diff line change
Expand Up @@ -1419,3 +1419,60 @@ def bar(x: Union[Mapping[Any, Any], Dict[Any, Sequence[Any]]]) -> None:
...
bar({1: 2})
[builtins fixtures/dict.pyi]

[case testOptionalTypeNarrowedByGenericCall]
# flags: --strict-optional
from typing import Dict, Optional

d: Dict[str, str] = {}

def foo(arg: Optional[str] = None) -> None:
if arg is None:
arg = d.get("a", "b")
reveal_type(arg) # N: Revealed type is "builtins.str"
[builtins fixtures/dict.pyi]

[case testOptionalTypeNarrowedByGenericCall2]
# flags: --strict-optional
from typing import Dict, Optional

d: Dict[str, str] = {}
x: Optional[str]
if x:
reveal_type(x) # N: Revealed type is "builtins.str"
x = d.get(x, x)
reveal_type(x) # N: Revealed type is "builtins.str"
[builtins fixtures/dict.pyi]

[case testOptionalTypeNarrowedByGenericCall3]
# flags: --strict-optional
from typing import Generic, TypeVar, Union

T = TypeVar("T")
def bar(arg: Union[str, T]) -> Union[str, T]: ...

def foo(arg: Union[str, int]) -> None:
if isinstance(arg, int):
arg = bar("default")
reveal_type(arg) # N: Revealed type is "builtins.str"
[builtins fixtures/isinstance.pyi]

[case testOptionalTypeNarrowedByGenericCall4]
# flags: --strict-optional
from typing import Optional, List, Generic, TypeVar

T = TypeVar("T", covariant=True)
class C(Generic[T]): ...

x: Optional[C[int]] = None
y = x = C()
reveal_type(y) # N: Revealed type is "__main__.C[builtins.int]"

[case testOptionalTypeNarrowedByGenericCall5]
from typing import Any, Tuple, Union

i: Union[Tuple[Any, ...], int]
b: Any
i = i if isinstance(i, int) else b
reveal_type(i) # N: Revealed type is "Union[Any, builtins.int]"
[builtins fixtures/isinstance.pyi]
2 changes: 1 addition & 1 deletion test-data/unit/check-typeddict.test
Original file line number Diff line number Diff line change
Expand Up @@ -893,7 +893,7 @@ B = TypedDict('B', {'@type': Literal['b-type'], 'b': int})

c: Union[A, B] = {'@type': 'a-type', 'a': 'Test'}
reveal_type(c) # N: Revealed type is "Union[TypedDict('__main__.A', {'@type': Literal['a-type'], 'a': builtins.str}), TypedDict('__main__.B', {'@type': Literal['b-type'], 'b': builtins.int})]"
[builtins fixtures/tuple.pyi]
[builtins fixtures/dict.pyi]

[case testTypedDictUnionAmbiguousCase]
from typing import Union, Mapping, Any, cast
Expand Down

0 comments on commit b83ac9c

Please sign in to comment.