Skip to content

Commit

Permalink
Use supertype context for variable type inference (#13494)
Browse files Browse the repository at this point in the history
  • Loading branch information
ilevkivskyi committed Aug 24, 2022
1 parent 61a9b92 commit 57de8db
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 7 deletions.
30 changes: 28 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,8 @@ def check_slots_definition(self, typ: Type, context: Context) -> None:

def check_match_args(self, var: Var, typ: Type, context: Context) -> None:
"""Check that __match_args__ contains literal strings"""
if not self.scope.active_class():
return
typ = get_proper_type(typ)
if not isinstance(typ, TupleType) or not all(
[is_string_literal(item) for item in typ.items]
Expand Down Expand Up @@ -2686,7 +2688,8 @@ def check_assignment(
self.check_indexed_assignment(index_lvalue, rvalue, lvalue)

if inferred:
rvalue_type = self.expr_checker.accept(rvalue)
type_context = self.get_variable_type_context(inferred)
rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context)
if not (
inferred.is_final
or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__")
Expand All @@ -2698,6 +2701,27 @@ def check_assignment(
# (type, operator) tuples for augmented assignments supported with partial types
partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")}

def get_variable_type_context(self, inferred: Var) -> Type | None:
type_contexts = []
if inferred.info:
for base in inferred.info.mro[1:]:
base_type, base_node = self.lvalue_type_from_base(inferred, base)
if base_type and not (
isinstance(base_node, Var) and base_node.invalid_partial_type
):
type_contexts.append(base_type)
# Use most derived supertype as type context if available.
if not type_contexts:
return None
candidate = type_contexts[0]
for other in type_contexts:
if is_proper_subtype(other, candidate):
candidate = other
elif not is_subtype(candidate, other):
# Multiple incompatible candidates, cannot use any of them as context.
return None
return candidate

def try_infer_partial_generic_type_from_assignment(
self, lvalue: Lvalue, rvalue: Expression, op: str
) -> None:
Expand Down Expand Up @@ -5870,7 +5894,9 @@ def enter_partial_types(
self.msg.need_annotation_for_var(var, context, self.options.python_version)
self.partial_reported.add(var)
if var.type:
var.type = self.fixup_partial_type(var.type)
fixed = self.fixup_partial_type(var.type)
var.invalid_partial_type = fixed != var.type
var.type = fixed

def handle_partial_var_type(
self, typ: PartialType, is_lvalue: bool, node: Var, context: Context
Expand Down
5 changes: 5 additions & 0 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ def deserialize(cls, data: JsonDict) -> Decorator:
"explicit_self_type",
"is_ready",
"is_inferred",
"invalid_partial_type",
"from_module_getattr",
"has_explicit_value",
"allow_incompatible_override",
Expand Down Expand Up @@ -975,6 +976,7 @@ class Var(SymbolNode):
"from_module_getattr",
"has_explicit_value",
"allow_incompatible_override",
"invalid_partial_type",
)

def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
Expand Down Expand Up @@ -1024,6 +1026,9 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None:
self.has_explicit_value = False
# If True, subclasses can override this with an incompatible type.
self.allow_incompatible_override = False
# If True, this means we didn't manage to infer full type and fall back to
# something like list[Any]. We may decide to not use such types as context.
self.invalid_partial_type = False

@property
def name(self) -> str:
Expand Down
6 changes: 4 additions & 2 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3007,7 +3007,10 @@ def process_type_annotation(self, s: AssignmentStmt) -> None:
):
self.fail("All protocol members must have explicitly declared types", s)
# Set the type if the rvalue is a simple literal (even if the above error occurred).
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr):
# We skip this step for type scope because it messes up with class attribute
# inference for literal types (also annotated and non-annotated variables at class
# scope are semantically different, so we should not souch statement type).
if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr) and not self.type:
if s.lvalues[0].is_inferred_def:
s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def)
if s.type:
Expand All @@ -3026,7 +3029,6 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool:

def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Type | None:
"""Return builtins.int if rvalue is an int literal, etc.
If this is a 'Final' context, we return "Literal[...]" instead."""
if self.options.semantic_analysis_only or self.function_stack:
# Skip this if we're only doing the semantic analysis pass.
Expand Down
6 changes: 3 additions & 3 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -4317,7 +4317,7 @@ class C(B):
x = object()
[out]
main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
main:6: error: Incompatible types in assignment (expression has type "object", base class "B" defined the type as "str")
main:6: error: Incompatible types in assignment (expression has type "object", base class "A" defined the type as "int")

[case testClassOneErrorPerLine]
class A:
Expand All @@ -4327,15 +4327,15 @@ class B(A):
x = 1.0
[out]
main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
main:5: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
main:5: error: Incompatible types in assignment (expression has type "float", base class "A" defined the type as "int")

[case testClassIgnoreType_RedefinedAttributeAndGrandparentAttributeTypesNotIgnored]
class A:
x = 0
class B(A):
x = '' # type: ignore
class C(B):
x = ''
x = '' # E: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int")
[out]

[case testClassIgnoreType_RedefinedAttributeTypeIgnoredInChildren]
Expand Down
65 changes: 65 additions & 0 deletions test-data/unit/check-inference.test
Original file line number Diff line number Diff line change
Expand Up @@ -3263,3 +3263,68 @@ from typing import Dict, Iterable, Tuple, Union
def foo(x: Union[Tuple[str, Dict[str, int], str], Iterable[object]]) -> None: ...
foo(("a", {"a": "b"}, "b"))
[builtins fixtures/dict.pyi]

[case testUseSupertypeAsInferenceContext]
# flags: --strict-optional
from typing import List, Optional

class B:
x: List[Optional[int]]

class C(B):
x = [1]

reveal_type(C().x) # N: Revealed type is "builtins.list[Union[builtins.int, None]]"
[builtins fixtures/list.pyi]

[case testUseSupertypeAsInferenceContextInvalidType]
from typing import List
class P:
x: List[int]
class C(P):
x = ['a'] # E: List item 0 has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]

[case testUseSupertypeAsInferenceContextPartial]
from typing import List

class A:
x: List[str]

class B(A):
x = []

reveal_type(B().x) # N: Revealed type is "builtins.list[builtins.str]"
[builtins fixtures/list.pyi]

[case testUseSupertypeAsInferenceContextPartialError]
class A:
x = ['a', 'b']

class B(A):
x = []
x.append(2) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str"
[builtins fixtures/list.pyi]

[case testUseSupertypeAsInferenceContextPartialErrorProperty]
from typing import List

class P:
@property
def x(self) -> List[int]: ...
class C(P):
x = []

C.x.append("no") # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int"
[builtins fixtures/list.pyi]

[case testUseSupertypeAsInferenceContextConflict]
from typing import List
class P:
x: List[int]
class M:
x: List[str]
class C(P, M):
x = [] # E: Need type annotation for "x" (hint: "x: List[<type>] = ...")
reveal_type(C.x) # N: Revealed type is "builtins.list[Any]"
[builtins fixtures/list.pyi]
41 changes: 41 additions & 0 deletions test-data/unit/check-literal.test
Original file line number Diff line number Diff line change
Expand Up @@ -2918,3 +2918,44 @@ def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False]
else:
return (bool(), 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]")
[builtins fixtures/bool.pyi]

[case testLiteralSubtypeContext]
from typing_extensions import Literal

class A:
foo: Literal['bar', 'spam']
class B(A):
foo = 'spam'

reveal_type(B().foo) # N: Revealed type is "Literal['spam']"
[builtins fixtures/tuple.pyi]

[case testLiteralSubtypeContextNested]
from typing import List
from typing_extensions import Literal

class A:
foo: List[Literal['bar', 'spam']]
class B(A):
foo = ['spam']

reveal_type(B().foo) # N: Revealed type is "builtins.list[Union[Literal['bar'], Literal['spam']]]"
[builtins fixtures/tuple.pyi]

[case testLiteralSubtypeContextGeneric]
from typing_extensions import Literal
from typing import Generic, List, TypeVar

T = TypeVar("T", bound=str)

class B(Generic[T]):
collection: List[T]
word: T

class C(B[Literal["word"]]):
collection = ["word"]
word = "word"

reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word']]"
reveal_type(C().word) # N: Revealed type is "Literal['word']"
[builtins fixtures/tuple.pyi]

0 comments on commit 57de8db

Please sign in to comment.