Skip to content

Commit

Permalink
Optimize Target and FieldSet operations (Cherry-pick of #18917) (#…
Browse files Browse the repository at this point in the history
…18949)

Recent profiles (from #18911, and reproduced locally) highlighted:
* `FieldSet.create` was repeatedly calling the unmemoized
`_get_field_set_fields`, which used the surprisingly expensive `from
typing import get_type_hints`.
    * This change moves type hint extraction into a memoized property.
* `Target.has_fields` was doing an N^2 lookup of fields in a computed
collection of registered fields.
* Moved to using the `field_values` `dict` to represent the set of
present fields, and removed the `plugin_fields` property (which is
always incorporated into the `field_values` at construction time
anyway).

----

Performance for `./pants --no-pantsd dependencies ::` in
`pantsbuild/pants` improves by 9% (or 15%, if the time for rule graph
solving is ignored).
  • Loading branch information
stuhood authored May 9, 2023
1 parent d858b9d commit 56a60be
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 42 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/backend/docker/goals/package_image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,13 +706,13 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--secret",
"id=system-secret,src=/var/run/secrets/mysecret",
"--secret",
f"id=project-secret,src={rule_runner.build_root}/secrets/mysecret",
"--secret",
f"id=target-secret,src={rule_runner.build_root}/docker/test/mysecret",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down Expand Up @@ -745,9 +745,9 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--ssh",
"default",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down
10 changes: 2 additions & 8 deletions src/python/pants/core/util_rules/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
from pants.engine.engine_aware import EngineAwareParameter
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
FieldSet,
SourcesField,
SourcesPaths,
SourcesPathsRequest,
_get_field_set_fields,
)
from pants.engine.target import FieldSet, SourcesField, SourcesPaths, SourcesPathsRequest
from pants.util.memo import memoized
from pants.util.meta import frozen_after_init, runtime_ignore_subscripts

Expand Down Expand Up @@ -319,7 +313,7 @@ def _get_sources_field_name(field_set_type: type[FieldSet]) -> str:
"""

sources_field_name = None
for fieldname, fieldtype in _get_field_set_fields(field_set_type).items():
for fieldname, fieldtype in field_set_type.fields.items():
if issubclass(fieldtype, SourcesField):
if sources_field_name is None:
sources_field_name = fieldname
Expand Down
65 changes: 41 additions & 24 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from enum import Enum
from pathlib import PurePath
from typing import (
AbstractSet,
Any,
ClassVar,
Dict,
Generic,
Iterable,
Iterator,
KeysView,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -57,7 +59,7 @@
from pants.util.dirutil import fast_relpath
from pants.util.docutil import bin_name, doc_url
from pants.util.frozendict import FrozenDict
from pants.util.memo import memoized_method, memoized_property
from pants.util.memo import memoized_classproperty, memoized_method, memoized_property
from pants.util.meta import frozen_after_init
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import bullet_list, pluralize, softwrap
Expand Down Expand Up @@ -363,7 +365,6 @@ class Target:

# These get calculated in the constructor
address: Address
plugin_fields: tuple[type[Field], ...]
field_values: FrozenDict[type[Field], Field]
residence_dir: str
name_explicitly_set: bool
Expand Down Expand Up @@ -412,11 +413,13 @@ def __init__(
)

self.address = address
self.plugin_fields = self._find_plugin_fields(union_membership or UnionMembership({}))
self.residence_dir = residence_dir if residence_dir is not None else address.spec_path
self.name_explicitly_set = name_explicitly_set
self.field_values = self._calculate_field_values(
unhydrated_values, address, ignore_unrecognized_fields=ignore_unrecognized_fields
unhydrated_values,
address,
union_membership,
ignore_unrecognized_fields=ignore_unrecognized_fields,
)
self.validate()

Expand All @@ -425,11 +428,14 @@ def _calculate_field_values(
self,
unhydrated_values: dict[str, Any],
address: Address,
# See `__init__`.
union_membership: UnionMembership | None,
*,
ignore_unrecognized_fields: bool,
) -> FrozenDict[type[Field], Field]:
all_field_types = self.class_field_types(union_membership)
field_values = {}
aliases_to_field_types = self._get_field_aliases_to_field_types(self.field_types)
aliases_to_field_types = self._get_field_aliases_to_field_types(all_field_types)

for alias, value in unhydrated_values.items():
if alias not in aliases_to_field_types:
Expand All @@ -453,7 +459,9 @@ def _calculate_field_values(
field_values[field_type] = field_type(value, address)

# For undefined fields, mark the raw value as None.
for field_type in set(self.field_types) - set(field_values.keys()):
for field_type in all_field_types:
if field_type in field_values:
continue
field_values[field_type] = field_type(None, address)
return FrozenDict(
sorted(
Expand All @@ -465,7 +473,7 @@ def _calculate_field_values(
@final
@classmethod
def _get_field_aliases_to_field_types(
cls, field_types: tuple[type[Field], ...]
cls, field_types: Iterable[type[Field]]
) -> dict[str, type[Field]]:
aliases_to_field_types = {}
for field_type in field_types:
Expand All @@ -476,8 +484,8 @@ def _get_field_aliases_to_field_types(

@final
@property
def field_types(self) -> Tuple[Type[Field], ...]:
return (*self.core_fields, *self.plugin_fields)
def field_types(self) -> KeysView[Type[Field]]:
return self.field_values.keys()

@distinct_union_type_per_subclass
class PluginField:
Expand Down Expand Up @@ -513,6 +521,7 @@ def __eq__(self, other: Union[Target, Any]) -> bool:

@final
@classmethod
@memoized_method
def _find_plugin_fields(cls, union_membership: UnionMembership) -> tuple[type[Field], ...]:
result: set[type[Field]] = set()
classes = [cls]
Expand Down Expand Up @@ -552,7 +561,7 @@ def _maybe_get(self, field: Type[_F]) -> Optional[_F]:
if result is not None:
return cast(_F, result)
field_subclass = self._find_registered_field_subclass(
field, registered_fields=self.field_types
field, registered_fields=self.field_values.keys()
)
if field_subclass is not None:
return cast(_F, self.field_values[field_subclass])
Expand Down Expand Up @@ -606,7 +615,7 @@ def get(self, field: Type[_F], *, default_raw_value: Optional[Any] = None) -> _F
@final
@classmethod
def _has_fields(
cls, fields: Iterable[Type[Field]], *, registered_fields: Iterable[Type[Field]]
cls, fields: Iterable[Type[Field]], *, registered_fields: AbstractSet[Type[Field]]
) -> bool:
unrecognized_fields = [field for field in fields if field not in registered_fields]
if not unrecognized_fields:
Expand Down Expand Up @@ -637,17 +646,23 @@ def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
custom subclass `CustomTags`, both `tgt.has_fields([Tags])` and
`python_tgt.has_fields([CustomTags])` will return True.
"""
return self._has_fields(fields, registered_fields=self.field_types)
return self._has_fields(fields, registered_fields=self.field_values.keys())

@final
@classmethod
def class_field_types(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
@memoized_method
def class_field_types(
cls, union_membership: UnionMembership | None
) -> FrozenOrderedSet[Type[Field]]:
"""Return all registered Fields belonging to this target type.
You can also use the instance property `tgt.field_types` to avoid having to pass the
parameter UnionMembership.
"""
return (*cls.core_fields, *cls._find_plugin_fields(union_membership))
if union_membership is None:
return FrozenOrderedSet(cls.core_fields)
else:
return FrozenOrderedSet((*cls.core_fields, *cls._find_plugin_fields(union_membership)))

@final
@classmethod
Expand Down Expand Up @@ -1304,23 +1319,14 @@ def gen_tgt(address: Address, full_fp: str, generated_target_fields: dict[str, A
# -----------------------------------------------------------------------------------------------
# FieldSet
# -----------------------------------------------------------------------------------------------
def _get_field_set_fields(field_set: Type[FieldSet]) -> Dict[str, Type[Field]]:
return {
name: field_type
for name, field_type in get_type_hints(field_set).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
}


def _get_field_set_fields_from_target(
field_set: Type[FieldSet], target: Target
) -> Dict[str, Field]:
all_expected_fields = _get_field_set_fields(field_set)
return {
dataclass_field_name: (
target[field_cls] if field_cls in field_set.required_fields else target.get(field_cls)
)
for dataclass_field_name, field_cls in all_expected_fields.items()
for dataclass_field_name, field_cls in field_set.fields.items()
}


Expand Down Expand Up @@ -1406,6 +1412,17 @@ def applicable_target_types(
def create(cls: Type[_FS], tgt: Target) -> _FS:
return cls(address=tgt.address, **_get_field_set_fields_from_target(cls, tgt))

@final
@memoized_classproperty
def fields(cls) -> FrozenDict[str, Type[Field]]:
return FrozenDict(
(
(name, field_type)
for name, field_type in get_type_hints(cls).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
)
)

def debug_hint(self) -> str:
return self.address.spec

Expand Down
15 changes: 7 additions & 8 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def test_has_fields() -> None:
empty_union_membership = UnionMembership({})
tgt = FortranTarget({}, Address("", target_name="lib"))

assert tgt.field_types == (FortranExtensions, FortranVersion)
assert FortranTarget.class_field_types(union_membership=empty_union_membership) == (
assert tgt.field_types == {FortranExtensions, FortranVersion}
assert set(FortranTarget.class_field_types(union_membership=empty_union_membership)) == {
FortranExtensions,
FortranVersion,
)
}

assert tgt.has_fields([]) is True
assert FortranTarget.class_has_fields([], union_membership=empty_union_membership) is True
Expand Down Expand Up @@ -269,16 +269,15 @@ class CustomField(BoolField):
tgt_values, Address("", target_name="lib"), union_membership=union_membership
)

assert tgt.field_types == (FortranExtensions, FortranVersion, CustomField)
assert tgt.field_types == {FortranExtensions, FortranVersion, CustomField}
assert tgt.core_fields == (FortranExtensions, FortranVersion)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True

assert FortranTarget.class_field_types(union_membership=union_membership) == (
assert set(FortranTarget.class_field_types(union_membership=union_membership)) == {
FortranExtensions,
FortranVersion,
CustomField,
)
}
assert FortranTarget.class_has_field(CustomField, union_membership=union_membership) is True
assert (
FortranTarget.class_get_field(CustomField, union_membership=union_membership) is CustomField
Expand All @@ -297,7 +296,7 @@ class OtherTarget(Target):
core_fields = ()

other_tgt = OtherTarget({}, Address("", target_name="other"))
assert other_tgt.plugin_fields == ()
assert tuple(other_tgt.field_types) == ()
assert other_tgt.has_field(CustomField) is False


Expand Down

0 comments on commit 56a60be

Please sign in to comment.