diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py
index ae4c8e7426a..951489077f3 100644
--- a/src/python/pants/backend/jvm/tasks/rewrite_base.py
+++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py
@@ -1,9 +1,14 @@
 # Copyright 2017 Pants project contributors (see CONTRIBUTORS.md).
 # Licensed under the Apache License, Version 2.0 (see LICENSE).
 
+import functools
+import itertools
+import math
 import os
 import shutil
+import threading
 from abc import ABCMeta, abstractmethod
+from typing import Any, List
 
 from pants.backend.jvm.tasks.nailgun_task import NailgunTask
 from pants.base.build_environment import get_buildroot
@@ -29,6 +34,13 @@ def register_options(cls, register):
              help='Path to output directory. Any updated files will be written here. '
                   'If not specified, files will be modified in-place.')
 
+    register('--files-per-worker', type=int, fingerprint=False,
+             default=None,
+             help='Number of files to pass to each rewrite tool execution.')
+    register('--worker-count', type=int, fingerprint=False,
+             default=None,
+             help='Total number of parallel rewrite tool threads or processes to run.')
+
   @classmethod
   def target_types(cls):
     """Returns a list of target type names (e.g.: `scala_library`) this rewriter operates on."""
@@ -63,17 +75,90 @@ def execute(self):
     with self.invalidated(relevant_targets) as invalidation_check:
       self._execute_for([vt.target for vt in invalidation_check.invalid_vts])
 
+  def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn):
+    """Run `invoke_fn` once per list of inputs, each in its own thread, and wait for all of them.
+
+    :param inputs_list_of_lists: One list of inputs for each thread to spawn.
+    :param invoke_fn: Called in each thread as `invoke_fn(parent_workunit, inputs)`.
+    :raises TaskError: if any invocation raised, chained from the first collected exception.
+    """
+    parent_workunit = self.context.run_tracker.get_current_workunit()
+
+    # Propagate exceptions in threads to the toplevel by checking this variable after joining all
+    # the threads.
+    all_exceptions = []
+
+    def thread_exception_wrapper(fn):
+      @functools.wraps(fn)
+      def inner(*args, **kwargs):
+        try:
+          fn(*args, **kwargs)
+        except Exception as e:
+          all_exceptions.append(e)
+      return inner
+
+    all_threads = [
+      threading.Thread(
+        name=f'{type(self).__name__} invocation thread #{idx}/{len(inputs_list_of_lists)}',
+        target=thread_exception_wrapper(invoke_fn),
+        args=[parent_workunit, inputs_single_list],
+      )
+      for idx, inputs_single_list in enumerate(inputs_list_of_lists)
+    ]
+    for thread in all_threads:
+      thread.start()
+    # join() does not re-raise a thread's exception: failures are collected in all_exceptions.
+    for thread in all_threads:
+      thread.join()
+    if all_exceptions:
+      joined_str = ', '.join(str(e) for e in all_exceptions)
+      raise TaskError(f'all errors: {joined_str}') from all_exceptions[0]
+
   def _execute_for(self, targets):
+    """Rewrite the sources of `targets`, optionally splitting the work across worker threads."""
     target_sources = self._calculate_sources(targets)
     if not target_sources:
       return
 
-    result = Xargs(self._invoke_tool).execute(target_sources)
-    if result != 0:
-      raise TaskError('{} is improperly implemented: a failed process '
-                      'should raise an exception earlier.'.format(type(self).__name__))
-
-  def _invoke_tool(self, target_sources):
+    files_per_worker = self.get_options().files_per_worker
+    worker_count = self.get_options().worker_count
+    # Non-positive values would either crash (range() with a step of 0) or produce no chunks,
+    # silently leaving every file unprocessed, so reject them up front.
+    for name, value in (('files-per-worker', files_per_worker), ('worker-count', worker_count)):
+      if value is not None and value < 1:
+        raise TaskError(f'--{name} must be a positive integer, but was: {value}.')
+
+    if files_per_worker is not None:
+      # If --files-per-worker is specified, split the target sources and run in separate threads!
+      inputs_list_of_lists = [
+        target_sources[i:i + files_per_worker]
+        for i in range(0, len(target_sources), files_per_worker)
+      ]
+      self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool)
+    elif worker_count is not None:
+      # If --worker-count is specified, split the target sources into that many
+      # threads, and run in separate threads!
+      chunk_size = math.ceil(len(target_sources) / worker_count)
+      sources_iterator = iter(target_sources)
+      inputs_list_of_lists = [
+        list(itertools.islice(sources_iterator, 0, chunk_size))
+        for _ in range(0, worker_count)
+      ]
+      self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool)
+    else:
+      # Otherwise, pass in the parent workunit to Xargs, which is passed into self._invoke_tool.
+      parent_workunit = self.context.run_tracker.get_current_workunit()
+      result = Xargs(self._invoke_tool, constant_args=[parent_workunit]).execute(target_sources)
+      if result != 0:
+        raise TaskError('{} is improperly implemented: a failed process '
+                        'should raise an exception earlier.'.format(type(self).__name__))
+
+  def _invoke_tool(self, parent_workunit, target_sources):
+    """Run the tool over one chunk of sources in the calling thread and return its result."""
+    # A worker may be handed an empty chunk (e.g. more workers than files): nothing to execute.
+    if not target_sources:
+      return 0
+    self.context.run_tracker.register_thread(parent_workunit)
     buildroot = get_buildroot()
     toolroot = buildroot
     if self.sideeffecting and self.get_options().output_dir:
