From 9df18f349788adb1198ac3a3f30ad6c92be205a7 Mon Sep 17 00:00:00 2001 From: Stu Hood Date: Mon, 18 Jul 2022 16:58:27 -0700 Subject: [PATCH] Generically compute dynamic defaults for `Field`s (#16206) As discussed on #16175, we don't currently consume the "dynamic" defaults of field values for the purposes of `parametrize`. That is at least partially because there is no generic way to do so: a `Field` has no way to declare a dynamic default currently, because `Field`s cannot declare a dependency `@rule_helper` to compute their value (...yet? see https://github.com/pantsbuild/pants/issues/12934#issuecomment-1111608974). This change adds a mechanism for generically declaring the default value of a `Field`. This is definitely not the most ergonomic API: over the next few versions, many dynamic `Field` defaults will hopefully move to `__defaults__`. And https://github.com/pantsbuild/pants/issues/12934#issuecomment-1111608974 will hopefully allow for significantly cleaning up those that remain. Fixes #16175. [ci skip-rust] [ci skip-build-wheels] --- src/python/pants/backend/java/bsp/rules.py | 19 +-- src/python/pants/backend/java/target_types.py | 6 +- .../backend/python/target_types_rules.py | 20 +++ src/python/pants/backend/scala/bsp/rules.py | 20 +-- .../pants/backend/scala/target_types.py | 2 + src/python/pants/bsp/util_rules/targets.py | 56 +++------ .../pants/bsp/util_rules/targets_test.py | 2 + src/python/pants/engine/internals/graph.py | 25 +++- .../pants/engine/internals/graph_test.py | 26 +++- .../pants/engine/internals/parametrize.py | 28 +++-- .../engine/internals/parametrize_test.py | 114 ++++++++++-------- src/python/pants/engine/target.py | 66 +++++++++- src/python/pants/jvm/target_types.py | 28 +++++ 13 files changed, 270 insertions(+), 142 deletions(-) diff --git a/src/python/pants/backend/java/bsp/rules.py b/src/python/pants/backend/java/bsp/rules.py index 0c8d1d20c00..859e697157c 100644 --- a/src/python/pants/backend/java/bsp/rules.py +++ b/src/python/pants/backend/java/bsp/rules.py @@ -14,8 +14,6 @@ BSPBuildTargetsMetadataResult, BSPCompileRequest, BSPCompileResult, - BSPResolveFieldFactoryRequest, - BSPResolveFieldFactoryResult, BSPResourcesRequest, BSPResourcesResult, ) @@ -28,7 +26,6 @@ from pants.jvm.bsp.resources import _jvm_bsp_resources from pants.jvm.bsp.resources import rules as jvm_resources_rules from pants.jvm.compile import ClasspathEntryRequestFactory -from pants.jvm.subsystems import JvmSubsystem from pants.jvm.target_types import JvmResolveField LANGUAGE_ID = "java" @@ -50,28 +47,15 @@ class JavaMetadataFieldSet(FieldSet): resolve: JvmResolveField -class JavaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest): - resolve_prefix = "jvm" - - class JavaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest): language_id = LANGUAGE_ID can_merge_metadata_from = () field_set_type = JavaMetadataFieldSet + resolve_prefix = "jvm" resolve_field = JvmResolveField -@rule -def bsp_resolve_field_factory( - request: JavaBSPResolveFieldFactoryRequest, - jvm: JvmSubsystem, -) -> BSPResolveFieldFactoryResult: - return BSPResolveFieldFactoryResult( - lambda target: target.get(JvmResolveField).normalized_value(jvm) - ) - - @rule async def bsp_resolve_java_metadata( _: JavaBSPBuildTargetsMetadataRequest, @@ -169,7 +153,6 @@ def rules(): *jvm_compile_rules(), *jvm_resources_rules(), UnionRule(BSPLanguageSupport, JavaBSPLanguageSupport), - UnionRule(BSPResolveFieldFactoryRequest, JavaBSPResolveFieldFactoryRequest), UnionRule(BSPBuildTargetsMetadataRequest, JavaBSPBuildTargetsMetadataRequest), UnionRule(BSPHandlerMapping, JavacOptionsHandlerMapping), UnionRule(BSPCompileRequest, JavaBSPCompileRequest), diff --git a/src/python/pants/backend/java/target_types.py b/src/python/pants/backend/java/target_types.py index 5d6e29491c0..9ee0d3d93e2 100644 --- a/src/python/pants/backend/java/target_types.py +++ b/src/python/pants/backend/java/target_types.py @@ -15,6 +15,7 @@ TargetFilesGenerator, generate_multiple_sources_field_help_message, ) +from pants.jvm import target_types as jvm_target_types from pants.jvm.target_types import ( JunitTestSourceField, JunitTestTimeoutField, @@ -137,4 +138,7 @@ class JavaSourcesGeneratorTarget(TargetFilesGenerator): def rules(): - return collect_rules() + return [ + *collect_rules(), + *jvm_target_types.rules(), + ] diff --git a/src/python/pants/backend/python/target_types_rules.py b/src/python/pants/backend/python/target_types_rules.py index 5d84084585d..0ac597d98be 100644 --- a/src/python/pants/backend/python/target_types_rules.py +++ b/src/python/pants/backend/python/target_types_rules.py @@ -49,6 +49,8 @@ from pants.engine.target import ( DependenciesRequest, ExplicitlyProvidedDependencies, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, GeneratedTargets, GenerateTargetsRequest, @@ -496,6 +498,23 @@ async def infer_python_distribution_dependencies( ) +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# ----------------------------------------------------------------------------------------------- + + +class PythonResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = PythonResolveField + + +@rule +def python_resolve_field_default_factory( + request: PythonResolveFieldDefaultFactoryRequest, + python_setup: PythonSetup, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.normalized_value(python_setup)) + + # ----------------------------------------------------------------------------------------------- # Dependency validation # ----------------------------------------------------------------------------------------------- @@ -561,6 +580,7 @@ def rules(): return ( *collect_rules(), *import_rules(), + UnionRule(FieldDefaultFactoryRequest, PythonResolveFieldDefaultFactoryRequest), UnionRule(TargetFilesGeneratorSettingsRequest, PythonFilesGeneratorSettingsRequest), UnionRule(GenerateTargetsRequest, GenerateTargetsFromPexBinaries), UnionRule(InferDependenciesRequest, InferPexBinaryEntryPointDependency), diff --git a/src/python/pants/backend/scala/bsp/rules.py b/src/python/pants/backend/scala/bsp/rules.py index 41d7417af6f..ca880194188 100644 --- a/src/python/pants/backend/scala/bsp/rules.py +++ b/src/python/pants/backend/scala/bsp/rules.py @@ -33,8 +33,6 @@ BSPCompileResult, BSPDependencyModulesRequest, BSPDependencyModulesResult, - BSPResolveFieldFactoryRequest, - BSPResolveFieldFactoryResult, BSPResourcesRequest, BSPResourcesResult, ) @@ -86,15 +84,14 @@ class ScalaMetadataFieldSet(FieldSet): jdk: JvmJdkField -class ScalaBSPResolveFieldFactoryRequest(BSPResolveFieldFactoryRequest): - resolve_prefix = "jvm" - - class ScalaBSPBuildTargetsMetadataRequest(BSPBuildTargetsMetadataRequest): language_id = LANGUAGE_ID can_merge_metadata_from = ("java",) field_set_type = ScalaMetadataFieldSet + resolve_prefix = "jvm" + resolve_field = JvmResolveField + @dataclass(frozen=True) class ThirdpartyModulesRequest: @@ -178,16 +175,6 @@ async def _materialize_scala_runtime_jars(scala_version: str) -> Snapshot: ) -@rule -def bsp_resolve_field_factory( - request: ScalaBSPResolveFieldFactoryRequest, - jvm: JvmSubsystem, -) -> BSPResolveFieldFactoryResult: - return BSPResolveFieldFactoryResult( - lambda target: target.get(JvmResolveField).normalized_value(jvm) - ) - - @rule async def bsp_resolve_scala_metadata( request: ScalaBSPBuildTargetsMetadataRequest, @@ -540,7 +527,6 @@ def rules(): *jvm_resources_rules(), UnionRule(BSPLanguageSupport, ScalaBSPLanguageSupport), UnionRule(BSPBuildTargetsMetadataRequest, ScalaBSPBuildTargetsMetadataRequest), - UnionRule(BSPResolveFieldFactoryRequest, ScalaBSPResolveFieldFactoryRequest), UnionRule(BSPHandlerMapping, ScalacOptionsHandlerMapping), UnionRule(BSPHandlerMapping, ScalaMainClassesHandlerMapping), UnionRule(BSPHandlerMapping, ScalaTestClassesHandlerMapping), diff --git a/src/python/pants/backend/scala/target_types.py b/src/python/pants/backend/scala/target_types.py index 41de1405a83..473bf6a0688 100644 --- a/src/python/pants/backend/scala/target_types.py +++ b/src/python/pants/backend/scala/target_types.py @@ -26,6 +26,7 @@ generate_multiple_sources_field_help_message, ) from pants.engine.unions import UnionRule +from pants.jvm import target_types as jvm_target_types from pants.jvm.target_types import ( JunitTestSourceField, JunitTestTimeoutField, @@ -349,5 +350,6 @@ class ScalacPluginTarget(Target): def rules(): return ( *collect_rules(), + *jvm_target_types.rules(), UnionRule(TargetFilesGeneratorSettingsRequest, ScalaSettingsRequest), ) diff --git a/src/python/pants/bsp/util_rules/targets.py b/src/python/pants/bsp/util_rules/targets.py index e1dbf6749b1..e88ef307c0e 100644 --- a/src/python/pants/bsp/util_rules/targets.py +++ b/src/python/pants/bsp/util_rules/targets.py @@ -10,7 +10,6 @@ from typing import ClassVar, Generic, Sequence, Type, TypeVar import toml -from typing_extensions import Protocol from pants.base.build_root import BuildRoot from pants.base.glob_match_error_behavior import GlobMatchErrorBehavior @@ -48,11 +47,12 @@ from pants.engine.internals.selectors import Get, MultiGet from pants.engine.rules import _uncacheable_rule, collect_rules, rule from pants.engine.target import ( + Field, + FieldDefaults, FieldSet, SourcesField, SourcesPaths, SourcesPathsRequest, - Target, Targets, ) from pants.engine.unions import UnionMembership, UnionRule, union @@ -66,38 +66,6 @@ _FS = TypeVar("_FS", bound=FieldSet) -@union -@dataclass(frozen=True) -class BSPResolveFieldFactoryRequest(Generic[_FS]): - """Requests an implementation of `BSPResolveFieldFactory` which can filter resolve fields. - - TODO: This is to work around the fact that Field value defaulting cannot have arbitrary - subsystem requirements, and so `JvmResolveField` and `PythonResolveField` have methods - which compute the true value of the field given a subsytem argument. Consumers need to - be type aware, and `@rules` cannot have dynamic requirements. - - See https://github.com/pantsbuild/pants/issues/12934 about potentially allowing unions - (including Field registrations) to have `@rule_helper` methods, which would allow the - computation of an AsyncFields to directly require a subsystem. - """ - - resolve_prefix: ClassVar[str] - - -# TODO: Workaround for https://github.com/python/mypy/issues/5485, because we cannot directly use -# a Callable. -class _ResolveFieldFactory(Protocol): - def __call__(self, target: Target) -> str | None: - pass - - -@dataclass(frozen=True) -class BSPResolveFieldFactoryResult: - """Computes the resolve field value for a Target, if applicable.""" - - resolve_field_value: _ResolveFieldFactory - - @union @dataclass(frozen=True) class BSPBuildTargetsMetadataRequest(Generic[_FS]): @@ -107,6 +75,9 @@ class BSPBuildTargetsMetadataRequest(Generic[_FS]): can_merge_metadata_from: ClassVar[tuple[str, ...]] field_set_type: ClassVar[Type[_FS]] # type: ignore[misc] + resolve_prefix: ClassVar[str] + resolve_field: ClassVar[type[Field]] + field_sets: tuple[_FS, ...] @@ -261,6 +232,7 @@ async def resolve_bsp_build_target_identifier( async def resolve_bsp_build_target_addresses( bsp_target: BSPBuildTargetInternal, union_membership: UnionMembership, + field_defaults: FieldDefaults, ) -> Targets: # NB: Using `RawSpecs` directly rather than `RawSpecsWithoutFileOwners` results in a rule graph cycle. targets = await Get( @@ -279,17 +251,19 @@ async def resolve_bsp_build_target_addresses( f"prefix like `$lang:$filter`, but the configured value: `{resolve_filter}` did not." ) - # TODO: See `BSPResolveFieldFactoryRequest` re: this awkwardness. - factories = await MultiGet( - Get(BSPResolveFieldFactoryResult, BSPResolveFieldFactoryRequest, request()) - for request in union_membership.get(BSPResolveFieldFactoryRequest) - if request.resolve_prefix == resolve_prefix - ) + resolve_fields = { + impl.resolve_field + for impl in union_membership.get(BSPBuildTargetsMetadataRequest) + if impl.resolve_prefix == resolve_prefix + } return Targets( t for t in targets - if any((factory.resolve_field_value)(t) == resolve_value for factory in factories) + if any( + t.has_field(field) and field_defaults.value_or_default(t[field]) == resolve_value + for field in resolve_fields + ) ) diff --git a/src/python/pants/bsp/util_rules/targets_test.py b/src/python/pants/bsp/util_rules/targets_test.py index c8399e3b6fd..5c86d52971e 100644 --- a/src/python/pants/bsp/util_rules/targets_test.py +++ b/src/python/pants/bsp/util_rules/targets_test.py @@ -4,6 +4,7 @@ import pytest +from pants.backend.java import target_types from pants.backend.java.bsp import rules as java_bsp_rules from pants.backend.java.compile import javac from pants.backend.java.target_types import JavaSourceTarget @@ -31,6 +32,7 @@ def rule_runner() -> RuleRunner: *jvm_tool.rules(), *jvm_util_rules.rules(), *jdk_rules.rules(), + *target_types.rules(), QueryRule(BSPBuildTargets, ()), QueryRule(Targets, [BuildTargetIdentifier]), ], diff --git a/src/python/pants/engine/internals/graph.py b/src/python/pants/engine/internals/graph.py index 91aca32c8fc..f074f7a99cf 100644 --- a/src/python/pants/engine/internals/graph.py +++ b/src/python/pants/engine/internals/graph.py @@ -45,6 +45,9 @@ DependenciesRequest, ExplicitlyProvidedDependencies, Field, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, + FieldDefaults, FieldSetsPerTarget, FieldSetsPerTargetRequest, FilteredTargets, @@ -1058,6 +1061,7 @@ async def resolve_dependencies( target_types_to_generate_requests: TargetTypesToGenerateTargetsRequests, union_membership: UnionMembership, subproject_roots: SubprojectRoots, + field_defaults: FieldDefaults, ) -> Addresses: wrapped_tgt, explicitly_provided = await MultiGet( Get( @@ -1121,7 +1125,7 @@ async def resolve_dependencies( ) explicitly_provided_includes = [ - parametrizations.get_subset(address, tgt).address + parametrizations.get_subset(address, tgt, field_defaults).address for address, parametrizations in zip( explicitly_provided_includes, explicit_dependency_parametrizations ) @@ -1247,6 +1251,25 @@ async def resolve_unparsed_address_inputs( return Addresses(addresses) +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# ----------------------------------------------------------------------------------------------- + + +@rule +async def field_defaults(union_membership: UnionMembership) -> FieldDefaults: + requests = list(union_membership.get(FieldDefaultFactoryRequest)) + factories = await MultiGet( + Get(FieldDefaultFactoryResult, FieldDefaultFactoryRequest, impl()) for impl in requests + ) + return FieldDefaults( + FrozenDict( + (request.field_type, factory.default_factory) + for request, factory in zip(requests, factories) + ) + ) + + # ----------------------------------------------------------------------------------------------- # Find applicable field sets # ----------------------------------------------------------------------------------------------- diff --git a/src/python/pants/engine/internals/graph_test.py b/src/python/pants/engine/internals/graph_test.py index 65cf6b8298d..08ac37c447b 100644 --- a/src/python/pants/engine/internals/graph_test.py +++ b/src/python/pants/engine/internals/graph_test.py @@ -40,6 +40,8 @@ DependenciesRequest, ExplicitlyProvidedDependencies, Field, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, GeneratedSources, GenerateSourcesRequest, @@ -88,6 +90,21 @@ class SpecialCasedDeps2(SpecialCasedDependencies): class ResolveField(StringField, AsyncFieldMixin): alias = "resolve" + default = None + + +_DEFAULT_RESOLVE = "default_test_resolve" + + +class ResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = ResolveField + + +@rule +def resolve_field_default_factory( + request: ResolveFieldDefaultFactoryRequest, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.value or _DEFAULT_RESOLVE) class MockMultipleSourcesField(MultipleSourcesField): @@ -829,6 +846,8 @@ def generated_targets_rule_runner() -> RuleRunner: QueryRule(Addresses, [Specs]), QueryRule(_DependencyMapping, [_DependencyMappingRequest]), QueryRule(_TargetParametrizations, [_TargetParametrizationsRequest]), + UnionRule(FieldDefaultFactoryRequest, ResolveFieldDefaultFactoryRequest), + resolve_field_default_factory, ], target_types=[MockTargetGenerator, MockGeneratedTarget], objects={"parametrize": Parametrize}, @@ -1111,12 +1130,11 @@ def test_parametrize_partial_atom_to_atom(generated_targets_rule_runner: RuleRun """\ generated( name='t1', - resolve=parametrize('a', 'b'), + resolve=parametrize('default_test_resolve', 'b'), source='f1.ext', ) generated( name='t2', - resolve='a', source='f2.ext', dependencies=[':t1'], ) @@ -1124,9 +1142,9 @@ def test_parametrize_partial_atom_to_atom(generated_targets_rule_runner: RuleRun ), ["f1.ext", "f2.ext"], expected_dependencies={ - "demo:t1@resolve=a": set(), + "demo:t1@resolve=default_test_resolve": set(), "demo:t1@resolve=b": set(), - "demo:t2": {"demo:t1@resolve=a"}, + "demo:t2": {"demo:t1@resolve=default_test_resolve"}, }, ) diff --git a/src/python/pants/engine/internals/parametrize.py b/src/python/pants/engine/internals/parametrize.py index 130df727cb8..efe04fad56c 100644 --- a/src/python/pants/engine/internals/parametrize.py +++ b/src/python/pants/engine/internals/parametrize.py @@ -12,7 +12,7 @@ from pants.engine.addresses import Address from pants.engine.collection import Collection from pants.engine.engine_aware import EngineAwareParameter -from pants.engine.target import Field, Target +from pants.engine.target import Field, FieldDefaults, Target from pants.util.frozendict import FrozenDict from pants.util.meta import frozen_after_init from pants.util.strutil import bullet_list, softwrap @@ -224,7 +224,9 @@ def get_all_superset_targets(self, address: Address) -> Iterator[Address]: if address.is_parametrized_subset_of(parametrized_tgt.address): yield parametrized_tgt.address - def get_subset(self, address: Address, consumer: Target) -> Target: + def get_subset( + self, address: Address, consumer: Target, field_defaults: FieldDefaults + ) -> Target: """Find the Target with the given Address, or with fields matching the given consumer.""" # Check for exact matches. instance = self.get(address) @@ -238,9 +240,9 @@ def remaining_fields_match(candidate: Target) -> bool: } return all( _concrete_fields_are_equivalent( + field_defaults, consumer=consumer, - candidate_field_type=field_type, - candidate_field_value=field.value, + candidate_field=field, ) for field_type, field in candidate.field_values.items() if field_type.alias in unspecified_param_field_names @@ -301,11 +303,17 @@ def _bare_address_error(self, address) -> ValueError: def _concrete_fields_are_equivalent( - *, consumer: Target, candidate_field_value: Any, candidate_field_type: type[Field] + field_defaults: FieldDefaults, *, consumer: Target, candidate_field: Field ) -> bool: - # TODO(#16175): Does not account for the computed default values of Fields. + candidate_field_type = type(candidate_field) + candidate_field_value = field_defaults.value_or_default(candidate_field) + if consumer.has_field(candidate_field_type): - return cast(bool, consumer[candidate_field_type].value == candidate_field_value) + return cast( + bool, + field_defaults.value_or_default(consumer[candidate_field_type]) + == candidate_field_value, + ) # Else, see if the consumer has a field that is a superclass of `candidate_field_value`, to # handle https://github.com/pantsbuild/pants/issues/16190. This is only safe because we are # confident that both `candidate_field_type` and the fields from `consumer` are _concrete_, @@ -314,10 +322,12 @@ def _concrete_fields_are_equivalent( ( consumer_field for consumer_field in consumer.field_types - if issubclass(candidate_field_type, consumer_field) + if isinstance(candidate_field, consumer_field) ), None, ) if superclass is None: return False - return cast(bool, consumer[superclass].value == candidate_field_value) + return cast( + bool, field_defaults.value_or_default(consumer[superclass]) == candidate_field_value + ) diff --git a/src/python/pants/engine/internals/parametrize_test.py b/src/python/pants/engine/internals/parametrize_test.py index 4c2459d8df2..a4da6d2776e 100644 --- a/src/python/pants/engine/internals/parametrize_test.py +++ b/src/python/pants/engine/internals/parametrize_test.py @@ -15,7 +15,7 @@ _TargetParametrization, _TargetParametrizations, ) -from pants.engine.target import Field, Target +from pants.engine.target import Field, FieldDefaults, Target from pants.util.frozendict import FrozenDict @@ -144,10 +144,12 @@ def assert_gets(addr: Address, expected: set[Address]) -> None: def test_concrete_fields_are_equivalent() -> None: class ParentField(Field): alias = "parent" + default = None help = "foo" class ChildField(ParentField): alias = "child" + default = None help = "foo" class UnrelatedField(Field): @@ -164,66 +166,78 @@ class ChildTarget(Target): help = "foo" core_fields = (ChildField,) + # Validate literal value matches. + empty_defaults = FieldDefaults(FrozenDict()) + unused_addr = Address("unused") parent_tgt = ParentTarget({"parent": "val"}, Address("parent")) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ParentField, candidate_field_value="val" - ) - is True + child_tgt = ChildTarget({"child": "val"}, Address("child")) + + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=ParentField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ParentField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=parent_tgt, + candidate_field=ParentField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ChildField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=ChildField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=ChildField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=parent_tgt, + candidate_field=ChildField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=parent_tgt, candidate_field_type=UnrelatedField, candidate_field_value="val" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=parent_tgt, candidate_field=UnrelatedField("val", unused_addr) ) - child_tgt = ChildTarget({"child": "val"}, Address("child")) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ParentField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ParentField("val", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ParentField, candidate_field_value="different" - ) - is False + assert not _concrete_fields_are_equivalent( + empty_defaults, + consumer=child_tgt, + candidate_field=ParentField("different", unused_addr), ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ChildField, candidate_field_value="val" - ) - is True + assert _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ChildField("val", unused_addr) + ) + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=ChildField("different", unused_addr) ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=ChildField, candidate_field_value="different" + assert not _concrete_fields_are_equivalent( + empty_defaults, consumer=child_tgt, candidate_field=UnrelatedField("val", unused_addr) + ) + + # Validate field defaulting. + parent_field_defaults = FieldDefaults( + FrozenDict( + { + ParentField: lambda f: f.value or "val", + } ) - is False ) - assert ( - _concrete_fields_are_equivalent( - consumer=child_tgt, candidate_field_type=UnrelatedField, candidate_field_value="val" + child_field_defaults = FieldDefaults( + FrozenDict( + { + ChildField: lambda f: f.value or "val", + } ) - is False + ) + assert _concrete_fields_are_equivalent( + parent_field_defaults, consumer=child_tgt, candidate_field=ParentField(None, unused_addr) + ) + assert _concrete_fields_are_equivalent( + parent_field_defaults, + consumer=ParentTarget({}, Address("parent")), + candidate_field=ChildField("val", unused_addr), + ) + assert _concrete_fields_are_equivalent( + child_field_defaults, consumer=parent_tgt, candidate_field=ChildField(None, unused_addr) + ) + assert _concrete_fields_are_equivalent( + child_field_defaults, + consumer=ChildTarget({}, Address("child")), + candidate_field=ParentField("val", unused_addr), ) diff --git a/src/python/pants/engine/target.py b/src/python/pants/engine/target.py index 2d595f5ae34..ee499b91a8b 100644 --- a/src/python/pants/engine/target.py +++ b/src/python/pants/engine/target.py @@ -36,7 +36,7 @@ get_type_hints, ) -from typing_extensions import final +from typing_extensions import Protocol, final from pants.base.deprecated import warn_or_error from pants.engine.addresses import Address, Addresses, UnparsedAddressInputs, assert_single_address @@ -264,10 +264,74 @@ def __eq__(self, other: Union[Any, AsyncFieldMixin]) -> bool: ) +@union +@dataclass(frozen=True) +class FieldDefaultFactoryRequest: + """Registers a dynamic default for a Field. + + See `FieldDefaults`. + """ + + field_type: ClassVar[type[Field]] + + +# TODO: Workaround for https://github.com/python/mypy/issues/5485, because we cannot directly use +# a Callable. +class FieldDefaultFactory(Protocol): + def __call__(self, field: Field) -> Any: + pass + + +@dataclass(frozen=True) +class FieldDefaultFactoryResult: + """A wrapper for a function which computes the default value of a Field.""" + + default_factory: FieldDefaultFactory + + +@dataclass(frozen=True) +class FieldDefaults: + """Generic Field default values. To install a default, see `FieldDefaultFactoryRequest`. + + TODO: This is to work around the fact that Field value defaulting cannot have arbitrary + subsystem requirements, and so e.g. `JvmResolveField` and `PythonResolveField` have methods + which compute the true value of the field given a subsytem argument. Consumers need to + be type aware, and `@rules` cannot have dynamic requirements. + + Additionally, `__defaults__` should mean that computed default Field values should become + more rare: i.e. `JvmResolveField` and `PythonResolveField` could potentially move to + hardcoded default values which users override with `__defaults__` if they'd like to change + the default resolve names. + + See https://github.com/pantsbuild/pants/issues/12934 about potentially allowing unions + (including Field registrations) to have `@rule_helper` methods, which would allow the + computation of an AsyncField to directly require a subsystem. + """ + + _factories: FrozenDict[type[Field], FieldDefaultFactory] + + @memoized_method + def factory(self, field_type: type[Field]) -> FieldDefaultFactory: + """Looks up a Field default factory in a subclass-aware way.""" + factory = self._factories.get(field_type, None) + if factory is not None: + return factory + + for ft, factory in self._factories.items(): + if issubclass(field_type, ft): + return factory + + return lambda f: f.value + + def value_or_default(self, field: Field) -> Any: + return (self.factory(type(field)))(field) + + # ----------------------------------------------------------------------------------------------- # Core Target abstractions # ----------------------------------------------------------------------------------------------- + # NB: This TypeVar is what allows `Target.get()` to properly work with MyPy so that MyPy knows # the precise Field returned. _F = TypeVar("_F", bound=Field) diff --git a/src/python/pants/jvm/target_types.py b/src/python/pants/jvm/target_types.py index d36e7c7b664..d0c98cba2c6 100644 --- a/src/python/pants/jvm/target_types.py +++ b/src/python/pants/jvm/target_types.py @@ -11,10 +11,13 @@ from pants.core.goals.run import RestartableField from pants.core.goals.test import TestTimeoutField from pants.engine.addresses import Address +from pants.engine.rules import collect_rules, rule from pants.engine.target import ( COMMON_TARGET_FIELDS, AsyncFieldMixin, Dependencies, + FieldDefaultFactoryRequest, + FieldDefaultFactoryResult, FieldSet, InvalidFieldException, InvalidTargetException, @@ -25,6 +28,7 @@ StringSequenceField, Target, ) +from pants.engine.unions import UnionRule from pants.jvm.subsystems import JvmSubsystem from pants.util.docutil import git_url from pants.util.strutil import softwrap @@ -391,3 +395,27 @@ class JvmWarTarget(Target): deploys in Java Servlet containers. """ ) + + +# ----------------------------------------------------------------------------------------------- +# Dynamic Field defaults +# -----------------------------------------------------------------------------------------------# + + +class JvmResolveFieldDefaultFactoryRequest(FieldDefaultFactoryRequest): + field_type = JvmResolveField + + +@rule +def jvm_resolve_field_default_factory( + request: JvmResolveFieldDefaultFactoryRequest, + jvm: JvmSubsystem, +) -> FieldDefaultFactoryResult: + return FieldDefaultFactoryResult(lambda f: f.normalized_value(jvm)) + + +def rules(): + return [ + *collect_rules(), + UnionRule(FieldDefaultFactoryRequest, JvmResolveFieldDefaultFactoryRequest), + ]