diff --git a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
index 9da3576862..d9afd43306 100644
--- a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
+++ b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py
@@ -1,4 +1,5 @@
 import json
+from datetime import timedelta
 from typing import Any, Dict, List
 
 from pyspark import SparkFiles
@@ -12,7 +13,7 @@ def as_of_join(
     feature_table: DataFrame,
     features: List[str],
     feature_prefix: str = "",
-    max_age: str = None,
+    max_age: int = None,
 ) -> DataFrame:
     """Perform an as of join between entity and feature table, given a maximum age tolerance.
     Join conditions:
@@ -36,8 +37,8 @@
         feature_prefix (str):
             Feature column prefix for the result dataframe. Useful for cases where the entity
             dataframe contains one or more columns that share the same name as the features.
-        max_age (str):
-            Tolerance for the feature event timestamp recency.
+        max_age (int):
+            Tolerance for the feature event timestamp recency, in seconds.
 
     Returns:
         DataFrame: Join result.
@@ -66,7 +67,7 @@
     | 1001|2020-09-02 00:00:00|           200|
     +------+-------------------+--------------+
-    >>> df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age = "12 hour")
+    >>> df = as_of_join(entity, ["entity"], feature_table, ["feature"], max_age = 12 * 60 * 60)
     >>> df.show()
     +------+-------------------+-------+
     |entity|    event_timestamp|feature|
     +------+-------------------+-------+
@@ -96,7 +97,7 @@
     if max_age:
         join_cond = join_cond & (
             selected_feature_table[feature_event_timestamp]
-            >= entity_with_id.event_timestamp - expr(f"INTERVAL {max_age}")
+            >= entity_with_id.event_timestamp - expr(f"INTERVAL {max_age} seconds")
         )
 
     for key in entity_keys:
@@ -317,7 +318,7 @@ def retrieve_historical_features(spark: SparkSession, conf: Dict) -> DataFrame:
                 "table": "transactions",
                 "features": ["daily_transactions"],
                 "join": ["customer_id"],
-                "max_age": "2 day",
+                "max_age": 172800,
             },
             {
                 "table": "bookings",
@@ -337,6 +338,9 @@ def retrieve_historical_features(spark: SparkSession, conf: Dict) -> DataFrame:
     `options` is optional. If present, the options will be used when reading / writing the
     input / output.
 
+    `max_age` is specified in seconds, and determines the lower bound for the event timestamp
+    of the retrieved features. If not specified, the lower bound is unbounded.
+
     If necessary, `col_mapping` can be provided to map the columns of the dataframes before
     performing the join operation. `col_mapping` is a dictionary where the key is the source
     column and the value is the mapped column.
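A quick sanity check of the new seconds-based tolerance (reviewer sketch, not part of the patch): the snippet below rebuilds the same join condition that the patched as_of_join generates, assuming a local SparkSession and toy data that mirror the docstring example above.

    from datetime import datetime

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import expr

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    entity = spark.createDataFrame(
        [(1001, datetime(2020, 9, 2))], ["entity", "event_timestamp"]
    )
    feature_table = spark.createDataFrame(
        [
            (1001, datetime(2020, 9, 1), datetime(2020, 9, 1), 110),
            (1001, datetime(2020, 9, 2), datetime(2020, 9, 2), 200),
        ],
        ["entity", "event_timestamp", "created_timestamp", "feature"],
    )

    max_age = 12 * 60 * 60  # 12 hours, expressed in seconds instead of "12 hour"

    # Same recency tolerance the patched as_of_join builds: a feature row qualifies
    # only if its timestamp is at most max_age seconds before the entity timestamp.
    join_cond = (
        (entity["entity"] == feature_table["entity"])
        & (feature_table["event_timestamp"] <= entity["event_timestamp"])
        & (
            feature_table["event_timestamp"]
            >= entity["event_timestamp"] - expr(f"INTERVAL {max_age} seconds")
        )
    )
    entity.join(feature_table, join_cond, "left").show()
    # Only the 2020-09-02 row (feature 200) falls inside the 12 hour tolerance,
    # matching the docstring example above.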
@@ -384,6 +388,18 @@ def map_column(df: DataFrame, col_mapping: Dict[str, str]):
         is_feature_table=True,
     )
 
+    max_timestamp = mapped_entity_df.agg({"event_timestamp": "max"}).collect()[0][0]
+    min_timestamp = mapped_entity_df.agg({"event_timestamp": "min"}).collect()[0][0]
+
+    for query in conf["queries"]:
+        max_age = query.get("max_age")
+        if max_age:
+            tables[query["table"]] = tables[query["table"]].filter(
+                col("event_timestamp").between(
+                    min_timestamp - timedelta(seconds=max_age), max_timestamp
+                )
+            )
+
     return join_entity_to_feature_tables(conf["queries"], mapped_entity_df, tables)
 
 
diff --git a/sdk/python/tests/test_as_of_join.py b/sdk/python/tests/test_as_of_join.py
index a72a44ec8f..21cd829f29 100644
--- a/sdk/python/tests/test_as_of_join.py
+++ b/sdk/python/tests/test_as_of_join.py
@@ -1,6 +1,10 @@
+import os
 import pathlib
-from datetime import datetime
+import shutil
+import tempfile
+from datetime import datetime, timedelta
 from os import path
+from typing import Any, Dict, List
 
 import pytest
 from pyspark.sql import DataFrame, SparkSession
@@ -35,6 +39,62 @@ def spark(pytestconfig):
     spark_session.stop()
 
 
+@pytest.fixture(scope="module")
+def large_entity_csv_file(pytestconfig, spark):
+    start_datetime = datetime(year=2020, month=8, day=31)
+    nr_rows = 1000
+    entity_data = [
+        (1000 + i, start_datetime + timedelta(days=i)) for i in range(nr_rows)
+    ]
+    temp_dir = tempfile.mkdtemp()
+    file_path = os.path.join(temp_dir, "large_entity")
+    entity_schema = StructType(
+        [
+            StructField("customer_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+        ]
+    )
+    large_entity_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(entity_data), entity_schema
+    )
+
+    large_entity_df.write.option("header", "true").csv(file_path)
+    yield file_path
+    shutil.rmtree(temp_dir)
+
+
+@pytest.fixture(scope="module")
+def large_feature_csv_file(pytestconfig, spark):
+    start_datetime = datetime(year=2020, month=8, day=30)
+    nr_rows = 1000
+    feature_data = [
+        (
+            1000 + i,
+            start_datetime + timedelta(days=i),
+            start_datetime + timedelta(days=i + 1),
+            i * 10,
+        )
+        for i in range(nr_rows)
+    ]
+    temp_dir = tempfile.mkdtemp()
+    file_path = os.path.join(temp_dir, "large_feature")
+    feature_schema = StructType(
+        [
+            StructField("customer_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+            StructField("created_timestamp", TimestampType()),
+            StructField("total_bookings", IntegerType()),
+        ]
+    )
+    large_feature_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(feature_data), feature_schema
+    )
+
+    large_feature_df.write.option("header", "true").csv(file_path)
+    yield file_path
+    shutil.rmtree(temp_dir)
+
+
 @pytest.fixture
 def single_entity_schema():
     return StructType(
@@ -235,7 +295,7 @@ def test_join_with_max_age(
         feature_table_df,
         ["daily_transactions"],
         feature_prefix="transactions__",
-        max_age="1 day",
+        max_age=86400,
     )
 
     expected_joined_schema = StructType(
@@ -308,7 +368,7 @@ def test_join_with_composite_entity(
         feature_table_df,
         ["customer_rating", "driver_rating"],
         feature_prefix="ratings__",
-        max_age="1 day",
+        max_age=86400,
     )
 
     expected_joined_schema = StructType(
@@ -397,12 +457,12 @@ def test_multiple_join(
     customer_feature_schema: StructType,
     driver_feature_schema: StructType,
 ):
-    query_conf = [
+    query_conf: List[Dict[str, Any]] = [
         {
             "table": "transactions",
             "features": ["daily_transactions"],
             "join": ["customer_id"],
-            "max_age": "1 day",
+            "max_age": 86400,
         },
         {
             "table": "bookings",
@@ -524,7 +584,7 @@ def test_historical_feature_retrieval(spark):
             "table": "transactions",
             "features": ["daily_transactions"],
             "join": ["customer_id"],
-            "max_age": "1 day",
+            "max_age": 86400,
         },
         {
             "table": "bookings",
@@ -562,7 +622,7 @@ def test_historical_feature_retrieval(spark):
 def test_historical_feature_retrieval_with_mapping(spark):
     test_data_dir = path.join(pathlib.Path(__file__).parent.absolute(), "data")
 
-    batch_retrieval_conf = {
+    retrieval_conf = {
         "entity": {
             "format": "csv",
             "path": f"file://{path.join(test_data_dir, 'column_mapping_test_entity.csv')}",
@@ -592,7 +652,7 @@
         ],
     }
 
-    joined_df = retrieve_historical_features(spark, batch_retrieval_conf)
+    joined_df = retrieve_historical_features(spark, retrieval_conf)
 
     expected_joined_schema = StructType(
         [
@@ -616,6 +676,53 @@
     assert_dataframe_equal(joined_df, expected_joined_df)
 
 
+def test_large_historical_feature_retrieval(
+    spark, large_entity_csv_file, large_feature_csv_file
+):
+    nr_rows = 1000
+    start_datetime = datetime(year=2020, month=8, day=31)
+    expected_join_data = [
+        (1000 + i, start_datetime + timedelta(days=i), i * 10) for i in range(nr_rows)
+    ]
+    expected_join_data_schema = StructType(
+        [
+            StructField("customer_id", IntegerType()),
+            StructField("event_timestamp", TimestampType()),
+            StructField("feature__total_bookings", IntegerType()),
+        ]
+    )
+
+    expected_join_data_df = spark.createDataFrame(
+        spark.sparkContext.parallelize(expected_join_data), expected_join_data_schema
+    )
+
+    retrieval_conf = {
+        "entity": {
+            "format": "csv",
+            "path": f"file://{large_entity_csv_file}",
+            "options": {"inferSchema": "true", "header": "true"},
+        },
+        "tables": [
+            {
+                "format": "csv",
+                "path": f"file://{large_feature_csv_file}",
+                "name": "feature",
+                "options": {"inferSchema": "true", "header": "true"},
+            },
+        ],
+        "queries": [
+            {
+                "table": "feature",
+                "features": ["total_bookings"],
+                "join": ["customer_id"],
+            }
+        ],
+    }
+
+    joined_df = retrieve_historical_features(spark, retrieval_conf)
+    assert_dataframe_equal(joined_df, expected_join_data_df)
+
+
 def test_schema_verification(spark):
     entity_schema = StructType(
         [
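The entity-timestamp pre-filter added to retrieve_historical_features (the @@ -384,6 +388,18 @@ hunk above) can be sanity-checked in isolation. Feature rows with an event timestamp outside [min entity timestamp - max_age, max entity timestamp] can never satisfy the as-of join tolerance, so dropping them up front shrinks the join input without changing the result. Below is a minimal reviewer sketch of that reasoning, assuming a local SparkSession and hypothetical toy data; it is not part of the patch.

    from datetime import datetime, timedelta

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col

    spark = SparkSession.builder.master("local[1]").getOrCreate()

    entity_df = spark.createDataFrame(
        [(1001, datetime(2020, 9, 1)), (1002, datetime(2020, 9, 3))],
        ["customer_id", "event_timestamp"],
    )
    feature_df = spark.createDataFrame(
        [
            (1001, datetime(2020, 8, 1), 10),   # far older than max_age: can never match
            (1001, datetime(2020, 8, 31), 20),  # within tolerance of an entity row
            (1002, datetime(2020, 9, 5), 30),   # newer than every entity row: can never match
        ],
        ["customer_id", "event_timestamp", "total_bookings"],
    )

    max_age = 172800  # two days in seconds, as in the "max_age": 172800 example above

    # Mirrors the added filter: collect the entity timestamp range once, then keep
    # only the feature rows that could possibly satisfy the join tolerance.
    max_ts = entity_df.agg({"event_timestamp": "max"}).collect()[0][0]
    min_ts = entity_df.agg({"event_timestamp": "min"}).collect()[0][0]

    filtered = feature_df.filter(
        col("event_timestamp").between(min_ts - timedelta(seconds=max_age), max_ts)
    )
    filtered.show()  # only the 2020-08-31 row survives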