diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py index 96c4bd20dd..afb2a35b18 100644 --- a/sdk/python/feast/infra/offline_stores/bigquery.py +++ b/sdk/python/feast/infra/offline_stores/bigquery.py @@ -185,7 +185,7 @@ def full_feature_names(self) -> bool: def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: return self._on_demand_feature_views - def to_df_internal(self) -> pd.DataFrame: + def _to_df_internal(self) -> pd.DataFrame: # TODO: Ideally only start this job when the user runs "get_historical_features", not when they run to_df() df = self.client.query(self.query).to_dataframe(create_bqstorage_client=True) return df @@ -234,7 +234,7 @@ def to_bigquery( print(f"Done writing to '{job_config.destination}'.") return str(job_config.destination) - def to_arrow(self) -> pyarrow.Table: + def _to_arrow_internal(self) -> pyarrow.Table: return self.client.query(self.query).to_arrow() diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py index 327415d343..17c3ebfad6 100644 --- a/sdk/python/feast/infra/offline_stores/file.py +++ b/sdk/python/feast/infra/offline_stores/file.py @@ -51,12 +51,12 @@ def full_feature_names(self) -> bool: def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: return self._on_demand_feature_views - def to_df_internal(self) -> pd.DataFrame: + def _to_df_internal(self) -> pd.DataFrame: # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment. df = self.evaluation_function() return df - def to_arrow(self): + def _to_arrow_internal(self): # Only execute the evaluation function to build the final historical retrieval dataframe at the last moment. df = self.evaluation_function() return pyarrow.Table.from_pandas(df) diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py index 16fc3b1fa0..4fa108d86b 100644 --- a/sdk/python/feast/infra/offline_stores/offline_store.py +++ b/sdk/python/feast/infra/offline_stores/offline_store.py @@ -40,7 +40,7 @@ def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: def to_df(self) -> pd.DataFrame: """Return dataset as Pandas DataFrame synchronously including on demand transforms""" - features_df = self.to_df_internal() + features_df = self._to_df_internal() if self.on_demand_feature_views is None: return features_df @@ -51,16 +51,27 @@ def to_df(self) -> pd.DataFrame: return features_df @abstractmethod - def to_df_internal(self) -> pd.DataFrame: + def _to_df_internal(self) -> pd.DataFrame: """Return dataset as Pandas DataFrame synchronously""" pass - # TODO(adchia): implement ODFV for to_arrow method @abstractmethod - def to_arrow(self) -> pyarrow.Table: + def _to_arrow_internal(self) -> pyarrow.Table: """Return dataset as pyarrow Table synchronously""" pass + def to_arrow(self) -> pyarrow.Table: + """Return dataset as pyarrow Table synchronously""" + if self.on_demand_feature_views is None: + return self._to_arrow_internal() + + features_df = self._to_df_internal() + for odfv in self.on_demand_feature_views: + features_df = features_df.join( + odfv.get_transformed_features_df(self.full_feature_names, features_df) + ) + return pyarrow.Table.from_pandas(features_df) + class OfflineStore(ABC): """ diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 05a47970be..211f23d3e9 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -234,7 +234,7 @@ def full_feature_names(self) -> bool: def on_demand_feature_views(self) -> Optional[List[OnDemandFeatureView]]: return self._on_demand_feature_views - def to_df_internal(self) -> pd.DataFrame: + def _to_df_internal(self) -> pd.DataFrame: with self._query_generator() as query: return aws_utils.unload_redshift_query_to_df( self._redshift_client, @@ -248,7 +248,7 @@ def to_df_internal(self) -> pd.DataFrame: self._drop_columns, ) - def to_arrow(self) -> pa.Table: + def _to_arrow_internal(self) -> pa.Table: with self._query_generator() as query: return aws_utils.unload_redshift_query_to_pa( self._redshift_client, diff --git a/sdk/python/feast/on_demand_feature_view.py b/sdk/python/feast/on_demand_feature_view.py index ae1ad2656c..f40e1ebbbc 100644 --- a/sdk/python/feast/on_demand_feature_view.py +++ b/sdk/python/feast/on_demand_feature_view.py @@ -72,10 +72,12 @@ def to_proto(self) -> OnDemandFeatureViewProto: inputs = {} for feature_ref, input in self.inputs.items(): if type(input) == FeatureView: - inputs[feature_ref] = OnDemandInput(feature_view=input.to_proto()) + fv = cast(FeatureView, input) + inputs[feature_ref] = OnDemandInput(feature_view=fv.to_proto()) else: + request_data_source = cast(RequestDataSource, input) inputs[feature_ref] = OnDemandInput( - request_data_source=input.to_proto() + request_data_source=request_data_source.to_proto() ) spec = OnDemandFeatureViewSpec( diff --git a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py index 4a6dd42b65..da6a68b3ba 100644 --- a/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py +++ b/sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py @@ -346,17 +346,9 @@ def test_historical_features(environment, universal_data_sources, full_feature_n event_timestamp, ) - # on demand features is only plumbed through to to_df for now. table_from_df_entities: pd.DataFrame = job_from_df.to_arrow().to_pandas() - actual_df_from_df_entities_for_table = actual_df_from_df_entities.drop( - columns=["conv_rate_plus_100", "conv_rate_plus_val_to_add"] - ) - assert "conv_rate_plus_100" not in table_from_df_entities.columns - assert "conv_rate_plus_val_to_add" not in table_from_df_entities.columns columns_expected_in_table = expected_df.columns.tolist() - columns_expected_in_table.remove("conv_rate_plus_100") - columns_expected_in_table.remove("conv_rate_plus_val_to_add") table_from_df_entities = ( table_from_df_entities[columns_expected_in_table] @@ -364,7 +356,7 @@ def test_historical_features(environment, universal_data_sources, full_feature_n .drop_duplicates() .reset_index(drop=True) ) - assert_frame_equal(actual_df_from_df_entities_for_table, table_from_df_entities) + assert_frame_equal(actual_df_from_df_entities, table_from_df_entities) # If request data is missing that's needed for on demand transform, throw an error with pytest.raises(RequestDataNotFoundInEntityDfException): diff --git a/sdk/python/tests/integration/registration/test_universal_types.py b/sdk/python/tests/integration/registration/test_universal_types.py index f6fd942ef9..dc175ab7d9 100644 --- a/sdk/python/tests/integration/registration/test_universal_types.py +++ b/sdk/python/tests/integration/registration/test_universal_types.py @@ -288,10 +288,12 @@ def assert_expected_arrow_types( ] if feature_is_list: if provider == "gcp": - assert ( - str(historical_features_arrow.schema.field_by_name("value").type) - == f"struct> not null>" - ) + assert str( + historical_features_arrow.schema.field_by_name("value").type + ) in [ + f"struct> not null>", + f"struct>>", + ] else: assert ( str(historical_features_arrow.schema.field_by_name("value").type)