Skip to content

Commit

Permalink
Optimize Target.has_fields by removing linear lookup in tuple and r…
Browse files Browse the repository at this point in the history
…edundant `Target.plugin_types` property.
  • Loading branch information
stuhood committed May 5, 2023
1 parent fa853ed commit 68ce5d4
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/python/pants/backend/docker/goals/package_image_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -954,13 +954,13 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--secret",
"id=system-secret,src=/var/run/secrets/mysecret",
"--secret",
f"id=project-secret,src={rule_runner.build_root}/secrets/mysecret",
"--secret",
f"id=target-secret,src={rule_runner.build_root}/docker/test/mysecret",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down Expand Up @@ -993,9 +993,9 @@ def check_docker_proc(process: Process):
assert process.argv == (
"/dummy/docker",
"build",
"--pull=False",
"--ssh",
"default",
"--pull=False",
"--tag",
"img1:latest",
"--file",
Expand Down
6 changes: 5 additions & 1 deletion src/python/pants/engine/internals/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,11 @@ def create_registrar_for_target(alias: str) -> tuple[str, Registrar]:
return alias, Registrar(
parse_state,
alias,
registered_target_types.aliases_to_types[alias].class_field_types(union_membership),
tuple(
registered_target_types.aliases_to_types[alias].class_field_types(
union_membership
)
),
)

type_aliases = dict(map(create_registrar_for_target, registered_target_types.aliases))
Expand Down
41 changes: 25 additions & 16 deletions src/python/pants/engine/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from enum import Enum
from pathlib import PurePath
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
Generic,
Iterable,
Iterator,
KeysView,
Mapping,
Optional,
Sequence,
Expand Down Expand Up @@ -379,7 +381,6 @@ class Target:

# These get calculated in the constructor
address: Address
plugin_fields: tuple[type[Field], ...]
field_values: FrozenDict[type[Field], Field]
residence_dir: str
name_explicitly_set: bool
Expand Down Expand Up @@ -440,17 +441,13 @@ def __init__(
)
object.__setattr__(self, "name_explicitly_set", name_explicitly_set)
try:
object.__setattr__(
self,
"plugin_fields",
self._find_plugin_fields(union_membership or UnionMembership({})),
)
object.__setattr__(
self,
"field_values",
self._calculate_field_values(
unhydrated_values,
address,
union_membership,
ignore_unrecognized_fields=ignore_unrecognized_fields,
),
)
Expand All @@ -466,11 +463,14 @@ def _calculate_field_values(
self,
unhydrated_values: dict[str, Any],
address: Address,
# See `__init__`.
union_membership: UnionMembership | None,
*,
ignore_unrecognized_fields: bool,
) -> FrozenDict[type[Field], Field]:
all_field_types = self.class_field_types(union_membership)
field_values = {}
aliases_to_field_types = self._get_field_aliases_to_field_types(self.field_types)
aliases_to_field_types = self._get_field_aliases_to_field_types(all_field_types)

for alias, value in unhydrated_values.items():
if alias not in aliases_to_field_types:
Expand All @@ -494,7 +494,9 @@ def _calculate_field_values(
field_values[field_type] = field_type(value, address)

# For undefined fields, mark the raw value as missing.
for field_type in set(self.field_types) - set(field_values.keys()):
for field_type in all_field_types:
if field_type in field_values:
continue
field_values[field_type] = field_type(NO_VALUE, address)
return FrozenDict(
sorted(
Expand All @@ -506,7 +508,7 @@ def _calculate_field_values(
@final
@classmethod
def _get_field_aliases_to_field_types(
cls, field_types: tuple[type[Field], ...]
cls, field_types: Iterable[type[Field]]
) -> dict[str, type[Field]]:
aliases_to_field_types = {}
for field_type in field_types:
Expand All @@ -517,8 +519,8 @@ def _get_field_aliases_to_field_types(

@final
@property
def field_types(self) -> Tuple[Type[Field], ...]:
return (*self.core_fields, *self.plugin_fields)
def field_types(self) -> KeysView[Type[Field]]:
return self.field_values.keys()

@distinct_union_type_per_subclass
class PluginField:
Expand Down Expand Up @@ -555,6 +557,7 @@ def __eq__(self, other: Union[Target, Any]) -> bool:

@final
@classmethod
@memoized_method
def _find_plugin_fields(cls, union_membership: UnionMembership) -> tuple[type[Field], ...]:
result: set[type[Field]] = set()
classes = [cls]
Expand Down Expand Up @@ -594,7 +597,7 @@ def _maybe_get(self, field: Type[_F]) -> Optional[_F]:
if result is not None:
return cast(_F, result)
field_subclass = self._find_registered_field_subclass(
field, registered_fields=self.field_types
field, registered_fields=self.field_values.keys()
)
if field_subclass is not None:
return cast(_F, self.field_values[field_subclass])
Expand Down Expand Up @@ -648,7 +651,7 @@ def get(self, field: Type[_F], *, default_raw_value: Optional[Any] = None) -> _F
@final
@classmethod
def _has_fields(
cls, fields: Iterable[Type[Field]], *, registered_fields: Iterable[Type[Field]]
cls, fields: Iterable[Type[Field]], *, registered_fields: AbstractSet[Type[Field]]
) -> bool:
unrecognized_fields = [field for field in fields if field not in registered_fields]
if not unrecognized_fields:
Expand Down Expand Up @@ -679,17 +682,23 @@ def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
custom subclass `CustomTags`, both `tgt.has_fields([Tags])` and
`python_tgt.has_fields([CustomTags])` will return True.
"""
return self._has_fields(fields, registered_fields=self.field_types)
return self._has_fields(fields, registered_fields=self.field_values.keys())

@final
@classmethod
def class_field_types(cls, union_membership: UnionMembership) -> Tuple[Type[Field], ...]:
@memoized_method
def class_field_types(
cls, union_membership: UnionMembership | None
) -> FrozenOrderedSet[Type[Field]]:
"""Return all registered Fields belonging to this target type.
You can also use the instance property `tgt.field_types` to avoid having to pass the
parameter UnionMembership.
"""
return (*cls.core_fields, *cls._find_plugin_fields(union_membership))
if union_membership is None:
return FrozenOrderedSet(cls.core_fields)
else:
return FrozenOrderedSet((*cls.core_fields, *cls._find_plugin_fields(union_membership)))

@final
@classmethod
Expand Down
15 changes: 7 additions & 8 deletions src/python/pants/engine/target_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,11 @@ def test_has_fields() -> None:
empty_union_membership = UnionMembership({})
tgt = FortranTarget({}, Address("", target_name="lib"))

assert tgt.field_types == (FortranExtensions, FortranVersion)
assert FortranTarget.class_field_types(union_membership=empty_union_membership) == (
assert tgt.field_types == {FortranExtensions, FortranVersion}
assert set(FortranTarget.class_field_types(union_membership=empty_union_membership)) == {
FortranExtensions,
FortranVersion,
)
}

assert tgt.has_fields([]) is True
assert FortranTarget.class_has_fields([], union_membership=empty_union_membership) is True
Expand Down Expand Up @@ -268,16 +268,15 @@ class CustomField(BoolField):
tgt_values, Address("", target_name="lib"), union_membership=union_membership
)

assert tgt.field_types == (FortranExtensions, FortranVersion, CustomField)
assert tgt.field_types == {FortranExtensions, FortranVersion, CustomField}
assert tgt.core_fields == (FortranExtensions, FortranVersion)
assert tgt.plugin_fields == (CustomField,)
assert tgt.has_field(CustomField) is True

assert FortranTarget.class_field_types(union_membership=union_membership) == (
assert set(FortranTarget.class_field_types(union_membership=union_membership)) == {
FortranExtensions,
FortranVersion,
CustomField,
)
}
assert FortranTarget.class_has_field(CustomField, union_membership=union_membership) is True
assert (
FortranTarget.class_get_field(CustomField, union_membership=union_membership) is CustomField
Expand All @@ -296,7 +295,7 @@ class OtherTarget(Target):
core_fields = ()

other_tgt = OtherTarget({}, Address("", target_name="other"))
assert other_tgt.plugin_fields == ()
assert tuple(other_tgt.field_types) == ()
assert other_tgt.has_field(CustomField) is False


Expand Down

0 comments on commit 68ce5d4

Please sign in to comment.