diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9a32dc52fcdb..9e44142302fa 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -21,8 +21,7 @@ DictionaryComprehension, ComplexExpr, EllipsisExpr, StarExpr, AwaitExpr, YieldExpr, YieldFromExpr, TypedDictExpr, PromoteExpr, NewTypeExpr, NamedTupleExpr, TypeVarExpr, TypeAliasExpr, BackquoteExpr, EnumCallExpr, - ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, - UNBOUND_TVAR, BOUND_TVAR, LITERAL_TYPE + ARG_POS, ARG_NAMED, ARG_STAR, ARG_STAR2, MODULE_REF, TVAR, LITERAL_TYPE, ) from mypy import nodes import mypy.checker @@ -1623,7 +1622,7 @@ def replace_tvars_any(self, tp: Type) -> Type: sym = self.chk.lookup_qualified(arg.name) except KeyError: pass - if sym and (sym.kind == UNBOUND_TVAR or sym.kind == BOUND_TVAR): + if sym and (sym.kind == TVAR): new_args[i] = AnyType() else: new_args[i] = self.replace_tvars_any(arg) diff --git a/mypy/nodes.py b/mypy/nodes.py index 43adc6d2d62b..0f4ae878cf60 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -39,12 +39,12 @@ def get_column(self) -> int: pass GDEF = 1 # type: int MDEF = 2 # type: int MODULE_REF = 3 # type: int -# Type variable declared using TypeVar(...) has kind UNBOUND_TVAR. It's not -# valid as a type. A type variable is valid as a type (kind BOUND_TVAR) within +# Type variable declared using TypeVar(...) has kind TVAR. It's not +# valid as a type unless bound in a TypeVarScope. That happens within: # (1) a generic class that uses the type variable as a type argument or # (2) a generic function that refers to the type variable in its signature. -UNBOUND_TVAR = 4 # type: int -BOUND_TVAR = 5 # type: int +TVAR = 4 # type: int + TYPE_ALIAS = 6 # type: int # Placeholder for a name imported via 'from ... import'. Second phase of # semantic will replace this the actual imported reference. This is @@ -65,8 +65,7 @@ def get_column(self) -> int: pass GDEF: 'Gdef', MDEF: 'Mdef', MODULE_REF: 'ModuleRef', - UNBOUND_TVAR: 'UnboundTvar', - BOUND_TVAR: 'Tvar', + TVAR: 'Tvar', TYPE_ALIAS: 'TypeAlias', UNBOUND_IMPORTED: 'UnboundImported', } @@ -2211,8 +2210,7 @@ class SymbolTableNode: # - LDEF: local definition (of any kind) # - GDEF: global (module-level) definition # - MDEF: class member definition - # - UNBOUND_TVAR: TypeVar(...) definition, not bound - # - TVAR: type variable in a bound scope (generic function / generic clas) + # - TVAR: TypeVar(...) definition # - MODULE_REF: reference to a module # - TYPE_ALIAS: type alias # - UNBOUND_IMPORTED: temporary kind for imported names @@ -2220,8 +2218,6 @@ class SymbolTableNode: # AST node of definition (FuncDef/Var/TypeInfo/Decorator/TypeVarExpr, # or None for a bound type variable). node = None # type: Optional[SymbolNode] - # Type variable definition (for bound type variables only) - tvar_def = None # type: Optional[mypy.types.TypeVarDef] # Module id (e.g. "foo.bar") or None mod_id = '' # If this not None, override the type of the 'node' attribute. @@ -2237,13 +2233,11 @@ class SymbolTableNode: def __init__(self, kind: int, node: Optional[SymbolNode], mod_id: str = None, typ: 'mypy.types.Type' = None, - tvar_def: 'mypy.types.TypeVarDef' = None, module_public: bool = True, normalized: bool = False) -> None: self.kind = kind self.node = node self.type_override = typ self.mod_id = mod_id - self.tvar_def = tvar_def self.module_public = module_public self.normalized = normalized @@ -2287,8 +2281,6 @@ def serialize(self, prefix: str, name: str) -> JsonDict: data = {'.class': 'SymbolTableNode', 'kind': node_kinds[self.kind], } # type: JsonDict - if self.tvar_def: - data['tvar_def'] = self.tvar_def.serialize() if not self.module_public: data['module_public'] = False if self.kind == MODULE_REF: @@ -2323,8 +2315,6 @@ def deserialize(cls, data: JsonDict) -> 'SymbolTableNode': if 'type_override' in data: typ = mypy.types.deserialize_type(data['type_override']) stnode = SymbolTableNode(kind, node, typ=typ) - if 'tvar_def' in data: - stnode.tvar_def = mypy.types.TypeVarDef.deserialize(data['tvar_def']) if 'module_public' in data: stnode.module_public = data['module_public'] return stnode diff --git a/mypy/semanal.py b/mypy/semanal.py index c59448bf62e4..25e6f30f8c22 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -45,8 +45,9 @@ from collections import OrderedDict from contextlib import contextmanager + from typing import ( - List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator + List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable, Iterator, ) from mypy.nodes import ( @@ -58,7 +59,7 @@ ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, PassStmt, GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, SliceExpr, CastExpr, RevealTypeExpr, TypeApplication, Context, SymbolTable, - SymbolTableNode, BOUND_TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, + SymbolTableNode, TVAR, ListComprehension, GeneratorExpr, LambdaExpr, MDEF, FuncBase, Decorator, SetExpr, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, ARG_NAMED_OPT, MroError, type_aliases, @@ -69,6 +70,7 @@ COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ARG_OPT, nongen_builtins, collections_type_aliases, get_member_expr_fullname, ) +from mypy.tvar_scope import TypeVarScope from mypy.typevars import has_no_typevars, fill_typevars from mypy.visitor import NodeVisitor from mypy.traverser import TraverserVisitor @@ -78,10 +80,12 @@ NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType, FunctionLike, UnboundType, TypeList, TypeVarDef, TypeType, TupleType, UnionType, StarType, EllipsisType, function_type, TypedDictType, + TypeQuery ) from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( TypeAnalyser, TypeAnalyserPass3, analyze_type_alias, no_subscript_builtin_alias, + TypeVariableQuery, TypeVarList, remove_dups, ) from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.sametypes import is_same_type @@ -190,15 +194,13 @@ class SemanticAnalyzer(NodeVisitor): type_stack = None # type: List[TypeInfo] # Type variables that are bound by the directly enclosing class bound_tvars = None # type: List[SymbolTableNode] - # Stack of type variables that were bound by outer classess - tvar_stack = None # type: List[List[SymbolTableNode]] + # Type variables bound by the current scope, be it class or function + tvar_scope = None # type: TypeVarScope # Per-module options options = None # type: Options # Stack of functions being analyzed function_stack = None # type: List[FuncItem] - # Stack of next available function type variable ids - next_function_tvar_id_stack = None # type: List[int] # Status of postponing analysis of nested function bodies. By using this we # can have mutually recursive nested functions. Values are FUNCTION_x @@ -227,10 +229,8 @@ def __init__(self, self.imports = set() self.type = None self.type_stack = [] - self.bound_tvars = None - self.tvar_stack = [] + self.tvar_scope = TypeVarScope() self.function_stack = [] - self.next_function_tvar_id_stack = [-1] self.block_depth = [0] self.loop_depth = 0 self.lib_path = lib_path @@ -322,12 +322,15 @@ def file_context(self, file_node: MypyFile, fnam: str, options: Options, del self.options def visit_func_def(self, defn: FuncDef) -> None: + phase_info = self.postpone_nested_functions_stack[-1] if phase_info != FUNCTION_SECOND_PHASE: self.function_stack.append(defn) # First phase of analysis for function. self.errors.push_function(defn.name()) - self.update_function_type_variables(defn) + if defn.type: + assert isinstance(defn.type, CallableType) + self.update_function_type_variables(defn.type, defn) self.errors.pop_function() self.function_stack.pop() @@ -421,72 +424,15 @@ def set_original_def(self, previous: Node, new: FuncDef) -> bool: else: return False - def update_function_type_variables(self, defn: FuncDef) -> None: + def update_function_type_variables(self, fun_type: CallableType, defn: FuncItem) -> None: """Make any type variables in the signature of defn explicit. Update the signature of defn to contain type variable definitions if defn is generic. """ - if defn.type: - assert isinstance(defn.type, CallableType) - typevars = self.infer_type_variables(defn.type) - # Do not define a new type variable if already defined in scope. - typevars = [(name, tvar) for name, tvar in typevars - if not self.is_defined_type_var(name, defn)] - if typevars: - next_tvar_id = self.next_function_tvar_id() - defs = [TypeVarDef(tvar[0], next_tvar_id - i, - tvar[1].values, tvar[1].upper_bound, - tvar[1].variance) - for i, tvar in enumerate(typevars)] - defn.type.variables = defs - - def infer_type_variables(self, - type: CallableType) -> List[Tuple[str, TypeVarExpr]]: - """Return list of unique type variables referred to in a callable.""" - names = [] # type: List[str] - tvars = [] # type: List[TypeVarExpr] - for arg in type.arg_types + [type.ret_type]: - for name, tvar_expr in self.find_type_variables_in_type(arg): - if name not in names: - names.append(name) - tvars.append(tvar_expr) - return list(zip(names, tvars)) - - def find_type_variables_in_type(self, type: Type) -> List[Tuple[str, TypeVarExpr]]: - """Return a list of all unique type variable references in type. - - This effectively does partial name binding, results of which are mostly thrown away. - """ - result = [] # type: List[Tuple[str, TypeVarExpr]] - if isinstance(type, UnboundType): - name = type.name - node = self.lookup_qualified(name, type) - if node and node.kind == UNBOUND_TVAR: - assert isinstance(node.node, TypeVarExpr) - result.append((name, node.node)) - for arg in type.args: - result.extend(self.find_type_variables_in_type(arg)) - elif isinstance(type, TypeList): - for item in type.items: - result.extend(self.find_type_variables_in_type(item)) - elif isinstance(type, UnionType): - for item in type.items: - result.extend(self.find_type_variables_in_type(item)) - elif isinstance(type, AnyType): - pass - elif isinstance(type, (EllipsisType, TupleType)): - # TODO: Need to process tuple items? - pass - elif isinstance(type, Instance): - for arg in type.args: - result.extend(self.find_type_variables_in_type(arg)) - else: - assert False, 'Unsupported type %s' % type - return result - - def is_defined_type_var(self, tvar: str, context: Context) -> bool: - return self.lookup_qualified(tvar, context).kind == BOUND_TVAR + with self.tvar_scope_frame(self.tvar_scope.method_frame()): + a = self.type_analyzer() + fun_type.variables = a.bind_function_type_variables(fun_type, defn) def visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None: # OverloadedFuncDef refers to any legitimate situation where you have @@ -600,60 +546,54 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) - else: self.fail("Decorated property not supported", item) - def next_function_tvar_id(self) -> int: - return self.next_function_tvar_id_stack[-1] - def analyze_function(self, defn: FuncItem) -> None: is_method = self.is_class_scope() - - tvarnodes = self.add_func_type_variables_to_symbol_table(defn) - next_function_tvar_id = min([self.next_function_tvar_id()] + - [n.tvar_def.id.raw_id - 1 for n in tvarnodes]) - self.next_function_tvar_id_stack.append(next_function_tvar_id) - - if defn.type: - # Signature must be analyzed in the surrounding scope so that - # class-level imported names and type variables are in scope. - self.check_classvar_in_signature(defn.type) - defn.type = self.anal_type(defn.type) - self.check_function_signature(defn) - if isinstance(defn, FuncDef): - defn.type = set_callable_name(defn.type, defn) - for arg in defn.arguments: - if arg.initializer: - arg.initializer.accept(self) - self.function_stack.append(defn) - self.enter() - for arg in defn.arguments: - self.add_local(arg.variable, defn) - for arg in defn.arguments: - if arg.initialization_statement: - lvalue = arg.initialization_statement.lvalues[0] - lvalue.accept(self) - - # The first argument of a non-static, non-class method is like 'self' - # (though the name could be different), having the enclosing class's - # instance type. - if is_method and not defn.is_static and not defn.is_class and defn.arguments: - defn.arguments[0].variable.is_self = True - - # First analyze body of the function but ignore nested functions. - self.postpone_nested_functions_stack.append(FUNCTION_FIRST_PHASE_POSTPONE_SECOND) - self.postponed_functions_stack.append([]) - defn.body.accept(self) - - # Analyze nested functions (if any) as a second phase. - self.postpone_nested_functions_stack[-1] = FUNCTION_SECOND_PHASE - for postponed in self.postponed_functions_stack[-1]: - postponed.accept(self) - self.postpone_nested_functions_stack.pop() - self.postponed_functions_stack.pop() - - self.next_function_tvar_id_stack.pop() - disable_typevars(tvarnodes) - - self.leave() - self.function_stack.pop() + with self.tvar_scope_frame(self.tvar_scope.method_frame()): + if defn.type: + self.check_classvar_in_signature(defn.type) + assert isinstance(defn.type, CallableType) + # Signature must be analyzed in the surrounding scope so that + # class-level imported names and type variables are in scope. + defn.type = self.type_analyzer().visit_callable_type(defn.type, nested=False) + self.check_function_signature(defn) + if isinstance(defn, FuncDef): + defn.type = set_callable_name(defn.type, defn) + for arg in defn.arguments: + if arg.initializer: + arg.initializer.accept(self) + # Bind the type variables again to visit the body. + if defn.type: + a = self.type_analyzer() + a.bind_function_type_variables(cast(CallableType, defn.type), defn) + self.function_stack.append(defn) + self.enter() + for arg in defn.arguments: + self.add_local(arg.variable, defn) + for arg in defn.arguments: + if arg.initialization_statement: + lvalue = arg.initialization_statement.lvalues[0] + lvalue.accept(self) + + # The first argument of a non-static, non-class method is like 'self' + # (though the name could be different), having the enclosing class's + # instance type. + if is_method and not defn.is_static and not defn.is_class and defn.arguments: + defn.arguments[0].variable.is_self = True + + # First analyze body of the function but ignore nested functions. + self.postpone_nested_functions_stack.append(FUNCTION_FIRST_PHASE_POSTPONE_SECOND) + self.postponed_functions_stack.append([]) + defn.body.accept(self) + + # Analyze nested functions (if any) as a second phase. + self.postpone_nested_functions_stack[-1] = FUNCTION_SECOND_PHASE + for postponed in self.postponed_functions_stack[-1]: + postponed.accept(self) + self.postpone_nested_functions_stack.pop() + self.postponed_functions_stack.pop() + + self.leave() + self.function_stack.pop() def check_classvar_in_signature(self, typ: Type) -> None: t = None # type: Type @@ -669,36 +609,6 @@ def check_classvar_in_signature(self, typ: Type) -> None: # Show only one error per signature break - def add_func_type_variables_to_symbol_table( - self, defn: FuncItem) -> List[SymbolTableNode]: - nodes = [] # type: List[SymbolTableNode] - if defn.type: - tt = defn.type - assert isinstance(tt, CallableType) - items = tt.variables - names = self.type_var_names() - for item in items: - name = item.name - if name in names: - self.name_already_defined(name, defn) - node = self.bind_type_var(name, item, defn) - nodes.append(node) - names.add(name) - return nodes - - def type_var_names(self) -> Set[str]: - if not self.type: - return set() - else: - return set(self.type.type_vars) - - def bind_type_var(self, fullname: str, tvar_def: TypeVarDef, - context: Context) -> SymbolTableNode: - node = self.lookup_qualified(fullname, context) - node.kind = BOUND_TVAR - node.tvar_def = tvar_def - return node - def check_function_signature(self, fdef: FuncItem) -> None: sig = fdef.type assert isinstance(sig, CallableType) @@ -718,36 +628,31 @@ def visit_class_def(self, defn: ClassDef) -> None: @contextmanager def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: - self.clean_up_bases_and_infer_type_variables(defn) - if self.analyze_typeddict_classdef(defn): - yield False - return - if self.analyze_namedtuple_classdef(defn): - # just analyze the class body so we catch type errors in default values - self.enter_class(defn) - yield True - self.leave_class() - else: - self.setup_class_def_analysis(defn) - - self.bind_class_type_vars(defn) - - self.analyze_base_classes(defn) - self.analyze_metaclass(defn) - - for decorator in defn.decorators: - self.analyze_class_decorator(defn, decorator) - - self.enter_class(defn) + with self.tvar_scope_frame(self.tvar_scope.class_frame()): + self.clean_up_bases_and_infer_type_variables(defn) + if self.analyze_typeddict_classdef(defn): + yield False + return + if self.analyze_namedtuple_classdef(defn): + # just analyze the class body so we catch type errors in default values + self.enter_class(defn) + yield True + self.leave_class() + else: + self.setup_class_def_analysis(defn) + self.analyze_base_classes(defn) + self.analyze_metaclass(defn) - yield True + for decorator in defn.decorators: + self.analyze_class_decorator(defn, decorator) - self.calculate_abstract_status(defn.info) - self.setup_type_promotion(defn) + self.enter_class(defn) + yield True - self.leave_class() + self.calculate_abstract_status(defn.info) + self.setup_type_promotion(defn) - self.unbind_class_type_vars() + self.leave_class() def enter_class(self, defn: ClassDef) -> None: # Remember previous active class @@ -764,24 +669,6 @@ def leave_class(self) -> None: self.locals.pop() self.type = self.type_stack.pop() - def bind_class_type_vars(self, defn: ClassDef) -> None: - """ Unbind type variables of previously active class and bind - the type variables for the active class. - """ - if self.bound_tvars: - disable_typevars(self.bound_tvars) - self.tvar_stack.append(self.bound_tvars) - self.bound_tvars = self.bind_class_type_variables_in_symbol_table(defn.info) - - def unbind_class_type_vars(self) -> None: - """ Unbind the active class' type vars and rebind the - type vars of the previously active class. - """ - disable_typevars(self.bound_tvars) - self.bound_tvars = self.tvar_stack.pop() - if self.bound_tvars: - enable_typevars(self.bound_tvars) - def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: decorator.accept(self) @@ -844,8 +731,7 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: Note that this is performed *before* semantic analysis. """ removed = [] # type: List[int] - declared_tvars = [] # type: List[Tuple[str, TypeVarExpr]] - type_vars = [] # type: List[TypeVarDef] + declared_tvars = [] # type: TypeVarList for i, base_expr in enumerate(defn.base_type_exprs): try: base = expr_to_unanalyzed_type(base_expr) @@ -861,26 +747,26 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: all_tvars = self.get_all_bases_tvars(defn, removed) if declared_tvars: - if len(self.remove_dups(declared_tvars)) < len(declared_tvars): + if len(remove_dups(declared_tvars)) < len(declared_tvars): self.fail("Duplicate type variables in Generic[...]", defn) - declared_tvars = self.remove_dups(declared_tvars) + declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): self.fail("If Generic[...] is present it should list all type variables", defn) # In case of error, Generic tvars will go first - declared_tvars = self.remove_dups(declared_tvars + all_tvars) + declared_tvars = remove_dups(declared_tvars + all_tvars) else: declared_tvars = all_tvars - for j, (name, tvar_expr) in enumerate(declared_tvars): - type_vars.append(TypeVarDef(name, j + 1, tvar_expr.values, - tvar_expr.upper_bound, tvar_expr.variance)) - if type_vars: - defn.type_vars = type_vars + if declared_tvars: if defn.info: - defn.info.type_vars = [tv.name for tv in type_vars] + defn.info.type_vars = [name for name, _ in declared_tvars] for i in reversed(removed): del defn.base_type_exprs[i] + tvar_defs = [] # type: List[TypeVarDef] + for name, tvar_expr in declared_tvars: + tvar_defs.append(self.tvar_scope.bind(name, tvar_expr)) + defn.type_vars = tvar_defs - def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeVarExpr]]]: + def analyze_typevar_declaration(self, t: Type) -> Optional[TypeVarList]: if not isinstance(t, UnboundType): return None unbound = t @@ -888,7 +774,7 @@ def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeV if sym is None or sym.node is None: return None if sym.node.fullname() == 'typing.Generic': - tvars = [] # type: List[Tuple[str, TypeVarExpr]] + tvars = [] # type: TypeVarList for arg in unbound.args: tvar = self.analyze_unbound_tvar(arg) if tvar: @@ -904,14 +790,17 @@ def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]: return None unbound = t sym = self.lookup_qualified(unbound.name, unbound) - if sym is not None and sym.kind == UNBOUND_TVAR: + if sym is None or sym.kind != TVAR: + return None + elif not self.tvar_scope.allow_binding(sym.fullname): + # It's bound by our type variable scope + return None + else: assert isinstance(sym.node, TypeVarExpr) return unbound.name, sym.node - return None - def get_all_bases_tvars(self, defn: ClassDef, - removed: List[int]) -> List[Tuple[str, TypeVarExpr]]: - tvars = [] # type: List[Tuple[str, TypeVarExpr]] + def get_all_bases_tvars(self, defn: ClassDef, removed: List[int]) -> TypeVarList: + tvars = [] # type: TypeVarList for i, base_expr in enumerate(defn.base_type_exprs): if i not in removed: try: @@ -919,34 +808,9 @@ def get_all_bases_tvars(self, defn: ClassDef, except TypeTranslationError: # This error will be caught later. continue - tvars.extend(self.get_tvars(base)) - return self.remove_dups(tvars) - - def get_tvars(self, tp: Type) -> List[Tuple[str, TypeVarExpr]]: - tvars = [] # type: List[Tuple[str, TypeVarExpr]] - if isinstance(tp, UnboundType): - tp_args = tp.args - elif isinstance(tp, TypeList): - tp_args = tp.items - else: - return tvars - for arg in tp_args: - tvar = self.analyze_unbound_tvar(arg) - if tvar: - tvars.append(tvar) - else: - tvars.extend(self.get_tvars(arg)) - return self.remove_dups(tvars) - - def remove_dups(self, tvars: List[T]) -> List[T]: - # Get unique elements in order of appearance - all_tvars = set(tvars) - new_tvars = [] # type: List[T] - for t in tvars: - if t in all_tvars: - new_tvars.append(t) - all_tvars.remove(t) - return new_tvars + base_tvars = base.accept(TypeVariableQuery(self.lookup_qualified, self.tvar_scope)) + tvars.extend(base_tvars) + return remove_dups(tvars) def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: # special case for NamedTuple @@ -1216,14 +1080,6 @@ def named_type_or_none(self, qualified_name: str, args: List[Type] = None) -> In assert isinstance(sym.node, TypeInfo) return Instance(sym.node, args or []) - def bind_class_type_variables_in_symbol_table( - self, info: TypeInfo) -> List[SymbolTableNode]: - nodes = [] # type: List[SymbolTableNode] - for var, binder in zip(info.type_vars, info.defn.type_vars): - node = self.bind_type_var(var, binder, info) - nodes.append(node) - return nodes - def is_typeddict(self, expr: Expression) -> bool: return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and expr.node.typeddict_type is not None) @@ -1516,16 +1372,31 @@ def visit_block_maybe(self, b: Block) -> None: if b: self.visit_block(b) - def anal_type(self, t: Type, allow_tuple_literal: bool = False, + def type_analyzer(self, *, + tvar_scope: Optional[TypeVarScope] = None, + allow_tuple_literal: bool = False, + aliasing: bool = False) -> TypeAnalyser: + if tvar_scope is None: + tvar_scope = self.tvar_scope + return TypeAnalyser(self.lookup_qualified, + self.lookup_fully_qualified, + tvar_scope, + self.fail, + aliasing=aliasing, + allow_tuple_literal=allow_tuple_literal, + allow_unnormalized=self.is_stub_file) + + def anal_type(self, t: Type, *, + tvar_scope: Optional[TypeVarScope] = None, + allow_tuple_literal: bool = False, aliasing: bool = False) -> Type: if t: - a = TypeAnalyser(self.lookup_qualified, - self.lookup_fully_qualified, - self.fail, - aliasing=aliasing, - allow_tuple_literal=allow_tuple_literal, - allow_unnormalized=self.is_stub_file) + a = self.type_analyzer( + tvar_scope=tvar_scope, + aliasing=aliasing, + allow_tuple_literal=allow_tuple_literal) return t.accept(a) + else: return None @@ -1536,7 +1407,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: s.rvalue.accept(self) if s.type: allow_tuple_literal = isinstance(s.lvalues[-1], (TupleExpr, ListExpr)) - s.type = self.anal_type(s.type, allow_tuple_literal) + s.type = self.anal_type(s.type, allow_tuple_literal=allow_tuple_literal) else: # For simple assignments, allow binding type aliases. # Also set the type if the rvalue is a simple literal. @@ -1547,6 +1418,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: res = analyze_type_alias(s.rvalue, self.lookup_qualified, self.lookup_fully_qualified, + self.tvar_scope, self.fail, allow_unnormalized=True) if res and (not isinstance(res, Instance) or res.args): # TODO: What if this gets reassigned? @@ -1929,7 +1801,7 @@ def process_typevar_declaration(self, s: AssignmentStmt) -> None: # Yes, it's a valid type variable definition! Add it to the symbol table. node = self.lookup(name, s) - node.kind = UNBOUND_TVAR + node.kind = TVAR TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) TypeVar.line = call.line call.analyzed = TypeVar @@ -2637,7 +2509,7 @@ def visit_for_stmt(self, s: ForStmt) -> None: if self.is_classvar(s.index_type): self.fail_invalid_classvar(s.index) allow_tuple_literal = isinstance(s.index, (TupleExpr, ListExpr)) - s.index_type = self.anal_type(s.index_type, allow_tuple_literal) + s.index_type = self.anal_type(s.index_type, allow_tuple_literal=allow_tuple_literal) self.store_declared_types(s.index, s.index_type) self.loop_depth += 1 @@ -2714,7 +2586,7 @@ def visit_with_stmt(self, s: WithStmt) -> None: if self.is_classvar(t): self.fail_invalid_classvar(n) allow_tuple_literal = isinstance(n, (TupleExpr, ListExpr)) - t = self.anal_type(t, allow_tuple_literal) + t = self.anal_type(t, allow_tuple_literal=allow_tuple_literal) new_types.append(t) self.store_declared_types(n, t) @@ -2785,7 +2657,7 @@ def visit_exec_stmt(self, s: ExecStmt) -> None: def visit_name_expr(self, expr: NameExpr) -> None: n = self.lookup(expr.name, expr) if n: - if n.kind == BOUND_TVAR: + if n.kind == TVAR and self.tvar_scope.get_binding(n): self.fail("'{}' is a type variable and only valid in type " "context".format(expr.name), expr) else: @@ -3010,6 +2882,7 @@ def visit_index_expr(self, expr: IndexExpr) -> None: res = analyze_type_alias(expr, self.lookup_qualified, self.lookup_fully_qualified, + self.tvar_scope, self.fail, allow_unnormalized=self.is_stub_file) expr.analyzed = TypeAliasExpr(res, fallback=self.alias_fallback(res), in_runtime=True) @@ -3155,6 +3028,13 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: # Helpers # + @contextmanager + def tvar_scope_frame(self, frame: TypeVarScope) -> Iterator[None]: + old_scope = self.tvar_scope + self.tvar_scope = frame + yield + self.tvar_scope = old_scope + def lookup(self, name: str, ctx: Context) -> SymbolTableNode: """Look up an unqualified name in all active namespaces.""" # 1a. Name declared using 'global x' takes precedence @@ -3860,18 +3740,6 @@ def find_duplicate(list: List[T]) -> T: return None -def disable_typevars(nodes: List[SymbolTableNode]) -> None: - for node in nodes: - assert node.kind in (BOUND_TVAR, UNBOUND_TVAR) - node.kind = UNBOUND_TVAR - - -def enable_typevars(nodes: List[SymbolTableNode]) -> None: - for node in nodes: - assert node.kind in (BOUND_TVAR, UNBOUND_TVAR) - node.kind = BOUND_TVAR - - def remove_imported_names_from_symtable(names: SymbolTable, module: str) -> None: """Remove all imported names from the symbol table of a module.""" diff --git a/mypy/tvar_scope.py b/mypy/tvar_scope.py new file mode 100644 index 000000000000..3cdb67bbf992 --- /dev/null +++ b/mypy/tvar_scope.py @@ -0,0 +1,82 @@ +from typing import Optional, Dict, Union +from mypy.types import TypeVarDef, TypeVarId +from mypy.nodes import TypeVarExpr, SymbolTableNode + + +class TypeVarScope: + """Scope that holds bindings for type variables. Node fullname -> TypeVarDef.""" + + def __init__(self, + parent: Optional['TypeVarScope'] = None, + is_class_scope: bool = False, + prohibited: Optional['TypeVarScope'] = None) -> None: + """Initializer for TypeVarScope + + Parameters: + parent: the outer scope for this scope + is_class_scope: True if this represents a generic class + prohibited: Type variables that aren't strictly in scope exactly, + but can't be bound because they're part of an outer class's scope. + """ + self.scope = {} # type: Dict[str, TypeVarDef] + self.parent = parent + self.func_id = 0 + self.class_id = 0 + self.is_class_scope = is_class_scope + self.prohibited = prohibited + if parent is not None: + self.func_id = parent.func_id + self.class_id = parent.class_id + + def get_function_scope(self) -> Optional['TypeVarScope']: + """Get the nearest parent that's a function scope, not a class scope""" + it = self + while it is not None and it.is_class_scope: + it = it.parent + return it + + def allow_binding(self, fullname: str) -> bool: + if fullname in self.scope: + return False + elif self.parent and not self.parent.allow_binding(fullname): + return False + elif self.prohibited and not self.prohibited.allow_binding(fullname): + return False + return True + + def method_frame(self) -> 'TypeVarScope': + """A new scope frame for binding a method""" + return TypeVarScope(self, False, None) + + def class_frame(self) -> 'TypeVarScope': + """A new scope frame for binding a class. Prohibits *this* class's tvars""" + return TypeVarScope(self.get_function_scope(), True, self) + + def bind(self, name: str, tvar_expr: TypeVarExpr) -> TypeVarDef: + if self.is_class_scope: + self.class_id += 1 + i = self.class_id + else: + self.func_id -= 1 + i = self.func_id + tvar_def = TypeVarDef( + name, i, values=tvar_expr.values, + upper_bound=tvar_expr.upper_bound, variance=tvar_expr.variance, + line=tvar_expr.line, column=tvar_expr.column) + self.scope[tvar_expr.fullname()] = tvar_def + return tvar_def + + def get_binding(self, item: Union[str, SymbolTableNode]) -> Optional[TypeVarDef]: + fullname = item.fullname if isinstance(item, SymbolTableNode) else item + if fullname in self.scope: + return self.scope[fullname] + elif self.parent is not None: + return self.parent.get_binding(fullname) + else: + return None + + def __str__(self) -> str: + me = ", ".join('{}: {}`{}'.format(k, v.name, v.id) for k, v in self.scope.items()) + if self.parent is None: + return me + return "{} <- {}".format(str(self.parent), me) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index a3ad0eb36fea..877714581080 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -1,20 +1,24 @@ """Semantic analysis of types""" from collections import OrderedDict -from typing import Callable, List, Optional, Set +from typing import Callable, List, Optional, Set, Tuple, Iterator, TypeVar, Iterable +from itertools import chain + +from contextlib import contextmanager from mypy.types import ( Type, UnboundType, TypeVarType, TupleType, TypedDictType, UnionType, Instance, AnyType, CallableType, NoneTyp, DeletedType, TypeList, TypeVarDef, TypeVisitor, SyntheticTypeVisitor, StarType, PartialType, EllipsisType, UninhabitedType, TypeType, get_typ_args, set_typ_args, - get_type_vars, union_items + get_type_vars, TypeQuery, union_items, ) from mypy.nodes import ( - BOUND_TVAR, UNBOUND_TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, + TVAR, TYPE_ALIAS, UNBOUND_IMPORTED, TypeInfo, Context, SymbolTableNode, Var, Expression, - IndexExpr, RefExpr, nongen_builtins, + IndexExpr, RefExpr, nongen_builtins, TypeVarExpr ) +from mypy.tvar_scope import TypeVarScope from mypy.sametypes import is_same_type from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.subtypes import is_subtype @@ -22,6 +26,9 @@ from mypy import experiments +T = TypeVar('T') + + type_constructors = { 'typing.Callable', 'typing.Optional', @@ -34,6 +41,7 @@ def analyze_type_alias(node: Expression, lookup_func: Callable[[str, Context], SymbolTableNode], lookup_fqn_func: Callable[[str], SymbolTableNode], + tvar_scope: TypeVarScope, fail_func: Callable[[str, Context], None], allow_unnormalized: bool = False) -> Type: """Return type if node is valid as a type alias rvalue. @@ -47,7 +55,7 @@ def analyze_type_alias(node: Expression, # Note that this misses the case where someone tried to use a # class-referenced type variable as a type alias. It's easier to catch # that one in checkmember.py - if node.kind == UNBOUND_TVAR or node.kind == BOUND_TVAR: + if node.kind == TVAR: fail_func('Type variable "{}" is invalid as target for type alias'.format( node.fullname), node) return None @@ -77,7 +85,7 @@ def analyze_type_alias(node: Expression, except TypeTranslationError: fail_func('Invalid type alias', node) return None - analyzer = TypeAnalyser(lookup_func, lookup_fqn_func, fail_func, aliasing=True, + analyzer = TypeAnalyser(lookup_func, lookup_fqn_func, tvar_scope, fail_func, aliasing=True, allow_unnormalized=allow_unnormalized) return type.accept(analyzer) @@ -99,6 +107,7 @@ class TypeAnalyser(SyntheticTypeVisitor[Type]): def __init__(self, lookup_func: Callable[[str, Context], SymbolTableNode], lookup_fqn_func: Callable[[str], SymbolTableNode], + tvar_scope: TypeVarScope, fail_func: Callable[[str, Context], None], *, aliasing: bool = False, allow_tuple_literal: bool = False, @@ -106,6 +115,7 @@ def __init__(self, self.lookup = lookup_func self.lookup_fqn_func = lookup_fqn_func self.fail = fail_func + self.tvar_scope = tvar_scope self.aliasing = aliasing self.allow_tuple_literal = allow_tuple_literal # Positive if we are analyzing arguments of another (outer) type @@ -129,12 +139,12 @@ def visit_unbound_type(self, t: UnboundType) -> Type: if (fullname in nongen_builtins and t.args and not sym.normalized and not self.allow_unnormalized): self.fail(no_subscript_builtin_alias(fullname), t) - if sym.kind == BOUND_TVAR: + if sym.kind == TVAR and self.tvar_scope.get_binding(sym) is not None: + tvar_def = self.tvar_scope.get_binding(sym) if len(t.args) > 0: self.fail('Type variable "{}" used with arguments'.format( t.name), t) - assert sym.tvar_def is not None - return TypeVarType(sym.tvar_def, t.line) + return TypeVarType(tvar_def, t.line) elif fullname == 'builtins.None': return NoneTyp() elif fullname == 'typing.Any' or fullname == 'builtins.Any': @@ -212,7 +222,8 @@ def visit_unbound_type(self, t: UnboundType) -> Type: # is pretty minor. return AnyType() # Allow unbound type variables when defining an alias - if not (self.aliasing and sym.kind == UNBOUND_TVAR): + if not (self.aliasing and sym.kind == TVAR and + self.tvar_scope.get_binding(sym) is None): self.fail('Invalid type "{}"'.format(name), t) return t info = sym.node # type: TypeInfo @@ -254,30 +265,15 @@ def get_type_var_names(self, tp: Type) -> List[str]: """Get all type variable names that are present in a generic type alias in order of textual appearance (recursively, if needed). """ - tvars = [] # type: List[str] - typ_args = get_typ_args(tp) - for arg in typ_args: - tvar = self.get_tvar_name(arg) - if tvar: - tvars.append(tvar) - else: - subvars = self.get_type_var_names(arg) - if subvars: - tvars.extend(subvars) - # Get unique type variables in order of appearance - all_tvars = set(tvars) - new_tvars = [] - for t in tvars: - if t in all_tvars: - new_tvars.append(t) - all_tvars.remove(t) - return new_tvars + return [name for name, _ + in tp.accept(TypeVariableQuery(self.lookup, self.tvar_scope, + include_callables=True, include_bound_tvars=True))] def get_tvar_name(self, t: Type) -> Optional[str]: if not isinstance(t, UnboundType): return None sym = self.lookup(t.name, t) - if sym is not None and (sym.kind == UNBOUND_TVAR or sym.kind == BOUND_TVAR): + if sym is not None and sym.kind == TVAR: return t.name return None @@ -320,11 +316,18 @@ def visit_instance(self, t: Instance) -> Type: def visit_type_var(self, t: TypeVarType) -> Type: return t - def visit_callable_type(self, t: CallableType) -> Type: - return t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=False), - ret_type=self.anal_type(t.ret_type, nested=False), - fallback=t.fallback or self.builtin_type('builtins.function'), - variables=self.anal_var_defs(t.variables)) + def visit_callable_type(self, t: CallableType, nested: bool = True) -> Type: + # Every Callable can bind its own type variables, if they're not in the outer scope + with self.tvar_scope_frame(): + if self.aliasing: + variables = t.variables + else: + variables = self.bind_function_type_variables(t, t) + ret = t.copy_modified(arg_types=self.anal_array(t.arg_types, nested=nested), + ret_type=self.anal_type(t.ret_type, nested=nested), + fallback=t.fallback or self.builtin_type('builtins.function'), + variables=self.anal_var_defs(variables)) + return ret def visit_tuple_type(self, t: TupleType) -> Type: # Types such as (t1, t2, ...) only allowed in assignment statements. They'll @@ -371,36 +374,91 @@ def analyze_callable_type(self, t: UnboundType) -> Type: fallback = self.builtin_type('builtins.function') if len(t.args) == 0: # Callable (bare). Treat as Callable[..., Any]. - return CallableType([AnyType(), AnyType()], - [nodes.ARG_STAR, nodes.ARG_STAR2], - [None, None], - ret_type=AnyType(), - fallback=fallback, - is_ellipsis_args=True) + ret = CallableType([AnyType(), AnyType()], + [nodes.ARG_STAR, nodes.ARG_STAR2], + [None, None], + ret_type=AnyType(), + fallback=fallback, + is_ellipsis_args=True) elif len(t.args) == 2: - ret_type = self.anal_type(t.args[1]) + ret_type = t.args[1] if isinstance(t.args[0], TypeList): # Callable[[ARG, ...], RET] (ordinary callable type) args = t.args[0].items - return CallableType(self.anal_array(args), - [nodes.ARG_POS] * len(args), - [None] * len(args), - ret_type=ret_type, - fallback=fallback) + ret = CallableType(args, + [nodes.ARG_POS] * len(args), + [None] * len(args), + ret_type=ret_type, + fallback=fallback) elif isinstance(t.args[0], EllipsisType): # Callable[..., RET] (with literal ellipsis; accept arbitrary arguments) - return CallableType([AnyType(), AnyType()], - [nodes.ARG_STAR, nodes.ARG_STAR2], - [None, None], - ret_type=ret_type, - fallback=fallback, - is_ellipsis_args=True) + ret = CallableType([AnyType(), AnyType()], + [nodes.ARG_STAR, nodes.ARG_STAR2], + [None, None], + ret_type=ret_type, + fallback=fallback, + is_ellipsis_args=True) else: self.fail('The first argument to Callable must be a list of types or "..."', t) return AnyType() - - self.fail('Invalid function type', t) - return AnyType() + else: + self.fail('Invalid function type', t) + return AnyType() + assert isinstance(ret, CallableType) + return ret.accept(self) + + @contextmanager + def tvar_scope_frame(self) -> Iterator[None]: + old_scope = self.tvar_scope + self.tvar_scope = self.tvar_scope.method_frame() + yield + self.tvar_scope = old_scope + + def infer_type_variables(self, + type: CallableType) -> List[Tuple[str, TypeVarExpr]]: + """Return list of unique type variables referred to in a callable.""" + names = [] # type: List[str] + tvars = [] # type: List[TypeVarExpr] + for arg in type.arg_types: + for name, tvar_expr in arg.accept(TypeVariableQuery(self.lookup, self.tvar_scope)): + if name not in names: + names.append(name) + tvars.append(tvar_expr) + # When finding type variables in the return type of a function, don't + # look inside Callable types. Type variables only appearing in + # functions in the return type belong to those functions, not the + # function we're currently analyzing. + for name, tvar_expr in type.ret_type.accept( + TypeVariableQuery(self.lookup, self.tvar_scope, include_callables=False)): + if name not in names: + names.append(name) + tvars.append(tvar_expr) + return list(zip(names, tvars)) + + def bind_function_type_variables(self, + fun_type: CallableType, defn: Context) -> List[TypeVarDef]: + """Find the type variables of the function type and bind them in our tvar_scope""" + if fun_type.variables: + for var in fun_type.variables: + var_expr = self.lookup(var.name, var).node + assert isinstance(var_expr, TypeVarExpr) + self.tvar_scope.bind(var.name, var_expr) + return fun_type.variables + typevars = self.infer_type_variables(fun_type) + # Do not define a new type variable if already defined in scope. + typevars = [(name, tvar) for name, tvar in typevars + if not self.is_defined_type_var(name, defn)] + defs = [] # type: List[TypeVarDef] + for name, tvar in typevars: + if not self.tvar_scope.allow_binding(tvar.fullname()): + self.fail("Type variable '{}' is bound by an outer class".format(name), defn) + self.tvar_scope.bind(name, tvar) + defs.append(self.tvar_scope.get_binding(tvar.fullname())) + + return defs + + def is_defined_type_var(self, tvar: str, context: Context) -> bool: + return self.tvar_scope.get_binding(self.lookup(tvar, context)) is not None def anal_array(self, a: List[Type], nested: bool = True) -> List[Type]: res = [] # type: List[Type] @@ -485,8 +543,8 @@ def visit_instance(self, t: Instance) -> None: t.invalid = True elif info.defn.type_vars: # Check type argument values. - for (i, arg), TypeVar in zip(enumerate(t.args), info.defn.type_vars): - if TypeVar.values: + for (i, arg), tvar in zip(enumerate(t.args), info.defn.type_vars): + if tvar.values: if isinstance(arg, TypeVarType): arg_values = arg.values if not arg_values: @@ -497,11 +555,11 @@ def visit_instance(self, t: Instance) -> None: else: arg_values = [arg] self.check_type_var_values(info, arg_values, - TypeVar.values, i + 1, t) - if not is_subtype(arg, TypeVar.upper_bound): + tvar.values, i + 1, t) + if not is_subtype(arg, tvar.upper_bound): self.fail('Type argument "{}" of "{}" must be ' 'a subtype of "{}"'.format( - arg, info.name(), TypeVar.upper_bound), t) + arg, info.name(), tvar.upper_bound), t) for arg in t.args: arg.accept(self) @@ -567,6 +625,64 @@ def visit_type_type(self, t: TypeType) -> None: pass +TypeVarList = List[Tuple[str, TypeVarExpr]] + + +def remove_dups(tvars: Iterable[T]) -> List[T]: + # Get unique elements in order of appearance + all_tvars = set() # type: Set[T] + new_tvars = [] # type: List[T] + for t in tvars: + if t not in all_tvars: + new_tvars.append(t) + all_tvars.add(t) + return new_tvars + + +def flatten_tvars(ll: Iterable[List[T]]) -> List[T]: + return remove_dups(chain.from_iterable(ll)) + + +class TypeVariableQuery(TypeQuery[TypeVarList]): + + def __init__(self, + lookup: Callable[[str, Context], SymbolTableNode], + scope: 'TypeVarScope', + *, + include_callables: bool = True, + include_bound_tvars: bool = False) -> None: + self.include_callables = include_callables + self.lookup = lookup + self.scope = scope + self.include_bound_tvars = include_bound_tvars + super().__init__(flatten_tvars) + + def _seems_like_callable(self, type: UnboundType) -> bool: + if not type.args: + return False + if isinstance(type.args[0], (EllipsisType, TypeList)): + return True + return False + + def visit_unbound_type(self, t: UnboundType) -> TypeVarList: + name = t.name + node = self.lookup(name, t) + if node and node.kind == TVAR and ( + self.include_bound_tvars or self.scope.get_binding(node) is None): + assert isinstance(node.node, TypeVarExpr) + return [(name, node.node)] + elif not self.include_callables and self._seems_like_callable(t): + return [] + else: + return super().visit_unbound_type(t) + + def visit_callable_type(self, t: CallableType) -> TypeVarList: + if self.include_callables: + return super().visit_callable_type(t) + else: + return [] + + def make_optional_type(t: Type) -> Type: """Return the type corresponding to Optional[t]. diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 0099d501b9f6..33fc51d3b046 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -1703,3 +1703,81 @@ def f(a, (b, c), d): pass # flags: --python-version 2.7 def f(a, (b, c), d): pass + +-- Type variable shenanagins +-- ------------------------- + +[case testGenericFunctionTypeDecl] +from typing import Callable, TypeVar + +T = TypeVar('T') + +f: Callable[[T], T] +reveal_type(f) # E: Revealed type is 'def [T] (T`-1) -> T`-1' +def g(__x: T) -> T: pass +f = g +reveal_type(f) # E: Revealed type is 'def [T] (T`-1) -> T`-1' +i = f(3) +reveal_type(i) # E: Revealed type is 'builtins.int*' + +[case testFunctionReturningGenericFunction] +from typing import Callable, TypeVar + +T = TypeVar('T') +def deco() -> Callable[[T], T]: pass +reveal_type(deco) # E: Revealed type is 'def () -> def [T] (T`-1) -> T`-1' +f = deco() +reveal_type(f) # E: Revealed type is 'def [T] (T`-1) -> T`-1' +i = f(3) +reveal_type(i) # E: Revealed type is 'builtins.int*' + +[case testFunctionReturningGenericFunctionPartialBinding] +from typing import Callable, TypeVar + +T = TypeVar('T') +U = TypeVar('U') + +def deco(x: U) -> Callable[[T, U], T]: pass +reveal_type(deco) # E: Revealed type is 'def [U] (x: U`-1) -> def [T] (T`-2, U`-1) -> T`-2' +f = deco("foo") +reveal_type(f) # E: Revealed type is 'def [T] (T`-2, builtins.str*) -> T`-2' +i = f(3, "eggs") +reveal_type(i) # E: Revealed type is 'builtins.int*' + +[case testFunctionReturningGenericFunctionTwoLevelBinding] +from typing import Callable, TypeVar + +T = TypeVar('T') +R = TypeVar('R') +def deco() -> Callable[[T], Callable[[T, R], R]]: pass +f = deco() +reveal_type(f) # E: Revealed type is 'def [T] (T`-1) -> def [R] (T`-1, R`-2) -> R`-2' +g = f(3) +reveal_type(g) # E: Revealed type is 'def [R] (builtins.int*, R`-2) -> R`-2' +s = g(4, "foo") +reveal_type(s) # E: Revealed type is 'builtins.str*' + +[case testGenericFunctionReturnAsDecorator] +from typing import Callable, TypeVar + +T = TypeVar('T') +def deco(__i: int) -> Callable[[T], T]: pass + +@deco(3) +def lol(x: int) -> str: ... + +reveal_type(lol) # E: Revealed type is 'def (x: builtins.int) -> builtins.str' +s = lol(4) +reveal_type(s) # E: Revealed type is 'builtins.str' + +[case testGenericFunctionOnReturnTypeOnly] +from typing import TypeVar, List + +T = TypeVar('T') + +def make_list() -> List[T]: pass + +l: List[int] = make_list() + +bad = make_list() # E: Need type annotation for variable +[builtins fixtures/list.pyi] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 9d1a0ae8899e..4aa4bf49c2c1 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -1347,7 +1347,17 @@ class A(Generic[T]): self.a = a g(self.a) g(n) # E: Argument 1 to "g" has incompatible type "int"; expected "T" -[out] + +[case testFunctionInGenericInnerClassTypeVariable] +from typing import TypeVar, Generic + +T = TypeVar('T') +class Outer(Generic[T]): + class Inner: + x: T # E: Invalid type "__main__.T" + def f(self, x: T) -> T: ... # E: Type variable 'T' is bound by an outer class + def g(self) -> None: + y: T # E: Invalid type "__main__.T" -- Callable subtyping with generic functions