diff --git a/src/python/pants/core/goals/package.py b/src/python/pants/core/goals/package.py index 528322da2e7..aab58a2606b 100644 --- a/src/python/pants/core/goals/package.py +++ b/src/python/pants/core/goals/package.py @@ -7,10 +7,12 @@ import os from abc import ABCMeta from dataclasses import dataclass +from typing import Iterable from pants.core.util_rules import distdir from pants.core.util_rules.distdir import DistDir from pants.core.util_rules.environments import EnvironmentNameRequest +from pants.engine.addresses import Address from pants.engine.environment import EnvironmentName from pants.engine.fs import Digest, MergeDigests, Workspace from pants.engine.goal import Goal, GoalSubsystem @@ -18,11 +20,15 @@ from pants.engine.target import ( AllTargets, AsyncFieldMixin, + Dependencies, FieldSet, FieldSetsPerTarget, FieldSetsPerTargetRequest, NoApplicableTargetsBehavior, + ShouldTraverseDepsPredicate, + SpecialCasedDependencies, StringField, + Target, TargetRootsToFieldSets, TargetRootsToFieldSetsRequest, Targets, @@ -30,6 +36,7 @@ from pants.engine.unions import UnionMembership, union from pants.util.docutil import bin_name from pants.util.logging import LogLevel +from pants.util.ordered_set import FrozenOrderedSet from pants.util.strutil import help_text logger = logging.getLogger(__name__) @@ -178,5 +185,35 @@ async def package_asset(workspace: Workspace, dist_dir: DistDir) -> Package: return Package(exit_code=0) +@dataclass(frozen=True) +class TraverseIfNotPackageTarget(ShouldTraverseDepsPredicate): + package_field_set_types: FrozenOrderedSet[PackageFieldSet] + roots: FrozenOrderedSet[Address] + always_traverse_roots: bool = True # traverse roots even if they are package targets + + def __init__( + self, + *, + union_membership: UnionMembership, + roots: Iterable[Address], + always_traverse_roots: bool = True, + ) -> None: + object.__setattr__(self, "package_field_set_types", union_membership.get(PackageFieldSet)) + object.__setattr__(self, "roots", FrozenOrderedSet(roots)) + object.__setattr__(self, "always_traverse_roots", always_traverse_roots) + super().__init__() + + def __call__(self, target: Target, field: Dependencies | SpecialCasedDependencies) -> bool: + if isinstance(field, SpecialCasedDependencies): + return False + if self.always_traverse_roots and target.address in self.roots: + return True + for field_set_type in self.package_field_set_types: + if field_set_type.is_applicable(target): + # False means do not traverse dependencies of this target + return False + return True + + def rules(): return (*collect_rules(), *distdir.rules()) diff --git a/src/python/pants/core/goals/package_test.py b/src/python/pants/core/goals/package_test.py index f1547e4c1fc..0f936f09505 100644 --- a/src/python/pants/core/goals/package_test.py +++ b/src/python/pants/core/goals/package_test.py @@ -10,13 +10,29 @@ import pytest from pants.core.goals import package -from pants.core.goals.package import BuiltPackage, BuiltPackageArtifact, Package, PackageFieldSet +from pants.core.goals.package import ( + BuiltPackage, + BuiltPackageArtifact, + Package, + PackageFieldSet, + TraverseIfNotPackageTarget, +) +from pants.engine.addresses import Address from pants.engine.fs import CreateDigest, Digest, FileContent from pants.engine.internals.selectors import Get from pants.engine.rules import rule -from pants.engine.target import StringField, Target -from pants.engine.unions import UnionRule -from pants.testutil.rule_runner import RuleRunner +from pants.engine.target import ( + Dependencies, + DependenciesRequest, + StringField, + Target, + Targets, + TransitiveTargets, + TransitiveTargetsRequest, +) +from pants.engine.unions import UnionMembership, UnionRule +from pants.testutil.rule_runner import QueryRule, RuleRunner +from pants.util.ordered_set import FrozenOrderedSet class MockTypeField(StringField): @@ -52,9 +68,13 @@ def synth(self, base_path: Path) -> tuple[CreateDigest, tuple[Path, ...]]: raise ValueError(f"don't understand {self.value}") +class MockDependenciesField(Dependencies): + pass + + class MockTarget(Target): alias = "mock" - core_fields = (MockTypeField,) + core_fields = (MockTypeField, MockDependenciesField) @dataclass(frozen=True) @@ -81,6 +101,8 @@ def rule_runner() -> RuleRunner: *package.rules(), package_mock_target, UnionRule(PackageFieldSet, MockPackageFieldSet), + QueryRule(Targets, [DependenciesRequest]), + QueryRule(TransitiveTargets, [TransitiveTargetsRequest]), ], target_types=[MockTarget], ) @@ -185,3 +207,48 @@ def test_package_replace_existing( assert set((dist_base / "x").iterdir()) == {a, b} assert a.read_text() == "directory: a" assert b.read_text() == "directory: b" + + +def test_transitive_targets_without_traversing_packages(rule_runner: RuleRunner) -> None: + rule_runner.write_files( + { + "src/BUILD": dedent( + """\ + mock(name='w', type='single_file') + mock(name='x', type='single_file') + mock(name='y', type='single_file', dependencies=[':w', ':x']) + mock(name='z', type='single_file', dependencies=[':y']) + """ + ) + } + ) + w = rule_runner.get_target(Address("src", target_name="w")) + x = rule_runner.get_target(Address("src", target_name="x")) + y = rule_runner.get_target(Address("src", target_name="y")) + z = rule_runner.get_target(Address("src", target_name="z")) + + direct_deps = rule_runner.request(Targets, [DependenciesRequest(z[MockDependenciesField])]) + assert direct_deps == Targets([y]) + + union_membership = rule_runner.request(UnionMembership, ()) + transitive_targets = rule_runner.request( + TransitiveTargets, + [ + TransitiveTargetsRequest( + [z.address], + should_traverse_deps_predicate=TraverseIfNotPackageTarget( + roots=[z.address], + union_membership=union_membership, + ), + ) + ], + ) + assert transitive_targets.roots == (z,) + # deps: z -> y -> x,w + # z should not see w or x as a transitive dep because y is also a package. + assert w not in transitive_targets.dependencies + assert x not in transitive_targets.dependencies + assert w not in transitive_targets.closure + assert x not in transitive_targets.closure + assert transitive_targets.dependencies == FrozenOrderedSet([y]) + assert transitive_targets.closure == FrozenOrderedSet([z, y])