Skip to content

Commit

Permalink
Add TraverseIfNotPackageTarget deps traversal predicate for use in …
Browse files Browse the repository at this point in the history
…plugins (#19306)

This builds on #19272, adding another `should_traverse_deps_predicate`
that stops dependency traversal at any package targets.

This is mostly extracted from #19155.

`TraverseIfNotPackageTarget` will be useful whenever a
`TransitiveTargetsRequest`, `CoarsenedTargetsRequest`, or
`DependenciesRequest` could benefit from treating package targets as
leaves. This PR does not change any `TransitiveTargetsRequest`s because
that is probably a change in user-facing behavior (even if it counts as
a bugfix) and needs to be documented as such. This PR merely adds the
feature, which, on its own, does not impact anything else.

Related:
- #18254
- #17368
- #15855
- #15082

---------

Co-authored-by: Andreas Stenius <andreas.stenius@imanage.com>
  • Loading branch information
cognifloyd and kaos authored Jun 17, 2023
1 parent c7b6187 commit d5812af
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 5 deletions.
37 changes: 37 additions & 0 deletions src/python/pants/core/goals/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,36 @@
import os
from abc import ABCMeta
from dataclasses import dataclass
from typing import Iterable

from pants.core.util_rules import distdir
from pants.core.util_rules.distdir import DistDir
from pants.core.util_rules.environments import EnvironmentNameRequest
from pants.engine.addresses import Address
from pants.engine.environment import EnvironmentName
from pants.engine.fs import Digest, MergeDigests, Workspace
from pants.engine.goal import Goal, GoalSubsystem
from pants.engine.rules import Get, MultiGet, collect_rules, goal_rule, rule
from pants.engine.target import (
AllTargets,
AsyncFieldMixin,
Dependencies,
FieldSet,
FieldSetsPerTarget,
FieldSetsPerTargetRequest,
NoApplicableTargetsBehavior,
ShouldTraverseDepsPredicate,
SpecialCasedDependencies,
StringField,
Target,
TargetRootsToFieldSets,
TargetRootsToFieldSetsRequest,
Targets,
)
from pants.engine.unions import UnionMembership, union
from pants.util.docutil import bin_name
from pants.util.logging import LogLevel
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import help_text

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -178,5 +185,35 @@ async def package_asset(workspace: Workspace, dist_dir: DistDir) -> Package:
return Package(exit_code=0)


@dataclass(frozen=True)
class TraverseIfNotPackageTarget(ShouldTraverseDepsPredicate):
package_field_set_types: FrozenOrderedSet[PackageFieldSet]
roots: FrozenOrderedSet[Address]
always_traverse_roots: bool = True # traverse roots even if they are package targets

def __init__(
self,
*,
union_membership: UnionMembership,
roots: Iterable[Address],
always_traverse_roots: bool = True,
) -> None:
object.__setattr__(self, "package_field_set_types", union_membership.get(PackageFieldSet))
object.__setattr__(self, "roots", FrozenOrderedSet(roots))
object.__setattr__(self, "always_traverse_roots", always_traverse_roots)
super().__init__()

def __call__(self, target: Target, field: Dependencies | SpecialCasedDependencies) -> bool:
if isinstance(field, SpecialCasedDependencies):
return False
if self.always_traverse_roots and target.address in self.roots:
return True
for field_set_type in self.package_field_set_types:
if field_set_type.is_applicable(target):
# False means do not traverse dependencies of this target
return False
return True


def rules():
return (*collect_rules(), *distdir.rules())
77 changes: 72 additions & 5 deletions src/python/pants/core/goals/package_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,29 @@
import pytest

from pants.core.goals import package
from pants.core.goals.package import BuiltPackage, BuiltPackageArtifact, Package, PackageFieldSet
from pants.core.goals.package import (
BuiltPackage,
BuiltPackageArtifact,
Package,
PackageFieldSet,
TraverseIfNotPackageTarget,
)
from pants.engine.addresses import Address
from pants.engine.fs import CreateDigest, Digest, FileContent
from pants.engine.internals.selectors import Get
from pants.engine.rules import rule
from pants.engine.target import StringField, Target
from pants.engine.unions import UnionRule
from pants.testutil.rule_runner import RuleRunner
from pants.engine.target import (
Dependencies,
DependenciesRequest,
StringField,
Target,
Targets,
TransitiveTargets,
TransitiveTargetsRequest,
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.testutil.rule_runner import QueryRule, RuleRunner
from pants.util.ordered_set import FrozenOrderedSet


class MockTypeField(StringField):
Expand Down Expand Up @@ -52,9 +68,13 @@ def synth(self, base_path: Path) -> tuple[CreateDigest, tuple[Path, ...]]:
raise ValueError(f"don't understand {self.value}")


class MockDependenciesField(Dependencies):
pass


class MockTarget(Target):
alias = "mock"
core_fields = (MockTypeField,)
core_fields = (MockTypeField, MockDependenciesField)


@dataclass(frozen=True)
Expand All @@ -81,6 +101,8 @@ def rule_runner() -> RuleRunner:
*package.rules(),
package_mock_target,
UnionRule(PackageFieldSet, MockPackageFieldSet),
QueryRule(Targets, [DependenciesRequest]),
QueryRule(TransitiveTargets, [TransitiveTargetsRequest]),
],
target_types=[MockTarget],
)
Expand Down Expand Up @@ -185,3 +207,48 @@ def test_package_replace_existing(
assert set((dist_base / "x").iterdir()) == {a, b}
assert a.read_text() == "directory: a"
assert b.read_text() == "directory: b"


def test_transitive_targets_without_traversing_packages(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"src/BUILD": dedent(
"""\
mock(name='w', type='single_file')
mock(name='x', type='single_file')
mock(name='y', type='single_file', dependencies=[':w', ':x'])
mock(name='z', type='single_file', dependencies=[':y'])
"""
)
}
)
w = rule_runner.get_target(Address("src", target_name="w"))
x = rule_runner.get_target(Address("src", target_name="x"))
y = rule_runner.get_target(Address("src", target_name="y"))
z = rule_runner.get_target(Address("src", target_name="z"))

direct_deps = rule_runner.request(Targets, [DependenciesRequest(z[MockDependenciesField])])
assert direct_deps == Targets([y])

union_membership = rule_runner.request(UnionMembership, ())
transitive_targets = rule_runner.request(
TransitiveTargets,
[
TransitiveTargetsRequest(
[z.address],
should_traverse_deps_predicate=TraverseIfNotPackageTarget(
roots=[z.address],
union_membership=union_membership,
),
)
],
)
assert transitive_targets.roots == (z,)
# deps: z -> y -> x,w
# z should not see w or x as a transitive dep because y is also a package.
assert w not in transitive_targets.dependencies
assert x not in transitive_targets.dependencies
assert w not in transitive_targets.closure
assert x not in transitive_targets.closure
assert transitive_targets.dependencies == FrozenOrderedSet([y])
assert transitive_targets.closure == FrozenOrderedSet([z, y])

0 comments on commit d5812af

Please sign in to comment.