diff --git a/mypy/checker.py b/mypy/checker.py index 076f9e3763d9..0498887acc87 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1624,6 +1624,8 @@ def check_slots_definition(self, typ: Type, context: Context) -> None: def check_match_args(self, var: Var, typ: Type, context: Context) -> None: """Check that __match_args__ contains literal strings""" + if not self.scope.active_class(): + return typ = get_proper_type(typ) if not isinstance(typ, TupleType) or not all( [is_string_literal(item) for item in typ.items] @@ -2686,7 +2688,8 @@ def check_assignment( self.check_indexed_assignment(index_lvalue, rvalue, lvalue) if inferred: - rvalue_type = self.expr_checker.accept(rvalue) + type_context = self.get_variable_type_context(inferred) + rvalue_type = self.expr_checker.accept(rvalue, type_context=type_context) if not ( inferred.is_final or (isinstance(lvalue, NameExpr) and lvalue.name == "__match_args__") @@ -2698,6 +2701,27 @@ def check_assignment( # (type, operator) tuples for augmented assignments supported with partial types partial_type_augmented_ops: Final = {("builtins.list", "+"), ("builtins.set", "|")} + def get_variable_type_context(self, inferred: Var) -> Type | None: + type_contexts = [] + if inferred.info: + for base in inferred.info.mro[1:]: + base_type, base_node = self.lvalue_type_from_base(inferred, base) + if base_type and not ( + isinstance(base_node, Var) and base_node.invalid_partial_type + ): + type_contexts.append(base_type) + # Use most derived supertype as type context if available. + if not type_contexts: + return None + candidate = type_contexts[0] + for other in type_contexts: + if is_proper_subtype(other, candidate): + candidate = other + elif not is_subtype(candidate, other): + # Multiple incompatible candidates, cannot use any of them as context. + return None + return candidate + def try_infer_partial_generic_type_from_assignment( self, lvalue: Lvalue, rvalue: Expression, op: str ) -> None: @@ -5870,7 +5894,9 @@ def enter_partial_types( self.msg.need_annotation_for_var(var, context, self.options.python_version) self.partial_reported.add(var) if var.type: - var.type = self.fixup_partial_type(var.type) + fixed = self.fixup_partial_type(var.type) + var.invalid_partial_type = fixed != var.type + var.type = fixed def handle_partial_var_type( self, typ: PartialType, is_lvalue: bool, node: Var, context: Context diff --git a/mypy/nodes.py b/mypy/nodes.py index 2b32d5f4f25c..4856ce3035e8 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -939,6 +939,7 @@ def deserialize(cls, data: JsonDict) -> Decorator: "explicit_self_type", "is_ready", "is_inferred", + "invalid_partial_type", "from_module_getattr", "has_explicit_value", "allow_incompatible_override", @@ -975,6 +976,7 @@ class Var(SymbolNode): "from_module_getattr", "has_explicit_value", "allow_incompatible_override", + "invalid_partial_type", ) def __init__(self, name: str, type: mypy.types.Type | None = None) -> None: @@ -1024,6 +1026,9 @@ def __init__(self, name: str, type: mypy.types.Type | None = None) -> None: self.has_explicit_value = False # If True, subclasses can override this with an incompatible type. self.allow_incompatible_override = False + # If True, this means we didn't manage to infer full type and fall back to + # something like list[Any]. We may decide to not use such types as context. + self.invalid_partial_type = False @property def name(self) -> str: diff --git a/mypy/semanal.py b/mypy/semanal.py index 4f62d3010a3b..2946880b783e 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -3007,7 +3007,10 @@ def process_type_annotation(self, s: AssignmentStmt) -> None: ): self.fail("All protocol members must have explicitly declared types", s) # Set the type if the rvalue is a simple literal (even if the above error occurred). - if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr): + # We skip this step for type scope because it messes up with class attribute + # inference for literal types (also annotated and non-annotated variables at class + # scope are semantically different, so we should not souch statement type). + if len(s.lvalues) == 1 and isinstance(s.lvalues[0], RefExpr) and not self.type: if s.lvalues[0].is_inferred_def: s.type = self.analyze_simple_literal_type(s.rvalue, s.is_final_def) if s.type: @@ -3026,7 +3029,6 @@ def is_annotated_protocol_member(self, s: AssignmentStmt) -> bool: def analyze_simple_literal_type(self, rvalue: Expression, is_final: bool) -> Type | None: """Return builtins.int if rvalue is an int literal, etc. - If this is a 'Final' context, we return "Literal[...]" instead.""" if self.options.semantic_analysis_only or self.function_stack: # Skip this if we're only doing the semantic analysis pass. diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 20a0c4ae80ea..53f4d6280311 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -4317,7 +4317,7 @@ class C(B): x = object() [out] main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") -main:6: error: Incompatible types in assignment (expression has type "object", base class "B" defined the type as "str") +main:6: error: Incompatible types in assignment (expression has type "object", base class "A" defined the type as "int") [case testClassOneErrorPerLine] class A: @@ -4327,7 +4327,7 @@ class B(A): x = 1.0 [out] main:4: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") -main:5: error: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") +main:5: error: Incompatible types in assignment (expression has type "float", base class "A" defined the type as "int") [case testClassIgnoreType_RedefinedAttributeAndGrandparentAttributeTypesNotIgnored] class A: @@ -4335,7 +4335,7 @@ class A: class B(A): x = '' # type: ignore class C(B): - x = '' + x = '' # E: Incompatible types in assignment (expression has type "str", base class "A" defined the type as "int") [out] [case testClassIgnoreType_RedefinedAttributeTypeIgnoredInChildren] diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index fc6cb6fc456a..ffcd6d8d94dd 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3263,3 +3263,68 @@ from typing import Dict, Iterable, Tuple, Union def foo(x: Union[Tuple[str, Dict[str, int], str], Iterable[object]]) -> None: ... foo(("a", {"a": "b"}, "b")) [builtins fixtures/dict.pyi] + +[case testUseSupertypeAsInferenceContext] +# flags: --strict-optional +from typing import List, Optional + +class B: + x: List[Optional[int]] + +class C(B): + x = [1] + +reveal_type(C().x) # N: Revealed type is "builtins.list[Union[builtins.int, None]]" +[builtins fixtures/list.pyi] + +[case testUseSupertypeAsInferenceContextInvalidType] +from typing import List +class P: + x: List[int] +class C(P): + x = ['a'] # E: List item 0 has incompatible type "str"; expected "int" +[builtins fixtures/list.pyi] + +[case testUseSupertypeAsInferenceContextPartial] +from typing import List + +class A: + x: List[str] + +class B(A): + x = [] + +reveal_type(B().x) # N: Revealed type is "builtins.list[builtins.str]" +[builtins fixtures/list.pyi] + +[case testUseSupertypeAsInferenceContextPartialError] +class A: + x = ['a', 'b'] + +class B(A): + x = [] + x.append(2) # E: Argument 1 to "append" of "list" has incompatible type "int"; expected "str" +[builtins fixtures/list.pyi] + +[case testUseSupertypeAsInferenceContextPartialErrorProperty] +from typing import List + +class P: + @property + def x(self) -> List[int]: ... +class C(P): + x = [] + +C.x.append("no") # E: Argument 1 to "append" of "list" has incompatible type "str"; expected "int" +[builtins fixtures/list.pyi] + +[case testUseSupertypeAsInferenceContextConflict] +from typing import List +class P: + x: List[int] +class M: + x: List[str] +class C(P, M): + x = [] # E: Need type annotation for "x" (hint: "x: List[] = ...") +reveal_type(C.x) # N: Revealed type is "builtins.list[Any]" +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index b6eae1da7d84..da8f1570a4f4 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -2918,3 +2918,44 @@ def incorrect_return2() -> Union[Tuple[Literal[True], int], Tuple[Literal[False] else: return (bool(), 'oops') # E: Incompatible return value type (got "Tuple[bool, str]", expected "Union[Tuple[Literal[True], int], Tuple[Literal[False], str]]") [builtins fixtures/bool.pyi] + +[case testLiteralSubtypeContext] +from typing_extensions import Literal + +class A: + foo: Literal['bar', 'spam'] +class B(A): + foo = 'spam' + +reveal_type(B().foo) # N: Revealed type is "Literal['spam']" +[builtins fixtures/tuple.pyi] + +[case testLiteralSubtypeContextNested] +from typing import List +from typing_extensions import Literal + +class A: + foo: List[Literal['bar', 'spam']] +class B(A): + foo = ['spam'] + +reveal_type(B().foo) # N: Revealed type is "builtins.list[Union[Literal['bar'], Literal['spam']]]" +[builtins fixtures/tuple.pyi] + +[case testLiteralSubtypeContextGeneric] +from typing_extensions import Literal +from typing import Generic, List, TypeVar + +T = TypeVar("T", bound=str) + +class B(Generic[T]): + collection: List[T] + word: T + +class C(B[Literal["word"]]): + collection = ["word"] + word = "word" + +reveal_type(C().collection) # N: Revealed type is "builtins.list[Literal['word']]" +reveal_type(C().word) # N: Revealed type is "Literal['word']" +[builtins fixtures/tuple.pyi]