Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Isolate input-dependent calculations in get_online_features #4041

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 63 additions & 34 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,22 +1535,11 @@ def get_online_features(
native_entity_values=True,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Mapping[
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
],
full_feature_names: bool = False,
native_entity_values: bool = True,
def _get_online_request_context(
self, features: Union[List[str], FeatureService], full_feature_names: bool
):
# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

_feature_refs = self._get_features(features, allow_cache=True)

(
requested_feature_views,
requested_on_demand_feature_views,
Expand All @@ -1564,19 +1553,6 @@ def _get_online_features(
join_keys_set,
) = self._get_entity_maps(requested_feature_views)

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)
_validate_feature_refs(_feature_refs, full_feature_names)
(
grouped_refs,
Expand All @@ -1588,7 +1564,6 @@ def _get_online_features(
)
set_usage_attribute("odfv", bool(grouped_odfv_refs))

# All requested features should be present in the result.
requested_result_row_names = {
feat_ref.replace(":", "__") for feat_ref in _feature_refs
}
Expand All @@ -1601,6 +1576,65 @@ def _get_online_features(

needed_request_data = self.get_needed_request_data(grouped_odfv_refs)

entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]

return (
_feature_refs,
requested_on_demand_feature_views,
entity_name_to_join_key_map,
entity_type_map,
join_keys_set,
grouped_refs,
requested_result_row_names,
needed_request_data,
entityless_case,
)

def _get_online_features(
self,
features: Union[List[str], FeatureService],
entity_values: Mapping[
str, Union[Sequence[Any], Sequence[Value], RepeatedValue]
],
full_feature_names: bool = False,
native_entity_values: bool = True,
):
(
_feature_refs,
requested_on_demand_feature_views,
entity_name_to_join_key_map,
entity_type_map,
join_keys_set,
grouped_refs,
requested_result_row_names,
needed_request_data,
entityless_case,
) = self._get_online_request_context(features, full_feature_names)

# Extract Sequence from RepeatedValue Protobuf.
entity_value_lists: Dict[str, Union[List[Any], List[Value]]] = {
k: list(v) if isinstance(v, Sequence) else list(v.val)
for k, v in entity_values.items()
}

entity_proto_values: Dict[str, List[Value]]
if native_entity_values:
# Convert values to Protobuf once.
entity_proto_values = {
k: python_values_to_proto_values(
v, entity_type_map.get(k, ValueType.UNKNOWN)
)
for k, v in entity_value_lists.items()
}
else:
entity_proto_values = entity_value_lists

num_rows = _validate_entity_values(entity_proto_values)

join_key_values: Dict[str, List[Value]] = {}
request_data_features: Dict[str, List[Value]] = {}
# Entity rows may be either entities or request data.
Expand Down Expand Up @@ -1640,11 +1674,6 @@ def _get_online_features(

# Add the Entityless case after populating result rows to avoid having to remove
# it later.
entityless_case = DUMMY_ENTITY_NAME in [
entity_name
for feature_view in feature_views
for entity_name in feature_view.entities
]
if entityless_case:
join_key_values[DUMMY_ENTITY_ID] = python_values_to_proto_values(
[DUMMY_ENTITY_VAL] * num_rows, DUMMY_ENTITY.value_type
Expand Down Expand Up @@ -1677,7 +1706,7 @@ def _get_online_features(
table,
)

if grouped_odfv_refs:
if requested_on_demand_feature_views:
self._augment_response_with_on_demand_transforms(
online_features_response,
_feature_refs,
Expand Down
Loading