From 4a737ac69899a5228ca42e457304068726906299 Mon Sep 17 00:00:00 2001
From: locnt241 <73770977+ElliotNguyen68@users.noreply.github.com>
Date: Thu, 7 Mar 2024 00:20:46 +0700
Subject: [PATCH] feat: Add Entity df in format of a Spark Dataframe instead of
 just pd.DataFrame or string for SparkOfflineStore (#3988)

* remove unused parameter when init sparksource

Signed-off-by: tanlocnguyen

* feat: add entity df to SparkOfflineStore when get_historical_features

Signed-off-by: tanlocnguyen

* fix: lint error

Signed-off-by: tanlocnguyen

---------

Signed-off-by: tanlocnguyen
Co-authored-by: tanlocnguyen
---
 .../contrib/spark_offline_store/spark.py      | 27 ++++++++++++-------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index c9591b7c3f..b1b1c04c7d 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -125,7 +125,7 @@ def get_historical_features(
         config: RepoConfig,
         feature_views: List[FeatureView],
         feature_refs: List[str],
-        entity_df: Union[pandas.DataFrame, str],
+        entity_df: Union[pandas.DataFrame, str, pyspark.sql.DataFrame],
         registry: Registry,
         project: str,
         full_feature_names: bool = False,
@@ -473,15 +473,16 @@ def _get_entity_df_event_timestamp_range(
             entity_df_event_timestamp.min().to_pydatetime(),
             entity_df_event_timestamp.max().to_pydatetime(),
         )
-    elif isinstance(entity_df, str):
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
         # If the entity_df is a string (SQL query), determine range
         # from table
-        df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
-
-        # Checks if executing entity sql resulted in any data
-        if df.rdd.isEmpty():
-            raise EntitySQLEmptyResults(entity_df)
-
+        if isinstance(entity_df, str):
+            df = spark_session.sql(entity_df).select(entity_df_event_timestamp_col)
+            # Checks if executing entity sql resulted in any data
+            if df.rdd.isEmpty():
+                raise EntitySQLEmptyResults(entity_df)
+        else:
+            df = entity_df
         # TODO(kzhang132): need utc conversion here.
 
         entity_df_event_timestamp_range = (
@@ -499,8 +500,11 @@ def _get_entity_schema(
 ) -> Dict[str, np.dtype]:
     if isinstance(entity_df, pd.DataFrame):
         return dict(zip(entity_df.columns, entity_df.dtypes))
-    elif isinstance(entity_df, str):
-        entity_spark_df = spark_session.sql(entity_df)
+    elif isinstance(entity_df, str) or isinstance(entity_df, pyspark.sql.DataFrame):
+        if isinstance(entity_df, str):
+            entity_spark_df = spark_session.sql(entity_df)
+        else:
+            entity_spark_df = entity_df
         return dict(
             zip(
                 entity_spark_df.columns,
@@ -526,6 +530,9 @@ def _upload_entity_df(
     elif isinstance(entity_df, str):
         spark_session.sql(entity_df).createOrReplaceTempView(table_name)
         return
+    elif isinstance(entity_df, pyspark.sql.DataFrame):
+        entity_df.createOrReplaceTempView(table_name)
+        return
     else:
         raise InvalidEntityType(type(entity_df))
 
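
A minimal usage sketch of what this patch enables: passing a native Spark
DataFrame as the entity dataframe when retrieving historical features. The
repo path, entity column names, and feature reference below are hypothetical
placeholders for illustration, not part of this change.

    from datetime import datetime

    from pyspark.sql import SparkSession

    from feast import FeatureStore

    spark = SparkSession.builder.getOrCreate()

    # Entity rows as a pyspark.sql.DataFrame; before this patch the Spark
    # offline store accepted only a pandas.DataFrame or a SQL string here.
    entity_df = spark.createDataFrame(
        [(1001, datetime(2024, 3, 1)), (1002, datetime(2024, 3, 2))],
        ["driver_id", "event_timestamp"],
    )

    store = FeatureStore(repo_path=".")  # hypothetical repo location
    job = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_hourly_stats:conv_rate"],  # hypothetical feature ref
    )
    print(job.to_df().head())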