diff --git a/mypy/exprtotype.py b/mypy/exprtotype.py index f08cac825e4d..348b73c8d46c 100644 --- a/mypy/exprtotype.py +++ b/mypy/exprtotype.py @@ -12,7 +12,7 @@ class TypeTranslationError(Exception): def expr_to_unanalyzed_type(expr: Node) -> Type: - """Translate an expression to the corresonding type. + """Translate an expression to the corresponding type. The result is not semantically analyzed. It can be UnboundType or TypeList. Raise TypeTranslationError if the expression cannot represent a type. @@ -32,7 +32,7 @@ def expr_to_unanalyzed_type(expr: Node) -> Type: if base.args: raise TypeTranslationError() if isinstance(expr.index, TupleExpr): - args = cast(TupleExpr, expr.index).items + args = expr.index.items else: args = [expr.index] base.args = [expr_to_unanalyzed_type(arg) for arg in args] @@ -54,15 +54,15 @@ def expr_to_unanalyzed_type(expr: Node) -> Type: def get_member_expr_fullname(expr: MemberExpr) -> str: - """Return the qualified name represention of a member expression. + """Return the qualified name representation of a member expression. Return a string of form foo.bar, foo.bar.baz, or similar, or None if the argument cannot be represented in this form. """ if isinstance(expr.expr, NameExpr): - initial = cast(NameExpr, expr.expr).name + initial = expr.expr.name elif isinstance(expr.expr, MemberExpr): - initial = get_member_expr_fullname(cast(MemberExpr, expr.expr)) + initial = get_member_expr_fullname(expr.expr) else: return None return '{}.{}'.format(initial, expr.name) diff --git a/mypy/nodes.py b/mypy/nodes.py index a1372ba7bfed..727e7cb4900f 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -39,7 +39,7 @@ def get_line(self) -> int: pass # (1) a generic class that uses the type variable as a type argument or # (2) a generic function that refers to the type variable in its signature. UNBOUND_TVAR = 4 # type: int -TVAR = 5 # type: int +BOUND_TVAR = 5 # type: int TYPE_ALIAS = 6 # type: int @@ -53,7 +53,7 @@ def get_line(self) -> int: pass MDEF: 'Mdef', MODULE_REF: 'ModuleRef', UNBOUND_TVAR: 'UnboundTvar', - TVAR: 'Tvar', + BOUND_TVAR: 'Tvar', } diff --git a/mypy/semanal.py b/mypy/semanal.py index bccfab11fe63..d1d051b71d49 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -53,7 +53,7 @@ ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, SliceExpr, CastExpr, TypeApplication, Context, SymbolTable, - SymbolTableNode, TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, + SymbolTableNode, BOUND_TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, FuncExpr, MDEF, FuncBase, Decorator, SetExpr, UndefinedExpr, TypeVarExpr, StrExpr, PrintStmt, ConditionalExpr, PromoteExpr, ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, @@ -118,7 +118,12 @@ class SemanticAnalyzer(NodeVisitor): # TypeInfo of directly enclosing class (or None) type = Undefined(TypeInfo) # Stack of outer classes (the second tuple item contains tvars). - type_stack = Undefined(List[Tuple[TypeInfo, List[SymbolTableNode]]]) + type_stack = Undefined(List[TypeInfo]) + # Type variables that are bound by the directly enclosing class + bound_tvars = Undefined(List[SymbolTableNode]) + # Stack of type varialbes that were bound by outer classess + tvar_stack = Undefined(List[List[SymbolTableNode]]) + # Stack of functions being analyzed function_stack = Undefined(List[FuncItem]) @@ -138,6 +143,8 @@ def __init__(self, lib_path: List[str], errors: Errors, self.imports = set() self.type = None self.type_stack = [] + self.bound_tvars = None + self.tvar_stack = [] self.function_stack = [] self.block_depth = [0] self.loop_depth = 0 @@ -260,7 +267,7 @@ def find_type_variables_in_type( return result def is_defined_type_var(self, tvar: str, context: Node) -> bool: - return self.lookup_qualified(tvar, context).kind == TVAR + return self.lookup_qualified(tvar, context).kind == BOUND_TVAR def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: t = [] # type: List[CallableType] @@ -351,7 +358,7 @@ def add_func_type_variables_to_symbol_table( name = item.name if name in names: self.name_already_defined(name, defn) - node = self.add_type_var(name, -i - 1, defn) + node = self.bind_type_var(name, -i - 1, defn) nodes.append(node) names.add(name) return nodes @@ -362,10 +369,10 @@ def type_var_names(self) -> Set[str]: else: return set(self.type.type_vars) - def add_type_var(self, fullname: str, id: int, + def bind_type_var(self, fullname: str, id: int, context: Context) -> SymbolTableNode: node = self.lookup_qualified(fullname, context) - node.kind = TVAR + node.kind = BOUND_TVAR node.tvar_id = id return node @@ -379,12 +386,17 @@ def check_function_signature(self, fdef: FuncItem) -> None: def visit_class_def(self, defn: ClassDef) -> None: self.clean_up_bases_and_infer_type_variables(defn) self.setup_class_def_analysis(defn) + + self.bind_class_type_vars(defn) + self.analyze_base_classes(defn) self.analyze_metaclass(defn) for decorator in defn.decorators: self.analyze_class_decorator(defn, decorator) + self.enter_class(defn) + self.setup_is_builtinclass(defn) # Analyze class body. @@ -393,14 +405,39 @@ def visit_class_def(self, defn: ClassDef) -> None: self.calculate_abstract_status(defn.info) self.setup_type_promotion(defn) - # Restore analyzer state. + self.leave_class() + self.unbind_class_type_vars() + + def enter_class(self, defn: ClassDef) -> None: + # Remember previous active class + self.type_stack.append(self.type) + self.locals.append(None) # Add class scope + self.block_depth.append(-1) # The class body increments this to 0 + self.type = defn.info + + def leave_class(self) -> None: + """ Restore analyzer state. """ self.block_depth.pop() self.locals.pop() - self.type, tvarnodes = self.type_stack.pop() - disable_typevars(tvarnodes) - if self.type_stack: - # Enable type variables of the enclosing class again. - enable_typevars(self.type_stack[-1][1]) + self.type = self.type_stack.pop() + + def bind_class_type_vars(self, defn: ClassDef) -> None: + """ Unbind type variables of previously active class and bind + the type variables for the active class. + """ + if self.bound_tvars: + disable_typevars(self.bound_tvars) + self.tvar_stack.append(self.bound_tvars) + self.bound_tvars = self.bind_class_type_variables_in_symbol_table(defn.info) + + def unbind_class_type_vars(self) -> None: + """ Unbind the active class' type vars and rebind the + type vars of the previously active class. + """ + disable_typevars(self.bound_tvars) + self.bound_tvars = self.tvar_stack.pop() + if self.bound_tvars: + enable_typevars(self.bound_tvars) def analyze_class_decorator(self, defn: ClassDef, decorator: Node) -> None: decorator.accept(self) @@ -547,15 +584,6 @@ def setup_class_def_analysis(self, defn: ClassDef) -> None: if self.is_func_scope(): kind = LDEF self.add_symbol(defn.name, SymbolTableNode(kind, defn.info), defn) - if self.type_stack: - # Disable type variables of the enclosing class. - disable_typevars(self.type_stack[-1][1]) - tvarnodes = self.add_class_type_variables_to_symbol_table(defn.info) - # Remember previous active class and type vars of *this* class. - self.type_stack.append((self.type, tvarnodes)) - self.locals.append(None) # Add class scope - self.block_depth.append(-1) # The class body increments this to 0 - self.type = defn.info def analyze_base_classes(self, defn: ClassDef) -> None: """Analyze and set up base classes.""" @@ -662,13 +690,13 @@ def named_type_or_none(self, qualified_name: str) -> Instance: def is_instance_type(self, t: Type) -> bool: return isinstance(t, Instance) - def add_class_type_variables_to_symbol_table( + def bind_class_type_variables_in_symbol_table( self, info: TypeInfo) -> List[SymbolTableNode]: vars = info.type_vars nodes = [] # type: List[SymbolTableNode] if vars: for i in range(len(vars)): - node = self.add_type_var(vars[i], i + 1, info) + node = self.bind_type_var(vars[i], i + 1, info) nodes.append(node) return nodes @@ -1410,7 +1438,7 @@ def visit_print_stmt(self, s: PrintStmt) -> None: def visit_name_expr(self, expr: NameExpr) -> None: n = self.lookup(expr.name, expr) if n: - if n.kind == TVAR: + if n.kind == BOUND_TVAR: self.fail("'{}' is a type variable and only valid in type " "context".format(expr.name), expr) else: @@ -2076,14 +2104,14 @@ def find_duplicate(list: List[T]) -> T: def disable_typevars(nodes: List[SymbolTableNode]) -> None: for node in nodes: - assert node.kind in (TVAR, UNBOUND_TVAR) + assert node.kind in (BOUND_TVAR, UNBOUND_TVAR) node.kind = UNBOUND_TVAR def enable_typevars(nodes: List[SymbolTableNode]) -> None: for node in nodes: - assert node.kind in (TVAR, UNBOUND_TVAR) - node.kind = TVAR + assert node.kind in (BOUND_TVAR, UNBOUND_TVAR) + node.kind = BOUND_TVAR def remove_imported_names_from_symtable(names: SymbolTable, diff --git a/mypy/test/data/semanal-classes.test b/mypy/test/data/semanal-classes.test index e3bc8650a904..6de222fc016c 100644 --- a/mypy/test/data/semanal-classes.test +++ b/mypy/test/data/semanal-classes.test @@ -331,6 +331,23 @@ MypyFile:1( ExpressionStmt:3( NameExpr(B [m])))) +[case testClassWithBaseClassWithinClass] +class A: + class B: pass + class C(B): pass +[out] +MypyFile:1( + ClassDef:1( + A + ClassDef:2( + B + PassStmt:2()) + ClassDef:3( + C + BaseType( + B) + PassStmt:3()))) + [case testDeclarationReferenceToNestedClass] def f() -> None: class A: pass diff --git a/mypy/test/data/semanal-errors.test b/mypy/test/data/semanal-errors.test index 09173a89efcd..22d874f5cba4 100644 --- a/mypy/test/data/semanal-errors.test +++ b/mypy/test/data/semanal-errors.test @@ -1379,3 +1379,10 @@ def f() -> None: [out] main: In function "f": main, line 3: Invalid assignment target + +[case testInvalidReferenceToAttributeOfOuterClass] +class A: + class X: pass + class B: + y = X # E: Name 'X' is not defined +[out] diff --git a/mypy/typeanal.py b/mypy/typeanal.py index b6b4c01b7585..33f2ccbb4873 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -7,7 +7,7 @@ Void, NoneTyp, TypeList, TypeVarDef, TypeVisitor, StarType ) from mypy.nodes import ( - GDEF, TYPE_ALIAS, TypeInfo, Context, SymbolTableNode, TVAR, TypeVarExpr, Var, Node, + GDEF, TYPE_ALIAS, TypeInfo, Context, SymbolTableNode, BOUND_TVAR, TypeVarExpr, Var, Node, IndexExpr, NameExpr, TupleExpr, RefExpr ) from mypy.sametypes import is_same_type @@ -56,7 +56,10 @@ def analyse_type_alias(node: Node, class TypeAnalyser(TypeVisitor[Type]): - """Semantic analyzer for types (semantic analysis pass 2).""" + """Semantic analyzer for types (semantic analysis pass 2). + + Converts unbound types into bound types. + """ def __init__(self, lookup_func: Callable[[str, Context], SymbolTableNode], @@ -70,7 +73,7 @@ def visit_unbound_type(self, t: UnboundType) -> Type: sym = self.lookup(t.name, t) if sym is not None: fullname = sym.node.fullname() - if sym.kind == TVAR: + if sym.kind == BOUND_TVAR: if len(t.args) > 0: self.fail('Type variable "{}" used with arguments'.format( t.name), t)