Skip to content

Commit

Permalink
feat: Isolate input-dependent calculations in get_online_features (#…
Browse files Browse the repository at this point in the history
…4041)

refactor get online features

Signed-off-by: tokoko <togurg14@freeuni.edu.ge>
  • Loading branch information
tokoko authored Apr 19, 2024
1 parent d45fb4e commit 2a6edea
Showing 1 changed file with 63 additions and 34 deletions.
97 changes: 63 additions & 34 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,22 +1535,11 @@ def get_online_features(
native_entity_values=True,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Mapping[
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
],
full_feature_names: bool = False,
native_entity_values: bool = True,
def _get_online_request_context(
self, features: Union[List[str], FeatureService], full_feature_names: bool
):
# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

_feature_refs = self._get_features(features, allow_cache=True)

(
requested_feature_views,
requested_on_demand_feature_views,
Expand All @@ -1564,19 +1553,6 @@ def _get_online_features(
join_keys_set,
) = self._get_entity_maps(requested_feature_views)

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)
_validate_feature_refs(_feature_refs, full_feature_names)
(
grouped_refs,
Expand All @@ -1588,7 +1564,6 @@ def _get_online_features(
)
set_usage_attribute("odfv", bool(grouped_odfv_refs))

# All requested features should be present in the result.
requested_result_row_names = {
feat_ref.replace(":", "__") for feat_ref in _feature_refs
}
Expand All @@ -1601,6 +1576,65 @@ def _get_online_features(

needed_request_data = self.get_needed_request_data(grouped_odfv_refs)

entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]

return (
_feature_refs,
requested_on_demand_feature_views,
entity_name_to_join_key_map,
entity_type_map,
join_keys_set,
grouped_refs,
requested_result_row_names,
needed_request_data,
entityless_case,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Mapping[
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
],
full_feature_names: bool = False,
native_entity_values: bool = True,
):
(
_feature_refs,
requested_on_demand_feature_views,
entity_name_to_join_key_map,
entity_type_map,
join_keys_set,
grouped_refs,
requested_result_row_names,
needed_request_data,
entityless_case,
) = self._get_online_request_context(features, full_feature_names)

# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)

join_key_values: Dict[str, List[Value]] = {}
request_data_features: Dict[str, List[Value]] = {}
# Entity rows may be either entities or request data.
Expand Down Expand Up @@ -1640,11 +1674,6 @@ def _get_online_features(

# Add the Entityless case after populating result rows to avoid having to remove
# it later.
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]
if entityless_case:
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
Expand Down Expand Up @@ -1677,7 +1706,7 @@ def _get_online_features(
table,
)

if grouped_odfv_refs:
if requested_on_demand_feature_views:
self._augment_response_with_on_demand_transforms(
online_features_response,
_feature_refs,
Expand Down

0 comments on commit 2a6edea

Please sign in to comment.