Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support strong first-party declarations of provided types #13698

Merged
2 changes: 2 additions & 0 deletions src/python/pants/backend/java/dependency_inference/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ async def infer_java_dependencies_and_exports_via_source_analysis(
third_party_matches: FrozenOrderedSet[Address] = FrozenOrderedSet()
if java_infer_subsystem.third_party_imports:
third_party_matches = find_artifact_mapping(typ, third_party_artifact_mapping)
if "AbstractMatcher" in str(request.source.address):
logger.warning("%s", f"TPM: {third_party_matches}")
matches = first_party_matches.union(third_party_matches)
if not matches:
continue
Expand Down
10 changes: 9 additions & 1 deletion src/python/pants/backend/java/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
generate_file_level_targets,
)
from pants.engine.unions import UnionMembership, UnionRule
from pants.jvm.target_types import JvmCompatibleResolveNamesField, JvmResolveName
from pants.jvm.target_types import (
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
JvmResolveName,
)


class JavaSourceField(SingleSourceField):
Expand Down Expand Up @@ -63,6 +67,7 @@ class JunitTestTarget(Target):
JavaTestSourceField,
Dependencies,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
)
help = "A single Java test, run with JUnit."

Expand All @@ -78,6 +83,7 @@ class JunitTestsGeneratorTarget(Target):
JavaTestsGeneratorSourcesField,
Dependencies,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
)
help = "Generate a `junit_test` target for each file in the `sources` field."

Expand Down Expand Up @@ -114,6 +120,7 @@ class JavaSourceTarget(Target):
Dependencies,
JavaSourceField,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
)
help = "A single Java source file containing application or library code."

Expand All @@ -129,6 +136,7 @@ class JavaSourcesGeneratorTarget(Target):
Dependencies,
JavaSourcesGeneratorSourcesField,
JvmCompatibleResolveNamesField,
JvmProvidesTypesField,
)
help = "Generate a `java_source` target for each file in the `sources` field."

Expand Down
65 changes: 62 additions & 3 deletions src/python/pants/jvm/dependency_inference/artifact_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
JvmArtifactArtifactField,
JvmArtifactGroupField,
JvmArtifactPackagesField,
JvmProvidesTypesField,
)
from pants.util.frozendict import FrozenDict
from pants.util.logging import LogLevel
Expand All @@ -38,6 +39,14 @@ def from_coord_str(cls, coord: str) -> UnversionedCoordinate:
return UnversionedCoordinate(group=coordinate_parts[0], artifact=coordinate_parts[1])


class FirstPartySourceProvided:
"""Marks when a first-party source has declared that it provides a given JVM symbol.

This resolves disambiguation cases in dependency inference when a first-party source provides
one type in a package that is otherwise fulfilled by third-party artifacts.
"""


@dataclass(frozen=True)
class AvailableThirdPartyArtifacts:
"""Maps JVM unversioned coordinates to target `Address`es and declared packages."""
Expand All @@ -56,10 +65,19 @@ def addresses_for_coordinates(


class MutableTrieNode:

__slots__ = [
tdyas marked this conversation as resolved.
Show resolved Hide resolved
"children",
"recursive",
"addresses",
"first_party",
] # don't use a `dict` to store attrs

def __init__(self):
self.children: dict[str, MutableTrieNode] = {}
self.recursive: bool = False
self.addresses: OrderedSet[Address] = OrderedSet()
self.first_party: bool = False
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we end up having the ! syntax for packages, we may want to redefine this field for more complicated exclusions, but for now, this will do.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed.


def ensure_child(self, name: str) -> MutableTrieNode:
if name in self.children:
Expand All @@ -71,13 +89,23 @@ def ensure_child(self, name: str) -> MutableTrieNode:

@frozen_after_init
class FrozenTrieNode:

chrisjrn marked this conversation as resolved.
Show resolved Hide resolved
__slots__ = [
"_is_frozen",
"_children",
"_recursive",
"_addresses",
"_first_party",
] # don't use a `dict` to store attrs (speeds up attr access significantly)

def __init__(self, node: MutableTrieNode) -> None:
children = {}
for key, child in node.children.items():
children[key] = FrozenTrieNode(child)
self._children: FrozenDict[str, FrozenTrieNode] = FrozenDict(children)
self._recursive: bool = node.recursive
self._addresses: FrozenOrderedSet[Address] = FrozenOrderedSet(node.addresses)
self._first_party: bool = node.first_party

def find_child(self, name: str) -> FrozenTrieNode | None:
return self._children.get(name)
Expand All @@ -86,6 +114,10 @@ def find_child(self, name: str) -> FrozenTrieNode | None:
def recursive(self) -> bool:
return self._recursive

@property
def first_party(self) -> bool:
return self._first_party

@property
def addresses(self) -> FrozenOrderedSet[Address]:
return self._addresses
Expand All @@ -103,20 +135,33 @@ def __eq__(self, other: Any) -> bool:
)

