From ed7f256f15bce1cb30e0bd1f3e4a9810518a0474 Mon Sep 17 00:00:00 2001 From: tokoko Date: Tue, 26 Mar 2024 05:24:24 +0000 Subject: [PATCH] refactor get online features Signed-off-by: tokoko --- sdk/python/feast/feature_store.py | 98 ++++++++++++++++++++----------- 1 file changed, 64 insertions(+), 34 deletions(-) diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py index 9ac2c14527..652a7270c7 100644 --- a/sdk/python/feast/feature_store.py +++ b/sdk/python/feast/feature_store.py @@ -1532,22 +1532,11 @@ def get_online_features( native_entity_values=True, ) - def _get_online_features( - self, - features: Union[List[str], FeatureService], - entity_values: Mapping[ - str, Union[Sequence[Any], Sequence[Value], RepeatedValue] - ], - full_feature_names: bool = False, - native_entity_values: bool = True, + def _get_online_request_context( + self, features: Union[List[str], FeatureService], full_feature_names: bool ): - # Extract Sequence from RepeatedValue Protobuf. - entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = { - k: list(v) if isinstance(v, Sequence) else list(v.val) - for k, v in entity_values.items() - } - _feature_refs = self._get_features(features, allow_cache=True) + ( requested_feature_views, requested_on_demand_feature_views, @@ -1561,20 +1550,8 @@ def _get_online_features( join_keys_set, ) = self._get_entity_maps(requested_feature_views) - entity_proto_values: Dict[str, List[Value]] - if native_entity_values: - # Convert values to Protobuf once. - entity_proto_values = { - k: python_values_to_proto_values( - v, entity_type_map.get(k, ValueType.UNKNOWN) - ) - for k, v in entity_value_lists.items() - } - else: - entity_proto_values = entity_value_lists - - num_rows = _validate_entity_values(entity_proto_values) _validate_feature_refs(_feature_refs, full_feature_names) + (grouped_refs, grouped_odfv_refs,) = _group_feature_refs( _feature_refs, requested_feature_views, @@ -1582,7 +1559,6 @@ def _get_online_features( ) set_usage_attribute("odfv", bool(grouped_odfv_refs)) - # All requested features should be present in the result. requested_result_row_names = { feat_ref.replace(":", "__") for feat_ref in _feature_refs } @@ -1595,6 +1571,65 @@ def _get_online_features( needed_request_data = self.get_needed_request_data(grouped_odfv_refs) + entityless_case = DUMMY_ENTITY_NAME in [ + entity_name + for feature_view in feature_views + for entity_name in feature_view.entities + ] + + return ( + _feature_refs, + requested_on_demand_feature_views, + entity_name_to_join_key_map, + entity_type_map, + join_keys_set, + grouped_refs, + requested_result_row_names, + needed_request_data, + entityless_case, + ) + + def _get_online_features( + self, + features: Union[List[str], FeatureService], + entity_values: Mapping[ + str, Union[Sequence[Any], Sequence[Value], RepeatedValue] + ], + full_feature_names: bool = False, + native_entity_values: bool = True, + ): + ( + _feature_refs, + requested_on_demand_feature_views, + entity_name_to_join_key_map, + entity_type_map, + join_keys_set, + grouped_refs, + requested_result_row_names, + needed_request_data, + entityless_case, + ) = self._get_online_request_context(features, full_feature_names) + + # Extract Sequence from RepeatedValue Protobuf. + entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = { + k: list(v) if isinstance(v, Sequence) else list(v.val) + for k, v in entity_values.items() + } + + entity_proto_values: Dict[str, List[Value]] + if native_entity_values: + # Convert values to Protobuf once. + entity_proto_values = { + k: python_values_to_proto_values( + v, entity_type_map.get(k, ValueType.UNKNOWN) + ) + for k, v in entity_value_lists.items() + } + else: + entity_proto_values = entity_value_lists + + num_rows = _validate_entity_values(entity_proto_values) + join_key_values: Dict[str, List[Value]] = {} request_data_features: Dict[str, List[Value]] = {} # Entity rows may be either entities or request data. @@ -1634,11 +1669,6 @@ def _get_online_features( # Add the Entityless case after populating result rows to avoid having to remove # it later. - entityless_case = DUMMY_ENTITY_NAME in [ - entity_name - for feature_view in feature_views - for entity_name in feature_view.entities - ] if entityless_case: join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values( [DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type @@ -1671,7 +1701,7 @@ def _get_online_features( table, ) - if grouped_odfv_refs: + if requested_on_demand_feature_views: self._augment_response_with_on_demand_transforms( online_features_response, _feature_refs,