diff --git a/libcst/codemod/_cli.py b/libcst/codemod/_cli.py index 45ae05e74..4726a34f5 100644 --- a/libcst/codemod/_cli.py +++ b/libcst/codemod/_cli.py @@ -16,8 +16,8 @@ import traceback from dataclasses import dataclass, replace from multiprocessing import cpu_count, Pool -from pathlib import Path, PurePath -from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union from libcst import parse_module, PartialParserConfig from libcst.codemod._codemod import Codemod @@ -32,6 +32,7 @@ TransformSkip, TransformSuccess, ) +from libcst.helpers import calculate_module_and_package from libcst.metadata import FullRepoManager _DEFAULT_GENERATED_CODE_MARKER: str = f"@gen{''}erated" @@ -184,35 +185,6 @@ def exec_transform_with_prettyprint( return maybe_code -def _calculate_module_and_package( - repo_root: Optional[str], filename: str -) -> Optional[Tuple[str, str]]: - # Given an absolute repo_root and an absolute filename, calculate the - # python module name for the file. - if repo_root is None: - # We don't have a repo root, so this is impossible to calculate. - return None - - try: - relative_filename = PurePath(filename).relative_to(repo_root) - except ValueError: - # This file seems to be out of the repo root. - return None - - # get rid of extension - relative_filename = relative_filename.with_suffix("") - - # handle special cases - if relative_filename.stem in ["__init__", "__main__"]: - relative_filename = relative_filename.parent - package = module = ".".join(relative_filename.parts) - else: - module = ".".join(relative_filename.parts) - package = ".".join(relative_filename.parts[:-1]) - - return module, package - - @dataclass(frozen=True) class ExecutionResult: # File we have results for @@ -269,11 +241,6 @@ def _execute_transform( # noqa: C901 ), ) - # attempt to work out the module and package name for this file - full_module_name, full_package_name = _calculate_module_and_package( - config.repo_root, filename - ) or (None, None) - # Somewhat gross hack to provide the filename in the transform's context. # We do this after the fork so that a context that was initialized with # some defaults before calling parallel_exec_transform_with_prettyprint @@ -281,11 +248,20 @@ def _execute_transform( # noqa: C901 transformer.context = replace( transformer.context, filename=filename, - full_module_name=full_module_name, - full_package_name=full_package_name, scratch={}, ) + # attempt to work out the module and package name for this file + module_name_and_package = calculate_module_and_package( + config.repo_root, filename + ) + if module_name_and_package is not None: + transformer.context = replace( + transformer.context, + full_module_name=module_name_and_package.name, + full_package_name=module_name_and_package.package, + ) + # Run the transform, bail if we failed or if we aren't formatting code try: input_tree = parse_module( diff --git a/libcst/codemod/tests/test_cli.py b/libcst/codemod/tests/test_cli.py deleted file mode 100644 index adc00e736..000000000 --- a/libcst/codemod/tests/test_cli.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -# -from typing import Optional, Tuple - -from libcst.codemod._cli import _calculate_module_and_package -from libcst.testing.utils import data_provider, UnitTest - - -class TestPackageCalculation(UnitTest): - @data_provider( - ( - # Providing no root should give back no module. - (None, "/some/dummy/file.py", None), - # Providing a file outside the root should give back no module. - ("/home/username/root", "/some/dummy/file.py", None), - ("/home/username/root/", "/some/dummy/file.py", None), - ("/home/username/root", "/home/username/file.py", None), - # Various files inside the root should give back valid modules. - ("/home/username/root", "/home/username/root/file.py", ("file", "")), - ("/home/username/root/", "/home/username/root/file.py", ("file", "")), - ( - "/home/username/root/", - "/home/username/root/some/dir/file.py", - ("some.dir.file", "some.dir"), - ), - # Various special files inside the root should give back valid modules. - ( - "/home/username/root/", - "/home/username/root/some/dir/__init__.py", - ("some.dir", "some.dir"), - ), - ( - "/home/username/root/", - "/home/username/root/some/dir/__main__.py", - ("some.dir", "some.dir"), - ), - # some windows tests - ( - "c:/Program Files/", - "d:/Program Files/some/dir/file.py", - None, - ), - ( - "c:/Program Files/other/", - "c:/Program Files/some/dir/file.py", - None, - ), - ( - "c:/Program Files/", - "c:/Program Files/some/dir/file.py", - ("some.dir.file", "some.dir"), - ), - ( - "c:/Program Files/", - "c:/Program Files/some/dir/__main__.py", - ("some.dir", "some.dir"), - ), - ), - ) - def test_calculate_module_and_package( - self, - repo_root: Optional[str], - filename: str, - module_and_package: Optional[Tuple[str, str]], - ) -> None: - self.assertEqual( - _calculate_module_and_package(repo_root, filename), module_and_package - ) diff --git a/libcst/helpers/__init__.py b/libcst/helpers/__init__.py index ccd12c728..6f0db041f 100644 --- a/libcst/helpers/__init__.py +++ b/libcst/helpers/__init__.py @@ -4,11 +4,6 @@ # LICENSE file in the root directory of this source tree. # -from libcst.helpers._statement import ( - get_absolute_module, - get_absolute_module_for_import, - get_absolute_module_for_import_or_raise, -) from libcst.helpers._template import ( parse_template_expression, parse_template_module, @@ -19,9 +14,17 @@ get_full_name_for_node, get_full_name_for_node_or_raise, ) -from libcst.helpers.module import insert_header_comments +from libcst.helpers.module import ( + calculate_module_and_package, + get_absolute_module, + get_absolute_module_for_import, + get_absolute_module_for_import_or_raise, + insert_header_comments, + ModuleNameAndPackage, +) __all__ = [ + "calculate_module_and_package", "get_absolute_module", "get_absolute_module_for_import", "get_absolute_module_for_import_or_raise", @@ -32,4 +35,5 @@ "parse_template_module", "parse_template_statement", "parse_template_expression", + "ModuleNameAndPackage", ] diff --git a/libcst/helpers/_statement.py b/libcst/helpers/_statement.py deleted file mode 100644 index f62a5eb87..000000000 --- a/libcst/helpers/_statement.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -# -from typing import Optional - -import libcst as cst -from libcst.helpers.expression import get_full_name_for_node - - -def get_absolute_module( - current_module: Optional[str], module_name: Optional[str], num_dots: int -) -> Optional[str]: - if num_dots == 0: - # This is an absolute import, so the module is correct. - return module_name - if current_module is None: - # We don't actually have the current module available, so we can't compute - # the absolute module from relative. - return None - # We have the current module, as well as the relative, let's compute the base. - modules = current_module.split(".") - if len(modules) < num_dots: - # This relative import goes past the base of the repository, so we can't calculate it. - return None - base_module = ".".join(modules[:-num_dots]) - # Finally, if the module name was supplied, append it to the end. - if module_name is not None: - # If we went all the way to the top, the base module should be empty, so we - # should return the relative bit as absolute. Otherwise, combine the base - # module and module name using a dot separator. - base_module = ( - f"{base_module}.{module_name}" if len(base_module) > 0 else module_name - ) - # If they tried to import all the way to the root, return None. Otherwise, - # return the module itself. - return base_module if len(base_module) > 0 else None - - -def get_absolute_module_for_import( - current_module: Optional[str], import_node: cst.ImportFrom -) -> Optional[str]: - # First, let's try to grab the module name, regardless of relative status. - module = import_node.module - module_name = get_full_name_for_node(module) if module is not None else None - # Now, get the relative import location if it exists. - num_dots = len(import_node.relative) - return get_absolute_module(current_module, module_name, num_dots) - - -def get_absolute_module_for_import_or_raise( - current_module: Optional[str], import_node: cst.ImportFrom -) -> str: - module = get_absolute_module_for_import(current_module, import_node) - if module is None: - raise Exception(f"Unable to compute absolute module for {import_node}") - return module diff --git a/libcst/helpers/module.py b/libcst/helpers/module.py index 50e42ff7a..f9ba41aaf 100644 --- a/libcst/helpers/module.py +++ b/libcst/helpers/module.py @@ -3,13 +3,16 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # +from dataclasses import dataclass from itertools import islice -from typing import List +from pathlib import PurePath +from typing import List, Optional -import libcst +from libcst import Comment, EmptyLine, ImportFrom, Module +from libcst.helpers.expression import get_full_name_for_node -def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.Module: +def insert_header_comments(node: Module, comments: List[str]) -> Module: """ Insert comments after last non-empty line in header. Use this to insert one or more comments after any copyright preamble in a :class:`~libcst.Module`. Each comment in @@ -25,9 +28,91 @@ def insert_header_comments(node: libcst.Module, comments: List[str]) -> libcst.M comment_lines = islice(node.header, last_comment_index + 1) empty_lines = islice(node.header, last_comment_index + 1, None) - inserted_lines = [ - libcst.EmptyLine(comment=libcst.Comment(value=comment)) for comment in comments - ] + inserted_lines = [EmptyLine(comment=Comment(value=comment)) for comment in comments] # pyre-fixme[60]: Concatenation not yet support for multiple variadic tuples: # `*comment_lines, *inserted_lines, *empty_lines`. return node.with_changes(header=(*comment_lines, *inserted_lines, *empty_lines)) + + +def get_absolute_module( + current_module: Optional[str], module_name: Optional[str], num_dots: int +) -> Optional[str]: + if num_dots == 0: + # This is an absolute import, so the module is correct. + return module_name + if current_module is None: + # We don't actually have the current module available, so we can't compute + # the absolute module from relative. + return None + # We have the current module, as well as the relative, let's compute the base. + modules = current_module.split(".") + if len(modules) < num_dots: + # This relative import goes past the base of the repository, so we can't calculate it. + return None + base_module = ".".join(modules[:-num_dots]) + # Finally, if the module name was supplied, append it to the end. + if module_name is not None: + # If we went all the way to the top, the base module should be empty, so we + # should return the relative bit as absolute. Otherwise, combine the base + # module and module name using a dot separator. + base_module = ( + f"{base_module}.{module_name}" if len(base_module) > 0 else module_name + ) + # If they tried to import all the way to the root, return None. Otherwise, + # return the module itself. + return base_module if len(base_module) > 0 else None + + +def get_absolute_module_for_import( + current_module: Optional[str], import_node: ImportFrom +) -> Optional[str]: + # First, let's try to grab the module name, regardless of relative status. + module = import_node.module + module_name = get_full_name_for_node(module) if module is not None else None + # Now, get the relative import location if it exists. + num_dots = len(import_node.relative) + return get_absolute_module(current_module, module_name, num_dots) + + +def get_absolute_module_for_import_or_raise( + current_module: Optional[str], import_node: ImportFrom +) -> str: + module = get_absolute_module_for_import(current_module, import_node) + if module is None: + raise Exception(f"Unable to compute absolute module for {import_node}") + return module + + +@dataclass(frozen=True) +class ModuleNameAndPackage: + name: str + package: str + + +def calculate_module_and_package( + repo_root: Optional[str], filename: str +) -> Optional[ModuleNameAndPackage]: + # Given an absolute repo_root and an absolute filename, calculate the + # python module name for the file. + if repo_root is None: + # We don't have a repo root, so this is impossible to calculate. + return None + + try: + relative_filename = PurePath(filename).relative_to(repo_root) + except ValueError: + # This file seems to be out of the repo root. + return None + + # get rid of extension + relative_filename = relative_filename.with_suffix("") + + # handle special cases + if relative_filename.stem in ["__init__", "__main__"]: + relative_filename = relative_filename.parent + package = name = ".".join(relative_filename.parts) + else: + name = ".".join(relative_filename.parts) + package = ".".join(relative_filename.parts[:-1]) + + return ModuleNameAndPackage(name, package) diff --git a/libcst/helpers/tests/test_module.py b/libcst/helpers/tests/test_module.py index 687e0260a..da9dab71a 100644 --- a/libcst/helpers/tests/test_module.py +++ b/libcst/helpers/tests/test_module.py @@ -3,9 +3,18 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. # -import libcst -from libcst.helpers import insert_header_comments -from libcst.testing.utils import UnitTest +from typing import Optional + +import libcst as cst +from libcst.helpers.common import ensure_type +from libcst.helpers.module import ( + calculate_module_and_package, + get_absolute_module_for_import, + get_absolute_module_for_import_or_raise, + insert_header_comments, + ModuleNameAndPackage, +) +from libcst.testing.utils import data_provider, UnitTest class ModuleTest(UnitTest): @@ -18,7 +27,7 @@ def test_insert_header_comments(self) -> None: expected_code = "\n".join( comment_lines + inserted_comments + empty_lines + non_header_line ) - node = libcst.parse_module(original_code) + node = cst.parse_module(original_code) self.assertEqual( insert_header_comments(node, inserted_comments).code, expected_code ) @@ -26,7 +35,7 @@ def test_insert_header_comments(self) -> None: # No comment case original_code = "\n".join(empty_lines + non_header_line) expected_code = "\n".join(inserted_comments + empty_lines + non_header_line) - node = libcst.parse_module(original_code) + node = cst.parse_module(original_code) self.assertEqual( insert_header_comments(node, inserted_comments).code, expected_code ) @@ -34,7 +43,7 @@ def test_insert_header_comments(self) -> None: # No empty lines case original_code = "\n".join(comment_lines + non_header_line) expected_code = "\n".join(comment_lines + inserted_comments + non_header_line) - node = libcst.parse_module(original_code) + node = cst.parse_module(original_code) self.assertEqual( insert_header_comments(node, inserted_comments).code, expected_code ) @@ -45,7 +54,7 @@ def test_insert_header_comments(self) -> None: expected_code = "\n".join( comment_lines + inserted_comments + empty_lines + non_header_line ) - node = libcst.parse_module(original_code) + node = cst.parse_module(original_code) self.assertEqual( insert_header_comments(node, inserted_comments).code, expected_code ) @@ -53,7 +62,146 @@ def test_insert_header_comments(self) -> None: # No header case original_code = "\n".join(non_header_line) expected_code = "\n".join(inserted_comments + non_header_line) - node = libcst.parse_module(original_code) + node = cst.parse_module(original_code) self.assertEqual( insert_header_comments(node, inserted_comments).code, expected_code ) + + @data_provider( + ( + # Simple imports that are already absolute. + (None, "from a.b import c", "a.b"), + ("x.y.z", "from a.b import c", "a.b"), + # Relative import that can't be resolved due to missing module. + (None, "from ..w import c", None), + # Relative import that goes past the module level. + ("x", "from ...y import z", None), + ("x.y.z", "from .....w import c", None), + ("x.y.z", "from ... import c", None), + # Correct resolution of absolute from relative modules. + ("x.y.z", "from . import c", "x.y"), + ("x.y.z", "from .. import c", "x"), + ("x.y.z", "from .w import c", "x.y.w"), + ("x.y.z", "from ..w import c", "x.w"), + ("x.y.z", "from ...w import c", "w"), + ) + ) + def test_get_absolute_module( + self, + module: Optional[str], + importfrom: str, + output: Optional[str], + ) -> None: + node = ensure_type(cst.parse_statement(importfrom), cst.SimpleStatementLine) + assert len(node.body) == 1, "Unexpected number of statements!" + import_node = ensure_type(node.body[0], cst.ImportFrom) + + self.assertEqual(get_absolute_module_for_import(module, import_node), output) + if output is None: + with self.assertRaises(Exception): + get_absolute_module_for_import_or_raise(module, import_node) + else: + self.assertEqual( + get_absolute_module_for_import_or_raise(module, import_node), output + ) + + @data_provider( + ( + # Nodes without an asname + (cst.ImportAlias(name=cst.Name("foo")), "foo", None), + ( + cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))), + "foo.bar", + None, + ), + # Nodes with an asname + ( + cst.ImportAlias( + name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz")) + ), + "foo", + "baz", + ), + ( + cst.ImportAlias( + name=cst.Attribute(cst.Name("foo"), cst.Name("bar")), + asname=cst.AsName(name=cst.Name("baz")), + ), + "foo.bar", + "baz", + ), + ) + ) + def test_importalias_helpers( + self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str] + ) -> None: + self.assertEqual(alias_node.evaluated_name, full_name) + self.assertEqual(alias_node.evaluated_alias, alias) + + @data_provider( + ( + # Providing no root should give back no module. + (None, "/some/dummy/file.py", None), + # Providing a file outside the root should give back no module. + ("/home/username/root", "/some/dummy/file.py", None), + ("/home/username/root/", "/some/dummy/file.py", None), + ("/home/username/root", "/home/username/file.py", None), + # Various files inside the root should give back valid modules. + ( + "/home/username/root", + "/home/username/root/file.py", + ModuleNameAndPackage("file", ""), + ), + ( + "/home/username/root/", + "/home/username/root/file.py", + ModuleNameAndPackage("file", ""), + ), + ( + "/home/username/root/", + "/home/username/root/some/dir/file.py", + ModuleNameAndPackage("some.dir.file", "some.dir"), + ), + # Various special files inside the root should give back valid modules. + ( + "/home/username/root/", + "/home/username/root/some/dir/__init__.py", + ModuleNameAndPackage("some.dir", "some.dir"), + ), + ( + "/home/username/root/", + "/home/username/root/some/dir/__main__.py", + ModuleNameAndPackage("some.dir", "some.dir"), + ), + # some windows tests + ( + "c:/Program Files/", + "d:/Program Files/some/dir/file.py", + None, + ), + ( + "c:/Program Files/other/", + "c:/Program Files/some/dir/file.py", + None, + ), + ( + "c:/Program Files/", + "c:/Program Files/some/dir/file.py", + ModuleNameAndPackage("some.dir.file", "some.dir"), + ), + ( + "c:/Program Files/", + "c:/Program Files/some/dir/__main__.py", + ModuleNameAndPackage("some.dir", "some.dir"), + ), + ), + ) + def test_calculate_module_and_package( + self, + repo_root: Optional[str], + filename: str, + module_and_package: Optional[ModuleNameAndPackage], + ) -> None: + self.assertEqual( + calculate_module_and_package(repo_root, filename), module_and_package + ) diff --git a/libcst/helpers/tests/test_statement.py b/libcst/helpers/tests/test_statement.py deleted file mode 100644 index f26900bdd..000000000 --- a/libcst/helpers/tests/test_statement.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. -# -from typing import Optional - -import libcst as cst -from libcst.helpers import ( - ensure_type, - get_absolute_module_for_import, - get_absolute_module_for_import_or_raise, -) -from libcst.testing.utils import data_provider, UnitTest - - -class StatementTest(UnitTest): - @data_provider( - ( - # Simple imports that are already absolute. - (None, "from a.b import c", "a.b"), - ("x.y.z", "from a.b import c", "a.b"), - # Relative import that can't be resolved due to missing module. - (None, "from ..w import c", None), - # Relative import that goes past the module level. - ("x", "from ...y import z", None), - ("x.y.z", "from .....w import c", None), - ("x.y.z", "from ... import c", None), - # Correct resolution of absolute from relative modules. - ("x.y.z", "from . import c", "x.y"), - ("x.y.z", "from .. import c", "x"), - ("x.y.z", "from .w import c", "x.y.w"), - ("x.y.z", "from ..w import c", "x.w"), - ("x.y.z", "from ...w import c", "w"), - ) - ) - def test_get_absolute_module( - self, - module: Optional[str], - importfrom: str, - output: Optional[str], - ) -> None: - node = ensure_type(cst.parse_statement(importfrom), cst.SimpleStatementLine) - assert len(node.body) == 1, "Unexpected number of statements!" - import_node = ensure_type(node.body[0], cst.ImportFrom) - - self.assertEqual(get_absolute_module_for_import(module, import_node), output) - if output is None: - with self.assertRaises(Exception): - get_absolute_module_for_import_or_raise(module, import_node) - else: - self.assertEqual( - get_absolute_module_for_import_or_raise(module, import_node), output - ) - - @data_provider( - ( - # Nodes without an asname - (cst.ImportAlias(name=cst.Name("foo")), "foo", None), - ( - cst.ImportAlias(name=cst.Attribute(cst.Name("foo"), cst.Name("bar"))), - "foo.bar", - None, - ), - # Nodes with an asname - ( - cst.ImportAlias( - name=cst.Name("foo"), asname=cst.AsName(name=cst.Name("baz")) - ), - "foo", - "baz", - ), - ( - cst.ImportAlias( - name=cst.Attribute(cst.Name("foo"), cst.Name("bar")), - asname=cst.AsName(name=cst.Name("baz")), - ), - "foo.bar", - "baz", - ), - ) - ) - def test_importalias_helpers( - self, alias_node: cst.ImportAlias, full_name: str, alias: Optional[str] - ) -> None: - self.assertEqual(alias_node.evaluated_name, full_name) - self.assertEqual(alias_node.evaluated_alias, alias)