Skip to content

Commit

Permalink
Optimize Target and FieldSet operations (cherry-pick #18917) (#18944
Browse files Browse the repository at this point in the history
)

Recent profiles (from #18911, and reproduced locally) highlighted:
* `FieldSet.create` was repeatedly calling the unmemoized
`_get_field_set_fields`, which used the surprisingly expensive `from
typing import get_type_hints`.
    * This change moves type hint extraction into a memoized property.
* `Target.has_fields` was doing an N^2 lookup of fields in a computed
collection of registered fields.
* Moved to using the `field_values` `dict` to represent the set of
present fields, and removed the `plugin_fields` property (which is
always incorporated into the `field_values` at construction time
anyway).

----

Performance for `./pants --no-pantsd dependencies ::` in
`pantsbuild/pants` improves by 9% (or 15%, if the time for rule graph
solving is ignored).

---------

Co-authored-by: Stu Hood <stuhood@gmail.com>
  • Loading branch information
benjyw and stuhood authored May 9, 2023
1 parent 8f77283 commit bf02313
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 45 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/backend/docker/goals/package_image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,13 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--secret",
"id=system-secret,src=/var/run/secrets/mysecret",
"--secret",
f"id=project-secret,src={rule_runner.build_root}/secrets/mysecret",
"--secret",
f"id=target-secret,src={rule_runner.build_root}/docker/test/mysecret",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down Expand Up @@ -993,9 +993,9 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--ssh",
"default",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down
10 changes: 2 additions & 8 deletions src/python/pants/core/util_rules/partitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,7 @@
from pants.engine.engine_aware import EngineAwareParameter
from pants.engine.internals.selectors import Get, MultiGet
from pants.engine.rules import collect_rules, rule
from pants.engine.target import (
FieldSet,
SourcesField,
SourcesPaths,
SourcesPathsRequest,
_get_field_set_fields,
)
from pants.engine.target import FieldSet, SourcesField, SourcesPaths, SourcesPathsRequest
from pants.util.memo import memoized
from pants.util.meta import runtime_ignore_subscripts

Expand Down Expand Up @@ -317,7 +311,7 @@ def _get_sources_field_name(field_set_type: type[FieldSet]) -> str:
"""

sources_field_name = None
for fieldname, fieldtype in _get_field_set_fields(field_set_type).items():
for fieldname, fieldtype in field_set_type.fields.items():
if issubclass(fieldtype, SourcesField):
if sources_field_name is None:
sources_field_name = fieldname
Expand Down
6 changes: 5 additions & 1 deletion src/python/pants/engine/internals/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,11 @@ def create_registrar_for_target(alias: str) -> tuple[str, Registrar]:
return alias, Registrar(
parse_state,
alias,
registered_target_types.aliases_to_types[alias].class_field_types(union_membership),
tuple(
registered_target_types.aliases_to_types[alias].class_field_types(
union_membership
)
),
)

type_aliases = dict(map(create_registrar_for_target, registered_target_types.aliases))
Expand Down
67 changes: 41 additions & 26 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from enum import Enum
from pathlib import PurePath
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
Generic,
Iterable,
Iterator,
KeysView,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -62,7 +64,7 @@
from pants.util.dirutil import fast_relpath
from pants.util.docutil import bin_name, doc_url
from pants.util.frozendict import FrozenDict
from pants.util.memo import memoized_method, memoized_property
from pants.util.memo import memoized_classproperty, memoized_method, memoized_property
from pants.util.ordered_set import FrozenOrderedSet
from pants.util.strutil import bullet_list, help_text, pluralize, softwrap

Expand Down Expand Up @@ -379,7 +381,6 @@ class Target:

# These get calculated in the constructor
address: Address
plugin_fields: tuple[type[Field], ...]
field_values: FrozenDict[type[Field], Field]
residence_dir: str
name_explicitly_set: bool
Expand Down Expand Up @@ -428,9 +429,6 @@ def __init__(
)

object.__setattr__(self, "address", address)
object.__setattr__(
self, "plugin_fields", self._find_plugin_fields(union_membership or UnionMembership({}))
)
object.__setattr__(
self, "residence_dir", residence_dir if residence_dir is not None else address.spec_path
)
Expand All @@ -439,7 +437,10 @@ def __init__(
self,
"field_values",
self._calculate_field_values(
unhydrated_values, address, ignore_unrecognized_fields=ignore_unrecognized_fields
unhydrated_values,
address,
union_membership,
ignore_unrecognized_fields=ignore_unrecognized_fields,
),
)

Expand All @@ -450,11 +451,14 @@ def _calculate_field_values(
self,
unhydrated_values: dict[str, Any],
address: Address,
# See `__init__`.
union_membership: UnionMembership | None,
*,
ignore_unrecognized_fields: bool,
) -> FrozenDict[type[Field], Field]:
all_field_types = self.class_field_types(union_membership)
field_values = {}
aliases_to_field_types = self._get_field_aliases_to_field_types(self.field_types)
aliases_to_field_types = self._get_field_aliases_to_field_types(all_field_types)

for alias, value in unhydrated_values.items():
if alias not in aliases_to_field_types:
Expand All @@ -478,7 +482,9 @@ def _calculate_field_values(
field_values[field_type] = field_type(value, address)

# For undefined fields, mark the raw value as missing.
for field_type in set(self.field_types) - set(field_values.keys()):
for field_type in all_field_types:
if field_type in field_values:
continue
field_values[field_type] = field_type(NO_VALUE, address)
return FrozenDict(
sorted(
Expand All @@ -490,7 +496,7 @@ def _calculate_field_values(
@final
@classmethod
def _get_field_aliases_to_field_types(
cls, field_types: tuple[type[Field], ...]
cls, field_types: Iterable[type[Field]]
) -> dict[str, type[Field]]:
aliases_to_field_types = {}
for field_type in field_types:
Expand All @@ -501,8 +507,8 @@ def _get_field_aliases_to_field_types(

@final
@property
def field_types(self) -> Tuple[Type[Field], ...]:
return (*self.core_fields, *self.plugin_fields)
def field_types(self) -> KeysView[Type[Field]]:
return self.field_values.keys()

@distinct_union_type_per_subclass
class PluginField:
Expand Down Expand Up @@ -538,6 +544,7 @@ def __eq__(self, other: Union[Target, Any]) -> bool:

@final
@classmethod
@memoized_method
def _find_plugin_fields(cls, union_membership: UnionMembership) -> tuple[type[Field], ...]:
result: set[type[Field]] = set()
classes = [cls]
Expand Down Expand Up @@ -577,7 +584,7 @@ def _maybe_get(self, field: Type[_F]) -> Optional[_F]:
if result is not None:
return cast(_F, result)
field_subclass = self._find_registered_field_subclass(
field, registered_fields=self.field_types
field, registered_fields=self.field_values.keys()
)
if field_subclass is not None:
return cast(_F, self.field_values[field_subclass])
Expand Down Expand Up @@ -631,7 +638,7 @@ def get(self, field: Type[_F], *, default_raw_value: Optional[Any] = None) -> _F
@final
@classmethod
def _has_fields(
cls, fields: Iterable[Type[Field]], *, registered_fields: Iterable[Type[Field]]
cls, fields: Iterable[Type[Field]], *, registered_fields: AbstractSet[Type[Field]]
) -> bool:
unrecognized_fields = [field for field in fields if field not in registered_fields]
if not unrecognized_fields:
Expand Down Expand Up @@ -662,17 +669,23 @@ def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
custom subclass `CustomTags`, both `tgt.has_fields([Tags])` and
`python_tgt.has_fields([CustomTags])` will return True.
"""
return self._has_fields(fields, registered_fields=self.field_types)
return self._has_fields(fields, registered_fields=self.field_values.keys())

@final
@classmethod
def class_field_types(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
@memoized_method
def class_field_types(
cls, union_membership: UnionMembership | None
) -> FrozenOrderedSet[Type[Field]]:
"""Return all registered Fields belonging to this target type.
You can also use the instance property `tgt.field_types` to avoid having to pass the
parameter UnionMembership.
"""
return (*cls.core_fields, *cls._find_plugin_fields(union_membership))
if union_membership is None:
return FrozenOrderedSet(cls.core_fields)
else:
return FrozenOrderedSet((*cls.core_fields, *cls._find_plugin_fields(union_membership)))

@final
@classmethod
Expand Down Expand Up @@ -1326,23 +1339,14 @@ def gen_tgt(address: Address, full_fp: str, generated_target_fields: dict[str, A
# -----------------------------------------------------------------------------------------------
# FieldSet
# -----------------------------------------------------------------------------------------------
def _get_field_set_fields(field_set: Type[FieldSet]) -> Dict[str, Type[Field]]:
return {
name: field_type
for name, field_type in get_type_hints(field_set).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
}


def _get_field_set_fields_from_target(
field_set: Type[FieldSet], target: Target
) -> Dict[str, Field]:
all_expected_fields = _get_field_set_fields(field_set)
return {
dataclass_field_name: (
target[field_cls] if field_cls in field_set.required_fields else target.get(field_cls)
)
for dataclass_field_name, field_cls in all_expected_fields.items()
for dataclass_field_name, field_cls in field_set.fields.items()
}


Expand Down Expand Up @@ -1428,6 +1432,17 @@ def applicable_target_types(
def create(cls: Type[_FS], tgt: Target) -> _FS:
return cls(address=tgt.address, **_get_field_set_fields_from_target(cls, tgt))

@final
@memoized_classproperty
def fields(cls) -> FrozenDict[str, Type[Field]]:
return FrozenDict(
(
(name, field_type)
for name, field_type in get_type_hints(cls).items()
if isinstance(field_type, type) and issubclass(field_type, Field)
)
)

def debug_hint(self) -> str:
return self.address.spec

Expand Down
15 changes: 7 additions & 8 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,11 +216,11 @@ def test_has_fields() -> None:
empty_union_membership = UnionMembership({})
tgt = FortranTarget({}, Address("", target_name="lib"))

assert tgt.field_types == (FortranExtensions, FortranVersion)
assert FortranTarget.class_field_types(union_membership=empty_union_membership) == (
assert tgt.field_types == {FortranExtensions, FortranVersion}
assert set(FortranTarget.class_field_types(union_membership=empty_union_membership)) == {
FortranExtensions,
FortranVersion,
)
}

assert tgt.has_fields([]) is True
assert FortranTarget.class_has_fields([], union_membership=empty_union_membership) is True
Expand Down Expand Up @@ -269,16 +269,15 @@ class CustomField(BoolField):
tgt_values, Address("", target_name="lib"), union_membership=union_membership
)

assert tgt.field_types == (FortranExtensions, FortranVersion, CustomField)
assert tgt.field_types == {FortranExtensions, FortranVersion, CustomField}
assert tgt.core_fields == (FortranExtensions, FortranVersion)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True

assert FortranTarget.class_field_types(union_membership=union_membership) == (
assert set(FortranTarget.class_field_types(union_membership=union_membership)) == {
FortranExtensions,
FortranVersion,
CustomField,
)
}
assert FortranTarget.class_has_field(CustomField, union_membership=union_membership) is True
assert (
FortranTarget.class_get_field(CustomField, union_membership=union_membership) is CustomField
Expand All @@ -297,7 +296,7 @@ class OtherTarget(Target):
core_fields = ()

other_tgt = OtherTarget({}, Address("", target_name="other"))
assert other_tgt.plugin_fields == ()
assert tuple(other_tgt.field_types) == ()
assert other_tgt.has_field(CustomField) is False


Expand Down

0 comments on commit bf02313

Please sign in to comment.