diff --git a/mypy/checker.py b/mypy/checker.py index 851f23185f4f..8e8d7b414b9f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3157,11 +3157,15 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue, name.type = AnyType(TypeOfAny.from_error) else: # Infer type of the target. + base_type = self.try_infer_lvalue_var_type_from_explicit_base_member(name, lvalue) + if base_type: + self.set_inferred_type(name, lvalue, strip_type(base_type)) - # Make the type more general (strip away function names etc.). - init_type = strip_type(init_type) + else: + # Make the type more general (strip away function names etc.). + init_type = strip_type(init_type) - self.set_inferred_type(name, lvalue, init_type) + self.set_inferred_type(name, lvalue, init_type) def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool: init_type = get_proper_type(init_type) @@ -3196,6 +3200,19 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool self.partial_types[-1].map[name] = lvalue return True + def try_infer_lvalue_var_type_from_explicit_base_member( + self, name: Var, lvalue: Lvalue + ) -> Optional[Type]: + if (isinstance(lvalue, RefExpr) and lvalue.kind in (MDEF, None) + and len(name.info.bases) > 0): # None for Vars defined via self + for base in name.info.mro[1:]: + if base.fullname != "builtins.object": + base_type, base_node = self.lvalue_type_from_base(name, base) + if isinstance(base_node, Var) and not base_node.is_inferred and base_type: + return base_type + + return None + def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool: """Check if t can be used as the basis for a partial defaultdict value type. diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 4de6e4a76f92..6c8c26ffbe19 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3255,3 +3255,59 @@ reveal_type(x) # N: Revealed type is "builtins.bytes*" if x: reveal_type(x) # N: Revealed type is "builtins.bytes*" [builtins fixtures/dict.pyi] + +[case testInferBaseTypeInSubclass] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[Union[A, B]] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 + +reveal_type(Derived.variable1) # N: Revealed type is "builtins.list[Union[__main__.A, __main__.B]]" +reveal_type(Derived.variable2) # N: Revealed type is "builtins.list[Union[__main__.A, __main__.B]]" +[out] + +[case testInferBaseTypeFailAssignment] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[A] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] +main:12: error: Incompatible types in assignment (expression has type "List[A]", base class "Base" defined the type as "List[Union[A, B]]") +main:12: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance +main:12: note: Consider using "Sequence" instead, which is covariant + +[case testDontInferNonExplicitBase] +class Base: pass +class Derived: pass + +class A: + x = Base() + +class B: + x = Derived() + +reveal_type(B.x) # N: Revealed type is "__main__.Derived" +[out] + +[case testDontInferFromObject] +class A: + __doc__ = __doc__ if __doc__ else "No docs" + +reveal_type(A.__doc__) # N: Revealed type is "builtins.str" +[out]