From fd8301bae1b2f5b6b5f4611e6a41f62dbad93fb1 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Wed, 2 Mar 2022 12:52:03 -0500 Subject: [PATCH 1/5] infer types from base if not explicitly typed Fixes #12268. If the type of an lvalue is not given at definition time, attempt to infer it from the type of a base member. --- mypy/checker.py | 23 +++++++++++++++++++-- test-data/unit/check-classes.test | 33 +++++++++++++++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 851f23185f4f..9eb7813fef91 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3091,8 +3091,13 @@ def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type], not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var) ): if isinstance(lvalue, NameExpr): - inferred = cast(Var, lvalue.node) - assert isinstance(inferred, Var) + # If the attribute is defined in a base, use that type. + base_type = self.try_setting_type_from_baseclass(lvalue) + if base_type: + lvalue_type = base_type + else: + inferred = cast(Var, lvalue.node) + assert isinstance(inferred, Var) else: assert isinstance(lvalue, MemberExpr) self.expr_checker.accept(lvalue.expr) @@ -3135,6 +3140,20 @@ def is_definition(self, s: Lvalue) -> bool: return s.is_inferred_def return False + # Attempt to find a type declaration in a base class + def try_setting_type_from_baseclass(self, lvalue: NameExpr) -> Optional[Type]: + lvalue_node = lvalue.node + if (isinstance(lvalue_node, Var) and lvalue.kind in (MDEF, None) + and len(lvalue_node.info.bases) > 0): # None for Vars defined via self + for base in lvalue_node.info.mro[1:]: + base_type, base_node = self.lvalue_type_from_base(lvalue_node, base) + if base_type: + lvalue_node.type = base_type + self.store_type(lvalue, base_type) + return base_type + + return None + def infer_variable_type(self, name: Var, lvalue: Lvalue, init_type: Type, context: Context) -> None: """Infer the type of initialized variables from initializer type.""" diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 65b0e8d69cb5..4166ec529c3f 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -4033,6 +4033,39 @@ class B(A): a = 1 [out] +[case testInferBaseTypeInSubclass] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[Union[A, B]] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] + +[case testInferBaseTypeFailAssignment] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[A] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] +main:12: error: Incompatible types in assignment (expression has type "List[A]", variable has type "List[Union[A, B]]") +main:12: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance +main:12: note: Consider using "Sequence" instead, which is covariant + [case testVariableSubclassAssignMismatch] class A: a = 1 # type: int From 1e610754d78d688f30ca8a378bb98ed3848c778f Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Wed, 2 Mar 2022 19:55:47 -0500 Subject: [PATCH 2/5] move test to check-inference Also change test output for one of the tests. --- mypy/checker.py | 43 +++++++++++++---------------- test-data/unit/check-classes.test | 33 ---------------------- test-data/unit/check-inference.test | 37 +++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 59 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 9eb7813fef91..6c8f15046f82 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3091,13 +3091,8 @@ def check_lvalue(self, lvalue: Lvalue) -> Tuple[Optional[Type], not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var) ): if isinstance(lvalue, NameExpr): - # If the attribute is defined in a base, use that type. - base_type = self.try_setting_type_from_baseclass(lvalue) - if base_type: - lvalue_type = base_type - else: - inferred = cast(Var, lvalue.node) - assert isinstance(inferred, Var) + inferred = cast(Var, lvalue.node) + assert isinstance(inferred, Var) else: assert isinstance(lvalue, MemberExpr) self.expr_checker.accept(lvalue.expr) @@ -3140,20 +3135,6 @@ def is_definition(self, s: Lvalue) -> bool: return s.is_inferred_def return False - # Attempt to find a type declaration in a base class - def try_setting_type_from_baseclass(self, lvalue: NameExpr) -> Optional[Type]: - lvalue_node = lvalue.node - if (isinstance(lvalue_node, Var) and lvalue.kind in (MDEF, None) - and len(lvalue_node.info.bases) > 0): # None for Vars defined via self - for base in lvalue_node.info.mro[1:]: - base_type, base_node = self.lvalue_type_from_base(lvalue_node, base) - if base_type: - lvalue_node.type = base_type - self.store_type(lvalue, base_type) - return base_type - - return None - def infer_variable_type(self, name: Var, lvalue: Lvalue, init_type: Type, context: Context) -> None: """Infer the type of initialized variables from initializer type.""" @@ -3176,11 +3157,15 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue, name.type = AnyType(TypeOfAny.from_error) else: # Infer type of the target. + base_type = self.try_infer_lvalue_var_type_from_bases(name, lvalue) + if base_type: + self.set_inferred_type(name, lvalue, strip_type(base_type)) - # Make the type more general (strip away function names etc.). - init_type = strip_type(init_type) + else: + # Make the type more general (strip away function names etc.). + init_type = strip_type(init_type) - self.set_inferred_type(name, lvalue, init_type) + self.set_inferred_type(name, lvalue, init_type) def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool: init_type = get_proper_type(init_type) @@ -3215,6 +3200,16 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool self.partial_types[-1].map[name] = lvalue return True + def try_infer_lvalue_var_type_from_bases(self, name: Var, lvalue: Lvalue) -> Optional[Type]: + if (isinstance(lvalue, RefExpr) and lvalue.kind in (MDEF, None) + and len(name.info.bases) > 0): # None for Vars defined via self + for base in name.info.mro[1:]: + base_type, base_node = self.lvalue_type_from_base(name, base) + if base_type: + return base_type + + return None + def is_valid_defaultdict_partial_value_type(self, t: ProperType) -> bool: """Check if t can be used as the basis for a partial defaultdict value type. diff --git a/test-data/unit/check-classes.test b/test-data/unit/check-classes.test index 4166ec529c3f..65b0e8d69cb5 100644 --- a/test-data/unit/check-classes.test +++ b/test-data/unit/check-classes.test @@ -4033,39 +4033,6 @@ class B(A): a = 1 [out] -[case testInferBaseTypeInSubclass] -from typing import Union, List - -class A: pass -class B: pass - -class Base: - variable1 : List[Union[A, B]] = [] - variable2 : List[Union[A, B]] = [] - -class Derived(Base): - variable1 = [A()] - variable2 = variable1 -[out] - -[case testInferBaseTypeFailAssignment] -from typing import Union, List - -class A: pass -class B: pass - -class Base: - variable1 : List[A] = [] - variable2 : List[Union[A, B]] = [] - -class Derived(Base): - variable1 = [A()] - variable2 = variable1 -[out] -main:12: error: Incompatible types in assignment (expression has type "List[A]", variable has type "List[Union[A, B]]") -main:12: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance -main:12: note: Consider using "Sequence" instead, which is covariant - [case testVariableSubclassAssignMismatch] class A: a = 1 # type: int diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 4de6e4a76f92..3b3817e4b6de 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2813,8 +2813,8 @@ class C(A): x = ['12'] reveal_type(A.x) # N: Revealed type is "builtins.list[Any]" -reveal_type(B.x) # N: Revealed type is "builtins.list[builtins.int]" -reveal_type(C.x) # N: Revealed type is "builtins.list[builtins.str]" +reveal_type(B.x) # N: Revealed type is "builtins.list[Any]" +reveal_type(C.x) # N: Revealed type is "builtins.list[Any]" [builtins fixtures/list.pyi] @@ -3255,3 +3255,36 @@ reveal_type(x) # N: Revealed type is "builtins.bytes*" if x: reveal_type(x) # N: Revealed type is "builtins.bytes*" [builtins fixtures/dict.pyi] + +[case testInferBaseTypeInSubclass] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[Union[A, B]] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] + +[case testInferBaseTypeFailAssignment] +from typing import Union, List + +class A: pass +class B: pass + +class Base: + variable1 : List[A] = [] + variable2 : List[Union[A, B]] = [] + +class Derived(Base): + variable1 = [A()] + variable2 = variable1 +[out] +main:12: error: Incompatible types in assignment (expression has type "List[A]", base class "Base" defined the type as "List[Union[A, B]]") +main:12: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance +main:12: note: Consider using "Sequence" instead, which is covariant From edf0154e6612a8638dbf74b7f927fcf272ee3c3a Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Wed, 2 Mar 2022 22:01:45 -0500 Subject: [PATCH 3/5] only take base type if it's explicitly given --- mypy/checker.py | 6 +++--- test-data/unit/check-inference.test | 20 ++++++++++++++++++-- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index 6c8f15046f82..b329339bafd1 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3157,7 +3157,7 @@ def infer_variable_type(self, name: Var, lvalue: Lvalue, name.type = AnyType(TypeOfAny.from_error) else: # Infer type of the target. - base_type = self.try_infer_lvalue_var_type_from_bases(name, lvalue) + base_type = self.try_infer_lvalue_var_type_from_explicit_base_member(name, lvalue) if base_type: self.set_inferred_type(name, lvalue, strip_type(base_type)) @@ -3200,12 +3200,12 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool self.partial_types[-1].map[name] = lvalue return True - def try_infer_lvalue_var_type_from_bases(self, name: Var, lvalue: Lvalue) -> Optional[Type]: + def try_infer_lvalue_var_type_from_explicit_base_member(self, name: Var, lvalue: Lvalue) -> Optional[Type]: if (isinstance(lvalue, RefExpr) and lvalue.kind in (MDEF, None) and len(name.info.bases) > 0): # None for Vars defined via self for base in name.info.mro[1:]: base_type, base_node = self.lvalue_type_from_base(name, base) - if base_type: + if isinstance(base_node, Var) and not base_node.is_inferred and base_type: return base_type return None diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index 3b3817e4b6de..ef70fb5c076e 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -2813,8 +2813,8 @@ class C(A): x = ['12'] reveal_type(A.x) # N: Revealed type is "builtins.list[Any]" -reveal_type(B.x) # N: Revealed type is "builtins.list[Any]" -reveal_type(C.x) # N: Revealed type is "builtins.list[Any]" +reveal_type(B.x) # N: Revealed type is "builtins.list[builtins.int]" +reveal_type(C.x) # N: Revealed type is "builtins.list[builtins.str]" [builtins fixtures/list.pyi] @@ -3269,6 +3269,9 @@ class Base: class Derived(Base): variable1 = [A()] variable2 = variable1 + +reveal_type(Derived.variable1) # N: Revealed type is "builtins.list[Union[__main__.A, __main__.B]]" +reveal_type(Derived.variable2) # N: Revealed type is "builtins.list[Union[__main__.A, __main__.B]]" [out] [case testInferBaseTypeFailAssignment] @@ -3288,3 +3291,16 @@ class Derived(Base): main:12: error: Incompatible types in assignment (expression has type "List[A]", base class "Base" defined the type as "List[Union[A, B]]") main:12: note: "List" is invariant -- see https://mypy.readthedocs.io/en/stable/common_issues.html#variance main:12: note: Consider using "Sequence" instead, which is covariant + +[case testDontInferNonExplicitBase] +class Base: pass +class Derived: pass + +class A: + x = Base() + +class B: + x = Derived() + +reveal_type(B.x) # N: Revealed type is "__main__.Derived" +[out] From 4b1056144ed53c6823008ff1f69e45678aae7905 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 00:09:55 -0500 Subject: [PATCH 4/5] fix lint --- mypy/checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mypy/checker.py b/mypy/checker.py index b329339bafd1..ed53ec787846 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3200,7 +3200,9 @@ def infer_partial_type(self, name: Var, lvalue: Lvalue, init_type: Type) -> bool self.partial_types[-1].map[name] = lvalue return True - def try_infer_lvalue_var_type_from_explicit_base_member(self, name: Var, lvalue: Lvalue) -> Optional[Type]: + def try_infer_lvalue_var_type_from_explicit_base_member( + self, name: Var, lvalue: Lvalue + ) -> Optional[Type]: if (isinstance(lvalue, RefExpr) and lvalue.kind in (MDEF, None) and len(name.info.bases) > 0): # None for Vars defined via self for base in name.info.mro[1:]: From 0ad182792752bcc32b9de23cb94cbb9747ebbac0 Mon Sep 17 00:00:00 2001 From: Keenan Gugeler Date: Thu, 3 Mar 2022 12:28:04 -0500 Subject: [PATCH 5/5] let's not infer the type if the base is `object` --- mypy/checker.py | 7 ++++--- test-data/unit/check-inference.test | 7 +++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index ed53ec787846..8e8d7b414b9f 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3206,9 +3206,10 @@ def try_infer_lvalue_var_type_from_explicit_base_member( if (isinstance(lvalue, RefExpr) and lvalue.kind in (MDEF, None) and len(name.info.bases) > 0): # None for Vars defined via self for base in name.info.mro[1:]: - base_type, base_node = self.lvalue_type_from_base(name, base) - if isinstance(base_node, Var) and not base_node.is_inferred and base_type: - return base_type + if base.fullname != "builtins.object": + base_type, base_node = self.lvalue_type_from_base(name, base) + if isinstance(base_node, Var) and not base_node.is_inferred and base_type: + return base_type return None diff --git a/test-data/unit/check-inference.test b/test-data/unit/check-inference.test index ef70fb5c076e..6c8c26ffbe19 100644 --- a/test-data/unit/check-inference.test +++ b/test-data/unit/check-inference.test @@ -3304,3 +3304,10 @@ class B: reveal_type(B.x) # N: Revealed type is "__main__.Derived" [out] + +[case testDontInferFromObject] +class A: + __doc__ = __doc__ if __doc__ else "No docs" + +reveal_type(A.__doc__) # N: Revealed type is "builtins.str" +[out]