Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize Target and FieldSet operations (Cherry-pick of #18917) #18949

Merged
merged 1 commit into from
May 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/python/pants/backend/docker/goals/package_image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,13 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--secret",
"id=system-secret,src=/var/run/secrets/mysecret",
"--secret",
f"id=project-secret,src={rule_runner.build_root}/secrets/mysecret",
"--secret",
f"id=target-secret,src={rule_runner.build_root}/docker/test/mysecret",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down Expand Up @@ -745,9 +745,9 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--ssh",
"default",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down
10 changes: 2 additions & 8 deletions src/python/pants/core/util_rules/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
from pants.engine.engine_aware import EngineAwareParameter
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
FieldSet,
SourcesField,
SourcesPaths,
SourcesPathsRequest,
_get_field_set_fields,
)
from pants.engine.target import FieldSet, SourcesField, SourcesPaths, SourcesPathsRequest
from pants.util.memo import memoized
from pants.util.meta import frozen_after_init, runtime_ignore_subscripts

Expand Down Expand Up @@ -319,7 +313,7 @@ def _get_sources_field_name(field_set_type: type[FieldSet]) -> str:
"""

sources_field_name = None
for fieldname, fieldtype in _get_field_set_fields(field_set_type).items():
for fieldname, fieldtype in field_set_type.fields.items():
if issubclass(fieldtype, SourcesField):
if sources_field_name is None:
sources_field_name = fieldname
Expand Down
65 changes: 41 additions & 24 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from enum import Enum
from pathlib import PurePath
from typing import (
AbstractSet,
Any,
ClassVar,
Dict,
Generic,
Iterable,
Iterator,
KeysView,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -57,7 +59,7 @@
from pants.util.dirutil import fast_relpath
from pants.util.docutil import bin_name, doc_url
from pants.util.frozendict import FrozenDict
from pants.util.memo import memoized_method, memoized_property
from pants.util.memo import memoized_classproperty, memoized_method, memoized_property
from pants.util.meta import frozen_after_init
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import bullet_list, pluralize, softwrap
Expand Down Expand Up @@ -363,7 +365,6 @@ class Target:

# These get calculated in the constructor
address: Address
plugin_fields: tuple[type[Field], ...]
field_values: FrozenDict[type[Field], Field]
residence_dir: str
name_explicitly_set: bool
Expand Down Expand Up @@ -412,11 +413,13 @@ def __init__(
)

self.address = address
self.plugin_fields = self._find_plugin_fields(union_membership or UnionMembership({}))
self.residence_dir = residence_dir if residence_dir is not None else address.spec_path
self.name_explicitly_set = name_explicitly_set
self.field_values = self._calculate_field_values(
unhydrated_values, address, ignore_unrecognized_fields=ignore_unrecognized_fields
unhydrated_values,
address,
union_membership,
ignore_unrecognized_fields=ignore_unrecognized_fields,
)
self.validate()

Expand All @@ -425,11 +428,14 @@ def _calculate_field_values(
self,
unhydrated_values: dict[str, Any],
address: Address,
# See `__init__`.
union_membership: UnionMembership | None,
*,
ignore_unrecognized_fields: bool,
) -> FrozenDict[type[Field], Field]:
all_field_types = self.class_field_types(union_membership)
field_values = {}
aliases_to_field_types = self._get_field_aliases_to_field_types(self.field_types)
aliases_to_field_types = self._get_field_aliases_to_field_types(all_field_types)

for alias, value in unhydrated_values.items():
if alias not in aliases_to_field_types:
Expand All @@ -453,7 +459,9 @@ def _calculate_field_values(
field_values[field_type] = field_type(value, address)

# For undefined fields, mark the raw value as None.
for field_type in set(self.field_types) - set(field_values.keys()):
for field_type in all_field_types:
if field_type in field_values:
continue
field_values[field_type] = field_type(None, address)
return FrozenDict(
sorted(
Expand All @@ -465,7 +473,7 @@ def _calculate_field_values(
@final
@classmethod
def _get_field_aliases_to_field_types(
cls, field_types: tuple[type[Field], ...]
cls, field_types: Iterable[type[Field]]
) -> dict[str, type[Field]]:
aliases_to_field_types = {}
for field_type in field_types:
Expand All @@ -476,8 +484,8 @@ def _get_field_aliases_to_field_types(

@final
@property
def field_types(self) -> Tuple[Type[Field], ...]:
return (*self.core_fields, *self.plugin_fields)
def field_types(self) -> KeysView[Type[Field]]:
return self.field_values.keys()

@distinct_union_type_per_subclass
class PluginField:
Expand Down Expand Up @@ -513,6 +521,7 @@ def __eq__(self, other: Union[Target, Any]) -> bool:

@final
@classmethod
@memoized_method
def _find_plugin_fields(cls, union_membership: UnionMembership) -> tuple[type[Field], ...]:
result: set[type[Field]] = set()
classes = [cls]
Expand Down Expand Up @@ -552,7 +561,7 @@ def _maybe_get(self, field: Type[_F]) -> Optional[_F]:
if result is not None:
return cast(_F, result)
field_subclass = self._find_registered_field_subclass(
field, registered_fields=self.field_types
field, registered_fields=self.field_values.keys()
)
if field_subclass is not None:
return cast(_F, self.field_values[field_subclass])
Expand Down Expand Up @@ -606,7 +615,7 @@ def get(self, field: Type[_F], *, default_raw_value: Optional[Any] = None) -> _F
@final
@classmethod
def _has_fields(
cls, fields: Iterable[Type[Field]], *, registered_fields: Iterable[Type[Field]]
cls, fields: Iterable[Type[Field]], *, registered_fields: AbstractSet[Type[Field]]
) -> bool:
unrecognized_fields = [field for field in fields if field not in registered_fields]
if not unrecognized_fields:
Expand Down Expand Up @@ -637,17 +646,23 @@ def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
custom subclass `CustomTags`, both `tgt.has_fields([Tags])` and
`python_tgt.has_fields([CustomTags])` will return True.
"""
return self._has_fields(fields, registered_fields=self.field_types)
return self._has_fields(fields, registered_fields=self.field_values.keys())

