add --use-native-image option to scalafmt and add thread parallelism to RewriteBase #8772

Closed

Changes from all commits
86 changes: 79 additions & 7 deletions src/python/pants/backend/jvm/tasks/rewrite_base.py
@@ -1,9 +1,14 @@
# Copyright 2017 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import functools
import itertools
import math
import os
import shutil
import threading
from abc import ABCMeta, abstractmethod
from typing import Any, List

from pants.backend.jvm.tasks.nailgun_task import NailgunTask
from pants.base.build_environment import get_buildroot
@@ -29,6 +34,13 @@ def register_options(cls, register):
help='Path to output directory. Any updated files will be written here. '
'If not specified, files will be modified in-place.')

register('--files-per-worker', type=int, fingerprint=False,
default=None,
help='Number of files to process per scalafmt invocation.')
register('--worker-count', type=int, fingerprint=False,
default=None,
help='Total number of parallel scalafmt threads or processes to run.')

@classmethod
def target_types(cls):
"""Returns a list of target type names (e.g.: `scala_library`) this rewriter operates on."""
@@ -63,17 +75,77 @@ def execute(self):
with self.invalidated(relevant_targets) as invalidation_check:
self._execute_for([vt.target for vt in invalidation_check.invalid_vts])

def _split_by_threads(self, inputs_list_of_lists: List[List[Any]], invoke_fn):
parent_workunit = self.context.run_tracker.get_current_workunit()

# Propagate exceptions in threads to the toplevel by checking this variable after joining all
# the threads.
all_exceptions = []

def thread_exception_wrapper(fn):
@functools.wraps(fn)
def inner(*args, **kwargs):
try:
fn(*args, **kwargs)
except Exception as e:
all_exceptions.append(e)
return inner

all_threads = [
threading.Thread(
name=f'scalafmt invocation thread #{idx}/{len(inputs_list_of_lists)}',
target=thread_exception_wrapper(invoke_fn),
args=[parent_workunit, inputs_single_list],
)
for idx, inputs_single_list in enumerate(inputs_list_of_lists)
]
for thread in all_threads:
thread.start()
for thread in all_threads:
try:
thread.join()
except Exception as e:
raise TaskError(str(e)) from e
if all_exceptions:
joined_str = ', '.join(str(e) for e in all_exceptions)
raise TaskError(f'all errors: {joined_str}') from all_exceptions[0]

def _execute_for(self, targets):
target_sources = self._calculate_sources(targets)
if not target_sources:
return

result = Xargs(self._invoke_tool).execute(target_sources)
if result != 0:
raise TaskError('{} is improperly implemented: a failed process '
'should raise an exception earlier.'.format(type(self).__name__))

def _invoke_tool(self, target_sources):
if self.get_options().files_per_worker is not None:
# If --files-per-worker is specified, split the target sources and run in separate threads!
n = self.get_options().files_per_worker
inputs_list_of_lists = [
target_sources[i:i + n]
for i in range(0, len(target_sources), n)
]
self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool)
elif self.get_options().worker_count is not None:
# If --worker-count is specified, split the target sources into that many
# chunks and run each chunk in its own thread.
num_processes = self.get_options().worker_count
sources_iterator = iter(target_sources)
inputs_list_of_lists = [
list(itertools.islice(sources_iterator, 0, math.ceil(len(target_sources) / num_processes)))
for _ in range(0, num_processes)
]
self._split_by_threads(inputs_list_of_lists=inputs_list_of_lists, invoke_fn=self._invoke_tool)
else:
# Otherwise, pass in the parent workunit to Xargs, which is passed into self._invoke_tool.
parent_workunit = self.context.run_tracker.get_current_workunit()
result = Xargs(self._invoke_tool, constant_args=[parent_workunit]).execute(target_sources)
if result != 0:
raise TaskError('{} is improperly implemented: a failed process '
'should raise an exception earlier.'.format(type(self).__name__))

def _invoke_tool(self, parent_workunit, target_sources):
# We want to avoid executing anything if there are no sources to process.
if not target_sources:
return 0
self.context.run_tracker.register_thread(parent_workunit)
buildroot = get_buildroot()
toolroot = buildroot
if self.sideeffecting and self.get_options().output_dir:
@@ -88,7 +160,7 @@ def _invoke_tool(self, target_sources):
for old, new in zip(old_file_paths, new_file_paths):
shutil.copyfile(old, new)
target_sources = new_sources
result = self.invoke_tool(toolroot, target_sources)
result = self.invoke_tool(parent_workunit, toolroot, target_sources)
self.process_result(result)
return result

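The comment in _split_by_threads about propagating exceptions reflects a Python detail: an exception raised inside a Thread target does not propagate to the thread calling join(), so each worker appends its exception to a shared list that is checked after all joins. A stripped-down, standalone sketch of that pattern (no Pants dependencies; all names here are illustrative):

import functools
import threading


def run_in_parallel(work_fn, work_items):
    errors = []  # list.append from multiple threads is safe in CPython

    def exception_collector(fn):
        @functools.wraps(fn)
        def inner(item):
            try:
                fn(item)
            except Exception as e:
                errors.append(e)
        return inner

    threads = [threading.Thread(target=exception_collector(work_fn), args=(item,))
               for item in work_items]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    if errors:
        raise RuntimeError('worker errors: ' + ', '.join(str(e) for e in errors))


try:
    # The second chunk fails, but the error only surfaces after every thread has joined.
    run_in_parallel(lambda chunk: 1 / len(chunk), [['a.scala'], [], ['b.scala']])
except RuntimeError as e:
    print(e)  # -> worker errors: division by zero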
144 changes: 126 additions & 18 deletions src/python/pants/backend/jvm/tasks/scalafmt.py
@@ -1,14 +1,80 @@
# Copyright 2016 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import os
import subprocess
from abc import abstractmethod
from typing import List

from pants.backend.jvm.subsystems.jvm_tool_mixin import JvmToolMixin
from pants.backend.jvm.tasks.rewrite_base import RewriteBase
from pants.base.exceptions import TaskError
from pants.base.workunit import WorkUnitLabel
from pants.binaries.binary_tool import NativeTool
from pants.binaries.binary_util import BinaryToolUrlGenerator
from pants.engine.platform import Platform
from pants.java.jar.jar_dependency import JarDependency
from pants.option.custom_types import file_option
from pants.process.xargs import Xargs
from pants.task.fmt_task_mixin import FmtTaskMixin
from pants.task.lint_task_mixin import LintTaskMixin
from pants.util.dirutil import chmod_plus_x
from pants.util.memo import memoized_method


class ScalaFmtNativeUrlGenerator(BinaryToolUrlGenerator):

_DIST_URL_FMT = 'https://github.com/scalameta/scalafmt/releases/download/v{version}/scalafmt-{system_id}.zip'

_SYSTEM_ID = {
'mac': 'macos',
'linux': 'linux',
}

def generate_urls(self, version, host_platform):
system_id = self._SYSTEM_ID[host_platform.os_name]
return [self._DIST_URL_FMT.format(version=version, system_id=system_id)]


class ScalaFmtSubsystem(JvmToolMixin, NativeTool):
options_scope = 'scalafmt'
default_version = '2.3.1'
archive_type = 'zip'

def get_external_url_generator(self):
return ScalaFmtNativeUrlGenerator()

@memoized_method
def select(self):
"""Reach into the unzipped directory and return the scalafmt executable.

Also make sure to chmod +x the scalafmt executable, since the release zip doesn't do that.
"""
extracted_dir = super().select()
inner_dir_name = Platform.current.match({
Platform.darwin: 'scalafmt-macos',
Platform.linux: 'scalafmt-linux',
})
output_file = os.path.join(extracted_dir, inner_dir_name, 'scalafmt')
chmod_plus_x(output_file)
return output_file

@property
def use_native_image(self) -> bool:
return bool(self.get_options().use_native_image)

@classmethod
def register_options(cls, register):
super().register_options(register)
register('--use-native-image', type=bool, advanced=True, fingerprint=False,
help='Use a pre-compiled native-image for scalafmt.')

cls.register_jvm_tool(register,
'scalafmt',
classpath=[
JarDependency(org='com.geirsson',
name='scalafmt-cli_2.11',
rev='1.5.1')])


class ScalaFmt(RewriteBase):
@@ -18,20 +84,22 @@ class ScalaFmt(RewriteBase):
process_result to run different scalafmt commands.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._all_command_lines = []

Reviewer comment (Contributor):
Rather than modifying the actual production code to also act like a spy, could we in the test code use a fake binary for scalafmt which logs its command line or something? Or unit test the thread-splitting logic specifically?
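A minimal sketch of the reviewer's second suggestion, unit-testing the splitting logic directly instead of having production code record its own command lines. The two helper functions are hypothetical extractions of the chunking expressions in rewrite_base.py, not code from this PR:

import itertools
import math
import unittest


def chunk_by_files_per_worker(sources, files_per_worker):
    # Mirrors the --files-per-worker slicing in rewrite_base.py.
    return [sources[i:i + files_per_worker]
            for i in range(0, len(sources), files_per_worker)]


def chunk_by_worker_count(sources, worker_count):
    # Mirrors the --worker-count islice/ceil splitting in rewrite_base.py.
    source_iter = iter(sources)
    chunk_size = math.ceil(len(sources) / worker_count)
    return [list(itertools.islice(source_iter, 0, chunk_size)) for _ in range(worker_count)]


class ChunkingTest(unittest.TestCase):
    def test_files_per_worker_preserves_order_and_content(self):
        sources = [f'file{i}.scala' for i in range(10)]
        chunks = chunk_by_files_per_worker(sources, 3)
        self.assertEqual([3, 3, 3, 1], [len(c) for c in chunks])
        self.assertEqual(sources, list(itertools.chain.from_iterable(chunks)))

    def test_worker_count_produces_requested_number_of_chunks(self):
        sources = [f'file{i}.scala' for i in range(10)]
        chunks = chunk_by_worker_count(sources, 4)
        self.assertEqual(4, len(chunks))
        self.assertEqual(sources, list(itertools.chain.from_iterable(chunks)))


if __name__ == '__main__':
    unittest.main()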


@classmethod
def subsystem_dependencies(cls):
return super().subsystem_dependencies() + (
ScalaFmtSubsystem.scoped(cls),
)

@classmethod
def register_options(cls, register):
super().register_options(register)
register('--configuration', advanced=True, type=file_option, fingerprint=True,
help='Path to the scalafmt config file; if not specified, the default scalafmt config is used.')

cls.register_jvm_tool(register,
'scalafmt',
classpath=[
JarDependency(org='com.geirsson',
name='scalafmt-cli_2.11',
rev='1.5.1')
])

@classmethod
def target_types(cls):
return ['scala_library', 'junit_tests']
@@ -44,20 +112,60 @@ def source_extension(cls):
def implementation_version(cls):
return super().implementation_version() + [('ScalaFmt', 5)]

def invoke_tool(self, absolute_root, target_sources):
# If no config file is specified use default scalafmt config.
config_file = self.get_options().configuration
args = list(self.additional_args)
if config_file is not None:
args.extend(['--config', config_file])
args.extend([source for _target, source in target_sources])

return self.runjava(classpath=self.tool_classpath('scalafmt'),
@property
def _use_native_image(self) -> bool:
return ScalaFmtSubsystem.scoped_instance(self).use_native_image

def _native_image_path(self) -> str:
return ScalaFmtSubsystem.scoped_instance(self).select()

def _tool_classpath(self) -> List[str]:
subsystem = ScalaFmtSubsystem.scoped_instance(self)
return subsystem.tool_classpath_from_products(
self.context.products,
key='scalafmt',
scope=subsystem.options_scope)

def _invoke_native_image_subprocess(self, prefix_args, workunit, all_source_paths):
self._all_command_lines.append((prefix_args, all_source_paths))
return subprocess.run(
args=(prefix_args + all_source_paths),
stdout=workunit.output('stdout'),
stderr=workunit.output('stderr'),
).returncode

def _invoke_jvm_process(self, prefix_args, all_source_paths):
return self.runjava(classpath=self._tool_classpath(),
main='org.scalafmt.cli.Cli',
args=args,
args=(prefix_args + all_source_paths),
workunit_name='scalafmt',
jvm_options=self.get_options().jvm_options)

def invoke_tool(self, current_workunit, absolute_root, target_sources):
self.context.log.debug(f'scalafmt called with sources: {target_sources}')

# If no config file is specified, use default scalafmt config.
config_file = self.get_options().configuration
prefix_args = list(self.additional_args)
if config_file is not None:
prefix_args.extend(['--config', config_file])

all_source_paths = [source for _target, source in target_sources]

if self._use_native_image:
with self.context.run_tracker.new_workunit(
name='scalafmt',
labels=[WorkUnitLabel.COMPILER],
) as workunit:
prefix_args = [self._native_image_path()] + prefix_args
self.context.log.debug(f'executing scalafmt with native image with prefix args: {prefix_args}')
return Xargs(
self._invoke_native_image_subprocess,
constant_args=[prefix_args, workunit],
).execute(all_source_paths)
else:
return Xargs(self._invoke_jvm_process, constant_args=[prefix_args]).execute(all_source_paths)

@property
@abstractmethod
def additional_args(self):
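As a worked example of the URL scheme used by ScalaFmtNativeUrlGenerator above, with the subsystem's default_version of 2.3.1 (standalone snippet; the format string and system ids are copied from the diff):

_DIST_URL_FMT = 'https://github.com/scalameta/scalafmt/releases/download/v{version}/scalafmt-{system_id}.zip'
_SYSTEM_ID = {'mac': 'macos', 'linux': 'linux'}

for os_name in ('mac', 'linux'):
    print(_DIST_URL_FMT.format(version='2.3.1', system_id=_SYSTEM_ID[os_name]))
# https://github.com/scalameta/scalafmt/releases/download/v2.3.1/scalafmt-macos.zip
# https://github.com/scalameta/scalafmt/releases/download/v2.3.1/scalafmt-linux.zip

The select() override then points at the scalafmt binary inside the matching scalafmt-macos/ or scalafmt-linux/ directory of the extracted archive and marks it executable.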
12 changes: 10 additions & 2 deletions src/python/pants/goal/run_tracker.py
@@ -187,6 +187,14 @@ def set_sorted_goal_infos(self, sorted_goal_infos):
def set_v2_console_rule_names(self, v2_console_rule_names):
self._v2_console_rule_names = v2_console_rule_names

def get_current_workunit(self):
"""Return the workunit associated with the current thread.

This can be used along with .register_thread() to give spawned threads access to the
workunit of the thread that spawned them.
"""
return self._threadlocal.current_workunit

def register_thread(self, parent_workunit):
"""Register the parent workunit for all work in the calling thread.

@@ -383,7 +391,7 @@ def error(msg):

if stats_version not in cls.SUPPORTED_STATS_VERSIONS:
raise ValueError("Invalid stats version")


auth_data = BasicAuth.global_instance().get_auth_for_provider(auth_provider)
headers = cls._get_headers(stats_version=stats_version)
@@ -462,7 +470,7 @@ def _stats(self):
else:
stats.update({
'self_timings': self.self_timings.get_all(),
'critical_path_timings': self.get_critical_path_timings().get_all(),
'critical_path_timings': self.get_critical_path_timings().get_all(),
'outcomes': self.outcomes,
})
return stats
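The new get_current_workunit() is the read side of the existing register_thread() API: the spawning thread captures its workunit, hands it to the worker thread, and the worker registers it before doing any tracked work, which is how rewrite_base.py uses the pair. A minimal sketch of that pairing (the scaffolding is illustrative; only the two RunTracker methods come from this diff):

import threading


def run_tracked_worker(run_tracker, work_fn):
    # Capture on the spawning thread, where the threadlocal workunit is set.
    parent_workunit = run_tracker.get_current_workunit()

    def worker():
        # The new thread has no threadlocal workunit until it registers the parent.
        run_tracker.register_thread(parent_workunit)
        work_fn()

    thread = threading.Thread(target=worker)
    thread.start()
    thread.join()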
17 changes: 13 additions & 4 deletions src/python/pants/process/xargs.py
@@ -2,9 +2,13 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

import errno
import logging
import subprocess


logger = logging.getLogger(__name__)


class Xargs:
"""A subprocess execution wrapper in the spirit of the xargs command line tool.

@@ -22,13 +26,15 @@ def call(args):
return subprocess.call(cmd + args, **kwargs)
return cls(call)

def __init__(self, cmd):
def __init__(self, cmd, constant_args=None):
"""Creates an xargs engine that calls cmd with argument chunks.

:param cmd: A function that can execute a command line in the form of a list of strings
passed as its final positional argument.
:param constant_args: Positional arguments passed to cmd before the argument chunk on each
invocation.
"""
self._cmd = cmd
self._constant_args = constant_args or []

def _split_args(self, args):
half = len(args) // 2
@@ -39,12 +45,15 @@ def execute(self, args):

:param list args: Extra arguments to pass to cmd.
"""
all_args = list(args)
splittable_args = list(args)
all_args_for_command_function = self._constant_args + [splittable_args]
logger.debug(f'xargs all_args_for_command_function: {all_args_for_command_function}')
try:
return self._cmd(all_args)
return self._cmd(*all_args_for_command_function)
except OSError as e:
if errno.E2BIG == e.errno:
args1, args2 = self._split_args(all_args)
args1, args2 = self._split_args(splittable_args)
logger.debug(f'xargs split cmd line:\nargs1={args1},\nargs2={args2}!')
result = self.execute(args1)
if result != 0:
return result
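To make the constant_args plumbing concrete: cmd is invoked as cmd(*constant_args, chunk), and only the chunk is re-split when the OS rejects an oversized command line with E2BIG. A small standalone sketch with a stand-in command function (the real caller in rewrite_base.py passes a workunit as the constant argument):

from pants.process.xargs import Xargs


def invoke(parent_workunit, source_chunk):
    # parent_workunit arrives first on every call; source_chunk shrinks only if a re-split happens.
    print(parent_workunit, source_chunk)
    return 0


result = Xargs(invoke, constant_args=['fake-workunit']).execute(['a.scala', 'b.scala'])
assert result == 0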