diff --git a/sdk/python/feast/infra/offline_stores/bigquery.py b/sdk/python/feast/infra/offline_stores/bigquery.py
index f095caef9b..259a3af7d9 100644
--- a/sdk/python/feast/infra/offline_stores/bigquery.py
+++ b/sdk/python/feast/infra/offline_stores/bigquery.py
@@ -4,6 +4,7 @@
 from datetime import date, datetime, timedelta
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     ContextManager,
     Dict,
@@ -303,6 +304,60 @@ def write_logged_features(
             job_config=job_config,
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, BigQueryOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when bigquery type required"
+            )
+        if not isinstance(feature_view.batch_source, BigQuerySource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The columns are expected to be (in this order): {column_names}."
+            )
+
+        if table.schema != pa_schema:
+            table = table.cast(pa_schema)
+
+        client = _get_bigquery_client(
+            project=config.offline_store.project_id,
+            location=config.offline_store.location,
+        )
+
+        job_config = bigquery.LoadJobConfig(
+            source_format=bigquery.SourceFormat.PARQUET,
+            schema=arrow_schema_to_bq_schema(pa_schema),
+            write_disposition="WRITE_APPEND",  # Default but included for clarity
+        )
+
+        with tempfile.TemporaryFile() as parquet_temp_file:
+            pyarrow.parquet.write_table(table=table, where=parquet_temp_file)
+
+            parquet_temp_file.seek(0)
+
+            client.load_table_from_file(
+                file_obj=parquet_temp_file,
+                destination=feature_view.batch_source.table,
+                job_config=job_config,
+            )
+
 
 class BigQueryRetrievalJob(RetrievalJob):
     def __init__(
diff --git a/sdk/python/feast/infra/offline_stores/file.py b/sdk/python/feast/infra/offline_stores/file.py
index 194c233f53..75968146de 100644
--- a/sdk/python/feast/infra/offline_stores/file.py
+++ b/sdk/python/feast/infra/offline_stores/file.py
@@ -27,6 +27,7 @@
 )
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
+    get_pyarrow_schema_from_batch_source,
 )
 from feast.infra.provider import (
     _get_requested_feature_views_to_features_dict,
@@ -408,7 +409,7 @@ def write_logged_features(
     def offline_write_batch(
         config: RepoConfig,
         feature_view: FeatureView,
-        data: pyarrow.Table,
+        table: pyarrow.Table,
         progress: Optional[Callable[[int], Any]],
     ):
         if not feature_view.batch_source:
@@ -423,20 +424,27 @@ def offline_write_batch(
             raise ValueError(
                 f"feature view batch source is {type(feature_view.batch_source)} not file source"
             )
+
+        pa_schema, column_names = get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The columns are expected to be (in this order): {column_names}."
+            )
+
         file_options = feature_view.batch_source.file_options
         filesystem, path = FileSource.create_filesystem_and_path(
             file_options.uri, file_options.s3_endpoint_override
         )
         prev_table = pyarrow.parquet.read_table(path, memory_map=True)
-        if prev_table.column_names != data.column_names:
-            raise ValueError(
-                f"Input dataframe has incorrect schema or wrong order, expected columns are: {prev_table.column_names}"
-            )
-        if data.schema != prev_table.schema:
-            data = data.cast(prev_table.schema)
-        new_table = pyarrow.concat_tables([data, prev_table])
-        writer = pyarrow.parquet.ParquetWriter(path, data.schema, filesystem=filesystem)
+        if table.schema != prev_table.schema:
+            table = table.cast(prev_table.schema)
+        new_table = pyarrow.concat_tables([table, prev_table])
+        writer = pyarrow.parquet.ParquetWriter(
+            path, table.schema, filesystem=filesystem
+        )
         writer.write_table(new_table)
         writer.close()
diff --git a/sdk/python/feast/infra/offline_stores/offline_store.py b/sdk/python/feast/infra/offline_stores/offline_store.py
index cd807764ba..439911fe2a 100644
--- a/sdk/python/feast/infra/offline_stores/offline_store.py
+++ b/sdk/python/feast/infra/offline_stores/offline_store.py
@@ -275,7 +275,7 @@ def write_logged_features(
     def offline_write_batch(
         config: RepoConfig,
         feature_view: FeatureView,
-        data: pyarrow.Table,
+        table: pyarrow.Table,
         progress: Optional[Callable[[int], Any]],
     ):
         """
@@ -286,8 +286,8 @@ def offline_write_batch(
 
         Args:
             config: Repo configuration object
-            table: FeatureView to write the data to.
-            data: pyarrow table containing feature data and timestamp column for historical feature retrieval
+            feature_view: FeatureView to write the data to.
+            table: pyarrow table containing feature data and timestamp column for historical feature retrieval
             progress: Optional function to be called once every mini-batch of rows is written to the online store. Can
                 be used to display progress.
         """
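
For context, here is a minimal sketch of how this interface is typically reached from user code. It assumes the FeatureStore.write_to_offline_store entry point and an example driver_hourly_stats feature view; the repo path, view name, and column set are illustrative, not part of this change.

    # Minimal sketch; repo path, feature view name, and columns are illustrative.
    import pandas as pd

    from feast import FeatureStore

    store = FeatureStore(repo_path=".")

    # Rows to append; column names and order are expected to match the batch
    # source schema, otherwise offline_write_batch raises a ValueError.
    rows = pd.DataFrame(
        {
            "driver_id": [1001],
            "event_timestamp": [pd.Timestamp.utcnow()],
            "created": [pd.Timestamp.utcnow()],
            "conv_rate": [0.85],
            "acc_rate": [0.91],
            "avg_daily_trips": [14],
        }
    )

    # Converts the dataframe to a pyarrow table and dispatches to the configured
    # offline store's offline_write_batch implementation.
    store.write_to_offline_store("driver_hourly_stats", rows)
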
""" diff --git a/sdk/python/feast/infra/offline_stores/offline_utils.py b/sdk/python/feast/infra/offline_stores/offline_utils.py index 893180f19f..abe8d4e4e5 100644 --- a/sdk/python/feast/infra/offline_stores/offline_utils.py +++ b/sdk/python/feast/infra/offline_stores/offline_utils.py @@ -5,9 +5,11 @@ import numpy as np import pandas as pd +import pyarrow as pa from jinja2 import BaseLoader, Environment from pandas import Timestamp +from feast.data_source import DataSource from feast.errors import ( EntityTimestampInferenceException, FeastEntityDFMissingColumnsError, @@ -17,6 +19,8 @@ from feast.infra.offline_stores.offline_store import OfflineStore from feast.infra.provider import _get_requested_feature_views_to_features_dict from feast.registry import BaseRegistry +from feast.repo_config import RepoConfig +from feast.type_map import feast_value_type_to_pa from feast.utils import to_naive_utc DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL = "event_timestamp" @@ -217,3 +221,25 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore: class_name = qualified_name.replace("Config", "") offline_store_class = import_class(module_name, class_name, "OfflineStore") return offline_store_class() + + +def get_pyarrow_schema_from_batch_source( + config: RepoConfig, batch_source: DataSource +) -> Tuple[pa.Schema, List[str]]: + """Returns the pyarrow schema and column names for the given batch source.""" + column_names_and_types = batch_source.get_table_column_names_and_types(config) + + pa_schema = [] + column_names = [] + for column_name, column_type in column_names_and_types: + pa_schema.append( + ( + column_name, + feast_value_type_to_pa( + batch_source.source_datatype_to_feast_value_type()(column_type) + ), + ) + ) + column_names.append(column_name) + + return pa.schema(pa_schema), column_names diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py index 943bac502c..8667989268 100644 --- a/sdk/python/feast/infra/offline_stores/redshift.py +++ b/sdk/python/feast/infra/offline_stores/redshift.py @@ -42,7 +42,6 @@ from feast.registry import BaseRegistry from feast.repo_config import FeastConfigBaseModel, RepoConfig from feast.saved_dataset import SavedDatasetStorage -from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type from feast.usage import log_exceptions_and_usage @@ -318,33 +317,23 @@ def offline_write_batch( raise ValueError( f"feature view batch source is {type(feature_view.batch_source)} not redshift source" ) - redshift_options = feature_view.batch_source.redshift_options - redshift_client = aws_utils.get_redshift_data_client( - config.offline_store.region - ) - column_name_to_type = feature_view.batch_source.get_table_column_names_and_types( - config + pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source( + config, feature_view.batch_source ) - pa_schema_list = [] - column_names = [] - for column_name, redshift_type in column_name_to_type: - pa_schema_list.append( - ( - column_name, - feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)), - ) - ) - column_names.append(column_name) - pa_schema = pa.schema(pa_schema_list) if column_names != table.column_names: raise ValueError( - f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}" + f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. " + f"The columns are expected to be (in this order): {column_names}." 
diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py
index 943bac502c..8667989268 100644
--- a/sdk/python/feast/infra/offline_stores/redshift.py
+++ b/sdk/python/feast/infra/offline_stores/redshift.py
@@ -42,7 +42,6 @@
 from feast.registry import BaseRegistry
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
-from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
 from feast.usage import log_exceptions_and_usage
 
@@ -318,33 +317,23 @@ def offline_write_batch(
             raise ValueError(
                 f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
             )
-        redshift_options = feature_view.batch_source.redshift_options
-        redshift_client = aws_utils.get_redshift_data_client(
-            config.offline_store.region
-        )
 
-        column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
-            config
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
         )
-        pa_schema_list = []
-        column_names = []
-        for column_name, redshift_type in column_name_to_type:
-            pa_schema_list.append(
-                (
-                    column_name,
-                    feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
-                )
-            )
-            column_names.append(column_name)
-        pa_schema = pa.schema(pa_schema_list)
         if column_names != table.column_names:
             raise ValueError(
-                f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The columns are expected to be (in this order): {column_names}."
             )
 
         if table.schema != pa_schema:
             table = table.cast(pa_schema)
 
+        redshift_options = feature_view.batch_source.redshift_options
+        redshift_client = aws_utils.get_redshift_data_client(
+            config.offline_store.region
+        )
         s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
 
         aws_utils.upload_arrow_table_to_redshift(
diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py
index 73c785eecf..ec06d8dce1 100644
--- a/sdk/python/feast/infra/offline_stores/snowflake.py
+++ b/sdk/python/feast/infra/offline_stores/snowflake.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     ContextManager,
     Dict,
@@ -306,6 +307,47 @@ def write_logged_features(
             auto_create_table=True,
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, SnowflakeOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when snowflake type required"
+            )
+        if not isinstance(feature_view.batch_source, SnowflakeSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The columns are expected to be (in this order): {column_names}."
+            )
+
+        if table.schema != pa_schema:
+            table = table.cast(pa_schema)
+
+        snowflake_conn = get_snowflake_conn(config.offline_store)
+
+        write_pandas(
+            snowflake_conn,
+            table.to_pandas(),
+            table_name=feature_view.batch_source.table,
+            auto_create_table=True,
+        )
+
 
 class SnowflakeRetrievalJob(RetrievalJob):
     def __init__(
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index f4d5defcad..6f40d3171b 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -76,7 +76,7 @@
 
 OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, DataSourceCreator] = {
     "file": ("local", FileDataSourceCreator),
-    "gcp": ("gcp", BigQueryDataSourceCreator),
+    "bigquery": ("gcp", BigQueryDataSourceCreator),
     "redshift": ("aws", RedshiftDataSourceCreator),
     "snowflake": ("aws", RedshiftDataSourceCreator),
 }
diff --git a/sdk/python/tests/integration/offline_store/test_offline_write.py b/sdk/python/tests/integration/offline_store/test_offline_write.py
index 997299c11b..30ead98389 100644
--- a/sdk/python/tests/integration/offline_store/test_offline_write.py
+++ b/sdk/python/tests/integration/offline_store/test_offline_write.py
@@ -109,7 +109,7 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):
 
 
 @pytest.mark.integration
-@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_offline_stores
 @pytest.mark.universal_online_stores(only=["sqlite"])
 def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
     store = environment.feature_store
diff --git a/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py b/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py
index b2f91f442e..5cea8a36ef 100644
--- a/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py
+++ b/sdk/python/tests/integration/offline_store/test_push_offline_retrieval.py
@@ -16,7 +16,7 @@
 
 
 @pytest.mark.integration
-@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_offline_stores
 @pytest.mark.universal_online_stores(only=["sqlite"])
 def test_push_features_and_read_from_offline_store(environment, universal_data_sources):
     store = environment.feature_store
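
For context, a sketch of the flow these integration tests exercise, now run against every universal offline store rather than only file and redshift. The push source and feature view names are illustrative, and the PushMode/to= push API is assumed to be the one available in this version of Feast.

    import pandas as pd

    from feast import FeatureStore
    from feast.data_source import PushMode

    store = FeatureStore(repo_path=".")

    event_df = pd.DataFrame(
        {
            "driver_id": [1001],
            "event_timestamp": [pd.Timestamp.utcnow()],
            "created": [pd.Timestamp.utcnow()],
            "conv_rate": [0.95],
            "acc_rate": [0.88],
            "avg_daily_trips": [20],
        }
    )

    # Route the pushed rows to the offline store; with this change the path is
    # exercised against BigQuery and Snowflake as well.
    store.push("driver_stats_push_source", event_df, to=PushMode.OFFLINE)

    # The appended rows are then visible to historical retrieval.
    training_df = store.get_historical_features(
        entity_df=pd.DataFrame(
            {"driver_id": [1001], "event_timestamp": [pd.Timestamp.utcnow()]}
        ),
        features=["driver_hourly_stats:conv_rate", "driver_hourly_stats:acc_rate"],
    ).to_df()
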