@final
@classmethod
def class_field_types(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
@memoized_method
def class_field_types(
cls, union_membership: UnionMembership | None
) -> FrozenOrderedSet[Type[Field]]:
"""Return all registered Fields belonging to this target type.

You can also use the instance property `tgt.field_types` to avoid having to pass the
parameter UnionMembership.
"""
return (*cls.core_fields, *cls._find_plugin_fields(union_membership))
if union_membership is None:
return FrozenOrderedSet(cls.core_fields)
else:
return FrozenOrderedSet((*cls.core_fields, *cls._find_plugin_fields(union_membership)))

@final
@classmethod
Expand Down Expand Up @@ -1304,23 +1319,14 @@ def gen_tgt(address: Address, full_fp: str, generated_target_fields: dict[str, A
# -----------------------------------------------------------------------------------------------
# FieldSet
# -----------------------------------------------------------------------------------------------
def _get_field_set_fields(field_set: Type[FieldSet]) -> Dict[str, Type[Field]]:
return {
name: field_type
for name, field_type in get_type_hints(field_set).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
}


def _get_field_set_fields_from_target(
field_set: Type[FieldSet], target: Target
) -> Dict[str, Field]:
all_expected_fields = _get_field_set_fields(field_set)
return {
dataclass_field_name: (
target[field_cls] if field_cls in field_set.required_fields else target.get(field_cls)
)
for dataclass_field_name, field_cls in all_expected_fields.items()
for dataclass_field_name, field_cls in field_set.fields.items()
}


Expand Down Expand Up @@ -1406,6 +1412,17 @@ def applicable_target_types(
def create(cls: Type[_FS], tgt: Target) -> _FS:
return cls(address=tgt.address, **_get_field_set_fields_from_target(cls, tgt))

@final
@memoized_classproperty
def fields(cls) -> FrozenDict[str, Type[Field]]:
return FrozenDict(
(
(name, field_type)
for name, field_type in get_type_hints(cls).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
)
)

def debug_hint(self) -> str:
return self.address.spec

Expand Down
15 changes: 7 additions & 8 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def test_has_fields() -> None:
empty_union_membership = UnionMembership({})
tgt = FortranTarget({}, Address("", target_name="lib"))

assert tgt.field_types == (FortranExtensions, FortranVersion)
assert FortranTarget.class_field_types(union_membership=empty_union_membership) == (
assert tgt.field_types == {FortranExtensions, FortranVersion}
assert set(FortranTarget.class_field_types(union_membership=empty_union_membership)) == {
FortranExtensions,
FortranVersion,
)
}

assert tgt.has_fields([]) is True
assert FortranTarget.class_has_fields([], union_membership=empty_union_membership) is True
Expand Down Expand Up @@ -269,16 +269,15 @@ class CustomField(BoolField):
tgt_values, Address("", target_name="lib"), union_membership=union_membership
)

assert tgt.field_types == (FortranExtensions, FortranVersion, CustomField)
assert tgt.field_types == {FortranExtensions, FortranVersion, CustomField}
assert tgt.core_fields == (FortranExtensions, FortranVersion)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True

assert FortranTarget.class_field_types(union_membership=union_membership) == (
assert set(FortranTarget.class_field_types(union_membership=union_membership)) == {
FortranExtensions,
FortranVersion,
CustomField,
)
}
assert FortranTarget.class_has_field(CustomField, union_membership=union_membership) is True
assert (
FortranTarget.class_get_field(CustomField, union_membership=union_membership) is CustomField
Expand All @@ -297,7 +296,7 @@ class OtherTarget(Target):
core_fields = ()

other_tgt = OtherTarget({}, Address("", target_name="other"))
assert other_tgt.plugin_fields == ()
assert tuple(other_tgt.field_types) == ()
assert other_tgt.has_field(CustomField) is False


Expand Down