@@ -88,7 +173,7 @@ def _invoke_tool(self, target_sources):
       for old, new in zip(old_file_paths, new_file_paths):
         shutil.copyfile(old, new)
       target_sources = new_sources
-    result = self.invoke_tool(toolroot, target_sources)
+    result = self.invoke_tool(parent_workunit, toolroot, target_sources)
     self.process_result(result)
     return result
 
diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py
index 493811df681..263e361ddaf 100644
--- a/src/python/pants/backend/jvm/tasks/scalafmt.py
+++ b/src/python/pants/backend/jvm/tasks/scalafmt.py
@@ -1,14 +1,80 @@
 # Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
 # Licensed under the Apache License, Version 2.0 (see LICENSE).
 
+import os
+import subprocess
 from abc import abstractmethod
+from typing import List
 
+from pants.backend.jvm.subsystems.jvm_tool_mixin import JvmToolMixin
 from pants.backend.jvm.tasks.rewrite_base import RewriteBase
 from pants.base.exceptions import TaskError
+from pants.base.workunit import WorkUnitLabel
+from pants.binaries.binary_tool import NativeTool
+from pants.binaries.binary_util import BinaryToolUrlGenerator
+from pants.engine.platform import Platform
 from pants.java.jar.jar_dependency import JarDependency
 from pants.option.custom_types import file_option
+from pants.process.xargs import Xargs
 from pants.task.fmt_task_mixin import FmtTaskMixin
 from pants.task.lint_task_mixin import LintTaskMixin
+from pants.util.dirutil import chmod_plus_x
+from pants.util.memo import memoized_method
+
+
+class ScalaFmtNativeUrlGenerator(BinaryToolUrlGenerator):
+
+  _DIST_URL_FMT = 'https://github.com/scalameta/scalafmt/releases/download/v{version}/scalafmt-{system_id}.zip'
+
+  _SYSTEM_ID = {
+    'mac': 'macos',
+    'linux': 'linux',
+  }
+
+  def generate_urls(self, version, host_platform):
+    system_id = self._SYSTEM_ID[host_platform.os_name]
+    return [self._DIST_URL_FMT.format(version=version, system_id=system_id)]
+
+
+class ScalaFmtSubsystem(JvmToolMixin, NativeTool):
+  options_scope = 'scalafmt'
+  default_version = '2.3.1'
+  archive_type = 'zip'
+
+  def get_external_url_generator(self):
+    return ScalaFmtNativeUrlGenerator()
+
+  @memoized_method
+  def select(self):
+    """Reach into the unzipped directory and return the scalafmt executable.
+
+    Also make sure to chmod +x the scalafmt executable, since the release zip doesn't do that.
+    """
+    extracted_dir = super().select()
+    inner_dir_name = Platform.current.match({
+      Platform.darwin: 'scalafmt-macos',
+      Platform.linux: 'scalafmt-linux',
+    })
+    output_file = os.path.join(extracted_dir, inner_dir_name, 'scalafmt')
+    chmod_plus_x(output_file)
+    return output_file
+
+  @property
+  def use_native_image(self) -> bool:
+    return bool(self.get_options().use_native_image)
+
+  @classmethod
+  def register_options(cls, register):
+    super().register_options(register)
+    register('--use-native-image', type=bool, advanced=True, fingerprint=False,
+             help='Use a pre-compiled native-image for scalafmt.')
+
+    cls.register_jvm_tool(register,
+                          'scalafmt',
+                          classpath=[
+                            JarDependency(org='com.geirsson',
+                                          name='scalafmt-cli_2.11',
+                                          rev='1.5.1')])
 
 
 class ScalaFmt(RewriteBase):
@@ -18,20 +84,22 @@ class ScalaFmt(RewriteBase):
   process_result to run different scalafmt commands.
   """
 
+  def __init__(self, *args, **kwargs):
+    super().__init__(*args, **kwargs)
+    self._all_command_lines = []
+
+  @classmethod
+  def subsystem_dependencies(cls):
+    return super().subsystem_dependencies() + (
+      ScalaFmtSubsystem.scoped(cls),
+    )
+
   @classmethod
   def register_options(cls, register):
     super().register_options(register)
     register('--configuration', advanced=True, type=file_option, fingerprint=True,
              help='Path to scalafmt config file, if not specified default scalafmt config used')
 
-    cls.register_jvm_tool(register,
-                          'scalafmt',
-                          classpath=[
-                            JarDependency(org='com.geirsson',
-                                          name='scalafmt-cli_2.11',
-                                          rev='1.5.1')
-                          ])
-
   @classmethod
   def target_types(cls):
     return ['scala_library', 'junit_tests']
@@ -44,20 +112,69 @@ def source_extension(cls):
   def implementation_version(cls):
     return super().implementation_version() + [('ScalaFmt', 5)]
 
-  def invoke_tool(self, absolute_root, target_sources):
-    # If no config file is specified use default scalafmt config.
-    config_file = self.get_options().configuration
-    args = list(self.additional_args)
-    if config_file is not None:
-      args.extend(['--config', config_file])
-    args.extend([source for _target, source in target_sources])
-
-    return self.runjava(classpath=self.tool_classpath('scalafmt'),
+  @property
+  def _use_native_image(self) -> bool:
+    return ScalaFmtSubsystem.scoped_instance(self).use_native_image
+
+  def _native_image_path(self) -> str:
+    return ScalaFmtSubsystem.scoped_instance(self).select()
+
+  def _tool_classpath(self) -> List[str]:
+    subsystem = ScalaFmtSubsystem.scoped_instance(self)
+    return subsystem.tool_classpath_from_products(
+      self.context.products,
+      key='scalafmt',
+      scope=subsystem.options_scope)
+
+  def _invoke_native_image_subprocess(self, prefix_args, workunit, all_source_paths):
+    """Run the scalafmt native-image over `all_source_paths` and return its exit code."""
+    self._all_command_lines.append((prefix_args, all_source_paths))
+    return subprocess.run(
+      args=(prefix_args + all_source_paths),
+      stdout=workunit.output('stdout'),
+      stderr=workunit.output('stderr'),
+    ).returncode
+
+  def _invoke_jvm_process(self, prefix_args, all_source_paths):
+    """Run scalafmt in the JVM over `all_source_paths` and return the result of runjava()."""
+    return self.runjava(classpath=self._tool_classpath(),
                         main='org.scalafmt.cli.Cli',
-                        args=args,
+                        args=(prefix_args + all_source_paths),
                         workunit_name='scalafmt',
                         jvm_options=self.get_options().jvm_options)
 
+  def invoke_tool(self, current_workunit, absolute_root, target_sources):
+    """Run scalafmt over `target_sources`, splitting the command line if it is too long.
+
+    :param current_workunit: Not used by this implementation.
+    :param absolute_root: Not used by this implementation.
+    :param target_sources: (target, source path) pairs; only the paths are passed to scalafmt.
+    :returns: The value returned by Xargs.execute() for the scalafmt invocation(s).
+    """
+    self.context.log.debug(f'scalafmt called with sources: {target_sources}')
+
+    # If no config file is specified, use default scalafmt config.
+    config_file = self.get_options().configuration
+    prefix_args = list(self.additional_args)
+    if config_file is not None:
+      prefix_args.extend(['--config', config_file])
+
+    all_source_paths = [source for _target, source in target_sources]
+
+    if self._use_native_image:
+      with self.context.run_tracker.new_workunit(
+          name='scalafmt',
+          labels=[WorkUnitLabel.COMPILER],
+      ) as workunit:
+        prefix_args = [self._native_image_path()] + prefix_args
+        self.context.log.debug(f'executing scalafmt with native image with prefix args: {prefix_args}')
+        return Xargs(
+          self._invoke_native_image_subprocess,
+          constant_args=[prefix_args, workunit],
+        ).execute(all_source_paths)
+    else:
+      return Xargs(self._invoke_jvm_process, constant_args=[prefix_args]).execute(all_source_paths)
+
   @property
   @abstractmethod
   def additional_args(self):
diff --git a/src/python/pants/goal/run_tracker.py b/src/python/pants/goal/run_tracker.py
index c1c71483d6a..417d0c4c424 100644
--- a/src/python/pants/goal/run_tracker.py
+++ b/src/python/pants/goal/run_tracker.py
@@ -187,6 +187,14 @@ def set_sorted_goal_infos(self, sorted_goal_infos):
   def set_v2_console_rule_names(self, v2_console_rule_names):
     self._v2_console_rule_names = v2_console_rule_names
 
+  def get_current_workunit(self):
+    """Return the workunit associated with the current thread.
+
+    This can be used along with .register_thread() to ensure spawned threads have access to the
+    parent workunit that spawned them!
+    """
+    return self._threadlocal.current_workunit
+
   def register_thread(self, parent_workunit):
     """Register the parent workunit for all work in the calling thread.
 
diff --git a/src/python/pants/process/xargs.py b/src/python/pants/process/xargs.py
index 5a2ed7bfecc..948e77cb060 100644
--- a/src/python/pants/process/xargs.py
+++ b/src/python/pants/process/xargs.py
@@ -2,9 +2,13 @@
 # Licensed under the Apache License, Version 2.0 (see LICENSE).
 
 import errno
+import logging
 import subprocess
 
 
+logger = logging.getLogger(__name__)
+
+
 class Xargs:
   """A subprocess execution wrapper in the spirit of the xargs command line tool.
 
@@ -22,13 +26,15 @@ def call(args):
       return subprocess.call(cmd + args, **kwargs)
     return cls(call)
 
-  def __init__(self, cmd):
+  def __init__(self, cmd, constant_args=None):
     """Creates an xargs engine that calls cmd with argument chunks.
 
     :param cmd: A function that can execute a command line in the form of a list of strings
       passed as its sole argument.
+    :param constant_args: Positional arguments passed to cmd before each chunk of arguments.
     """
     self._cmd = cmd
+    self._constant_args = list(constant_args or [])
 
   def _split_args(self, args):
     half = len(args) // 2
@@ -39,12 +45,18 @@ def execute(self, args):
 
     :param list args: Extra arguments to pass to cmd.
     """
-    all_args = list(args)
+    splittable_args = list(args)
+    all_args_for_command_function = self._constant_args + [splittable_args]
+    logger.debug('xargs all_args_for_command_function: %s', all_args_for_command_function)
     try:
-      return self._cmd(all_args)
+      return self._cmd(*all_args_for_command_function)
     except OSError as e:
       if errno.E2BIG == e.errno:
-        args1, args2 = self._split_args(all_args)
+        if len(splittable_args) <= 1:
+          # Nothing left to split: re-raise rather than recursing forever on the same arguments.
+          raise
+        args1, args2 = self._split_args(splittable_args)
+        logger.debug('xargs split cmd line:\nargs1=%s,\nargs2=%s!', args1, args2)
         result = self.execute(args1)
         if result != 0:
           return result
diff --git a/src/python/pants/testutil/base/context_utils.py b/src/python/pants/testutil/base/context_utils.py
index 61a2e33b486..d538c319c10 100644
--- a/src/python/pants/testutil/base/context_utils.py
+++ b/src/python/pants/testutil/base/context_utils.py
@@ -49,6 +49,7 @@ class DummyRunTracker:
 
     def __init__(self):
       self.logger = RunTrackerLogger(self)
+      self._cur_workunit = TestContext.DummyWorkUnit()
 
     class DummyArtifactCacheStats:
       def add_hits(self, cache_name, targets): pass
@@ -59,6 +60,16 @@ def add_misses(self, cache_name, targets, causes): pass
 
     def report_target_info(self, scope, target, keys, val): pass
 
