Skip to content

Commit

Permalink
Generically compute dynamic defaults for Fields (#16206)
Browse files Browse the repository at this point in the history
As discussed on #16175, we don't currently consume the "dynamic" defaults of field values for the purposes of `parametrize`. That is at least partially because there is no generic way to do so: a `Field` has no way to declare a dynamic default currently, because `Field`s cannot declare a dependency `@rule_helper` to compute their value (...yet? see #12934 (comment)).

This change adds a mechanism for generically declaring the default value of a `Field`. This is definitely not the most ergonomic API: over the next few versions, many dynamic `Field` defaults will hopefully move to `__defaults__`. And #12934 (comment) will hopefully allow for significantly cleaning up those that remain. 

Fixes #16175.

[ci skip-rust]
[ci skip-build-wheels]
  • Loading branch information
stuhood authored Jul 18, 2022
1 parent 5ff9d52 commit 9df18f3
Show file tree
Hide file tree
Showing 13 changed files with 270 additions and 142 deletions.
19 changes: 1 addition & 18 deletions src/python/pants/backend/java/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
BSPBuildTargetsMetadataResult,
BSPCompileRequest,
BSPCompileResult,
BSPResolveFieldFactoryRequest,
BSPResolveFieldFactoryResult,
BSPResourcesRequest,
BSPResourcesResult,
)
Expand All @@ -28,7 +26,6 @@
from pants.jvm.bsp.resources import _jvm_bsp_resources
from pants.jvm.bsp.resources import rules as jvm_resources_rules
from pants.jvm.compile import ClasspathEntryRequestFactory
from pants.jvm.subsystems import JvmSubsystem
from pants.jvm.target_types import JvmResolveField

LANGUAGE_ID = "java"
Expand All @@ -50,28 +47,15 @@ class JavaMetadataFieldSet(FieldSet):
resolve: JvmResolveField


class JavaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest):
resolve_prefix = "jvm"


class JavaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest):
language_id = LANGUAGE_ID
can_merge_metadata_from = ()
field_set_type = JavaMetadataFieldSet

resolve_prefix = "jvm"
resolve_field = JvmResolveField


@rule
def bsp_resolve_field_factory(
request: JavaBSPResolveFieldFactoryRequest,
jvm: JvmSubsystem,
) -> BSPResolveFieldFactoryResult:
return BSPResolveFieldFactoryResult(
lambda target: target.get(JvmResolveField).normalized_value(jvm)
)


@rule
async def bsp_resolve_java_metadata(
_: JavaBSPBuildTargetsMetadataRequest,
Expand Down Expand Up @@ -169,7 +153,6 @@ def rules():
*jvm_compile_rules(),
*jvm_resources_rules(),
UnionRule(BSPLanguageSupport, JavaBSPLanguageSupport),
UnionRule(BSPResolveFieldFactoryRequest, JavaBSPResolveFieldFactoryRequest),
UnionRule(BSPBuildTargetsMetadataRequest, JavaBSPBuildTargetsMetadataRequest),
UnionRule(BSPHandlerMapping, JavacOptionsHandlerMapping),
UnionRule(BSPCompileRequest, JavaBSPCompileRequest),
Expand Down
6 changes: 5 additions & 1 deletion src/python/pants/backend/java/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
TargetFilesGenerator,
generate_multiple_sources_field_help_message,
)
from pants.jvm import target_types as jvm_target_types
from pants.jvm.target_types import (
JunitTestSourceField,
JunitTestTimeoutField,
Expand Down Expand Up @@ -137,4 +138,7 @@ class JavaSourcesGeneratorTarget(TargetFilesGenerator):


def rules():
return collect_rules()
return [
*collect_rules(),
*jvm_target_types.rules(),
]
20 changes: 20 additions & 0 deletions src/python/pants/backend/python/target_types_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
from pants.engine.target import (
DependenciesRequest,
ExplicitlyProvidedDependencies,
FieldDefaultFactoryRequest,
FieldDefaultFactoryResult,
FieldSet,
GeneratedTargets,
GenerateTargetsRequest,
Expand Down Expand Up @@ -496,6 +498,23 @@ async def infer_python_distribution_dependencies(
)


# -----------------------------------------------------------------------------------------------
# Dynamic Field defaults
# -----------------------------------------------------------------------------------------------


class PythonResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest):
field_type = PythonResolveField


@rule
def python_resolve_field_default_factory(
request: PythonResolveFieldDefaultFactoryRequest,
python_setup: PythonSetup,
) -> FieldDefaultFactoryResult:
return FieldDefaultFactoryResult(lambda f: f.normalized_value(python_setup))


# -----------------------------------------------------------------------------------------------
# Dependency validation
# -----------------------------------------------------------------------------------------------
Expand Down Expand Up @@ -561,6 +580,7 @@ def rules():
return (
*collect_rules(),
*import_rules(),
UnionRule(FieldDefaultFactoryRequest, PythonResolveFieldDefaultFactoryRequest),
UnionRule(TargetFilesGeneratorSettingsRequest, PythonFilesGeneratorSettingsRequest),
UnionRule(GenerateTargetsRequest, GenerateTargetsFromPexBinaries),
UnionRule(InferDependenciesRequest, InferPexBinaryEntryPointDependency),
Expand Down
20 changes: 3 additions & 17 deletions src/python/pants/backend/scala/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
BSPCompileResult,
BSPDependencyModulesRequest,
BSPDependencyModulesResult,
BSPResolveFieldFactoryRequest,
BSPResolveFieldFactoryResult,
BSPResourcesRequest,
BSPResourcesResult,
)
Expand Down Expand Up @@ -86,15 +84,14 @@ class ScalaMetadataFieldSet(FieldSet):
jdk: JvmJdkField


class ScalaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest):
resolve_prefix = "jvm"


class ScalaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest):
language_id = LANGUAGE_ID
can_merge_metadata_from = ("java",)
field_set_type = ScalaMetadataFieldSet

resolve_prefix = "jvm"
resolve_field = JvmResolveField


@dataclass(frozen=True)
class ThirdpartyModulesRequest:
Expand Down Expand Up @@ -178,16 +175,6 @@ async def _materialize_scala_runtime_jars(scala_version: str) -> Snapshot:
)


@rule
def bsp_resolve_field_factory(
request: ScalaBSPResolveFieldFactoryRequest,
jvm: JvmSubsystem,
) -> BSPResolveFieldFactoryResult:
return BSPResolveFieldFactoryResult(
lambda target: target.get(JvmResolveField).normalized_value(jvm)
)


@rule
async def bsp_resolve_scala_metadata(
request: ScalaBSPBuildTargetsMetadataRequest,
Expand Down Expand Up @@ -540,7 +527,6 @@ def rules():
*jvm_resources_rules(),
UnionRule(BSPLanguageSupport, ScalaBSPLanguageSupport),
UnionRule(BSPBuildTargetsMetadataRequest, ScalaBSPBuildTargetsMetadataRequest),
UnionRule(BSPResolveFieldFactoryRequest, ScalaBSPResolveFieldFactoryRequest),
UnionRule(BSPHandlerMapping, ScalacOptionsHandlerMapping),
UnionRule(BSPHandlerMapping, ScalaMainClassesHandlerMapping),
UnionRule(BSPHandlerMapping, ScalaTestClassesHandlerMapping),
Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/backend/scala/target_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
generate_multiple_sources_field_help_message,
)
from pants.engine.unions import UnionRule
from pants.jvm import target_types as jvm_target_types
from pants.jvm.target_types import (
JunitTestSourceField,
JunitTestTimeoutField,
Expand Down Expand Up @@ -349,5 +350,6 @@ class ScalacPluginTarget(Target):
def rules():
return (
*collect_rules(),
*jvm_target_types.rules(),
UnionRule(TargetFilesGeneratorSettingsRequest, ScalaSettingsRequest),
)
56 changes: 15 additions & 41 deletions src/python/pants/bsp/util_rules/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import ClassVar, Generic, Sequence, Type, TypeVar

import toml
from typing_extensions import Protocol

from pants.base.build_root import BuildRoot
from pants.base.glob_match_error_behavior import GlobMatchErrorBehavior
Expand Down Expand Up @@ -48,11 +47,12 @@
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import _uncacheable_rule, collect_rules, rule
from pants.engine.target import (
Field,
FieldDefaults,
FieldSet,
SourcesField,
SourcesPaths,
SourcesPathsRequest,
Target,
Targets,
)
from pants.engine.unions import UnionMembership, UnionRule, union
Expand All @@ -66,38 +66,6 @@
_FS = TypeVar("_FS", bound=FieldSet)


@union
@dataclass(frozen=True)
class BSPResolveFieldFactoryRequest(Generic[_FS]):
"""Requests an implementation of `BSPResolveFieldFactory` which can filter resolve fields.
TODO: This is to work around the fact that Field value defaulting cannot have arbitrary
subsystem requirements, and so `JvmResolveField` and `PythonResolveField` have methods
which compute the true value of the field given a subsytem argument. Consumers need to
be type aware, and `@rules` cannot have dynamic requirements.
See https://github.com/pantsbuild/pants/issues/12934 about potentially allowing unions
(including Field registrations) to have `@rule_helper` methods, which would allow the
computation of an AsyncFields to directly require a subsystem.
"""

resolve_prefix: ClassVar[str]


# TODO: Workaround for https://github.com/python/mypy/issues/5485, because we cannot directly use
# a Callable.
class _ResolveFieldFactory(Protocol):
def __call__(self, target: Target) -> str | None:
pass


@dataclass(frozen=True)
class BSPResolveFieldFactoryResult:
"""Computes the resolve field value for a Target, if applicable."""

resolve_field_value: _ResolveFieldFactory


@union
@dataclass(frozen=True)
class BSPBuildTargetsMetadataRequest(Generic[_FS]):
Expand All @@ -107,6 +75,9 @@ class BSPBuildTargetsMetadataRequest(Generic[_FS]):
can_merge_metadata_from: ClassVar[tuple[str, ...]]
field_set_type: ClassVar[Type[_FS]] # type: ignore[misc]

resolve_prefix: ClassVar[str]
resolve_field: ClassVar[type[Field]]

field_sets: tuple[_FS, ...]


Expand Down Expand Up @@ -261,6 +232,7 @@ async def resolve_bsp_build_target_identifier(
async def resolve_bsp_build_target_addresses(
bsp_target: BSPBuildTargetInternal,
union_membership: UnionMembership,
field_defaults: FieldDefaults,
) -> Targets:
# NB: Using `RawSpecs` directly rather than `RawSpecsWithoutFileOwners` results in a rule graph cycle.
targets = await Get(
Expand All @@ -279,17 +251,19 @@ async def resolve_bsp_build_target_addresses(
f"prefix like `$lang:$filter`, but the configured value: `{resolve_filter}` did not."
)

# TODO: See `BSPResolveFieldFactoryRequest` re: this awkwardness.
factories = await MultiGet(
Get(BSPResolveFieldFactoryResult, BSPResolveFieldFactoryRequest, request())
for request in union_membership.get(BSPResolveFieldFactoryRequest)
if request.resolve_prefix == resolve_prefix
)
resolve_fields = {
impl.resolve_field
for impl in union_membership.get(BSPBuildTargetsMetadataRequest)
if impl.resolve_prefix == resolve_prefix
}

return Targets(
t
for t in targets
if any((factory.resolve_field_value)(t) == resolve_value for factory in factories)
if any(
t.has_field(field) and field_defaults.value_or_default(t[field]) == resolve_value
for field in resolve_fields
)
)


Expand Down
2 changes: 2 additions & 0 deletions src/python/pants/bsp/util_rules/targets_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from pants.backend.java import target_types
from pants.backend.java.bsp import rules as java_bsp_rules
from pants.backend.java.compile import javac
from pants.backend.java.target_types import JavaSourceTarget
Expand Down Expand Up @@ -31,6 +32,7 @@ def rule_runner() -> RuleRunner:
*jvm_tool.rules(),
*jvm_util_rules.rules(),
*jdk_rules.rules(),
*target_types.rules(),
QueryRule(BSPBuildTargets, ()),
QueryRule(Targets, [BuildTargetIdentifier]),
],
Expand Down
25 changes: 24 additions & 1 deletion src/python/pants/engine/internals/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@
DependenciesRequest,
ExplicitlyProvidedDependencies,
Field,
FieldDefaultFactoryRequest,
FieldDefaultFactoryResult,
FieldDefaults,
FieldSetsPerTarget,
FieldSetsPerTargetRequest,
FilteredTargets,
Expand Down Expand Up @@ -1058,6 +1061,7 @@ async def resolve_dependencies(
target_types_to_generate_requests: TargetTypesToGenerateTargetsRequests,
union_membership: UnionMembership,
subproject_roots: SubprojectRoots,
field_defaults: FieldDefaults,
) -> Addresses:
wrapped_tgt, explicitly_provided = await MultiGet(
Get(
Expand Down Expand Up @@ -1121,7 +1125,7 @@ async def resolve_dependencies(
)

explicitly_provided_includes = [
parametrizations.get_subset(address, tgt).address
parametrizations.get_subset(address, tgt, field_defaults).address
for address, parametrizations in zip(
explicitly_provided_includes, explicit_dependency_parametrizations
)
Expand Down Expand Up @@ -1247,6 +1251,25 @@ async def resolve_unparsed_address_inputs(
return Addresses(addresses)


# -----------------------------------------------------------------------------------------------
# Dynamic Field defaults
# -----------------------------------------------------------------------------------------------


@rule
async def field_defaults(union_membership: UnionMembership) -> FieldDefaults:
requests = list(union_membership.get(FieldDefaultFactoryRequest))
factories = await MultiGet(
Get(FieldDefaultFactoryResult, FieldDefaultFactoryRequest, impl()) for impl in requests
)
return FieldDefaults(
FrozenDict(
(request.field_type, factory.default_factory)
for request, factory in zip(requests, factories)
)
)


# -----------------------------------------------------------------------------------------------
# Find applicable field sets
# -----------------------------------------------------------------------------------------------
Expand Down
Loading

0 comments on commit 9df18f3

Please sign in to comment.