From d71be5304069c2a2fbac691cce1d74fc1b462d30 Mon Sep 17 00:00:00 2001 From: Felix Wang Date: Thu, 16 Jun 2022 16:28:04 -0700 Subject: [PATCH] Add test for SFV online retrieval Signed-off-by: Felix Wang --- .../online_store/test_universal_online.py | 80 +++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/sdk/python/tests/integration/online_store/test_universal_online.py b/sdk/python/tests/integration/online_store/test_universal_online.py index d05045e295..b01448e7cc 100644 --- a/sdk/python/tests/integration/online_store/test_universal_online.py +++ b/sdk/python/tests/integration/online_store/test_universal_online.py @@ -441,6 +441,82 @@ def test_online_retrieval_with_event_timestamps( ) +@pytest.mark.integration +@pytest.mark.universal_online_stores +@pytest.mark.goserver +@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v)) +def test_stream_feature_view_online_retrieval( + environment, universal_data_sources, feature_server_endpoint, full_feature_names +): + """ + Tests materialization and online retrieval for stream feature views. + + This test is separate from test_online_retrieval since combining feature views and + stream feature views into a single test resulted in test flakiness. This is tech + debt that should be resolved soon. + """ + # Set up feature store. + fs = environment.feature_store + entities, datasets, data_sources = universal_data_sources + feature_views = construct_universal_feature_views(data_sources) + pushable_feature_view = feature_views.pushed_locations + fs.apply([location(), pushable_feature_view]) + + # Materialize. + fs.materialize( + environment.start_date - timedelta(days=1), + environment.end_date + timedelta(days=1), + ) + + # Get online features by randomly sampling 10 entities that exist in the batch source. + sample_locations = datasets.location_df.sample(10)["location_id"] + entity_rows = [ + {"location_id": sample_location} for sample_location in sample_locations + ] + + feature_refs = [ + "pushable_location_stats:temperature", + ] + unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f] + + online_features_dict = get_online_features_dict( + environment=environment, + endpoint=feature_server_endpoint, + features=feature_refs, + entity_rows=entity_rows, + full_feature_names=full_feature_names, + ) + + # Check that the response has the expected set of keys. + keys = set(online_features_dict.keys()) + expected_keys = set( + f.replace(":", "__") if full_feature_names else f.split(":")[-1] + for f in feature_refs + ) | {"location_id"} + assert ( + keys == expected_keys + ), f"Response keys are different from expected: {keys - expected_keys} (extra) and {expected_keys - keys} (missing)" + + # Check that the feature values match. + tc = unittest.TestCase() + for i, entity_row in enumerate(entity_rows): + df_features = get_latest_feature_values_from_location_df( + entity_row, datasets.location_df + ) + + assert df_features["location_id"] == online_features_dict["location_id"][i] + for unprefixed_feature_ref in unprefixed_feature_refs: + tc.assertAlmostEqual( + df_features[unprefixed_feature_ref], + online_features_dict[ + response_feature_name( + unprefixed_feature_ref, feature_refs, full_feature_names + ) + ][i], + delta=0.0001, + ) + + @pytest.mark.integration @pytest.mark.universal_online_stores @pytest.mark.goserver @@ -859,6 +935,10 @@ def get_latest_feature_values_for_location_df(entity_row, origin_df, destination } +def get_latest_feature_values_from_location_df(entity_row, location_df): + return get_latest_row(entity_row, location_df, "location_id", "location_id") + + def assert_feature_service_correctness( environment, endpoint,