diff --git a/src/python/pants/backend/java/target_types.py b/src/python/pants/backend/java/target_types.py index 46811b19c3c..5ed1a5b4e33 100644 --- a/src/python/pants/backend/java/target_types.py +++ b/src/python/pants/backend/java/target_types.py @@ -22,7 +22,11 @@ generate_file_level_targets, ) from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.target_types import JvmCompatibleResolveNamesField, JvmResolveName +from pants.jvm.target_types import ( + JvmCompatibleResolveNamesField, + JvmProvidesTypesField, + JvmResolveName, +) class JavaSourceField(SingleSourceField): @@ -63,6 +67,7 @@ class JunitTestTarget(Target): JavaTestSourceField, Dependencies, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "A single Java test, run with JUnit." @@ -78,6 +83,7 @@ class JunitTestsGeneratorTarget(Target): JavaTestsGeneratorSourcesField, Dependencies, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "Generate a `junit_test` target for each file in the `sources` field." @@ -114,6 +120,7 @@ class JavaSourceTarget(Target): Dependencies, JavaSourceField, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "A single Java source file containing application or library code." @@ -129,6 +136,7 @@ class JavaSourcesGeneratorTarget(Target): Dependencies, JavaSourcesGeneratorSourcesField, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "Generate a `java_source` target for each file in the `sources` field." 
diff --git a/src/python/pants/backend/scala/target_types.py b/src/python/pants/backend/scala/target_types.py index 38905be010a..b8a416035ca 100644 --- a/src/python/pants/backend/scala/target_types.py +++ b/src/python/pants/backend/scala/target_types.py @@ -20,7 +20,7 @@ generate_file_level_targets, ) from pants.engine.unions import UnionMembership, UnionRule -from pants.jvm.target_types import JvmCompatibleResolveNamesField +from pants.jvm.target_types import JvmCompatibleResolveNamesField, JvmProvidesTypesField class ScalaSourceField(SingleSourceField): @@ -61,6 +61,7 @@ class ScalaJunitTestTarget(Target): Dependencies, ScalaTestSourceField, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "A single Scala test, run with JUnit." @@ -81,6 +82,7 @@ class ScalaJunitTestsGeneratorTarget(Target): ScalaTestsGeneratorSourcesField, Dependencies, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = ( "Generate a `junit_test` target for each file in the `sources` field (defaults to " @@ -121,6 +123,7 @@ class ScalaSourceTarget(Target): Dependencies, ScalaSourceField, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = "A single Scala source file containing application or library code." 
@@ -141,6 +144,7 @@ class ScalaSourcesGeneratorTarget(Target): Dependencies, ScalaSourcesGeneratorSourcesField, JvmCompatibleResolveNamesField, + JvmProvidesTypesField, ) help = ( "Generate a `scala_source` target for each file in the `sources` field (defaults to " diff --git a/src/python/pants/jvm/dependency_inference/artifact_mapper.py b/src/python/pants/jvm/dependency_inference/artifact_mapper.py index 56b05fec3d6..6c5af7a791f 100644 --- a/src/python/pants/jvm/dependency_inference/artifact_mapper.py +++ b/src/python/pants/jvm/dependency_inference/artifact_mapper.py @@ -16,6 +16,7 @@ JvmArtifactArtifactField, JvmArtifactGroupField, JvmArtifactPackagesField, + JvmProvidesTypesField, ) from pants.util.frozendict import FrozenDict from pants.util.logging import LogLevel @@ -56,10 +57,18 @@ def addresses_for_coordinates( class MutableTrieNode: + __slots__ = [ + "children", + "recursive", + "addresses", + "first_party", + ] # don't use a `dict` to store attrs + def __init__(self): self.children: dict[str, MutableTrieNode] = {} self.recursive: bool = False self.addresses: OrderedSet[Address] = OrderedSet() + self.first_party: bool = False def ensure_child(self, name: str) -> MutableTrieNode: if name in self.children: @@ -71,6 +80,14 @@ def ensure_child(self, name: str) -> MutableTrieNode: @frozen_after_init class FrozenTrieNode: + __slots__ = [ + "_is_frozen", + "_children", + "_recursive", + "_addresses", + "_first_party", + ] # don't use a `dict` to store attrs (speeds up attr access significantly) + def __init__(self, node: MutableTrieNode) -> None: children = {} for key, child in node.children.items(): @@ -78,6 +95,7 @@ def __init__(self, node: MutableTrieNode) -> None: self._children: FrozenDict[str, FrozenTrieNode] = FrozenDict(children) self._recursive: bool = node.recursive self._addresses: FrozenOrderedSet[Address] = FrozenOrderedSet(node.addresses) + self._first_party: bool = node.first_party def find_child(self, name: str) -> FrozenTrieNode | None: return 
self._children.get(name) @@ -86,6 +104,10 @@ def find_child(self, name: str) -> FrozenTrieNode | None: def recursive(self) -> bool: return self._recursive + @property + def first_party(self) -> bool: + return self._first_party + @property def addresses(self) -> FrozenOrderedSet[Address]: return self._addresses @@ -103,13 +125,17 @@ def __eq__(self, other: Any) -> bool: ) def __repr__(self): - return f"FrozenTrieNode(children={repr(self._children)}, recursive={self._recursive}, addresses={self._addresses})" + return f"FrozenTrieNode(children={repr(self._children)}, recursive={self._recursive}, addresses={self._addresses}, first_party={self._first_party})" class AllJvmArtifactTargets(Targets): pass +class AllJvmTypeProvidingTargets(Targets): + pass + + @rule(desc="Find all jvm_artifact targets in project", level=LogLevel.DEBUG) def find_all_jvm_artifact_targets(targets: AllTargets) -> AllJvmArtifactTargets: return AllJvmArtifactTargets( @@ -117,6 +143,15 @@ def find_all_jvm_artifact_targets(targets: AllTargets) -> AllJvmArtifactTargets: ) +@rule(desc="Find all targets with experimental_provides fields in project", level=LogLevel.DEBUG) +def find_all_jvm_provides_fields(targets: AllTargets) -> AllJvmTypeProvidingTargets: + return AllJvmTypeProvidingTargets( + tgt + for tgt in targets + if tgt.has_field(JvmProvidesTypesField) and tgt[JvmProvidesTypesField].value is not None + ) + + @dataclass(frozen=True) class ThirdPartyPackageToArtifactMapping: mapping_root: FrozenTrieNode @@ -126,6 +161,7 @@ class ThirdPartyPackageToArtifactMapping: async def find_available_third_party_artifacts( all_jvm_artifact_tgts: AllJvmArtifactTargets, ) -> AvailableThirdPartyArtifacts: + address_mapping: dict[UnversionedCoordinate, OrderedSet[Address]] = defaultdict(OrderedSet) package_mapping: dict[UnversionedCoordinate, OrderedSet[str]] = defaultdict(OrderedSet) for tgt in all_jvm_artifact_tgts: @@ -163,11 +199,15 @@ async def find_available_third_party_artifacts( async def 
compute_java_third_party_artifact_mapping( java_infer_subsystem: JavaInferSubsystem, available_artifacts: AvailableThirdPartyArtifacts, + all_jvm_type_providing_tgts: AllJvmTypeProvidingTargets, ) -> ThirdPartyPackageToArtifactMapping: """Implements the mapping logic from the `jvm_artifact` and `java-infer` help.""" def insert( - mapping: MutableTrieNode, package_pattern: str, addresses: Iterable[Address] + mapping: MutableTrieNode, + package_pattern: str, + addresses: Iterable[Address], + first_party: bool, ) -> None: imp_parts = package_pattern.split(".") recursive = False @@ -181,6 +221,7 @@ def insert( current_node = child_node current_node.addresses.update(addresses) + current_node.first_party = first_party current_node.recursive = recursive # Build a default mapping from coord to package. @@ -205,7 +246,12 @@ def insert( # Default to exposing the `group` name as a package. packages = (f"{coord.group}.**",) for package in packages: - insert(mapping, package, addresses) + insert(mapping, package, addresses, False) + + # Mark types that have strong first-party declarations as first-party + for tgt in all_jvm_type_providing_tgts: + for provides_type in tgt[JvmProvidesTypesField].value or []: + insert(mapping, provides_type, [], True) return ThirdPartyPackageToArtifactMapping(FrozenTrieNode(mapping)) @@ -231,6 +277,9 @@ def find_artifact_mapping( # If the length of the found nodes equals the number of parts of the package path, then there # is an exact match. if len(found_nodes) == len(imp_parts): + best_match = found_nodes[-1] + if best_match.first_party: + return FrozenOrderedSet() # The first-party symbol mapper should provide this dep return found_nodes[-1].addresses # Otherwise, check for the first found node (in reverse order) to match recursively, and use its coordinate. 
diff --git a/src/python/pants/jvm/dependency_inference/artifact_mapper_test.py b/src/python/pants/jvm/dependency_inference/artifact_mapper_test.py index 4275f28cbc3..39165f1cdae 100644 --- a/src/python/pants/jvm/dependency_inference/artifact_mapper_test.py +++ b/src/python/pants/jvm/dependency_inference/artifact_mapper_test.py @@ -17,13 +17,20 @@ FrozenTrieNode, ThirdPartyPackageToArtifactMapping, ) +from pants.jvm.dependency_inference.symbol_mapper import JvmFirstPartyPackageMappingException from pants.jvm.jdk_rules import rules as java_util_rules from pants.jvm.resolve.coursier_fetch import rules as coursier_fetch_rules from pants.jvm.resolve.coursier_setup import rules as coursier_setup_rules from pants.jvm.target_types import JvmArtifact from pants.jvm.testutil import maybe_skip_jdk_test from pants.jvm.util_rules import rules as util_rules -from pants.testutil.rule_runner import PYTHON_BOOTSTRAP_ENV, QueryRule, RuleRunner, logging +from pants.testutil.rule_runner import ( + PYTHON_BOOTSTRAP_ENV, + QueryRule, + RuleRunner, + engine_error, + logging, +) NAMED_RESOLVE_OPTIONS = '--jvm-resolves={"test": "coursier_resolve.lockfile"}' DEFAULT_RESOLVE_OPTION = "--jvm-default-resolve=test" @@ -300,3 +307,119 @@ def test_third_party_dep_inference_nonrecursive(rule_runner: RuleRunner) -> None assert rule_runner.request(Addresses, [DependenciesRequest(lib2[Dependencies])]) == Addresses( [Address("", target_name="joda-time_joda-time")] ) + + +@maybe_skip_jdk_test +def test_third_party_dep_inference_with_provides(rule_runner: RuleRunner) -> None: + rule_runner.set_options( + [ + "--java-infer-third-party-import-mapping={'org.joda.time.**':'joda-time:joda-time', 'org.joda.time.DateTime':'joda-time:joda-time-2'}" + ], + env_inherit=PYTHON_BOOTSTRAP_ENV, + ) + rule_runner.write_files( + { + "BUILD": dedent( + """\ + jvm_artifact( + name = "joda-time_joda-time", + group = "joda-time", + artifact = "joda-time", + version = "2.10.10", + ) + + java_sources( + name = 'lib', + 
experimental_provides_types = ['org.joda.time.MefripulousDateTime', ], + ) + """ + ), + "PrintDate.java": dedent( + """\ + package org.pantsbuild.example; + + import org.joda.time.DateTime; + import org.joda.time.MefripulousDateTime; + + public class PrintDate { + public static void main(String[] args) { + DateTime dt = new DateTime(); + System.out.println(dt.toString()); + new MefripulousDateTime().mefripulate(); + } + } + """ + ), + "MefripulousDateTime.java": dedent( + """\ + package org.joda.time; + + public class MefripulousDateTime { + public void mefripulate() { + DateTime dt = new LocalDateTime(); + System.out.println(dt.toString()); + } + } + """ + ), + } + ) + + lib1 = rule_runner.get_target( + Address("", target_name="lib", relative_file_path="PrintDate.java") + ) + assert rule_runner.request(Addresses, [DependenciesRequest(lib1[Dependencies])]) == Addresses( + [ + Address("", target_name="joda-time_joda-time"), + Address("", target_name="lib", relative_file_path="MefripulousDateTime.java"), + ] + ) + + +@maybe_skip_jdk_test +def test_third_party_dep_inference_with_incorrect_provides(rule_runner: RuleRunner) -> None: + rule_runner.set_options( + [ + "--java-infer-third-party-import-mapping={'org.joda.time.**':'joda-time:joda-time', 'org.joda.time.DateTime':'joda-time:joda-time-2'}" + ], + env_inherit=PYTHON_BOOTSTRAP_ENV, + ) + rule_runner.write_files( + { + "BUILD": dedent( + """\ + jvm_artifact( + name = "joda-time_joda-time", + group = "joda-time", + artifact = "joda-time", + version = "2.10.10", + ) + + java_sources( + name = 'lib', + experimental_provides_types = ['org.joda.time.DateTime', ], + ) + """ + ), + "PrintDate.java": dedent( + """\ + package org.pantsbuild.example; + + import org.joda.time.DateTime; + + public class PrintDate { + public static void main(String[] args) { + DateTime dt = new DateTime(); + System.out.println(dt.toString()); + } + } + """ + ), + } + ) + + lib1 = rule_runner.get_target( + Address("", target_name="lib", 
relative_file_path="PrintDate.java") + ) + with engine_error(JvmFirstPartyPackageMappingException): + rule_runner.request(Addresses, [DependenciesRequest(lib1[Dependencies])]) diff --git a/src/python/pants/jvm/dependency_inference/symbol_mapper.py b/src/python/pants/jvm/dependency_inference/symbol_mapper.py index 73372ce7923..31bea853f04 100644 --- a/src/python/pants/jvm/dependency_inference/symbol_mapper.py +++ b/src/python/pants/jvm/dependency_inference/symbol_mapper.py @@ -10,6 +10,8 @@ from pants.build_graph.address import Address from pants.engine.rules import Get, MultiGet, collect_rules, rule from pants.engine.unions import UnionMembership, union +from pants.jvm.dependency_inference.artifact_mapper import AllJvmTypeProvidingTargets +from pants.jvm.target_types import JvmProvidesTypesField from pants.util.logging import LogLevel logger = logging.getLogger(__name__) @@ -20,6 +22,10 @@ # ----------------------------------------------------------------------------------------------- +class JvmFirstPartyPackageMappingException(Exception): + pass + + class SymbolMap: """A mapping of JVM package names to owning addresses.""" @@ -78,6 +84,7 @@ class FirstPartySymbolMapping: @rule(level=LogLevel.DEBUG) async def merge_first_party_module_mappings( union_membership: UnionMembership, + targets_that_provide_types: AllJvmTypeProvidingTargets, ) -> FirstPartySymbolMapping: all_mappings = await MultiGet( Get( @@ -92,6 +99,25 @@ async def merge_first_party_module_mappings( for dep_map in all_mappings: merged_dep_map.merge(dep_map) + # `experimental_provides_types` ("`provides`") can be declared on a `java_sources` target, + # so each generated `java_source` target will have that `provides` annotation. All that matters + # here is that _one_ of the source files amongst the set of sources actually provides that type. + + # Collect each address associated with a `provides` annotation and index by the provided type. 
+ provided_types: dict[str, set[Address]] = defaultdict(set) + for tgt in targets_that_provide_types: + for provided_type in tgt[JvmProvidesTypesField].value or []: + provided_types[provided_type].add(tgt.address) + + # Check that at least one address declared by each `provides` value actually provides the type: + for provided_type, provided_addresses in provided_types.items(): + symbol_addresses = merged_dep_map.addresses_for_symbol(provided_type) + if not provided_addresses.intersection(symbol_addresses): + raise JvmFirstPartyPackageMappingException( + f"The target {next(iter(provided_addresses))} declares that it provides the JVM type " + f"`{provided_type}`, however, it does not appear to actually provide that type." + ) + return FirstPartySymbolMapping(merged_dep_map) diff --git a/src/python/pants/jvm/target_types.py b/src/python/pants/jvm/target_types.py index aed927ebf40..2948e7d44a2 100644 --- a/src/python/pants/jvm/target_types.py +++ b/src/python/pants/jvm/target_types.py @@ -64,6 +64,17 @@ class JvmArtifactPackagesField(StringSequenceField): ) +class JvmProvidesTypesField(StringSequenceField): + alias = "experimental_provides_types" + help = ( + "Signals that the specified types should be fulfilled by these source files during " + "dependency inference. This allows for specific types within packages that are otherwise " + "inferred as belonging to `jvm_artifact` targets to be unambiguously inferred as belonging " + "to this first-party source. If a given type is defined, at least one source file captured " + "by this target must actually provide that symbol." + ) + + class JvmArtifactFieldSet(FieldSet): group: JvmArtifactGroupField