diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index d56477989..934e1b402 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from collections import defaultdict from dataclasses import dataclass from typing import Dict, List, Optional, Sequence, Set, Tuple, Union @@ -14,6 +15,7 @@ from libcst.codemod.visitors._add_imports import AddImportsVisitor from libcst.codemod.visitors._gather_global_names import GatherGlobalNamesVisitor from libcst.codemod.visitors._gather_imports import GatherImportsVisitor +from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_full_name_for_node from libcst.metadata import PositionProvider, QualifiedNameProvider @@ -29,6 +31,41 @@ ] +def _module_and_target(qualified_name: str) -> Tuple[str, str]: + relative_prefix = "" + while qualified_name.startswith("."): + relative_prefix += "." + qualified_name = qualified_name[1:] + split = qualified_name.rsplit(".", 1) + if len(split) == 1: + qualifier, target = "", split[0] + else: + qualifier, target = split + return (relative_prefix + qualifier, target) + + +def _get_unique_qualified_name( + visitor: m.MatcherDecoratableVisitor, node: cst.CSTNode +) -> str: + name = None + names = [q.name for q in visitor.get_metadata(QualifiedNameProvider, node)] + if len(names) == 0: + # we hit this branch if the stub is directly using a fully + # qualified name, which is not technically valid python but is + # convenient to allow. + name = get_full_name_for_node(node) + elif len(names) == 1 and isinstance(names[0], str): + name = names[0] + if name is None: + start = visitor.get_metadata(PositionProvider, node).start + raise ValueError( + "Could not resolve a unique qualified name for type " + + f"{get_full_name_for_node(node)} at {start.line}:{start.column}. " + + f"Candidate names were: {names!r}" + ) + return name + + def _get_import_alias_names( import_aliases: Sequence[cst.ImportAlias], ) -> Set[str]: @@ -186,6 +223,130 @@ def finish(self) -> None: self.typevars = {k: v for k, v in self.typevars.items() if k in self.names} +@dataclass(frozen=True) +class ImportedSymbol: + """Import of foo.Bar, where both foo and Bar are potentially aliases.""" + + module_name: str + module_alias: Optional[str] = None + target_name: Optional[str] = None + target_alias: Optional[str] = None + + @property + def symbol(self) -> Optional[str]: + return self.target_alias or self.target_name + + @property + def module_symbol(self) -> str: + return self.module_alias or self.module_name + + +class ImportedSymbolCollector(m.MatcherDecoratableVisitor): + """ + Collect imported symbols from a stub module. + """ + + METADATA_DEPENDENCIES = ( + PositionProvider, + QualifiedNameProvider, + ) + + def __init__(self, existing_imports: Set[str], context: CodemodContext) -> None: + super().__init__() + self.existing_imports: Set[str] = existing_imports + self.imported_symbols: Dict[str, Set[ImportedSymbol]] = defaultdict(set) + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + for base in node.bases: + value = base.value + if isinstance(value, NAME_OR_ATTRIBUTE): + self._handle_NameOrAttribute(value) + elif isinstance(value, cst.Subscript): + self._handle_Subscript(value) + + def visit_FunctionDef(self, node: cst.FunctionDef) -> bool: + if node.returns is not None: + self._handle_Annotation(annotation=node.returns) + self._handle_Parameters(node.params) + + # pyi files don't support inner functions, return False to stop the traversal. + return False + + def visit_AnnAssign(self, node: cst.AnnAssign) -> None: + self._handle_Annotation(annotation=node.annotation) + + # Handler functions. + # + # These ultimately all call _handle_NameOrAttribute, which adds the + # qualified name to the list of imported symbols + + def _handle_NameOrAttribute( + self, + node: NameOrAttribute, + ) -> None: + obj = sym = None # keep pyre happy + if isinstance(node, cst.Name): + obj = None + sym = node.value + elif isinstance(node, cst.Attribute): + obj = node.value.value # pyre-ignore[16] + sym = node.attr.value + qualified_name = _get_unique_qualified_name(self, node) + module, target = _module_and_target(qualified_name) + if module in ("", "builtins"): + return + elif qualified_name not in self.existing_imports: + mod = ImportedSymbol( + module_name=module, + module_alias=obj if obj != module else None, + target_name=target, + target_alias=sym if sym != target else None, + ) + self.imported_symbols[sym].add(mod) + + def _handle_Index(self, slice: cst.Index) -> None: + value = slice.value + if isinstance(value, cst.Subscript): + self._handle_Subscript(value) + elif isinstance(value, cst.Attribute): + self._handle_NameOrAttribute(value) + + def _handle_Subscript(self, node: cst.Subscript) -> None: + value = node.value + if isinstance(value, NAME_OR_ATTRIBUTE): + self._handle_NameOrAttribute(value) + else: + raise ValueError("Expected any indexed type to have") + if _get_unique_qualified_name(self, node) in ("Type", "typing.Type"): + return + slice = node.slice + if isinstance(slice, tuple): + for item in slice: + if isinstance(item.slice.value, NAME_OR_ATTRIBUTE): + self._handle_NameOrAttribute(item.slice.value) + else: + if isinstance(item.slice, cst.Index): + self._handle_Index(item.slice) + elif isinstance(slice, cst.Index): + self._handle_Index(slice) + + def _handle_Annotation(self, annotation: cst.Annotation) -> None: + node = annotation.annotation + if isinstance(node, cst.Subscript): + self._handle_Subscript(node) + elif isinstance(node, NAME_OR_ATTRIBUTE): + self._handle_NameOrAttribute(node) + elif isinstance(node, cst.SimpleString): + pass + else: + raise ValueError(f"Unexpected annotation node: {node}") + + def _handle_Parameters(self, parameters: cst.Parameters) -> None: + for parameter in list(parameters.params): + if parameter.annotation is not None: + self._handle_Annotation(annotation=parameter.annotation) + + class TypeCollector(m.MatcherDecoratableVisitor): """ Collect type annotations from a stub module. @@ -201,6 +362,7 @@ class TypeCollector(m.MatcherDecoratableVisitor): def __init__( self, existing_imports: Set[str], + module_imports: Dict[str, ImportItem], context: CodemodContext, ) -> None: super().__init__() @@ -212,6 +374,9 @@ def __init__( # as well as module names, although downstream we effectively ignore # the module names as of the current implementation. self.existing_imports: Set[str] = existing_imports + # Module imports, gathered by prescanning the stub file to determine + # which modules need to be imported directly to qualify their symbols. + self.module_imports: Dict[str, ImportItem] = module_imports # Fields that help us track temporary state as we recurse self.qualifier: List[str] = [] self.current_assign: Optional[cst.Assign] = None # used to collect typevars @@ -323,36 +488,6 @@ def leave_Module( ) -> None: self.annotations.finish() - def _get_unique_qualified_name( - self, - node: cst.CSTNode, - ) -> str: - name = None - names = [q.name for q in self.get_metadata(QualifiedNameProvider, node)] - if len(names) == 0: - # we hit this branch if the stub is directly using a fully - # qualified name, which is not technically valid python but is - # convenient to allow. - name = get_full_name_for_node(node) - elif len(names) == 1 and isinstance(names[0], str): - name = names[0] - if name is None: - start = self.get_metadata(PositionProvider, node).start - raise ValueError( - "Could not resolve a unique qualified name for type " - + f"{get_full_name_for_node(node)} at {start.line}:{start.column}. " - + f"Candidate names were: {names!r}" - ) - return name - - def _get_qualified_name_and_dequalified_node( - self, - node: Union[cst.Name, cst.Attribute], - ) -> Tuple[str, Union[cst.Name, cst.Attribute]]: - qualified_name = self._get_unique_qualified_name(node) - dequalified_node = node.attr if isinstance(node, cst.Attribute) else node - return qualified_name, dequalified_node - def _module_and_target( self, qualified_name: str, @@ -382,6 +517,16 @@ def _handle_qualification_and_should_qualify( elif qualified_name not in self.existing_imports: if module in self.existing_imports: return True + elif module in self.module_imports: + m = self.module_imports[module] + if m.obj_name is None: + asname = m.alias + else: + asname = None + AddImportsVisitor.add_needed_import( + self.context, m.module_name, asname=asname + ) + return True else: if node and isinstance(node, cst.Name) and node.value != target: asname = node.value @@ -407,17 +552,18 @@ def _handle_NameOrAttribute( self, node: NameOrAttribute, ) -> Union[cst.Name, cst.Attribute]: - ( - qualified_name, - dequalified_node, - ) = self._get_qualified_name_and_dequalified_node(node) + qualified_name = _get_unique_qualified_name(self, node) should_qualify = self._handle_qualification_and_should_qualify( qualified_name, node ) self.annotations.names.add(qualified_name) if should_qualify: - return node + qualified_node = ( + cst.parse_module(qualified_name) if isinstance(node, cst.Name) else node + ) + return qualified_node # pyre-ignore[7] else: + dequalified_node = node.attr if isinstance(node, cst.Attribute) else node return dequalified_node def _handle_Index( @@ -443,7 +589,7 @@ def _handle_Subscript( new_node = node.with_changes(value=self._handle_NameOrAttribute(value)) else: raise ValueError("Expected any indexed type to have") - if self._get_unique_qualified_name(node) in ("Type", "typing.Type"): + if _get_unique_qualified_name(self, node) in ("Type", "typing.Type"): # Note: we are intentionally not handling qualification of # anything inside `Type` because it's common to have nested # classes, which we cannot currently distinguish from classes @@ -679,7 +825,8 @@ def transform_module_impl( self.strict_annotation_matching = ( self.strict_annotation_matching or strict_annotation_matching ) - visitor = TypeCollector(existing_import_names, self.context) + module_imports = self._get_module_imports(stub, import_gatherer) + visitor = TypeCollector(existing_import_names, module_imports, self.context) cst.MetadataWrapper(stub).visit(visitor) self.annotations.update(visitor.annotations) @@ -697,6 +844,70 @@ def transform_module_impl( else: return tree + # helpers for collecting type information from the stub files + + def _get_module_imports( + self, stub: cst.Module, existing_import_gatherer: GatherImportsVisitor + ) -> Dict[str, ImportItem]: + """Returns a dict of modules that need to be imported to qualify symbols.""" + # We correlate all imported symbols, e.g. foo.bar.Baz, with a list of module + # and from imports. If the same unqualified symbol is used from different + # modules, we give preference to an explicit from-import if any, and qualify + # everything else by importing the module. + # + # e.g. the following stub: + # import foo as quux + # from bar import Baz as X + # def f(x: X) -> quux.X: ... + # will return {'foo': ImportItem("foo", "quux")}. When the apply type + # annotation visitor hits `quux.X` it will retrieve the canonical name + # `foo.X` and then note that `foo` is in the module imports map, so it will + # leave the symbol qualified. + import_gatherer = GatherImportsVisitor(CodemodContext()) + stub.visit(import_gatherer) + symbol_map = import_gatherer.symbol_mapping + existing_import_names = _get_imported_names( + existing_import_gatherer.all_imports + ) + symbol_collector = ImportedSymbolCollector(existing_import_names, self.context) + cst.MetadataWrapper(stub).visit(symbol_collector) + module_imports = {} + for sym, imported_symbols in symbol_collector.imported_symbols.items(): + existing = existing_import_gatherer.symbol_mapping.get(sym) + if existing and any( + s.module_name != existing.module_name for s in imported_symbols + ): + # If a symbol is imported in the main file, we have to qualify + # it when imported from a different module in the stub file. + used = True + elif len(imported_symbols) == 1: + # If we have a single use of a new symbol we can from-import it + continue + else: + # There are multiple occurrences in the stub file and none in + # the main file. At least one can be from-imported. + used = False + for imp_sym in imported_symbols: + if not imp_sym.symbol: + continue + imp = symbol_map.get(imp_sym.symbol) + if not used and imp and imp.module_name == imp_sym.module_name: + # We can only import a symbol directly once. + used = True + elif sym in existing_import_names: + if imp: + module_imports[imp.module_name] = imp + else: + imp = symbol_map.get(imp_sym.module_symbol) + if imp: + # imp will be None in corner cases like + # import foo.bar as Baz + # x: Baz + # which is technically valid python but nonsensical as a + # type annotation. Dropping it on the floor for now. + module_imports[imp.module_name] = imp + return module_imports + # helpers for processing annotation nodes def _quote_future_annotations(self, annotation: cst.Annotation) -> cst.Annotation: # TODO: We probably want to make sure references to classes defined in the current diff --git a/libcst/codemod/visitors/_gather_imports.py b/libcst/codemod/visitors/_gather_imports.py index 147607980..e62e374a6 100644 --- a/libcst/codemod/visitors/_gather_imports.py +++ b/libcst/codemod/visitors/_gather_imports.py @@ -8,6 +8,7 @@ import libcst from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareVisitor +from libcst.codemod.visitors._imports import ImportItem from libcst.helpers import get_absolute_module_for_import @@ -60,6 +61,8 @@ def __init__(self, context: CodemodContext) -> None: self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {} # Track all of the imports found in this transform self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = [] + # Track the import for every symbol introduced into the module + self.symbol_mapping: Dict[str, ImportItem] = {} def visit_Import(self, node: libcst.Import) -> None: # Track this import statement for later analysis. @@ -67,12 +70,15 @@ def visit_Import(self, node: libcst.Import) -> None: for name in node.names: alias = name.evaluated_alias + imp = ImportItem(name.evaluated_name, alias=alias) if alias is not None: # Track this as an aliased module self.module_aliases[name.evaluated_name] = alias + self.symbol_mapping[alias] = imp else: # Get the module we're importing as a string. self.module_imports.add(name.evaluated_name) + self.symbol_mapping[name.evaluated_name] = imp def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: # Track this import statement for later analysis. @@ -114,3 +120,9 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None: return self.object_mapping[module].update(new_objects) + for ia in nodenames: + imp = ImportItem( + module, obj_name=ia.evaluated_name, alias=ia.evaluated_alias + ) + key = ia.evaluated_alias or ia.evaluated_name + self.symbol_mapping[key] = imp diff --git a/libcst/codemod/visitors/tests/test_apply_type_annotations.py b/libcst/codemod/visitors/tests/test_apply_type_annotations.py index 196cb383c..385ab7fad 100644 --- a/libcst/codemod/visitors/tests/test_apply_type_annotations.py +++ b/libcst/codemod/visitors/tests/test_apply_type_annotations.py @@ -309,6 +309,50 @@ def foo(x: B): pass """, ), + "with_conflicting_imported_symbols": ( + """ + import a.foo as bar + from b.c import Baz as B + import d + + def f(a: d.A, b: B) -> bar.B: ... + """, + """ + def f(a, b): + pass + """, + """ + import a.foo as bar + from b.c import Baz as B + from d import A + + def f(a: A, b: B) -> bar.B: + pass + """, + ), + "with_conflicts_between_imported_and_existing_symbols": ( + """ + from a import A + from b import B + + def f(x: A, y: B) -> None: ... + """, + """ + from b import A, B + + def f(x, y): + y = A(x) + z = B(y) + """, + """ + from b import A, B + import a + + def f(x: a.A, y: B) -> None: + y = A(x) + z = B(y) + """, + ), "with_nested_import": ( """ def foo(x: django.http.response.HttpResponse) -> str: