From fe9abdba19959500e70d6d57b1f4a6b90d868e9d Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 6 Dec 2019 16:10:27 -0800 Subject: [PATCH 1/6] add scalafmt native-image NativeTool --- .../pants/backend/jvm/tasks/scalafmt.py | 115 ++++++++++++++++-- src/python/pants/process/xargs.py | 6 + 2 files changed, 108 insertions(+), 13 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py index 493811df681..6a0f0e19d50 100644 --- a/src/python/pants/backend/jvm/tasks/scalafmt.py +++ b/src/python/pants/backend/jvm/tasks/scalafmt.py @@ -1,14 +1,79 @@ # Copyright 2016 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import os +import subprocess from abc import abstractmethod +from typing import List, cast +from pants.backend.jvm.subsystems.jvm_tool_mixin import JvmToolMixin from pants.backend.jvm.tasks.rewrite_base import RewriteBase from pants.base.exceptions import TaskError +from pants.base.workunit import WorkUnitLabel +from pants.binaries.binary_tool import NativeTool +from pants.binaries.binary_util import BinaryToolUrlGenerator +from pants.engine.platform import Platform from pants.java.jar.jar_dependency import JarDependency from pants.option.custom_types import file_option from pants.task.fmt_task_mixin import FmtTaskMixin from pants.task.lint_task_mixin import LintTaskMixin +from pants.util.dirutil import chmod_plus_x +from pants.util.memo import memoized_method + + +class ScalaFmtNativeUrlGenerator(BinaryToolUrlGenerator): + + _DIST_URL_FMT = 'https://github.com/scalameta/scalafmt/releases/download/v{version}/scalafmt-{system_id}.zip' + + _SYSTEM_ID = { + 'mac': 'macos', + 'linux': 'linux', + } + + def generate_urls(self, version, host_platform): + system_id = self._SYSTEM_ID[host_platform.os_name] + return [self._DIST_URL_FMT.format(version=version, system_id=system_id)] + + +class ScalaFmtSubsystem(JvmToolMixin, NativeTool): + options_scope = 'scalafmt' + default_version = '2.3.1' + archive_type = 'zip' + + def get_external_url_generator(self): + return ScalaFmtNativeUrlGenerator() + + @memoized_method + def select(self): + """Reach into the unzipped directory and return the scalafmt executable. + + Also make sure to chmod +x the scalafmt executable, since the release zip doesn't do that. + """ + extracted_dir = super().select() + inner_dir_name = Platform.current.match({ + Platform.darwin: 'scalafmt-macos', + Platform.linux: 'scalafmt-linux', + }) + output_file = os.path.join(extracted_dir, inner_dir_name, 'scalafmt') + chmod_plus_x(output_file) + return output_file + + @property + def use_native_image(self) -> bool: + return cast(bool, self.get_options().use_native_image) + + @classmethod + def register_options(cls, register): + super().register_options(register) + register('--use-native-image', type=bool, advanced=True, fingerprint=False, + help='Use a pre-compiled native-image for scalafmt.') + + cls.register_jvm_tool(register, + 'scalafmt', + classpath=[ + JarDependency(org='com.geirsson', + name='scalafmt-cli_2.11', + rev='1.5.1')]) class ScalaFmt(RewriteBase): @@ -18,20 +83,18 @@ class ScalaFmt(RewriteBase): process_result to run different scalafmt commands. """ + @classmethod + def subsystem_dependencies(cls): + return super().subsystem_dependencies() + ( + ScalaFmtSubsystem.scoped(cls), + ) + @classmethod def register_options(cls, register): super().register_options(register) register('--configuration', advanced=True, type=file_option, fingerprint=True, help='Path to scalafmt config file, if not specified default scalafmt config used') - cls.register_jvm_tool(register, - 'scalafmt', - classpath=[ - JarDependency(org='com.geirsson', - name='scalafmt-cli_2.11', - rev='1.5.1') - ]) - @classmethod def target_types(cls): return ['scala_library', 'junit_tests'] @@ -44,7 +107,23 @@ def source_extension(cls): def implementation_version(cls): return super().implementation_version() + [('ScalaFmt', 5)] + @property + def _use_native_image(self) -> bool: + return ScalaFmtSubsystem.scoped_instance(self).use_native_image + + def _native_image_path(self) -> str: + return ScalaFmtSubsystem.scoped_instance(self).select() + + def _tool_classpath(self) -> List[str]: + subsystem = ScalaFmtSubsystem.scoped_instance(self) + return subsystem.tool_classpath_from_products( + self.context.products, + key='scalafmt', + scope=subsystem.options_scope) + def invoke_tool(self, absolute_root, target_sources): + self.context.log.debug(f'scalafmt called with sources: {target_sources}') + # If no config file is specified use default scalafmt config. config_file = self.get_options().configuration args = list(self.additional_args) @@ -52,11 +131,21 @@ def invoke_tool(self, absolute_root, target_sources): args.extend(['--config', config_file]) args.extend([source for _target, source in target_sources]) - return self.runjava(classpath=self.tool_classpath('scalafmt'), - main='org.scalafmt.cli.Cli', - args=args, - workunit_name='scalafmt', - jvm_options=self.get_options().jvm_options) + if self._use_native_image: + with self.context.new_workunit(name='scalafmt', labels=[WorkUnitLabel.COMPILER]) as workunit: + args = [self._native_image_path(), *args] + self.context.log.debug(f'executing scalafmt with native image with args: {args}') + return subprocess.run( + args=args, + stdout=workunit.output('stdout'), + stderr=workunit.output('stderr'), + ).returncode + else: + return self.runjava(classpath=self._tool_classpath(), + main='org.scalafmt.cli.Cli', + args=args, + workunit_name='scalafmt', + jvm_options=self.get_options().jvm_options) @property @abstractmethod diff --git a/src/python/pants/process/xargs.py b/src/python/pants/process/xargs.py index 5a2ed7bfecc..abbd8aa7a03 100644 --- a/src/python/pants/process/xargs.py +++ b/src/python/pants/process/xargs.py @@ -2,9 +2,13 @@ # Licensed under the Apache License, Version 2.0 (see LICENSE). import errno +import logging import subprocess +logger = logging.getLogger(__name__) + + class Xargs: """A subprocess execution wrapper in the spirit of the xargs command line tool. @@ -40,11 +44,13 @@ def execute(self, args): :param list args: Extra arguments to pass to cmd. """ all_args = list(args) + logger.debug(f'xargs all_args: {all_args}') try: return self._cmd(all_args) except OSError as e: if errno.E2BIG == e.errno: args1, args2 = self._split_args(all_args) + logger.debug(f'xargs split cmd line:\nargs1={args1},\nargs2={args2}!') result = self.execute(args1) if result != 0: return result From fea5e44938e9e9ac295f652834f9c5e7d614a708 Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 6 Dec 2019 17:32:48 -0800 Subject: [PATCH 2/6] add generic parallelism options to RewriteBase! --- .../pants/backend/jvm/tasks/rewrite_base.py | 65 +++++++++++++++++-- .../pants/backend/jvm/tasks/scalafmt.py | 8 ++- src/python/pants/goal/run_tracker.py | 12 +++- src/python/pants/process/xargs.py | 6 +- 4 files changed, 78 insertions(+), 13 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py index ae4c8e7426a..c3dc0edf3d3 100644 --- a/src/python/pants/backend/jvm/tasks/rewrite_base.py +++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py @@ -3,7 +3,9 @@ import os import shutil +import threading from abc import ABCMeta, abstractmethod +from typing import Any, List from pants.backend.jvm.tasks.nailgun_task import NailgunTask from pants.base.build_environment import get_buildroot @@ -29,6 +31,13 @@ def register_options(cls, register): help='Path to output directory. Any updated files will be written here. ' 'If not specified, files will be modified in-place.') + register('--files-per-process', type=int, fingerprint=False, + default=None, + help='Number of files to use per individual process execution for native-image.') + register('--total-number-parallel-processes', type=int, fingerprint=False, + default=None, + help='Total number of parallel scalafmt processes to run.') + @classmethod def target_types(cls): """Returns a list of target type names (e.g.: `scala_library`) this rewriter operates on.""" @@ -63,17 +72,59 @@ def execute(self): with self.invalidated(relevant_targets) as invalidation_check: self._execute_for([vt.target for vt in invalidation_check.invalid_vts]) + def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn): + parent_workunit = self.context.run_tracker.get_current_workunit() + + def hacked_thread_excepthook(args): + raise args.exc_value + + threading.excepthook = hacked_thread_excepthook # type: ignore[attr-defined] + all_threads = [ + threading.Thread( + name=f'scalafmt invocation thread #{idx}/{len(inputs_list_of_lists)}', + target=invoke_fn, + args=[parent_workunit, inputs_single_list], + ) + for idx, inputs_single_list in enumerate(inputs_list_of_lists) + ] + for thread in all_threads: + thread.start() + for thread in all_threads: + thread.join() + def _execute_for(self, targets): target_sources = self._calculate_sources(targets) if not target_sources: return - result = Xargs(self._invoke_tool).execute(target_sources) - if result != 0: - raise TaskError('{} is improperly implemented: a failed process ' - 'should raise an exception earlier.'.format(type(self).__name__)) - - def _invoke_tool(self, target_sources): + if self.get_options().files_per_process is not None: + # If --files-per-process is specified, split the target sources and run in separate threads! + n = self.get_options().files_per_process + inputs_list_of_lists = [ + target_sources[i:i + n] + for i in range(0, len(target_sources), n) + ] + self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool) + elif self.get_options().total_number_parallel_processes is not None: + # If --total-number-parallel-processes is specified, split the target sources into that many + # threads, and run in separate threads! + num_processes = self.get_options().total_number_parallel_processes + n = len(target_sources) // (num_processes - 1) + inputs_list_of_lists = [ + target_sources[i:i + n] + for i in range(0, len(target_sources), n) + ] + self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool) + else: + # Otherwise, pass in the parent workunit to Xargs, which is passed into self._invoke_tool. + parent_workunit = self.context.run_tracker.get_current_workunit() + result = Xargs(self._invoke_tool, constant_args=[parent_workunit]).execute(target_sources) + if result != 0: + raise TaskError('{} is improperly implemented: a failed process ' + 'should raise an exception earlier.'.format(type(self).__name__)) + + def _invoke_tool(self, parent_workunit, target_sources): + self.context.run_tracker.register_thread(parent_workunit) buildroot = get_buildroot() toolroot = buildroot if self.sideeffecting and self.get_options().output_dir: @@ -88,7 +139,7 @@ def _invoke_tool(self, target_sources): for old, new in zip(old_file_paths, new_file_paths): shutil.copyfile(old, new) target_sources = new_sources - result = self.invoke_tool(toolroot, target_sources) + result = self.invoke_tool(parent_workunit, toolroot, target_sources) self.process_result(result) return result diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py index 6a0f0e19d50..8a6f1c2d706 100644 --- a/src/python/pants/backend/jvm/tasks/scalafmt.py +++ b/src/python/pants/backend/jvm/tasks/scalafmt.py @@ -121,7 +121,8 @@ def _tool_classpath(self) -> List[str]: key='scalafmt', scope=subsystem.options_scope) - def invoke_tool(self, absolute_root, target_sources): + def invoke_tool(self, current_workunit, absolute_root, target_sources): + self.context.run_tracker._threadlocal.current_workunit = current_workunit self.context.log.debug(f'scalafmt called with sources: {target_sources}') # If no config file is specified use default scalafmt config. @@ -132,7 +133,10 @@ def invoke_tool(self, absolute_root, target_sources): args.extend([source for _target, source in target_sources]) if self._use_native_image: - with self.context.new_workunit(name='scalafmt', labels=[WorkUnitLabel.COMPILER]) as workunit: + with self.context.run_tracker.new_workunit( + name='scalafmt', + labels=[WorkUnitLabel.COMPILER], + ) as workunit: args = [self._native_image_path(), *args] self.context.log.debug(f'executing scalafmt with native image with args: {args}') return subprocess.run( diff --git a/src/python/pants/goal/run_tracker.py b/src/python/pants/goal/run_tracker.py index c1c71483d6a..417d0c4c424 100644 --- a/src/python/pants/goal/run_tracker.py +++ b/src/python/pants/goal/run_tracker.py @@ -187,6 +187,14 @@ def set_sorted_goal_infos(self, sorted_goal_infos): def set_v2_console_rule_names(self, v2_console_rule_names): self._v2_console_rule_names = v2_console_rule_names + def get_current_workunit(self): + """Return the workunit associated with the current thread. + + This can be used along with .register_thread() to ensure spawned threads have access to the + parent workunit that spawned them! + """ + return self._threadlocal.current_workunit + def register_thread(self, parent_workunit): """Register the parent workunit for all work in the calling thread. @@ -383,7 +391,7 @@ def error(msg): if stats_version not in cls.SUPPORTED_STATS_VERSIONS: raise ValueError("Invalid stats version") - + auth_data = BasicAuth.global_instance().get_auth_for_provider(auth_provider) headers = cls._get_headers(stats_version=stats_version) @@ -462,7 +470,7 @@ def _stats(self): else: stats.update({ 'self_timings': self.self_timings.get_all(), - 'critical_path_timings': self.get_critical_path_timings().get_all(), + 'critical_path_timings': self.get_critical_path_timings().get_all(), 'outcomes': self.outcomes, }) return stats diff --git a/src/python/pants/process/xargs.py b/src/python/pants/process/xargs.py index abbd8aa7a03..5346550ba37 100644 --- a/src/python/pants/process/xargs.py +++ b/src/python/pants/process/xargs.py @@ -26,13 +26,15 @@ def call(args): return subprocess.call(cmd + args, **kwargs) return cls(call) - def __init__(self, cmd): + def __init__(self, cmd, constant_args=None): """Creates an xargs engine that calls cmd with argument chunks. :param cmd: A function that can execute a command line in the form of a list of strings passed as its sole argument. + :param constant_args: Any positional arguments to be added to each invocation. """ self._cmd = cmd + self._constant_args = constant_args or [] def _split_args(self, args): half = len(args) // 2 @@ -46,7 +48,7 @@ def execute(self, args): all_args = list(args) logger.debug(f'xargs all_args: {all_args}') try: - return self._cmd(all_args) + return self._cmd(*(*self._constant_args, all_args)) except OSError as e: if errno.E2BIG == e.errno: args1, args2 = self._split_args(all_args) From be62cdaf1b67e39feb9ca25821044dac516f395c Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 6 Dec 2019 17:37:27 -0800 Subject: [PATCH 3/6] propagate exceptions to the top-level and wrap in TaskError --- src/python/pants/backend/jvm/tasks/rewrite_base.py | 6 +++++- src/python/pants/backend/jvm/tasks/scalafmt.py | 1 - 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py index c3dc0edf3d3..5a0467d3e65 100644 --- a/src/python/pants/backend/jvm/tasks/rewrite_base.py +++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py @@ -78,6 +78,7 @@ def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn): def hacked_thread_excepthook(args): raise args.exc_value + # Propagate exceptions in threads to the toplevel. threading.excepthook = hacked_thread_excepthook # type: ignore[attr-defined] all_threads = [ threading.Thread( @@ -90,7 +91,10 @@ def hacked_thread_excepthook(args): for thread in all_threads: thread.start() for thread in all_threads: - thread.join() + try: + thread.join() + except Exception as e: + raise TaskError(str(e), exit_code=1) def _execute_for(self, targets): target_sources = self._calculate_sources(targets) diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py index 8a6f1c2d706..52ca152a890 100644 --- a/src/python/pants/backend/jvm/tasks/scalafmt.py +++ b/src/python/pants/backend/jvm/tasks/scalafmt.py @@ -122,7 +122,6 @@ def _tool_classpath(self) -> List[str]: scope=subsystem.options_scope) def invoke_tool(self, current_workunit, absolute_root, target_sources): - self.context.run_tracker._threadlocal.current_workunit = current_workunit self.context.log.debug(f'scalafmt called with sources: {target_sources}') # If no config file is specified use default scalafmt config. From 1d950f28f9e6418d480370836c276679d144823c Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 13 Dec 2019 13:52:32 -0800 Subject: [PATCH 4/6] fix process splitting calculation and add xargs to the thread-based solutions! --- .../pants/backend/jvm/tasks/rewrite_base.py | 31 +++++++++++-------- .../pants/backend/jvm/tasks/scalafmt.py | 30 +++++++++++------- src/python/pants/process/xargs.py | 9 +++--- 3 files changed, 42 insertions(+), 28 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py index 5a0467d3e65..60de473f213 100644 --- a/src/python/pants/backend/jvm/tasks/rewrite_base.py +++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py @@ -1,6 +1,8 @@ # Copyright 2017 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import itertools +import math import os import shutil import threading @@ -31,12 +33,12 @@ def register_options(cls, register): help='Path to output directory. Any updated files will be written here. ' 'If not specified, files will be modified in-place.') - register('--files-per-process', type=int, fingerprint=False, + register('--files-per-worker', type=int, fingerprint=False, default=None, - help='Number of files to use per individual process execution for native-image.') - register('--total-number-parallel-processes', type=int, fingerprint=False, + help='Number of files to use per each scalafmt execution.') + register('--worker-count', type=int, fingerprint=False, default=None, - help='Total number of parallel scalafmt processes to run.') + help='Total number of parallel scalafmt threads or processes to run.') @classmethod def target_types(cls): @@ -101,22 +103,22 @@ def _execute_for(self, targets): if not target_sources: return - if self.get_options().files_per_process is not None: - # If --files-per-process is specified, split the target sources and run in separate threads! - n = self.get_options().files_per_process + if self.get_options().files_per_worker is not None: + # If --files-per-worker is specified, split the target sources and run in separate threads! + n = self.get_options().files_per_worker inputs_list_of_lists = [ target_sources[i:i + n] for i in range(0, len(target_sources), n) ] self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool) - elif self.get_options().total_number_parallel_processes is not None: - # If --total-number-parallel-processes is specified, split the target sources into that many + elif self.get_options().worker_count is not None: + # If --worker-count is specified, split the target sources into that many # threads, and run in separate threads! - num_processes = self.get_options().total_number_parallel_processes - n = len(target_sources) // (num_processes - 1) + num_processes = self.get_options().worker_count + sources_iterator = iter(target_sources) inputs_list_of_lists = [ - target_sources[i:i + n] - for i in range(0, len(target_sources), n) + list(itertools.islice(sources_iterator, 0, math.ceil(len(target_sources) / num_processes))) + for _ in range(0, num_processes) ] self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool) else: @@ -128,6 +130,9 @@ def _execute_for(self, targets): 'should raise an exception earlier.'.format(type(self).__name__)) def _invoke_tool(self, parent_workunit, target_sources): + # We want to avoid executing anything if there are no sources to generate. + if not target_sources: + return 0 self.context.run_tracker.register_thread(parent_workunit) buildroot = get_buildroot() toolroot = buildroot diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py index 52ca152a890..a564a6c722b 100644 --- a/src/python/pants/backend/jvm/tasks/scalafmt.py +++ b/src/python/pants/backend/jvm/tasks/scalafmt.py @@ -15,6 +15,7 @@ from pants.engine.platform import Platform from pants.java.jar.jar_dependency import JarDependency from pants.option.custom_types import file_option +from pants.process.xargs import Xargs from pants.task.fmt_task_mixin import FmtTaskMixin from pants.task.lint_task_mixin import LintTaskMixin from pants.util.dirutil import chmod_plus_x @@ -121,32 +122,39 @@ def _tool_classpath(self) -> List[str]: key='scalafmt', scope=subsystem.options_scope) + def _invoke_native_image_subprocess(self, prefix_args, workunit, all_source_paths): + return subprocess.run( + args=(prefix_args + all_source_paths), + stdout=workunit.output('stdout'), + stderr=workunit.output('stderr'), + ).returncode + def invoke_tool(self, current_workunit, absolute_root, target_sources): self.context.log.debug(f'scalafmt called with sources: {target_sources}') # If no config file is specified use default scalafmt config. config_file = self.get_options().configuration - args = list(self.additional_args) + prefix_args = list(self.additional_args) if config_file is not None: - args.extend(['--config', config_file]) - args.extend([source for _target, source in target_sources]) + prefix_args.extend(['--config', config_file]) + + all_source_paths = [source for _target, source in target_sources] if self._use_native_image: with self.context.run_tracker.new_workunit( name='scalafmt', labels=[WorkUnitLabel.COMPILER], ) as workunit: - args = [self._native_image_path(), *args] - self.context.log.debug(f'executing scalafmt with native image with args: {args}') - return subprocess.run( - args=args, - stdout=workunit.output('stdout'), - stderr=workunit.output('stderr'), - ).returncode + prefix_args = [self._native_image_path()] + prefix_args + self.context.log.debug(f'executing scalafmt with native image with prefix args: {prefix_args}') + return Xargs( + self._invoke_native_image_subprocess, + constant_args=[prefix_args, workunit], + ).execute(all_source_paths) else: return self.runjava(classpath=self._tool_classpath(), main='org.scalafmt.cli.Cli', - args=args, + args=(prefix_args + all_source_paths), workunit_name='scalafmt', jvm_options=self.get_options().jvm_options) diff --git a/src/python/pants/process/xargs.py b/src/python/pants/process/xargs.py index 5346550ba37..948e77cb060 100644 --- a/src/python/pants/process/xargs.py +++ b/src/python/pants/process/xargs.py @@ -45,13 +45,14 @@ def execute(self, args): :param list args: Extra arguments to pass to cmd. """ - all_args = list(args) - logger.debug(f'xargs all_args: {all_args}') + splittable_args = list(args) + all_args_for_command_function = self._constant_args + [splittable_args] + logger.debug(f'xargs all_args_for_command_function: {all_args_for_command_function}') try: - return self._cmd(*(*self._constant_args, all_args)) + return self._cmd(*all_args_for_command_function) except OSError as e: if errno.E2BIG == e.errno: - args1, args2 = self._split_args(all_args) + args1, args2 = self._split_args(splittable_args) logger.debug(f'xargs split cmd line:\nargs1={args1},\nargs2={args2}!') result = self.execute(args1) if result != 0: From 032d61e9ba1b79fcaea55778ef4846cc609b743f Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 13 Dec 2019 15:24:31 -0800 Subject: [PATCH 5/6] add testing! --- .../pants/backend/jvm/tasks/rewrite_base.py | 19 +++++++++--- .../pants/backend/jvm/tasks/scalafmt.py | 24 ++++++++++----- .../pants/testutil/base/context_utils.py | 11 +++++++ .../backend/jvm/tasks/test_scalafmt.py | 30 ++++++++++++++++++- 4 files changed, 71 insertions(+), 13 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py index 60de473f213..a580f9be927 100644 --- a/src/python/pants/backend/jvm/tasks/rewrite_base.py +++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py @@ -1,6 +1,7 @@ # Copyright 2017 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). +import functools import itertools import math import os @@ -77,15 +78,22 @@ def execute(self): def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn): parent_workunit = self.context.run_tracker.get_current_workunit() - def hacked_thread_excepthook(args): - raise args.exc_value + all_exceptions = [] + + def thread_exception_wrapper(fn): + @functools.wraps(fn) + def inner(*args, **kwargs): + try: + fn(*args, **kwargs) + except Exception as e: + all_exceptions.append(e) # Propagate exceptions in threads to the toplevel. threading.excepthook = hacked_thread_excepthook # type: ignore[attr-defined] all_threads = [ threading.Thread( name=f'scalafmt invocation thread #{idx}/{len(inputs_list_of_lists)}', - target=invoke_fn, + target=thread_exception_wrapper(invoke_fn), args=[parent_workunit, inputs_single_list], ) for idx, inputs_single_list in enumerate(inputs_list_of_lists) @@ -96,7 +104,10 @@ def hacked_thread_excepthook(args): try: thread.join() except Exception as e: - raise TaskError(str(e), exit_code=1) + raise TaskError(str(e)) from e + if all_exceptions: + joined_str = ', '.join(str(e) for e in all_exceptions) + raise TaskError(f'all errors: {joined_str}') from all_exceptions[0] def _execute_for(self, targets): target_sources = self._calculate_sources(targets) diff --git a/src/python/pants/backend/jvm/tasks/scalafmt.py b/src/python/pants/backend/jvm/tasks/scalafmt.py index a564a6c722b..263e361ddaf 100644 --- a/src/python/pants/backend/jvm/tasks/scalafmt.py +++ b/src/python/pants/backend/jvm/tasks/scalafmt.py @@ -4,7 +4,7 @@ import os import subprocess from abc import abstractmethod -from typing import List, cast +from typing import List from pants.backend.jvm.subsystems.jvm_tool_mixin import JvmToolMixin from pants.backend.jvm.tasks.rewrite_base import RewriteBase @@ -61,7 +61,7 @@ def select(self): @property def use_native_image(self) -> bool: - return cast(bool, self.get_options().use_native_image) + return bool(self.get_options().use_native_image) @classmethod def register_options(cls, register): @@ -84,6 +84,10 @@ class ScalaFmt(RewriteBase): process_result to run different scalafmt commands. """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._all_command_lines = [] + @classmethod def subsystem_dependencies(cls): return super().subsystem_dependencies() + ( @@ -123,16 +127,24 @@ def _tool_classpath(self) -> List[str]: scope=subsystem.options_scope) def _invoke_native_image_subprocess(self, prefix_args, workunit, all_source_paths): + self._all_command_lines.append((prefix_args, all_source_paths)) return subprocess.run( args=(prefix_args + all_source_paths), stdout=workunit.output('stdout'), stderr=workunit.output('stderr'), ).returncode + def _invoke_jvm_process(self, prefix_args, all_source_paths): + return self.runjava(classpath=self._tool_classpath(), + main='org.scalafmt.cli.Cli', + args=(prefix_args + all_source_paths), + workunit_name='scalafmt', + jvm_options=self.get_options().jvm_options) + def invoke_tool(self, current_workunit, absolute_root, target_sources): self.context.log.debug(f'scalafmt called with sources: {target_sources}') - # If no config file is specified use default scalafmt config. + # If no config file is specified, use default scalafmt config. config_file = self.get_options().configuration prefix_args = list(self.additional_args) if config_file is not None: @@ -152,11 +164,7 @@ def invoke_tool(self, current_workunit, absolute_root, target_sources): constant_args=[prefix_args, workunit], ).execute(all_source_paths) else: - return self.runjava(classpath=self._tool_classpath(), - main='org.scalafmt.cli.Cli', - args=(prefix_args + all_source_paths), - workunit_name='scalafmt', - jvm_options=self.get_options().jvm_options) + return Xargs(self._invoke_jvm_process, constant_args=[prefix_args]).execute(all_source_paths) @property @abstractmethod diff --git a/src/python/pants/testutil/base/context_utils.py b/src/python/pants/testutil/base/context_utils.py index 61a2e33b486..d538c319c10 100644 --- a/src/python/pants/testutil/base/context_utils.py +++ b/src/python/pants/testutil/base/context_utils.py @@ -49,6 +49,7 @@ class DummyRunTracker: def __init__(self): self.logger = RunTrackerLogger(self) + self._cur_workunit = TestContext.DummyWorkUnit() class DummyArtifactCacheStats: def add_hits(self, cache_name, targets): pass @@ -59,6 +60,16 @@ def add_misses(self, cache_name, targets, causes): pass def report_target_info(self, scope, target, keys, val): pass + def get_current_workunit(self): + return self._cur_workunit + + def register_thread(self, parent_workunit): + self._cur_workunit = parent_workunit + + @contextmanager + def new_workunit(self, *args, **kwargs): + yield self.get_current_workunit() + class TestLogger(logging.getLoggerClass()): # type: ignore[misc] # MyPy does't understand this dynamic base class """A logger that converts our structured records into flat ones. diff --git a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py index 4acaa9be0a7..714754b6ba4 100644 --- a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py +++ b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py @@ -8,7 +8,7 @@ from pants.backend.jvm.subsystems.scoverage_platform import ScoveragePlatform from pants.backend.jvm.targets.junit_tests import JUnitTests from pants.backend.jvm.targets.scala_library import ScalaLibrary -from pants.backend.jvm.tasks.scalafmt import ScalaFmtCheckFormat, ScalaFmtFormat +from pants.backend.jvm.tasks.scalafmt import ScalaFmtCheckFormat, ScalaFmtFormat, ScalaFmtSubsystem from pants.base.exceptions import TaskError from pants.build_graph.build_file_aliases import BuildFileAliases from pants.build_graph.resources import Resources @@ -32,6 +32,7 @@ def setUp(self): init_subsystem(ScalaPlatform) init_subsystem(ScoveragePlatform) init_subsystem(SourceRootConfig) + init_subsystem(ScalaFmtSubsystem) self.configuration = self.create_file( relpath='build-support/scalafmt/config', @@ -165,3 +166,30 @@ def test_output_dir(self): relative_test_file = fast_relpath(self.test_file, self.build_root) with open(os.path.join(output_dir, relative_test_file), 'r') as fp: self.assertNotEqual(self.test_file_contents, fp.read()) + + def _execute_native_image(self, **kwargs): + self.set_options(skip=False, **kwargs) + self.set_options_for_scope('scalafmt', use_native_image=True) + + context = self.context(target_roots=self.library) + task = self.execute(context) + + # Assert that it ran successfully. + with open(self.test_file, 'r') as fp: + self.assertNotEqual(self.test_file_contents, fp.read()) + + # Assert that the native-image executable was most recently used. + scalafmt_native_image_basedir = ScalaFmtSubsystem.global_instance().select() + most_recent_command_line = task._all_command_lines[-1] + prefix_args, _workunit, _all_source_paths = most_recent_command_line + executable_file = prefix_args[0] + self.assertTrue(executable_file.startswith(scalafmt_native_image_basedir)) + + def test_native_image_execution(self): + self._execute_native_image() + + def test_native_image_threading_worker_count(self): + self._execute_native_image(worker_count=4) + + def test_native_image_threading_files_per_worker(self): + self._execute_native_image(files_per_worker=1) From c98d8cbc4280f30776851ce2bf0a515f64cb249b Mon Sep 17 00:00:00 2001 From: Danny McClanahan <1305167+cosmicexplorer@users.noreply.github.com> Date: Fri, 13 Dec 2019 17:18:01 -0800 Subject: [PATCH 6/6] fix testing! --- .../pants/backend/jvm/tasks/rewrite_base.py | 5 +- .../backend/jvm/tasks/test_scalafmt.py | 67 ++++++++++++++++--- 2 files changed, 60 insertions(+), 12 deletions(-) diff --git a/src/python/pants/backend/jvm/tasks/rewrite_base.py b/src/python/pants/backend/jvm/tasks/rewrite_base.py index a580f9be927..951489077f3 100644 --- a/src/python/pants/backend/jvm/tasks/rewrite_base.py +++ b/src/python/pants/backend/jvm/tasks/rewrite_base.py @@ -78,6 +78,8 @@ def execute(self): def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn): parent_workunit = self.context.run_tracker.get_current_workunit() + # Propagate exceptions in threads to the toplevel by checking this variable after joining all + # the threads. all_exceptions = [] def thread_exception_wrapper(fn): @@ -87,9 +89,8 @@ def inner(*args, **kwargs): fn(*args, **kwargs) except Exception as e: all_exceptions.append(e) + return inner - # Propagate exceptions in threads to the toplevel. - threading.excepthook = hacked_thread_excepthook # type: ignore[attr-defined] all_threads = [ threading.Thread( name=f'scalafmt invocation thread #{idx}/{len(inputs_list_of_lists)}', diff --git a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py index 714754b6ba4..3c6cff01615 100644 --- a/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py +++ b/tests/python/pants_test/backend/jvm/tasks/test_scalafmt.py @@ -3,6 +3,7 @@ import os from textwrap import dedent +from uuid import uuid4 from pants.backend.jvm.subsystems.scala_platform import ScalaPlatform from pants.backend.jvm.subsystems.scoverage_platform import ScoveragePlatform @@ -167,29 +168,75 @@ def test_output_dir(self): with open(os.path.join(output_dir, relative_test_file), 'r') as fp: self.assertNotEqual(self.test_file_contents, fp.read()) - def _execute_native_image(self, **kwargs): + def _execute_native_image(self, num_sources=1, **kwargs): self.set_options(skip=False, **kwargs) self.set_options_for_scope('scalafmt', use_native_image=True) - context = self.context(target_roots=self.library) + uuid = '_' + str(uuid4()).replace('-', '_') + + more_sources = [ + self.create_file(relpath=f'src/scala/org/pantsbuild/badscalastyle/BadScalaStyle{i}.scala', + contents=self.test_file_contents) + for i in range(0, num_sources) + ] + generated_library = self.make_target(spec=f'src/scala/org/pantsbuild/badscalastyle:{uuid}', + sources=[f'BadScalaStyle{i}.scala' + for i in range(0, num_sources)], + target_type=ScalaLibrary) + + context = self.context(target_roots=generated_library) task = self.execute(context) # Assert that it ran successfully. - with open(self.test_file, 'r') as fp: - self.assertNotEqual(self.test_file_contents, fp.read()) + for src in more_sources: + with open(src, 'r') as fp: + self.assertNotEqual(self.test_file_contents, fp.read()) # Assert that the native-image executable was most recently used. scalafmt_native_image_basedir = ScalaFmtSubsystem.global_instance().select() - most_recent_command_line = task._all_command_lines[-1] - prefix_args, _workunit, _all_source_paths = most_recent_command_line - executable_file = prefix_args[0] - self.assertTrue(executable_file.startswith(scalafmt_native_image_basedir)) + + all_source_paths = [] + for prefix_args, source_paths in task._all_command_lines: + executable_file = prefix_args[0] + self.assertTrue(executable_file.startswith(scalafmt_native_image_basedir)) + all_source_paths.append(source_paths) + + return all_source_paths + + def _assert_num_files_per_process_matches(self, all_source_paths, all_num_files): + self.assertEqual(len(all_source_paths), len(all_num_files)) + for paths, expected_num_files in zip(all_source_paths, all_num_files): + self.assertEqual(len(paths), expected_num_files) def test_native_image_execution(self): self._execute_native_image() def test_native_image_threading_worker_count(self): - self._execute_native_image(worker_count=4) + all_source_paths = self._execute_native_image(num_sources=4, worker_count=2) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[2, 2]) + + def test_native_image_threading_uneven_divisor(self): + all_source_paths = self._execute_native_image(num_sources=5, worker_count=3) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[2, 2, 1]) + + def test_native_image_threading_massive_process_limit(self): + all_source_paths = self._execute_native_image(num_sources=2, worker_count=3) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[1, 1]) def test_native_image_threading_files_per_worker(self): - self._execute_native_image(files_per_worker=1) + all_source_paths = self._execute_native_image(num_sources=2, files_per_worker=1) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[1, 1]) + + def test_native_image_threading_uneven_files_per_worker(self): + all_source_paths = self._execute_native_image(num_sources=5, files_per_worker=3) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[3, 2]) + + def test_native_image_threading_massive_files_per_process(self): + all_source_paths = self._execute_native_image(num_sources=2, files_per_worker=5) + self._assert_num_files_per_process_matches(all_source_paths, + all_num_files=[2])