Skip to content

Commit

Permalink
Relative imports are relative to the containing package. (#11181)
Browse files Browse the repository at this point in the history
The existing code assumed they were relative to the module
(i.e., that relative imports in foo.py and __init__.py should
behave differently), but that is not the case.

See, e.g., https://docs.python.org/3/reference/import.html.
  • Loading branch information
benjyw authored Nov 14, 2020
1 parent 91e82e2 commit 43bb218
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import sys
from dataclasses import dataclass
from pathlib import PurePath
from typing import Optional, Set, Tuple

from typed_ast import ast27
Expand Down Expand Up @@ -56,15 +57,16 @@ def parse_file(*, filename: str, content: str) -> Optional[Tuple]:
return None


def find_python_imports(*, filename: str, content: str, module_name: str) -> ParsedPythonImports:
def find_python_imports(*, filename: str, content: str) -> ParsedPythonImports:
package_parts = PurePath(filename).parts[0:-1]
parse_result = parse_file(filename=filename, content=content)
# If there were syntax errors, gracefully early return. This is more user friendly than
# propagating the exception. Dependency inference simply won't be used for that file, and
# it'll be up to the tool actually being run (e.g. Pytest or Flake8) to error.
if parse_result is None:
return ParsedPythonImports(FrozenOrderedSet(), FrozenOrderedSet())
tree, ast_visitor_cls = parse_result
ast_visitor = ast_visitor_cls(module_name)
ast_visitor = ast_visitor_cls(package_parts)
ast_visitor.visit(tree)
return ParsedPythonImports(
explicit_imports=FrozenOrderedSet(sorted(ast_visitor.explicit_imports)),
Expand All @@ -78,8 +80,8 @@ def find_python_imports(*, filename: str, content: str, module_name: str) -> Par


class _BaseAstVisitor:
def __init__(self, module_name: str) -> None:
self._module_parts = module_name.split(".")
def __init__(self, package_parts: Tuple[str, ...]) -> None:
self._package_parts = package_parts
self.explicit_imports: Set[str] = set()
self.inferred_imports: Set[str] = set()

Expand All @@ -92,10 +94,15 @@ def visit_Import(self, node) -> None:
self.explicit_imports.add(alias.name)

def visit_ImportFrom(self, node) -> None:
rel_module = node.module
abs_module = ".".join(
self._module_parts[0 : -node.level] + ([] if rel_module is None else [rel_module])
)
if node.level:
# Relative import.
rel_module = node.module
abs_module = ".".join(
self._package_parts[0 : len(self._package_parts) - node.level + 1]
+ (tuple() if rel_module is None else (rel_module,))
)
else:
abs_module = node.module
for alias in node.names:
self.explicit_imports.add(f"{abs_module}.{alias.name}")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

def test_normal_imports() -> None:
imports = find_python_imports(
filename="foo.py",
filename="project/foo.py",
content=dedent(
"""\
from __future__ import print_function
Expand All @@ -35,7 +35,6 @@ def test_normal_imports() -> None:
import subprocess23 as subprocess
"""
),
module_name="project.app",
)
assert set(imports.explicit_imports) == {
"__future__.print_function",
Expand All @@ -53,20 +52,22 @@ def test_normal_imports() -> None:
assert not imports.inferred_imports


def test_relative_imports() -> None:
@pytest.mark.parametrize("basename", ["foo.py", "__init__.py"])
def test_relative_imports(basename: str) -> None:
imports = find_python_imports(
filename="foo.py",
filename=f"project/util/{basename}",
content=dedent(
"""\
from . import sibling
from .sibling import Nibling
from .subdir.child import Child
from ..parent import Parent
"""
),
module_name="project.util.test_utils",
)
assert set(imports.explicit_imports) == {
"project.util.sibling",
"project.util.sibling.Nibling",
"project.util.subdir.child.Child",
"project.parent.Parent",
}
Expand All @@ -75,7 +76,7 @@ def test_relative_imports() -> None:

def test_imports_from_strings() -> None:
imports = find_python_imports(
filename="foo.py",
filename="project/foo.py",
content=dedent(
"""\
modules = [
Expand Down Expand Up @@ -105,7 +106,6 @@ def test_imports_from_strings() -> None:
importlib.import_module(module)
"""
),
module_name="project.app",
)
assert not imports.explicit_imports
assert set(imports.inferred_imports) == {
Expand All @@ -121,14 +121,14 @@ def test_imports_from_strings() -> None:


def test_gracefully_handle_syntax_errors() -> None:
imports = find_python_imports(filename="foo.py", content="x =", module_name="project.app")
imports = find_python_imports(filename="project/foo.py", content="x =")
assert not imports.explicit_imports
assert not imports.inferred_imports


def test_works_with_python2() -> None:
imports = find_python_imports(
filename="foo.py",
filename="project/foo.py",
content=dedent(
"""\
print "Python 2 lives on."
Expand All @@ -140,7 +140,6 @@ def test_works_with_python2() -> None:
importlib.import_module(u"dep.from.str")
"""
),
module_name="project.app",
)
assert set(imports.explicit_imports) == {"demo", "project.demo.Demo"}
assert set(imports.inferred_imports) == {"dep.from.bytes", "dep.from.str"}
Expand All @@ -152,7 +151,7 @@ def test_works_with_python2() -> None:
)
def test_works_with_python38() -> None:
imports = find_python_imports(
filename="foo.py",
filename="project/foo.py",
content=dedent(
"""\
is_py38 = True
Expand All @@ -165,7 +164,6 @@ def test_works_with_python38() -> None:
importlib.import_module("dep.from.str")
"""
),
module_name="project.app",
)
assert set(imports.explicit_imports) == {"demo", "project.demo.Demo"}
assert set(imports.inferred_imports) == {"dep.from.str"}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import itertools
from pathlib import PurePath
from typing import List, cast

from pants.backend.python.dependency_inference import module_mapper
Expand Down Expand Up @@ -108,18 +107,13 @@ async def infer_python_dependencies(
return InferredDependencies([], sibling_dependencies_inferrable=False)

stripped_sources = await Get(StrippedSourceFiles, SourceFilesRequest([request.sources_field]))
modules = tuple(
PythonModule.create_from_stripped_path(PurePath(fp))
for fp in stripped_sources.snapshot.files
)
digest_contents = await Get(DigestContents, Digest, stripped_sources.snapshot.digest)

owners_requests: List[Get[PythonModuleOwners, PythonModule]] = []
for file_content, module in zip(digest_contents, modules):
for file_content in digest_contents:
file_imports_obj = find_python_imports(
filename=file_content.path,
content=file_content.content.decode(),
module_name=module.module,
)
detected_imports = (
file_imports_obj.all_imports
Expand Down

0 comments on commit 43bb218

Please sign in to comment.