diff --git a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py index 26be6888f4..437a4828ea 100644 --- a/sdk/python/feast/pyspark/historical_feature_retrieval_job.py +++ b/sdk/python/feast/pyspark/historical_feature_retrieval_job.py @@ -8,6 +8,9 @@ from pyspark.sql import DataFrame, SparkSession, Window from pyspark.sql.functions import col, expr, monotonically_increasing_id, row_number +EVENT_TIMESTAMP_ALIAS = "event_timestamp" +CREATED_TIMESTAMP_ALIAS = "created_timestamp" + class Source(abc.ABC): """ @@ -287,8 +290,6 @@ def as_of_join( entity_event_timestamp_column: str, feature_table_df: DataFrame, feature_table: FeatureTable, - feature_event_timestamp_column: str, - feature_created_timestamp_column: str, ) -> DataFrame: """Perform an as of join between entity and feature table, given a maximum age tolerance. Join conditions: @@ -308,10 +309,6 @@ def as_of_join( feature_table_df (Dataframe): Spark dataframe representing the feature table. feature_table (FeatureTable): Feature table specification, which provide information on how the join should be performed, such as the entity primary keys and max age. - feature_event_timestamp_column (str): Column name in feature_table_df which represents - event timestamp. - feature_created_timestamp_column (str): Column name in feature_table_df which represents - when the feature is created. Returns: DataFrame: Join result, which contains all the original columns from entity_df, as well @@ -338,8 +335,7 @@ def as_of_join( None >>> feature_table_1.name 'table1' - >>> df = as_of_join(entity_df, "event_timestamp", feature_table_1_df, feature_table_1, - "event_timestamp", "created_timestamp") + >>> df = as_of_join(entity_df, "event_timestamp", feature_table_1_df, feature_table_1) >>> df.show() +------+-------------------+---------------+ |entity| event_timestamp|table1__feature| @@ -359,8 +355,7 @@ def as_of_join( 43200 >>> feature_table_2.name 'table2' - >>> df = as_of_join(entity_df, "event_timestamp", feature_table_2_df, feature_table_2, - "event_timestamp", "created_timestamp") + >>> df = as_of_join(entity_df, "event_timestamp", feature_table_2_df, feature_table_2) >>> df.show() +------+-------------------+---------------+ |entity| event_timestamp|table2__feature| @@ -372,10 +367,10 @@ def as_of_join( entity_with_id = entity_df.withColumn("_row_nr", monotonically_increasing_id()) feature_event_timestamp_column_with_prefix = ( - f"{feature_table.name}__{feature_event_timestamp_column}" + f"{feature_table.name}__{EVENT_TIMESTAMP_ALIAS}" ) feature_created_timestamp_column_with_prefix = ( - f"{feature_table.name}__{feature_created_timestamp_column}" + f"{feature_table.name}__{CREATED_TIMESTAMP_ALIAS}" ) projection = [ @@ -432,8 +427,6 @@ def join_entity_to_feature_tables( entity_event_timestamp_column: str, feature_table_dfs: List[DataFrame], feature_tables: List[FeatureTable], - feature_event_timestamp_columns: List[str], - feature_created_timestamp_columns: List[str], ) -> DataFrame: """Perform as of join between entity and multiple feature table. @@ -445,10 +438,6 @@ def join_entity_to_feature_tables( feature_table_dfs (List[Dataframe]): List of Spark dataframes representing the feature tables. feature_tables (List[FeatureTable]): List of feature table specification. The length and ordering of this argument must follow that of feature_table_dfs. - feature_event_timestamp_columns (List[str]): Column names which represent event timestamp for the - feature tables. The length and ordering of this argument must follow that of feature_table_dfs. - feature_created_timestamp_columns (str): Column names which represent when the feature is created. - The length and ordering of this argument must follow that of feature_table_dfs. Returns: DataFrame: Join result, which contains all the original columns from entity_df, as well @@ -496,8 +485,7 @@ def join_entity_to_feature_tables( tables, ) >>> joined_df = join_entity_to_feature_tables(entity_df, "event_timestamp", - [table1_df, table2_df], [table1, table2], - ["event_timestamp"] * 2, ["created_timestamp"] * 2) + [table1_df, table2_df], [table1, table2]) >>> joined_df.show() +------+-------------------+----------------+----------------+ @@ -508,24 +496,9 @@ def join_entity_to_feature_tables( """ joined_df = entity_df - for ( - feature_table_df, - feature_table, - feature_event_timestamp_column, - feature_created_timestamp_column, - ) in zip( - feature_table_dfs, - feature_tables, - feature_event_timestamp_columns, - feature_created_timestamp_columns, - ): + for (feature_table_df, feature_table,) in zip(feature_table_dfs, feature_tables): joined_df = as_of_join( - joined_df, - entity_event_timestamp_column, - feature_table_df, - feature_table, - feature_event_timestamp_column, - feature_created_timestamp_column, + joined_df, entity_event_timestamp_column, feature_table_df, feature_table, ) return joined_df @@ -597,6 +570,11 @@ def _read_and_verify_feature_table_df_from_source( mapped_source_df = _map_column(source_df, source.field_mapping) + if not source.created_timestamp_column: + raise SchemaError( + "Created timestamp column must not be none for feature table." + ) + column_selection = ( feature_table.feature_names + feature_table.entity_names @@ -628,9 +606,11 @@ def _read_and_verify_feature_table_df_from_source( ) return mapped_source_df.select( - feature_table.feature_names - + feature_table.entity_names - + [source.event_timestamp_column, source.created_timestamp_column] + [col(name) for name in feature_table.feature_names + feature_table.entity_names] + + [ + col(source.event_timestamp_column).alias(EVENT_TIMESTAMP_ALIAS), + col(source.created_timestamp_column).alias(CREATED_TIMESTAMP_ALIAS), + ] ) @@ -706,18 +686,6 @@ def retrieve_historical_features( for feature_table, source in zip(feature_tables, feature_tables_sources) ] - feature_event_timestamp_columns = [ - source.event_timestamp_column for source in feature_tables_sources - ] - feature_created_timestamp_columns: List[str] = [] - for source in feature_tables_sources: - if source.created_timestamp_column: - feature_created_timestamp_columns.append(source.created_timestamp_column) - else: - raise SchemaError( - "Created timestamp column must not be none for feature table." - ) - expected_entities = [] for feature_table in feature_tables: expected_entities.extend(feature_table.entities) @@ -747,8 +715,6 @@ def retrieve_historical_features( entity_source.event_timestamp_column, feature_table_dfs, feature_tables, - feature_event_timestamp_columns, - feature_created_timestamp_columns, ) diff --git a/sdk/python/tests/test_as_of_join.py b/sdk/python/tests/test_as_of_join.py index 23ce94e176..31cd150cbf 100644 --- a/sdk/python/tests/test_as_of_join.py +++ b/sdk/python/tests/test_as_of_join.py @@ -229,12 +229,7 @@ def test_join_without_max_age( ) joined_df = as_of_join( - entity_df, - "event_timestamp", - feature_table_df, - feature_table, - "event_timestamp", - "created_timestamp", + entity_df, "event_timestamp", feature_table_df, feature_table, ) expected_joined_schema = StructType( @@ -298,12 +293,7 @@ def test_join_with_max_age( ) joined_df = as_of_join( - entity_df, - "event_timestamp", - feature_table_df, - feature_table, - "event_timestamp", - "created_timestamp", + entity_df, "event_timestamp", feature_table_df, feature_table, ) expected_joined_schema = StructType( @@ -377,12 +367,7 @@ def test_join_with_composite_entity( ) joined_df = as_of_join( - entity_df, - "event_timestamp", - feature_table_df, - feature_table, - "event_timestamp", - "created_timestamp", + entity_df, "event_timestamp", feature_table_df, feature_table, ) expected_joined_schema = StructType( @@ -444,12 +429,7 @@ def test_select_subset_of_columns_as_entity_primary_keys( ) joined_df = as_of_join( - entity_df, - "event_timestamp", - feature_table_df, - feature_table, - "event_timestamp", - "created_timestamp", + entity_df, "event_timestamp", feature_table_df, feature_table, ) expected_joined_schema = StructType( @@ -552,8 +532,6 @@ def test_multiple_join( "event_timestamp", [customer_table_df, driver_table_df], [customer_table, driver_table], - ["event_timestamp"] * 2, - ["created_timestamp"] * 2, ) expected_joined_schema = StructType(