Patch summary (text before the first "diff --git" header is skipped by patch tools):
  * Makefile: add a test-python-universal-local target that sets FEAST_IS_LOCAL_TEST=True.
  * feature_store.py: populate on demand feature view (ODFV) input feature views in
    get_online_features, then drop fields that were neither requested nor produced by an ODFV.
  * repo_configuration.py: register the Redis/GCP/AWS configs only when FEAST_IS_LOCAL_TEST != "True".
  * test_universal_online.py: skip the Redis-backed test in local mode; request data is no longer
    counted in the expected output keys.

diff --git a/Makefile b/Makefile
index 0d1bc034c2..3ee9b1468c 100644
--- a/Makefile
+++ b/Makefile
@@ -61,6 +61,9 @@ test-python:
 test-python-integration:
 	FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration sdk/python/tests
 
+test-python-universal-local:
+	FEAST_USAGE=False IS_TEST=True FEAST_IS_LOCAL_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests
+
 test-python-universal:
 	FEAST_USAGE=False IS_TEST=True python -m pytest -n 8 --integration --universal sdk/python/tests
 
diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 50213e80dc..67302edcca 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -880,9 +880,9 @@ def get_online_features(
         """
         _feature_refs = self._get_features(features, feature_refs)
         (
-            all_feature_views,
-            all_request_feature_views,
-            all_on_demand_feature_views,
+            requested_feature_views,
+            requested_request_feature_views,
+            requested_on_demand_feature_views,
         ) = self._get_feature_views_to_use(
             features=features, allow_cache=True, hide_dummy_entity=False
         )
@@ -895,9 +895,9 @@ def get_online_features(
             _,
         ) = _group_feature_refs(
             _feature_refs,
-            all_feature_views,
-            all_request_feature_views,
-            all_on_demand_feature_views,
+            requested_feature_views,
+            requested_request_feature_views,
+            requested_on_demand_feature_views,
         )
         if len(grouped_odfv_refs) > 0:
             log_event(UsageEvent.GET_ONLINE_FEATURES_WITH_ODFV)
@@ -913,10 +913,10 @@ def get_online_features(
 
         provider = self._get_provider()
         entities = self._list_entities(allow_cache=True, hide_dummy_entity=False)
-        entity_name_to_join_key_map = {}
+        entity_name_to_join_key_map: Dict[str, str] = {}
         for entity in entities:
             entity_name_to_join_key_map[entity.name] = entity.join_key
-        for feature_view in all_feature_views:
+        for feature_view in requested_feature_views:
             for entity_name in feature_view.entities:
                 entity = self._registry.get_entity(
                     entity_name, self.project, allow_cache=True
@@ -976,17 +976,6 @@ def get_online_features(
             # Also create entity values to append to the result
             result_rows.append(_entity_row_to_field_values(entity_row_proto))
 
-        # Add more feature values to the existing result rows for the request data features
-        for feature_name, feature_values in request_data_features.items():
-            for row_idx, feature_value in enumerate(feature_values):
-                result_row = result_rows[row_idx]
-                result_row.fields[feature_name].CopyFrom(
-                    python_value_to_proto_value(feature_value)
-                )
-                result_row.statuses[
-                    feature_name
-                ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
-
         for table, requested_features in grouped_refs:
             table_join_keys = [
                 entity_name_to_join_key_map[entity_name]
@@ -1002,17 +991,85 @@ def get_online_features(
                 union_of_entity_keys,
             )
 
+        requested_result_row_names = self._get_requested_result_fields(
+            result_rows, needed_request_fv_features
+        )
+        self._populate_odfv_dependencies(
+            entity_name_to_join_key_map,
+            full_feature_names,
+            grouped_odfv_refs,
+            provider,
+            request_data_features,
+            result_rows,
+            union_of_entity_keys,
+        )
+
         initial_response = OnlineResponse(
             GetOnlineFeaturesResponse(field_values=result_rows)
         )
         return self._augment_response_with_on_demand_transforms(
             _feature_refs,
-            all_on_demand_feature_views,
+            requested_result_row_names,
+            requested_on_demand_feature_views,
             full_feature_names,
             initial_response,
             result_rows,
         )
 
+    def _get_requested_result_fields(
+        self,
+        result_rows: List[GetOnlineFeaturesResponse.FieldValues],
+        needed_request_fv_features: Set[str],
+    ) -> Set[str]:
+        # Get requested feature values so we can drop odfv dependencies that aren't requested
+        requested_result_row_names: Set[str] = set()
+        for result_row in result_rows:
+            for feature_name in result_row.fields.keys():
+                requested_result_row_names.add(feature_name)
+        # Request feature view values are also request data features that should be in the
+        # final output
+        requested_result_row_names.update(needed_request_fv_features)
+        return requested_result_row_names
+
+    def _populate_odfv_dependencies(
+        self,
+        entity_name_to_join_key_map: Dict[str, str],
+        full_feature_names: bool,
+        grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
+        provider: Provider,
+        request_data_features: Dict[str, List[Any]],
+        result_rows: List[GetOnlineFeaturesResponse.FieldValues],
+        union_of_entity_keys: List[EntityKeyProto],
+    ) -> None:
+        # Add more feature values to the existing result rows for the request data features
+        for feature_name, feature_values in request_data_features.items():
+            for row_idx, feature_value in enumerate(feature_values):
+                result_row = result_rows[row_idx]
+                result_row.fields[feature_name].CopyFrom(
+                    python_value_to_proto_value(feature_value)
+                )
+                result_row.statuses[
+                    feature_name
+                ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
+
+        # Add data if odfv requests specific feature views as dependencies
+        if len(grouped_odfv_refs) > 0:
+            for odfv, _ in grouped_odfv_refs:
+                for fv in odfv.input_feature_views.values():
+                    table_join_keys = [
+                        entity_name_to_join_key_map[entity_name]
+                        for entity_name in fv.entities
+                    ]
+                    self._populate_result_rows_from_feature_view(
+                        table_join_keys,
+                        full_feature_names,
+                        provider,
+                        [feature.name for feature in fv.features],
+                        result_rows,
+                        fv,
+                        union_of_entity_keys,
+                    )
+
     def get_needed_request_data(
         self,
         grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
@@ -1097,27 +1154,10 @@ def _populate_result_rows_from_feature_view(
                             feature_ref
                         ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
 
-    def _get_needed_request_data_features(
-        self,
-        grouped_odfv_refs: List[Tuple[OnDemandFeatureView, List[str]]],
-        grouped_request_fv_refs: List[Tuple[RequestFeatureView, List[str]]],
-    ) -> Set[str]:
-        needed_request_data_features = set()
-        for odfv_to_feature_names in grouped_odfv_refs:
-            odfv, requested_feature_names = odfv_to_feature_names
-            odfv_request_data_schema = odfv.get_request_data_schema()
-            for feature_name in odfv_request_data_schema.keys():
-                needed_request_data_features.add(feature_name)
-        for request_fv_to_feature_names in grouped_request_fv_refs:
-            request_fv, requested_feature_names = request_fv_to_feature_names
-            for fv in request_fv.features:
-                needed_request_data_features.add(fv.name)
-        return needed_request_data_features
-
-    # TODO(adchia): remove request data, which isn't part of the feature_refs
     def _augment_response_with_on_demand_transforms(
         self,
         feature_refs: List[str],
+        requested_result_row_names: Set[str],
         odfvs: List[OnDemandFeatureView],
         full_feature_names: bool,
         initial_response: OnlineResponse,
@@ -1137,6 +1177,7 @@ def _augment_response_with_on_demand_transforms(
                 odfv_feature_refs[view_name].append(feature_name)
 
         # Apply on demand transformations
+        odfv_result_names: Set[str] = set()
         for odfv_name, _feature_refs in odfv_feature_refs.items():
             odfv = all_on_demand_feature_views[odfv_name]
             transformed_features_df = odfv.get_transformed_features_df(
@@ -1155,6 +1196,7 @@ def _augment_response_with_on_demand_transforms(
                         if full_feature_names
                         else transformed_feature
                     )
+                    odfv_result_names.add(transformed_feature_name)
                     proto_value = python_value_to_proto_value(
                         transformed_features_df[transformed_feature].values[row_idx]
                     )
@@ -1162,6 +1204,19 @@ def _augment_response_with_on_demand_transforms(
                     result_row.statuses[
                         transformed_feature_name
                     ] = GetOnlineFeaturesResponse.FieldStatus.PRESENT
+
+        # Drop values that aren't needed. Row 0 supplies the field names to check; guard
+        # against an empty result_rows, where indexing [0] would raise IndexError.
+        unneeded_features = [
+            val
+            for val in (result_rows[0].fields if result_rows else ())
+            if val not in requested_result_row_names and val not in odfv_result_names
+        ]
+        for result_row in result_rows:
+            for unneeded_feature in unneeded_features:
+                result_row.fields.pop(unneeded_feature)
+                result_row.statuses.pop(unneeded_feature)
+
         return OnlineResponse(GetOnlineFeaturesResponse(field_values=result_rows))
 
     def _get_feature_views_to_use(
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 203d1df474..67b330be49 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -49,30 +49,35 @@
 DEFAULT_FULL_REPO_CONFIGS: List[IntegrationTestRepoConfig] = [
     # Local configurations
     IntegrationTestRepoConfig(),
-    IntegrationTestRepoConfig(online_store=REDIS_CONFIG),
-    # GCP configurations
-    IntegrationTestRepoConfig(
-        provider="gcp",
-        offline_store_creator=BigQueryDataSourceCreator,
-        online_store="datastore",
-    ),
-    IntegrationTestRepoConfig(
-        provider="gcp",
-        offline_store_creator=BigQueryDataSourceCreator,
-        online_store=REDIS_CONFIG,
-    ),
-    # AWS configurations
-    IntegrationTestRepoConfig(
-        provider="aws",
-        offline_store_creator=RedshiftDataSourceCreator,
-        online_store=DYNAMO_CONFIG,
-    ),
-    IntegrationTestRepoConfig(
-        provider="aws",
-        offline_store_creator=RedshiftDataSourceCreator,
-        online_store=REDIS_CONFIG,
-    ),
-]
+]
+if os.getenv("FEAST_IS_LOCAL_TEST", "False") != "True":
+    DEFAULT_FULL_REPO_CONFIGS.extend(
+        [
+            IntegrationTestRepoConfig(online_store=REDIS_CONFIG),
+            # GCP configurations
+            IntegrationTestRepoConfig(
+                provider="gcp",
+                offline_store_creator=BigQueryDataSourceCreator,
+                online_store="datastore",
+            ),
+            IntegrationTestRepoConfig(
+                provider="gcp",
+                offline_store_creator=BigQueryDataSourceCreator,
+                online_store=REDIS_CONFIG,
+            ),
+            # AWS configurations
+            IntegrationTestRepoConfig(
+                provider="aws",
+                offline_store_creator=RedshiftDataSourceCreator,
+                online_store=DYNAMO_CONFIG,
+            ),
+            IntegrationTestRepoConfig(
+                provider="aws",
+                offline_store_creator=RedshiftDataSourceCreator,
+                online_store=REDIS_CONFIG,
+            ),
+        ]
+    )
 full_repo_configs_module = os.environ.get(FULL_REPO_CONFIGS_MODULE_ENV_NAME)
 if full_repo_configs_module is not None:
     try:
diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py
index 9557972e5e..c90021f9ce 100644
--- a/sdk/python/tests/integration/online_store/test_universal_online.py
+++ b/sdk/python/tests/integration/online_store/test_universal_online.py
@@ -1,5 +1,6 @@
 import datetime
 import itertools
+import os
 import unittest
 from datetime import timedelta
 
@@ -29,6 +30,8 @@
 # TODO: make this work with all universal (all online store types)
 @pytest.mark.integration
 def test_write_to_online_store_event_check(local_redis_environment):
+    if os.getenv("FEAST_IS_LOCAL_TEST", "False") == "True":
+        pytest.skip("FEAST_IS_LOCAL_TEST is set; skipping Redis-backed test")
     fs = local_redis_environment.feature_store
 
     # write same data points 3 with different timestamps
@@ -274,11 +277,20 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name
     )
     assert online_features is not None
 
+    # Test that the on demand feature views compute properly even if the dependent conv_rate
+    # feature isn't requested.
+    online_features_no_conv_rate = fs.get_online_features(
+        features=[ref for ref in feature_refs if ref != "driver_stats:conv_rate"],
+        entity_rows=entity_rows,
+        full_feature_names=full_feature_names,
+    )
+    assert online_features_no_conv_rate is not None
+
     online_features_dict = online_features.to_dict()
     keys = online_features_dict.keys()
     assert (
-        len(keys) == len(feature_refs) + 3
-    )  # Add three for the driver id and the customer id entity keys + val_to_add request data.
+        len(keys) == len(feature_refs) + 2
+    )  # Add two for the driver id and the customer id entity keys
     for feature in feature_refs:
         # full_feature_names does not apply to request feature views
         if full_feature_names and feature != "driver_age:driver_age":
@@ -526,8 +538,8 @@ def assert_feature_service_correctness(
                 for projection in feature_service.feature_view_projections
             ]
         )
-        + 3
-    )  # Add two for the driver id and the customer id entity keys and val_to_add request data
+        + 2
+    )  # Add two for the driver id and the customer id entity keys
 
     tc = unittest.TestCase()
     for i, entity_row in enumerate(entity_rows):