def __repr__(self):
return f"FrozenTrieNode(children={repr(self._children)}, recursive={self._recursive}, addresses={self._addresses})"
return f"FrozenTrieNode(children={repr(self._children)}, recursive={self._recursive}, addresses={self._addresses}, first_party={self._first_party})"


class AllJvmArtifactTargets(Targets):
pass


class AllJvmTypeProvidingTargets(Targets):
pass


@rule(desc="Find all jvm_artifact targets in project", level=LogLevel.DEBUG)
def find_all_jvm_artifact_targets(targets: AllTargets) -> AllJvmArtifactTargets:
return AllJvmArtifactTargets(
tgt for tgt in targets if tgt.has_fields((JvmArtifactGroupField, JvmArtifactArtifactField))
)


@rule(desc="Find all targets with experimental_provides fields in project", level=LogLevel.DEBUG)
def find_all_jvm_provides_fields(targets: AllTargets) -> AllJvmTypeProvidingTargets:
return AllJvmTypeProvidingTargets(
tgt
for tgt in targets
if tgt.has_fields((JvmProvidesTypesField,)) and tgt[JvmProvidesTypesField].value is not None
chrisjrn marked this conversation as resolved.
Show resolved Hide resolved
)


@dataclass(frozen=True)
class ThirdPartyPackageToArtifactMapping:
mapping_root: FrozenTrieNode
Expand All @@ -126,6 +171,7 @@ class ThirdPartyPackageToArtifactMapping:
async def find_available_third_party_artifacts(
all_jvm_artifact_tgts: AllJvmArtifactTargets,
) -> AvailableThirdPartyArtifacts:

address_mapping: dict[UnversionedCoordinate, OrderedSet[Address]] = defaultdict(OrderedSet)
package_mapping: dict[UnversionedCoordinate, OrderedSet[str]] = defaultdict(OrderedSet)
for tgt in all_jvm_artifact_tgts:
Expand Down Expand Up @@ -163,11 +209,15 @@ async def find_available_third_party_artifacts(
async def compute_java_third_party_artifact_mapping(
java_infer_subsystem: JavaInferSubsystem,
available_artifacts: AvailableThirdPartyArtifacts,
all_jvm_type_providing_tgts: AllJvmTypeProvidingTargets,
) -> ThirdPartyPackageToArtifactMapping:
"""Implements the mapping logic from the `jvm_artifact` and `java-infer` help."""

def insert(
mapping: MutableTrieNode, package_pattern: str, addresses: Iterable[Address]
mapping: MutableTrieNode,
package_pattern: str,
addresses: Iterable[Address],
first_party: bool,
) -> None:
imp_parts = package_pattern.split(".")
recursive = False
Expand All @@ -181,6 +231,7 @@ def insert(
current_node = child_node

current_node.addresses.update(addresses)
current_node.first_party = first_party
current_node.recursive = recursive

# Build a default mapping from coord to package.
Expand All @@ -205,7 +256,12 @@ def insert(
# Default to exposing the `group` name as a package.
packages = (f"{coord.group}.**",)
for package in packages:
insert(mapping, package, addresses)
insert(mapping, package, addresses, False)

# Mark types that have strong first-party declarations as first-party
for tgt in all_jvm_type_providing_tgts:
for provides_types in tgt[JvmProvidesTypesField].value or []:
chrisjrn marked this conversation as resolved.
Show resolved Hide resolved
insert(mapping, provides_types, [], True)

return ThirdPartyPackageToArtifactMapping(FrozenTrieNode(mapping))

Expand All @@ -231,6 +287,9 @@ def find_artifact_mapping(
# If the length of the found nodes equals the number of parts of the package path, then there
# is an exact match.
if len(found_nodes) == len(imp_parts):
best_match = found_nodes[-1]
if best_match.first_party:
return FrozenOrderedSet() # The first-party symbol mapper should provide this dep
return found_nodes[-1].addresses

# Otherwise, check for the first found node (in reverse order) to match recursively, and use its coordinate.
Expand Down
5 changes: 5 additions & 0 deletions src/python/pants/jvm/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,11 @@ class JvmArtifactPackagesField(StringSequenceField):
)


class JvmProvidesTypesField(StringSequenceField):
alias = "experimental_provides_types"
help = "TODO: Add help for this."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still needs to be filled in.



class JvmArtifactFieldSet(FieldSet):

group: JvmArtifactGroupField
Expand Down
8 changes: 8 additions & 0 deletions testprojects/src/jvm/org/pantsbuild/example/app/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,11 @@ deploy_jar(
scala_sources(
compatible_resolves=["exampleapp"],
)

jvm_artifact(
name="com.google.truth_truth",
group="com.google.truth",
artifact="truth",
version="0.45",
# scope = "test",
)
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package org.pantsbuild.example.app;

import org.pantsbuild.example.lib.ExampleLib
import com.google.common.truth.BadAndDangerous

class ExampleApp {
def main(args: Array[String]): Unit = {
println(ExampleLib.hello())
println(BadAndDangerous.hello())
com.google.common.truth.Truth.assertThat(new Object())
}
}
Loading