Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AddImportsVisitor: add imports before the first non-import statement #1024

Merged
merged 5 commits into from
Oct 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 88 additions & 32 deletions libcst/codemod/visitors/_add_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,51 @@

import libcst
from libcst import matchers as m, parse_statement
from libcst._nodes.statement import Import, ImportFrom, SimpleStatementLine
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
from libcst.codemod.visitors._gather_imports import GatherImportsVisitor
from libcst.codemod.visitors._gather_imports import _GatherImportsMixin
from libcst.codemod.visitors._imports import ImportItem
from libcst.helpers import get_absolute_module_from_package_for_import
from libcst.helpers.common import ensure_type


class _GatherTopImportsBeforeStatements(_GatherImportsMixin):
"""
Works similarly to GatherImportsVisitor, but only considers imports
declared before any other statements of the module with the exception
of docstrings and __strict__ flag.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def leave_Module(self, original_node: libcst.Module) -> None:
start = 1 if _skip_first(original_node) else 0
for stmt in original_node.body[start:]:
if m.matches(
stmt,
m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()]),
):
stmt = ensure_type(stmt, SimpleStatementLine)
# Workaround for python 3.8 and 3.9, won't accept Union for isinstance
if m.matches(stmt.body[0], m.ImportFrom()):
imp = ensure_type(stmt.body[0], ImportFrom)
self.all_imports.append(imp)
if m.matches(stmt.body[0], m.Import()):
imp = ensure_type(stmt.body[0], Import)
self.all_imports.append(imp)
else:
break
for imp in self.all_imports:
if m.matches(imp, m.Import()):
imp = ensure_type(imp, Import)
self._handle_Import(imp)
else:
imp = ensure_type(imp, ImportFrom)
self._handle_ImportFrom(imp)


class AddImportsVisitor(ContextAwareTransformer):
Expand Down Expand Up @@ -169,12 +209,12 @@ def __init__(
for module in sorted(from_imports_aliases)
}

# Track the list of imports found in the file
# Track the list of imports found at the top of the file
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def visit_Module(self, node: libcst.Module) -> None:
# Do a preliminary pass to gather the imports we already have
gatherer = GatherImportsVisitor(self.context)
# Do a preliminary pass to gather the imports we already have at the top
gatherer = _GatherTopImportsBeforeStatements(self.context)
node.visit(gatherer)
self.all_imports = gatherer.all_imports

Expand Down Expand Up @@ -213,6 +253,10 @@ def leave_ImportFrom(
# There's nothing to do here!
return updated_node

# Ensure this is one of the imports at the top
if original_node not in self.all_imports:
return updated_node

# Get the module we're importing as a string, see if we have work to do.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, updated_node
Expand Down Expand Up @@ -260,39 +304,26 @@ def _split_module(
statement_before_import_location = 0
import_add_location = 0

# never insert an import before initial __strict__ flag
if m.matches(
orig_module,
m.Module(
body=[
m.SimpleStatementLine(
body=[
m.Assign(
targets=[m.AssignTarget(target=m.Name("__strict__"))]
)
]
),
m.ZeroOrMore(),
]
),
):
statement_before_import_location = import_add_location = 1

# This works under the principle that while we might modify node contents,
# we have yet to modify the number of statements. So we can match on the
# original tree but break up the statements of the modified tree. If we
# change this assumption in this visitor, we will have to change this code.
for i, statement in enumerate(orig_module.body):
if i == 0 and m.matches(
statement, m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())])

# Finds the location to add imports. It is the end of the first import block that occurs before any other statement (save for docstrings)

# Never insert an import before initial __strict__ flag or docstring
if _skip_first(orig_module):
statement_before_import_location = import_add_location = 1

for i, statement in enumerate(
orig_module.body[statement_before_import_location:]
):
if m.matches(
statement, m.SimpleStatementLine(body=[m.ImportFrom() | m.Import()])
):
statement_before_import_location = import_add_location = 1
elif isinstance(statement, libcst.SimpleStatementLine):
for possible_import in statement.body:
for last_import in self.all_imports:
if possible_import is last_import:
import_add_location = i + 1
break
import_add_location = i + statement_before_import_location + 1
else:
break

return (
list(updated_module.body[:statement_before_import_location]),
Expand Down Expand Up @@ -414,3 +445,28 @@ def leave_Module(
*statements_after_imports,
)
)


def _skip_first(orig_module: libcst.Module) -> bool:
# Is there a __strict__ flag or docstring at the top?
if m.matches(
orig_module,
m.Module(
body=[
m.SimpleStatementLine(
body=[
m.Assign(targets=[m.AssignTarget(target=m.Name("__strict__"))])
]
),
m.ZeroOrMore(),
]
)
| m.Module(
body=[
m.SimpleStatementLine(body=[m.Expr(value=m.SimpleString())]),
m.ZeroOrMore(),
]
),
):
return True
return False
105 changes: 59 additions & 46 deletions libcst/codemod/visitors/_gather_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,43 +12,9 @@
from libcst.helpers import get_absolute_module_from_package_for_import


class GatherImportsVisitor(ContextAwareVisitor):
class _GatherImportsMixin(ContextAwareVisitor):
"""
Gathers all imports in a module and stores them as attributes on the instance.
Intended to be instantiated and passed to a :class:`~libcst.Module`
:meth:`~libcst.CSTNode.visit` method in order to gather up information about
imports on a module. Note that this is not a substitute for scope analysis or
qualified name support. Please see :ref:`libcst-scope-tutorial` for a more
robust way of determining the qualified name and definition for an arbitrary
node.

