Skip to content

Commit

Permalink
jvm: teach dependency inference about multiple resolves (#14491)
Browse files Browse the repository at this point in the history
As described in #14293, dependency inference does not currently work properly for first-party sources when multiple resolves are configured. This PR fixes the issue for JVM (Java/Scala) first-party sources.

The solution is similar to what was already done for third-party dependency inference: make the resolve part of the key in the symbol map. With this change, dependency inference looks up symbols by both name and resolve.
  • Loading branch information
Tom Dyas authored Feb 15, 2022
1 parent c671f02 commit 05ec375
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def infer_java_dependencies_and_exports_via_source_analysis(
dependencies: OrderedSet[Address] = OrderedSet()
exports: OrderedSet[Address] = OrderedSet()
for typ in types:
first_party_matches = first_party_dep_map.symbols.addresses_for_symbol(typ)
first_party_matches = first_party_dep_map.symbols.addresses_for_symbol(typ, resolve=resolve)
third_party_matches = (
third_party_artifact_mapping.addresses_for_symbol(typ, resolve)
if java_infer_subsystem.third_party_imports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from pants.engine.unions import UnionRule
from pants.jvm.dependency_inference import symbol_mapper
from pants.jvm.dependency_inference.symbol_mapper import FirstPartyMappingRequest, SymbolMap
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmResolveField
from pants.util.logging import LogLevel

logger = logging.getLogger(__name__)
Expand All @@ -33,18 +35,23 @@ class FirstPartyJavaTargetsMappingRequest(FirstPartyMappingRequest):

@rule(desc="Map all first party Java targets to their packages", level=LogLevel.DEBUG)
async def map_first_party_java_targets_to_symbols(
_: FirstPartyJavaTargetsMappingRequest, java_targets: AllJavaTargets
_: FirstPartyJavaTargetsMappingRequest,
java_targets: AllJavaTargets,
jvm: JvmSubsystem,
) -> SymbolMap:
source_analysis = await MultiGet(
Get(JavaSourceDependencyAnalysis, SourceFilesRequest([target[JavaSourceField]]))
for target in java_targets
)
address_and_analysis = zip([t.address for t in java_targets], source_analysis)
address_and_analysis = zip(
[(tgt.address, tgt[JvmResolveField].normalized_value(jvm)) for tgt in java_targets],
source_analysis,
)

dep_map = SymbolMap()
for address, analysis in address_and_analysis:
for (address, resolve), analysis in address_and_analysis:
for top_level_type in analysis.top_level_types:
dep_map.add_symbol(top_level_type, address=address)
dep_map.add_symbol(top_level_type, address=address, resolve=resolve)

return dep_map

Expand Down
4 changes: 3 additions & 1 deletion src/python/pants/backend/scala/dependency_inference/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ async def infer_scala_dependencies_via_source_analysis(

dependencies: OrderedSet[Address] = OrderedSet()
for symbol in symbols:
first_party_matches = first_party_symbol_map.symbols.addresses_for_symbol(symbol)
first_party_matches = first_party_symbol_map.symbols.addresses_for_symbol(
symbol, resolve=resolve
)
third_party_matches = third_party_artifact_mapping.addresses_for_symbol(symbol, resolve)
matches = first_party_matches.union(third_party_matches)
if not matches:
Expand Down
70 changes: 70 additions & 0 deletions src/python/pants/backend/scala/dependency_inference/rules_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,3 +233,73 @@ def main(args: Array[String]): Unit = {
InferredDependencies, [InferScalaSourceDependencies(tgt[ScalaSourceField])]
)
assert deps == InferredDependencies([Address("bar", relative_file_path="B.scala")])


@maybe_skip_jdk_test
def test_multi_resolve_dependency_inference(rule_runner: RuleRunner) -> None:
rule_runner.write_files(
{
"lib/BUILD": dedent(
"""\
scala_sources(name="lib_2_13", resolve="scala-2.13")
scala_sources(name="lib_2_12", resolve="scala-2.12")
"""
),
"lib/Library.scala": dedent(
"""\
package org.pantsbuild.lib
object Library {
def grok(): Unit = {
println("Hello world!")
}
}
"""
),
"user/BUILD": dedent(
"""\
scala_sources(name="user_2_13", resolve="scala-2.13")
scala_sources(name="user_2_12", resolve="scala-2.12")
"""
),
"user/Main.scala": dedent(
"""\
package org.pantsbuild.user
import org.pantsbuild.lib.Library
object Main {
def main(args: Array[String]): Unit = {
Library.grok()
}
}
"""
),
}
)
rule_runner.set_options(
[
'--jvm-resolves={"scala-2.13":"3rdparty/jvm/scala-2.13.lock", "scala-2.12":"3rdparty/jvm/scala-2.12.lock"}'
],
env_inherit=PYTHON_BOOTSTRAP_ENV,
)

tgt = rule_runner.get_target(
Address("user", relative_file_path="Main.scala", target_name="user_2_13")
)
deps = rule_runner.request(
InferredDependencies, [InferScalaSourceDependencies(tgt[ScalaSourceField])]
)
assert deps == InferredDependencies(
[Address("lib", relative_file_path="Library.scala", target_name="lib_2_13")]
)

tgt = rule_runner.get_target(
Address("user", relative_file_path="Main.scala", target_name="user_2_12")
)
deps = rule_runner.request(
InferredDependencies, [InferScalaSourceDependencies(tgt[ScalaSourceField])]
)
assert deps == InferredDependencies(
[Address("lib", relative_file_path="Library.scala", target_name="lib_2_12")]
)
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from pants.engine.unions import UnionRule
from pants.jvm.dependency_inference import symbol_mapper
from pants.jvm.dependency_inference.symbol_mapper import FirstPartyMappingRequest, SymbolMap
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmResolveField
from pants.util.logging import LogLevel


Expand All @@ -33,19 +35,23 @@ def find_all_scala_targets(targets: AllTargets) -> AllScalaTargets:
async def map_first_party_scala_targets_to_symbols(
_: FirstPartyScalaTargetsMappingRequest,
scala_targets: AllScalaTargets,
jvm: JvmSubsystem,
) -> SymbolMap:
source_analysis = await MultiGet(
Get(ScalaSourceDependencyAnalysis, SourceFilesRequest([target[ScalaSourceField]]))
for target in scala_targets
)
address_and_analysis = zip([t.address for t in scala_targets], source_analysis)
address_and_analysis = zip(
[(tgt.address, tgt[JvmResolveField].normalized_value(jvm)) for tgt in scala_targets],
source_analysis,
)

symbol_map = SymbolMap()
for address, analysis in address_and_analysis:
for (address, resolve), analysis in address_and_analysis:
for symbol in analysis.provided_symbols:
symbol_map.add_symbol(symbol, address)
symbol_map.add_symbol(symbol, address, resolve=resolve)
for symbol in analysis.provided_symbols_encoded:
symbol_map.add_symbol(symbol, address)
symbol_map.add_symbol(symbol, address, resolve=resolve)

return symbol_map

Expand Down
31 changes: 18 additions & 13 deletions src/python/pants/jvm/dependency_inference/symbol_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from pants.engine.rules import Get, MultiGet, collect_rules, rule
from pants.engine.unions import UnionMembership, union
from pants.jvm.dependency_inference.artifact_mapper import AllJvmTypeProvidingTargets
from pants.jvm.target_types import JvmProvidesTypesField
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmProvidesTypesField, JvmResolveField
from pants.util.logging import LogLevel

logger = logging.getLogger(__name__)
Expand All @@ -31,28 +32,30 @@ class SymbolMap:
"""A mapping of JVM package names to owning addresses."""

def __init__(self):
self._symbol_map: dict[str, set[Address]] = defaultdict(set)
self._symbol_map: dict[tuple[str, str], set[Address]] = defaultdict(set)

def add_symbol(self, symbol: str, address: Address):
def add_symbol(self, symbol: str, address: Address, *, resolve: str):
"""Declare a single Address as a provider of a symbol."""
self._symbol_map[symbol].add(address)
self._symbol_map[(resolve, symbol)].add(address)

def addresses_for_symbol(self, symbol: str) -> frozenset[Address]:
def addresses_for_symbol(self, symbol: str, *, resolve: str) -> frozenset[Address]:
"""Returns the set of addresses that provide the passed symbol.
:param symbol: a fully-qualified JVM symbol (e.g. `foo.bar.Thing`).
:param resolve: name of resolve name in which to check for symbol.
"""
return frozenset(self._symbol_map[symbol])
return frozenset(self._symbol_map[(resolve, symbol)])

def merge(self, other: SymbolMap) -> None:
"""Merge 'other' into this dependency map."""
for symbol, addresses in other._symbol_map.items():
self._symbol_map[symbol] |= addresses
for (resolve, symbol), addresses in other._symbol_map.items():
self._symbol_map[(resolve, symbol)] |= addresses

def to_json_dict(self):
return {
"symbol_map": {
sym: [str(addr) for addr in addrs] for sym, addrs in self._symbol_map.items()
f"{resolve}/{sym}": [str(addr) for addr in addrs]
for (resolve, sym), addrs in self._symbol_map.items()
},
}

Expand Down Expand Up @@ -82,6 +85,7 @@ class FirstPartySymbolMapping:
async def merge_first_party_module_mappings(
union_membership: UnionMembership,
targets_that_provide_types: AllJvmTypeProvidingTargets,
jvm: JvmSubsystem,
) -> FirstPartySymbolMapping:
all_mappings = await MultiGet(
Get(
Expand All @@ -101,14 +105,15 @@ async def merge_first_party_module_mappings(
# here is that _one_ of the souce files amongst the set of sources actually provides that type.

# Collect each address associated with a `provides` annotation and index by the provided type.
provided_types: dict[str, set[Address]] = defaultdict(set)
provided_types: dict[tuple[str, str], set[Address]] = defaultdict(set)
for tgt in targets_that_provide_types:
resolve = tgt[JvmResolveField].normalized_value(jvm)
for provided_type in tgt[JvmProvidesTypesField].value or []:
provided_types[provided_type].add(tgt.address)
provided_types[(resolve, provided_type)].add(tgt.address)

# Check that at least one address declared by each `provides` value actually provides the type:
for provided_type, provided_addresses in provided_types.items():
symbol_addresses = merged_dep_map.addresses_for_symbol(provided_type)
for (resolve, provided_type), provided_addresses in provided_types.items():
symbol_addresses = merged_dep_map.addresses_for_symbol(provided_type, resolve=resolve)
if not provided_addresses.intersection(symbol_addresses):
raise JvmFirstPartyPackageMappingException(
f"The target {next(iter(provided_addresses))} declares that it provides the JVM type "
Expand Down

0 comments on commit 05ec375

Please sign in to comment.