+    def get_current_workunit(self):
+      return self._cur_workunit
+
+    def register_thread(self, parent_workunit):
+      self._cur_workunit = parent_workunit
+
+    @contextmanager
+    def new_workunit(self, *args, **kwargs):
+      yield self.get_current_workunit()
+
   class TestLogger(logging.getLoggerClass()):  # type: ignore[misc] # MyPy does't understand this dynamic base class
     """A logger that converts our structured records into flat ones.
 
diff --git a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py
index 4acaa9be0a7..3c6cff01615 100644
--- a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py
+++ b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py
@@ -3,12 +3,13 @@
 
 import os
 from textwrap import dedent
+from uuid import uuid4
 
 from pants.backend.jvm.subsystems.scala_platform import ScalaPlatform
 from pants.backend.jvm.subsystems.scoverage_platform import ScoveragePlatform
 from pants.backend.jvm.targets.junit_tests import JUnitTests
 from pants.backend.jvm.targets.scala_library import ScalaLibrary
-from pants.backend.jvm.tasks.scalafmt import ScalaFmtCheckFormat, ScalaFmtFormat
+from pants.backend.jvm.tasks.scalafmt import ScalaFmtCheckFormat, ScalaFmtFormat, ScalaFmtSubsystem
 from pants.base.exceptions import TaskError
 from pants.build_graph.build_file_aliases import BuildFileAliases
 from pants.build_graph.resources import Resources
@@ -32,6 +33,7 @@ def setUp(self):
     init_subsystem(ScalaPlatform)
     init_subsystem(ScoveragePlatform)
     init_subsystem(SourceRootConfig)
+    init_subsystem(ScalaFmtSubsystem)
 
     self.configuration = self.create_file(
       relpath='build-support/scalafmt/config',
@@ -165,3 +167,80 @@ def test_output_dir(self):
     relative_test_file = fast_relpath(self.test_file, self.build_root)
     with open(os.path.join(output_dir, relative_test_file), 'r') as fp:
       self.assertNotEqual(self.test_file_contents, fp.read())
+
+  def _execute_native_image(self, num_sources=1, **kwargs):
+    """Run the task using the scalafmt native-image over `num_sources` generated scala files.
+
+    :returns: The source paths recorded for each native-image subprocess invocation.
+    """
+    self.set_options(skip=False, **kwargs)
+    self.set_options_for_scope('scalafmt', use_native_image=True)
+
+    uuid = '_' + str(uuid4()).replace('-', '_')
+
+    more_sources = [
+      self.create_file(relpath=f'src/scala/org/pantsbuild/badscalastyle/BadScalaStyle{i}.scala',
+                       contents=self.test_file_contents)
+      for i in range(0, num_sources)
+    ]
+    generated_library = self.make_target(spec=f'src/scala/org/pantsbuild/badscalastyle:{uuid}',
+                                         sources=[f'BadScalaStyle{i}.scala'
+                                                  for i in range(0, num_sources)],
+                                         target_type=ScalaLibrary)
+
+    context = self.context(target_roots=generated_library)
+    task = self.execute(context)
+
+    # Assert that it ran successfully.
+    for src in more_sources:
+      with open(src, 'r') as fp:
+        self.assertNotEqual(self.test_file_contents, fp.read())
+
+    # Assert that the native-image executable was most recently used.
+    scalafmt_native_image_basedir = ScalaFmtSubsystem.global_instance().select()
+
+    all_source_paths = []
+    for prefix_args, source_paths in task._all_command_lines:
+      executable_file = prefix_args[0]
+      self.assertTrue(executable_file.startswith(scalafmt_native_image_basedir))
+      all_source_paths.append(source_paths)
+
+    return all_source_paths
+
+  def _assert_num_files_per_process_matches(self, all_source_paths, all_num_files):
+    # Worker threads record their command lines in whatever order they get scheduled, so compare
+    # the per-process file counts irrespective of order.
+    self.assertEqual(sorted(len(paths) for paths in all_source_paths), sorted(all_num_files))
+
+  def test_native_image_execution(self):
+    self._execute_native_image()
+
+  def test_native_image_threading_worker_count(self):
+    all_source_paths = self._execute_native_image(num_sources=4, worker_count=2)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[2, 2])
+
+  def test_native_image_threading_uneven_divisor(self):
+    all_source_paths = self._execute_native_image(num_sources=5, worker_count=3)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[2, 2, 1])
+
+  def test_native_image_threading_massive_process_limit(self):
+    all_source_paths = self._execute_native_image(num_sources=2, worker_count=3)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[1, 1])
+
+  def test_native_image_threading_files_per_worker(self):
+    all_source_paths = self._execute_native_image(num_sources=2, files_per_worker=1)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[1, 1])
+
+  def test_native_image_threading_uneven_files_per_worker(self):
+    all_source_paths = self._execute_native_image(num_sources=5, files_per_worker=3)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[3, 2])
+
+  def test_native_image_threading_massive_files_per_process(self):
+    all_source_paths = self._execute_native_image(num_sources=2, files_per_worker=5)
+    self._assert_num_files_per_process_matches(all_source_paths,
+                                               all_num_files=[2])