After visiting a module the following attributes will be populated:

module_imports
A sequence of strings representing modules that were imported directly, such as
in the case of ``import typing``. Each module directly imported but not aliased
will be included here.
object_mapping
A mapping of strings to sequences of strings representing modules where we
imported objects from, such as in the case of ``from typing import Optional``.
Each from import that was not aliased will be included here, where the keys of
the mapping are the module we are importing from, and the value is a
sequence of objects we are importing from the module.
module_aliases
A mapping of strings representing modules that were imported and aliased,
such as in the case of ``import typing as t``. Each module imported this
way will be represented as a key in this mapping, and the value will be
the local alias of the module.
alias_mapping
A mapping of strings to sequences of tuples representing modules where we
imported objects from and aliased using ``as`` syntax, such as in the case
of ``from typing import Optional as opt``. Each from import that was aliased
will be included here, where the keys of the mapping are the module we are
importing from, and the value is a tuple representing the original object
name and the alias.
all_imports
A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom`
statements that were encountered in the module.
A Mixin class for tracking visited imports.
"""

def __init__(self, context: CodemodContext) -> None:
Expand All @@ -59,15 +25,10 @@ def __init__(self, context: CodemodContext) -> None:
# Track the aliased imports in this transform
self.module_aliases: Dict[str, str] = {}
self.alias_mapping: Dict[str, List[Tuple[str, str]]] = {}
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
# Track the import for every symbol introduced into the module
self.symbol_mapping: Dict[str, ImportItem] = {}

def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

def _handle_Import(self, node: libcst.Import) -> None:
for name in node.names:
alias = name.evaluated_alias
imp = ImportItem(name.evaluated_name, alias=alias)
Expand All @@ -80,10 +41,7 @@ def visit_Import(self, node: libcst.Import) -> None:
self.module_imports.add(name.evaluated_name)
self.symbol_mapping[name.evaluated_name] = imp

def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

def _handle_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Get the module we're importing as a string.
module = get_absolute_module_from_package_for_import(
self.context.full_package_name, node
Expand Down Expand Up @@ -128,3 +86,58 @@ def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
)
key = ia.evaluated_alias or ia.evaluated_name
self.symbol_mapping[key] = imp


class GatherImportsVisitor(_GatherImportsMixin):
"""
Gathers all imports in a module and stores them as attributes on the instance.
Intended to be instantiated and passed to a :class:`~libcst.Module`
:meth:`~libcst.CSTNode.visit` method in order to gather up information about
imports on a module. Note that this is not a substitute for scope analysis or
qualified name support. Please see :ref:`libcst-scope-tutorial` for a more
robust way of determining the qualified name and definition for an arbitrary
node.

After visiting a module the following attributes will be populated:

module_imports
A sequence of strings representing modules that were imported directly, such as
in the case of ``import typing``. Each module directly imported but not aliased
will be included here.
object_mapping
A mapping of strings to sequences of strings representing modules where we
imported objects from, such as in the case of ``from typing import Optional``.
Each from import that was not aliased will be included here, where the keys of
the mapping are the module we are importing from, and the value is a
sequence of objects we are importing from the module.
module_aliases
A mapping of strings representing modules that were imported and aliased,
such as in the case of ``import typing as t``. Each module imported this
way will be represented as a key in this mapping, and the value will be
the local alias of the module.
alias_mapping
A mapping of strings to sequences of tuples representing modules where we
imported objects from and aliased using ``as`` syntax, such as in the case
of ``from typing import Optional as opt``. Each from import that was aliased
will be included here, where the keys of the mapping are the module we are
importing from, and the value is a tuple representing the original object
name and the alias.
all_imports
A collection of all :class:`~libcst.Import` and :class:`~libcst.ImportFrom`
statements that were encountered in the module.
"""

def __init__(self, context: CodemodContext) -> None:
super().__init__(context)
# Track all of the imports found in this transform
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []

def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_Import(node)

def visit_ImportFrom(self, node: libcst.ImportFrom) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)
self._handle_ImportFrom(node)
Loading
Loading