Skip to content

Commit

Permalink
MyPy and Pylint partition inputs via CoarsenedTarget (#15141)
Browse files Browse the repository at this point in the history
`CoarsenedTarget`s are structure-shared, and because they preserve their internal structure, they can service requests for transitive targets for different roots from the same data structure. Concretely: MyPy and Pylint can consume `CoarsenedTargets` to execute a single `@rule`-level graph walk, and then compute per-root closures from the resulting `CoarsenedTarget` instances.

This does not address #11270 in a general way (and it punts on #15241, which means that we still need per-root transitive walks), but it might provide a prototypical way to solve that problem on a case-by-case basis.

Performance-wise, this improves a cold `check ::` for ~1k files:
* `main`: 32s total, with 26s spent in partitioning
* `branch`: 19s total, with 13s spent in partitioning

The rest of the time is wrapped up in #15241.
  • Loading branch information
stuhood authored Apr 25, 2022
1 parent e8e22d5 commit 387635f
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 99 deletions.
81 changes: 40 additions & 41 deletions src/python/pants/backend/python/lint/pylint/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,15 @@

from collections import defaultdict
from dataclasses import dataclass
from typing import Tuple
from typing import Mapping, Tuple

from pants.backend.python.lint.pylint.subsystem import (
Pylint,
PylintFieldSet,
PylintFirstPartyPlugins,
)
from pants.backend.python.subsystems.setup import PythonSetup
from pants.backend.python.target_types import (
InterpreterConstraintsField,
PythonResolveField,
PythonSourceField,
)
from pants.backend.python.target_types import InterpreterConstraintsField, PythonResolveField
from pants.backend.python.util_rules import pex_from_targets
from pants.backend.python.util_rules.interpreter_constraints import InterpreterConstraints
from pants.backend.python.util_rules.pex import (
Expand All @@ -39,7 +35,7 @@
from pants.engine.fs import CreateDigest, Digest, Directory, MergeDigests, RemovePrefix
from pants.engine.process import FallibleProcessResult
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import Target, TransitiveTargets, TransitiveTargetsRequest
from pants.engine.target import CoarsenedTarget, CoarsenedTargets, CoarsenedTargetsRequest, Target
from pants.engine.unions import UnionRule
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
Expand All @@ -48,7 +44,7 @@

@dataclass(frozen=True)
class PylintPartition:
root_targets: FrozenOrderedSet[Target]
root_field_sets: FrozenOrderedSet[PylintFieldSet]
closure: FrozenOrderedSet[Target]
resolve_description: str | None
interpreter_constraints: InterpreterConstraints
Expand Down Expand Up @@ -84,7 +80,7 @@ async def pylint_lint_partition(
requirements_pex_get = Get(
Pex,
RequirementsPexRequest(
(t.address for t in partition.root_targets),
(fs.address for fs in partition.root_field_sets),
# NB: These constraints must be identical to the other PEXes. Otherwise, we risk using
# a different version for the requirements than the other two PEXes, which can result
# in a PEX runtime error about missing dependencies.
Expand All @@ -103,7 +99,7 @@ async def pylint_lint_partition(

prepare_python_sources_get = Get(PythonSourceFiles, PythonSourceFilesRequest(partition.closure))
field_set_sources_get = Get(
SourceFiles, SourceFilesRequest(t[PythonSourceField] for t in partition.root_targets)
SourceFiles, SourceFilesRequest(fs.source for fs in partition.root_field_sets)
)
# Ensure that the empty report dir exists.
report_directory_digest_get = Get(Digest, CreateDigest([Directory(REPORT_DIR)]))
Expand Down Expand Up @@ -168,8 +164,8 @@ async def pylint_lint_partition(
input_digest=input_digest,
output_directories=(REPORT_DIR,),
extra_env={"PEX_EXTRA_SYS_PATH": ":".join(pythonpath)},
concurrency_available=len(partition.root_targets),
description=f"Run Pylint on {pluralize(len(partition.root_targets), 'file')}.",
concurrency_available=len(partition.root_field_sets),
description=f"Run Pylint on {pluralize(len(partition.root_field_sets), 'file')}.",
level=LogLevel.DEBUG,
),
)
Expand All @@ -181,7 +177,7 @@ async def pylint_lint_partition(
)


# TODO(#10863): Improve the performance of this, especially by not needing to calculate transitive
# TODO(#15241): Improve the performance of this, especially by not needing to calculate transitive
# targets per field set. Doing that would require changing how we calculate interpreter
# constraints to be more like how we determine resolves, i.e. only inspecting the root target
# (and later validating the closure is compatible).
Expand All @@ -195,49 +191,52 @@ async def pylint_determine_partitions(
# Note that Pylint uses the AST of the interpreter that runs it. So, we include any plugin
# targets in this interpreter constraints calculation. However, we don't have to consider the
# resolve of the plugin targets, per https://github.com/pantsbuild/pants/issues/14320.
transitive_targets_per_field_set = await MultiGet(
Get(TransitiveTargets, TransitiveTargetsRequest([field_set.address]))
for field_set in request.field_sets
coarsened_targets = await Get(
CoarsenedTargets,
CoarsenedTargetsRequest(
(field_set.address for field_set in request.field_sets), expanded_targets=True
),
)
coarsened_targets_by_address = coarsened_targets.by_address()

resolve_and_interpreter_constraints_to_transitive_targets = defaultdict(set)
for transitive_targets in transitive_targets_per_field_set:
resolve = transitive_targets.roots[0][PythonResolveField].normalized_value(python_setup)
resolve_and_interpreter_constraints_to_coarsened_targets: Mapping[
tuple[str, InterpreterConstraints],
tuple[OrderedSet[PylintFieldSet], OrderedSet[CoarsenedTarget]],
] = defaultdict(lambda: (OrderedSet(), OrderedSet()))
for root in request.field_sets:
ct = coarsened_targets_by_address[root.address]
# NB: If there is a cycle in the roots, we still only take the first resolve, as the other
# members will be validated when the partition is actually built.
resolve = ct.representative[PythonResolveField].normalized_value(python_setup)
# NB: We need to consume the entire un-memoized closure here. See the method comment.
interpreter_constraints = InterpreterConstraints.create_from_compatibility_fields(
(
*(
tgt[InterpreterConstraintsField]
for tgt in transitive_targets.closure
for tgt in ct.closure()
if tgt.has_field(InterpreterConstraintsField)
),
*first_party_plugins.interpreter_constraints_fields,
),
python_setup,
)
resolve_and_interpreter_constraints_to_transitive_targets[
roots, root_cts = resolve_and_interpreter_constraints_to_coarsened_targets[
(resolve, interpreter_constraints)
].add(transitive_targets)
]
roots.add(root)
root_cts.add(ct)

partitions = []
for (resolve, interpreter_constraints), all_transitive_targets in sorted(
resolve_and_interpreter_constraints_to_transitive_targets.items()
):
combined_roots: OrderedSet[Target] = OrderedSet()
combined_closure: OrderedSet[Target] = OrderedSet()
for transitive_targets in all_transitive_targets:
combined_roots.update(transitive_targets.roots)
combined_closure.update(transitive_targets.closure)
partitions.append(
# Note that we don't need to pass the resolve. pex_from_targets.py will already
# calculate it by inspecting the roots & validating that all dependees are valid.
PylintPartition(
FrozenOrderedSet(combined_roots),
FrozenOrderedSet(combined_closure),
resolve if len(python_setup.resolves) > 1 else None,
interpreter_constraints,
)
return PylintPartitions(
PylintPartition(
FrozenOrderedSet(roots),
FrozenOrderedSet(CoarsenedTargets(root_cts).closure()),
resolve if len(python_setup.resolves) > 1 else None,
interpreter_constraints,
)
return PylintPartitions(partitions)
for (resolve, interpreter_constraints), (roots, root_cts) in sorted(
resolve_and_interpreter_constraints_to_coarsened_targets.items()
)
)


@rule(desc="Lint using Pylint", level=LogLevel.DEBUG)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -522,7 +522,7 @@ def assert_partition(
resolve: str,
) -> None:
root_addresses = {t.address for t in roots}
assert {t.address for t in partition.root_targets} == root_addresses
assert {fs.address for fs in partition.root_field_sets} == root_addresses
assert {t.address for t in partition.closure} == {
*root_addresses,
*(t.address for t in deps),
Expand Down
84 changes: 47 additions & 37 deletions src/python/pants/backend/python/typecheck/mypy/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import itertools
from collections import defaultdict
from dataclasses import dataclass
from typing import Iterable, Optional, Tuple
from typing import Iterable, Mapping, Optional, Tuple

from pants.backend.python.subsystems.setup import PythonSetup
from pants.backend.python.target_types import PythonResolveField, PythonSourceField
Expand All @@ -31,7 +31,13 @@
from pants.engine.fs import CreateDigest, Digest, FileContent, MergeDigests, RemovePrefix
from pants.engine.process import FallibleProcessResult
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.target import FieldSet, Target, TransitiveTargets, TransitiveTargetsRequest
from pants.engine.target import (
CoarsenedTarget,
CoarsenedTargets,
CoarsenedTargetsRequest,
FieldSet,
Target,
)
from pants.engine.unions import UnionRule
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet, OrderedSet
Expand All @@ -51,7 +57,7 @@ def opt_out(cls, tgt: Target) -> bool:

@dataclass(frozen=True)
class MyPyPartition:
root_targets: FrozenOrderedSet[Target]
root_field_sets: FrozenOrderedSet[MyPyFieldSet]
closure: FrozenOrderedSet[Target]
resolve_description: str | None
interpreter_constraints: InterpreterConstraints
Expand Down Expand Up @@ -132,14 +138,14 @@ async def mypy_typecheck_partition(
closure_sources_get = Get(PythonSourceFiles, PythonSourceFilesRequest(partition.closure))
roots_sources_get = Get(
SourceFiles,
SourceFilesRequest(tgt.get(PythonSourceField) for tgt in partition.root_targets),
SourceFilesRequest(fs.sources for fs in partition.root_field_sets),
)

# See `requirements_venv_pex` for how this will get wrapped in a `VenvPex`.
requirements_pex_get = Get(
Pex,
RequirementsPexRequest(
(tgt.address for tgt in partition.root_targets),
(fs.address for fs in partition.root_field_sets),
hardcoded_interpreter_constraints=partition.interpreter_constraints,
),
)
Expand Down Expand Up @@ -252,7 +258,7 @@ async def mypy_typecheck_partition(
)


# TODO(#10863): Improve the performance of this, especially by not needing to calculate transitive
# TODO(#15241): Improve the performance of this, especially by not needing to calculate transitive
# targets per field set. Doing that would require changing how we calculate interpreter
# constraints to be more like how we determine resolves, i.e. only inspecting the root target
# (and later validating the closure is compatible).
Expand All @@ -261,43 +267,47 @@ async def mypy_determine_partitions(
request: MyPyRequest, mypy: MyPy, python_setup: PythonSetup
) -> MyPyPartitions:
# When determining how to batch by interpreter constraints, we must consider the entire
# transitive closure to get the final resulting constraints.
transitive_targets_per_field_set = await MultiGet(
Get(TransitiveTargets, TransitiveTargetsRequest([field_set.address]))
for field_set in request.field_sets
# transitive closure _per-root_ to get the final resulting constraints. See the method
# comment.
coarsened_targets = await Get(
CoarsenedTargets,
CoarsenedTargetsRequest(
(field_set.address for field_set in request.field_sets), expanded_targets=True
),
)

resolve_and_interpreter_constraints_to_transitive_targets = defaultdict(set)
for transitive_targets in transitive_targets_per_field_set:
resolve = transitive_targets.roots[0][PythonResolveField].normalized_value(python_setup)
coarsened_targets_by_address = coarsened_targets.by_address()

resolve_and_interpreter_constraints_to_coarsened_targets: Mapping[
tuple[str, InterpreterConstraints],
tuple[OrderedSet[MyPyFieldSet], OrderedSet[CoarsenedTarget]],
] = defaultdict(lambda: (OrderedSet(), OrderedSet()))
for root in request.field_sets:
ct = coarsened_targets_by_address[root.address]
# NB: If there is a cycle in the roots, we still only take the first resolve, as the other
# members will be validated when the partition is actually built.
resolve = ct.representative[PythonResolveField].normalized_value(python_setup)
# NB: We need to consume the entire un-memoized closure here. See the method comment.
interpreter_constraints = (
InterpreterConstraints.create_from_targets(transitive_targets.closure, python_setup)
InterpreterConstraints.create_from_targets(ct.closure(), python_setup)
or mypy.interpreter_constraints
)
resolve_and_interpreter_constraints_to_transitive_targets[
roots, root_cts = resolve_and_interpreter_constraints_to_coarsened_targets[
(resolve, interpreter_constraints)
].add(transitive_targets)

partitions = []
for (resolve, interpreter_constraints), all_transitive_targets in sorted(
resolve_and_interpreter_constraints_to_transitive_targets.items()
):
combined_roots: OrderedSet[Target] = OrderedSet()
combined_closure: OrderedSet[Target] = OrderedSet()
for transitive_targets in all_transitive_targets:
combined_roots.update(transitive_targets.roots)
combined_closure.update(transitive_targets.closure)
partitions.append(
# Note that we don't need to pass the resolve. pex_from_targets.py will already
# calculate it by inspecting the roots & validating that all dependees are valid.
MyPyPartition(
FrozenOrderedSet(combined_roots),
FrozenOrderedSet(combined_closure),
resolve if len(python_setup.resolves) > 1 else None,
interpreter_constraints,
)
]
roots.add(root)
root_cts.add(ct)

return MyPyPartitions(
MyPyPartition(
FrozenOrderedSet(roots),
FrozenOrderedSet(CoarsenedTargets(root_cts).closure()),
resolve if len(python_setup.resolves) > 1 else None,
interpreter_constraints,
)
for (resolve, interpreter_constraints), (roots, root_cts) in sorted(
resolve_and_interpreter_constraints_to_coarsened_targets.items()
)
return MyPyPartitions(partitions)
)


# TODO(#10864): Improve performance, e.g. by leveraging the MyPy cache.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ def assert_partition(
resolve: str,
) -> None:
root_addresses = {t.address for t in roots}
assert {t.address for t in partition.root_targets} == root_addresses
assert {fs.address for fs in partition.root_field_sets} == root_addresses
assert {t.address for t in partition.closure} == {
*root_addresses,
*(t.address for t in deps),
Expand Down
20 changes: 12 additions & 8 deletions src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
AllUnexpandedTargets,
CoarsenedTarget,
CoarsenedTargets,
CoarsenedTargetsRequest,
Dependencies,
DependenciesRequest,
ExplicitlyProvidedDependencies,
Expand Down Expand Up @@ -542,16 +543,19 @@ async def transitive_targets(request: TransitiveTargetsRequest) -> TransitiveTar


@rule
async def coarsened_targets(addresses: Addresses) -> CoarsenedTargets:
def coarsened_targets_request(addresses: Addresses) -> CoarsenedTargetsRequest:
return CoarsenedTargetsRequest(addresses)


@rule
async def coarsened_targets(request: CoarsenedTargetsRequest) -> CoarsenedTargets:
dependency_mapping = await Get(
_DependencyMapping,
_DependencyMappingRequest(
# NB: We set include_special_cased_deps=True because although computing CoarsenedTargets
# requires a transitive graph walk (to ensure that all cycles are actually detected),
# the resulting CoarsenedTargets instance is not itself transitive: everything not directly
# involved in a cycle with one of the input Addresses is discarded in the output.
TransitiveTargetsRequest(addresses, include_special_cased_deps=True),
expanded_targets=False,
TransitiveTargetsRequest(
request.roots, include_special_cased_deps=request.include_special_cased_deps
),
expanded_targets=request.expanded_targets,
),
)
addresses_to_targets = {
Expand All @@ -568,7 +572,7 @@ async def coarsened_targets(addresses: Addresses) -> CoarsenedTargets:

coarsened_targets: dict[Address, CoarsenedTarget] = {}
root_coarsened_targets = []
root_addresses_set = set(addresses)
root_addresses_set = set(request.roots)
for component in components:
component = sorted(component)
component_set = set(component)
Expand Down
Loading

0 comments on commit 387635f

Please sign in to comment.