fix: Fix file offline store logic for feature views without ttl (#2971)
* Add new test for historical retrieval with feature views with no ttl

Signed-off-by: Felix Wang <wangfelix98@gmail.com>

* Fix no ttl logic

Signed-off-by: Felix Wang <wangfelix98@gmail.com>
felixwang9817 authored Jul 26, 2022
1 parent 3ce5139 commit 26f6b69
Showing 2 changed files with 104 additions and 7 deletions.
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -635,6 +635,14 @@ def _filter_ttl(
             )
         ]
 
         df_to_join = df_to_join.persist()
+    else:
+        df_to_join = df_to_join[
+            # do not drop entity rows if one of the sources returns NaNs
+            df_to_join[timestamp_field].isna()
+            | (df_to_join[timestamp_field] <= df_to_join[entity_df_event_timestamp_col])
+        ]
+
+        df_to_join = df_to_join.persist()
 
     return df_to_join
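The new else branch means a feature view without a ttl gets an unbounded lookback: only rows whose timestamp falls after the entity row's event timestamp are dropped, and rows with a missing timestamp are kept so entity rows are not lost when one of the sources returns NaNs. A minimal pandas sketch of that filter (the real _filter_ttl operates on dask DataFrames; the column names and data below are illustrative, not taken from the commit):

import pandas as pd

# Toy join result: one entity row (event timestamp 2022-07-10) joined
# against three timestamped feature rows plus one row with no timestamp.
df_to_join = pd.DataFrame(
    {
        "event_timestamp": pd.to_datetime(
            ["2022-07-01", "2022-07-09", "2022-07-15", None]
        ),
        "entity_event_timestamp": pd.to_datetime(["2022-07-10"] * 4),
        "value": [1.0, 2.0, 3.0, 4.0],
    }
)

# No-ttl branch: keep rows with a missing timestamp or a timestamp at or
# before the entity event timestamp; only the future row is dropped.
filtered = df_to_join[
    df_to_join["event_timestamp"].isna()
    | (df_to_join["event_timestamp"] <= df_to_join["entity_event_timestamp"])
]
print(filtered["value"].tolist())  # [1.0, 2.0, 4.0]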
103 changes: 96 additions & 7 deletions sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py
@@ -115,60 +115,70 @@ def get_expected_training_df(
         entity_df.to_dict("records"), event_timestamp
     )
 
+    # Set sufficiently large ttl that it effectively functions as infinite for the calculations below.
+    default_ttl = timedelta(weeks=52)
+
     # Manually do point-in-time join of driver, customer, and order records against
     # the entity df
     for entity_row in entity_rows:
         customer_record = find_asof_record(
             customer_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - customer_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(customer_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id"],
             filter_values=[entity_row["customer_id"]],
         )
         driver_record = find_asof_record(
             driver_records,
             ts_key=driver_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - driver_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(driver_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["driver_id"],
             filter_values=[entity_row["driver_id"]],
         )
         order_record = find_asof_record(
             order_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - order_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(order_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id", "driver_id"],
             filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
         )
         origin_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["origin_id"]],
         )
         destination_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["destination_id"]],
         )
         global_record = find_asof_record(
             global_records,
             ts_key=global_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - global_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(global_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
         field_mapping_record = find_asof_record(
             field_mapping_records,
             ts_key=field_mapping_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - field_mapping_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(field_mapping_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
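In the expected-dataframe computation above, each as-of window now starts at event_timestamp - get_feature_view_ttl(fv, default_ttl) rather than event_timestamp - fv.ttl, so a zero ttl widens the lookback to 52 weeks instead of collapsing the window to a single instant. A small self-contained sketch of the arithmetic (the two ttl values here are illustrative):

from datetime import datetime, timedelta

default_ttl = timedelta(weeks=52)
event_timestamp = datetime(2022, 7, 10)

# A feature view with a real ttl (two days): the as-of window is
# [event_timestamp - ttl, event_timestamp].
ttl = timedelta(days=2)
print(event_timestamp - (ttl if ttl else default_ttl))  # 2022-07-08 00:00:00

# A feature view with ttl = 0: a zero timedelta is falsy, so the default
# applies and the window reaches 52 weeks back instead of degenerating to
# the single instant [event_timestamp, event_timestamp].
ttl = timedelta(seconds=0)
print(event_timestamp - (ttl if ttl else default_ttl))  # 2021-07-11 00:00:00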

@@ -666,6 +676,78 @@ def test_historical_features_persisting(
     )
 
 
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
+def test_historical_features_with_no_ttl(
+    environment, universal_data_sources, full_feature_names
+):
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    # Remove ttls.
+    feature_views.customer.ttl = timedelta(seconds=0)
+    feature_views.order.ttl = timedelta(seconds=0)
+    feature_views.global_fv.ttl = timedelta(seconds=0)
+    feature_views.field_mapping.ttl = timedelta(seconds=0)
+
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=[
+            "customer_profile:current_balance",
+            "customer_profile:avg_passenger_count",
+            "customer_profile:lifetime_trip_count",
+            "order:order_is_success",
+            "global_stats:num_rides",
+            "global_stats:avg_ride_length",
+            "field_mapping:feature_name",
+        ],
+        full_feature_names=full_feature_names,
+    )
+
+    event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
+    expected_df = get_expected_training_df(
+        datasets.customer_df,
+        feature_views.customer,
+        datasets.driver_df,
+        feature_views.driver,
+        datasets.orders_df,
+        feature_views.order,
+        datasets.location_df,
+        feature_views.location,
+        datasets.global_df,
+        feature_views.global_fv,
+        datasets.field_mapping_df,
+        feature_views.field_mapping,
+        entity_df,
+        event_timestamp,
+        full_feature_names,
+    ).drop(
+        columns=[
+            response_feature_name("conv_rate_plus_100", full_feature_names),
+            response_feature_name("conv_rate_plus_100_rounded", full_feature_names),
+            response_feature_name("avg_daily_trips", full_feature_names),
+            response_feature_name("conv_rate", full_feature_names),
+            "origin__temperature",
+            "destination__temperature",
+        ]
+    )
+
+    assert_frame_equal(
+        expected_df,
+        job.to_df(),
+        keys=[event_timestamp, "driver_id", "customer_id"],
+    )
+
+
 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 def test_historical_features_from_bigquery_sources_containing_backfills(environment):
@@ -781,6 +863,13 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
     return feature
 
 
+def get_feature_view_ttl(
+    feature_view: FeatureView, default_ttl: timedelta
+) -> timedelta:
+    """Returns the ttl of a feature view if it is non-zero. Otherwise returns the specified default."""
+    return feature_view.ttl if feature_view.ttl else default_ttl
+
+
 def assert_feature_service_correctness(
     store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
 ):
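A note on the get_feature_view_ttl helper: a zero timedelta is falsy in Python, which is why the test's ttl = timedelta(seconds=0) assignments fall through to the default. A quick duck-typed sketch (the SimpleNamespace stand-ins are illustrative, not real Feast FeatureView objects):

from datetime import timedelta
from types import SimpleNamespace

# Stand-ins for Feast FeatureView objects; only the .ttl attribute matters here.
fv_with_ttl = SimpleNamespace(ttl=timedelta(days=2))
fv_without_ttl = SimpleNamespace(ttl=timedelta(seconds=0))

default_ttl = timedelta(weeks=52)

def get_feature_view_ttl(feature_view, default_ttl: timedelta) -> timedelta:
    # Same body as the helper above; bool(timedelta(0)) is False.
    return feature_view.ttl if feature_view.ttl else default_ttl

print(get_feature_view_ttl(fv_with_ttl, default_ttl))     # 2 days, 0:00:00
print(get_feature_view_ttl(fv_without_ttl, default_ttl))  # 364 days, 0:00:00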
