diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index 25fb037f87..9c2ea8a276 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -1423,6 +1423,8 @@ def _write_to_offline_store(
         feature_view = self.get_feature_view(
             feature_view_name, allow_registry_cache=allow_registry_cache
         )
+        df = df.reset_index(drop=True)
+
         table = pa.Table.from_pandas(df)
         provider = self._get_provider()
         provider.ingest_df_to_offline_store(feature_view, table)
diff --git a/sdk/python/feast/infra/offline_stores/redshift.py b/sdk/python/feast/infra/offline_stores/redshift.py
index a5483e8140..943bac502c 100644
--- a/sdk/python/feast/infra/offline_stores/redshift.py
+++ b/sdk/python/feast/infra/offline_stores/redshift.py
@@ -3,6 +3,7 @@
 from datetime import datetime
 from pathlib import Path
 from typing import (
+    Any,
     Callable,
     ContextManager,
     Dict,
@@ -41,6 +42,7 @@
 from feast.registry import BaseRegistry
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
+from feast.type_map import feast_value_type_to_pa, redshift_to_feast_value_type
 from feast.usage import log_exceptions_and_usage
@@ -297,6 +299,69 @@ def write_logged_features(
             fail_if_exists=False,
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, RedshiftOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when redshift type required"
+            )
+        if not isinstance(feature_view.batch_source, RedshiftSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
+            )
+        redshift_options = feature_view.batch_source.redshift_options
+        redshift_client = aws_utils.get_redshift_data_client(
+            config.offline_store.region
+        )
+
+        column_name_to_type = feature_view.batch_source.get_table_column_names_and_types(
+            config
+        )
+        pa_schema_list = []
+        column_names = []
+        for column_name, redshift_type in column_name_to_type:
+            pa_schema_list.append(
+                (
+                    column_name,
+                    feast_value_type_to_pa(redshift_to_feast_value_type(redshift_type)),
+                )
+            )
+            column_names.append(column_name)
+        pa_schema = pa.schema(pa_schema_list)
+        if column_names != table.column_names:
+            raise ValueError(
+                f"Input dataframe has incorrect schema or wrong order, expected columns are: {column_names}"
+            )
+
+        if table.schema != pa_schema:
+            table = table.cast(pa_schema)
+
+        s3_resource = aws_utils.get_s3_resource(config.offline_store.region)
+
+        aws_utils.upload_arrow_table_to_redshift(
+            table=table,
+            redshift_data_client=redshift_client,
+            cluster_id=config.offline_store.cluster_id,
+            database=redshift_options.database
+            or config.offline_store.database,  # Users can define database in the source if needed but it's not required.
+            user=config.offline_store.user,
+            s3_resource=s3_resource,
+            s3_path=f"{config.offline_store.s3_staging_location}/push/{uuid.uuid4()}.parquet",
+            iam_role=config.offline_store.iam_role,
+            table_name=redshift_options.table,
+            schema=pa_schema,
+            fail_if_exists=False,
+        )
+
 
 class RedshiftRetrievalJob(RetrievalJob):
     def __init__(
diff --git a/sdk/python/feast/infra/passthrough_provider.py b/sdk/python/feast/infra/passthrough_provider.py
index e023afe782..9d18e6b249 100644
--- a/sdk/python/feast/infra/passthrough_provider.py
+++ b/sdk/python/feast/infra/passthrough_provider.py
@@ -103,14 +103,14 @@ def online_write_batch(
     def offline_write_batch(
         self,
         config: RepoConfig,
-        table: FeatureView,
+        feature_view: FeatureView,
         data: pa.Table,
         progress: Optional[Callable[[int], Any]],
     ) -> None:
         set_usage_attribute("provider", self.__class__.__name__)
 
         if self.offline_store:
-            self.offline_store.offline_write_batch(config, table, data, progress)
+            self.offline_store.offline_write_batch(config, feature_view, data, progress)
 
     @log_exceptions_and_usage(sampler=RatioSampler(ratio=0.001))
     def online_read(
diff --git a/sdk/python/feast/infra/utils/aws_utils.py b/sdk/python/feast/infra/utils/aws_utils.py
index bb75160a87..7badda9846 100644
--- a/sdk/python/feast/infra/utils/aws_utils.py
+++ b/sdk/python/feast/infra/utils/aws_utils.py
@@ -235,6 +235,15 @@ def upload_df_to_redshift(
     )
 
 
+def delete_redshift_table(
+    redshift_data_client, cluster_id: str, database: str, user: str, table_name: str,
+):
+    drop_query = f"DROP TABLE IF EXISTS {table_name}"
+    execute_redshift_statement(
+        redshift_data_client, cluster_id, database, user, drop_query,
+    )
+
+
 def upload_arrow_table_to_redshift(
     table: Union[pyarrow.Table, Path],
     redshift_data_client,
@@ -320,7 +329,7 @@ def upload_arrow_table_to_redshift(
             cluster_id,
             database,
             user,
-            f"{create_query}; {copy_query}",
+            f"{create_query}; {copy_query};",
         )
     finally:
         # Clean up S3 temporary data
@@ -371,6 +380,53 @@ def temporarily_upload_df_to_redshift(
     )
 
 
+@contextlib.contextmanager
+def temporarily_upload_arrow_table_to_redshift(
+    table: Union[pyarrow.Table, Path],
+    redshift_data_client,
+    cluster_id: str,
+    database: str,
+    user: str,
+    s3_resource,
+    iam_role: str,
+    s3_path: str,
+    table_name: str,
+    schema: Optional[pyarrow.Schema] = None,
+    fail_if_exists: bool = True,
+) -> Iterator[None]:
+    """Uploads an Arrow Table to Redshift as a new table with cleanup logic.
+
+    This is essentially the same as upload_arrow_table_to_redshift (check out its docstring for full details),
+    but unlike it this method is a generator and should be used with a `with` block. For example:
+
+    >>> with temporarily_upload_arrow_table_to_redshift(...): # doctest: +SKIP
+    >>>     # Use `table_name` table in Redshift here
+    >>> # `table_name` will not exist at this point, since it's cleaned up by the `with` block
+
+    """
+    # Upload the dataframe to Redshift
+    upload_arrow_table_to_redshift(
+        table,
+        redshift_data_client,
+        cluster_id,
+        database,
+        user,
+        s3_resource,
+        s3_path,
+        iam_role,
+        table_name,
+        schema,
+        fail_if_exists,
+    )
+
+    yield
+
+    # Clean up the uploaded Redshift table
+    execute_redshift_statement(
+        redshift_data_client, cluster_id, database, user, f"DROP TABLE {table_name}",
+    )
+
+
 def download_s3_directory(s3_resource, bucket: str, key: str, local_dir: str):
     """Download the S3 directory to a local disk"""
     bucket_obj = s3_resource.Bucket(bucket)
diff --git a/sdk/python/tests/conftest.py b/sdk/python/tests/conftest.py
index 671acb3b92..bf69a85fa3 100644
--- a/sdk/python/tests/conftest.py
+++ b/sdk/python/tests/conftest.py
@@ -33,6 +33,7 @@
 from tests.integration.feature_repos.repo_configuration import (
     AVAILABLE_OFFLINE_STORES,
     AVAILABLE_ONLINE_STORES,
+    OFFLINE_STORE_TO_PROVIDER_CONFIG,
     Environment,
     TestData,
     construct_test_environment,
@@ -196,16 +197,24 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
     """
     if "environment" in metafunc.fixturenames:
         markers = {m.name: m for m in metafunc.definition.own_markers}
-
+        offline_stores = None
         if "universal_offline_stores" in markers:
-            offline_stores = AVAILABLE_OFFLINE_STORES
+            # Offline stores can be explicitly requested
+            if "only" in markers["universal_offline_stores"].kwargs:
+                offline_stores = [
+                    OFFLINE_STORE_TO_PROVIDER_CONFIG.get(store_name)
+                    for store_name in markers["universal_offline_stores"].kwargs["only"]
+                    if store_name in OFFLINE_STORE_TO_PROVIDER_CONFIG
+                ]
+            else:
+                offline_stores = AVAILABLE_OFFLINE_STORES
         else:
             # default offline store for testing online store dimension
             offline_stores = [("local", FileDataSourceCreator)]
 
         online_stores = None
         if "universal_online_stores" in markers:
-            # Online stores are explicitly requested
+            # Online stores can be explicitly requested
             if "only" in markers["universal_online_stores"].kwargs:
                 online_stores = [
                     AVAILABLE_ONLINE_STORES.get(store_name)
@@ -240,40 +249,44 @@ def pytest_generate_tests(metafunc: pytest.Metafunc):
             extra_dimensions.append({"go_feature_retrieval": True})
 
         configs = []
-        for provider, offline_store_creator in offline_stores:
-            for online_store, online_store_creator in online_stores:
-                for dim in extra_dimensions:
-                    config = {
-                        "provider": provider,
-                        "offline_store_creator": offline_store_creator,
-                        "online_store": online_store,
-                        "online_store_creator": online_store_creator,
-                        **dim,
-                    }
-                    # temporary Go works only with redis
-                    if config.get("go_feature_retrieval") and (
-                        not isinstance(online_store, dict)
-                        or online_store["type"] != "redis"
-                    ):
-                        continue
-
-                    # aws lambda works only with dynamo
-                    if (
-                        config.get("python_feature_server")
-                        and config.get("provider") == "aws"
-                        and (
-                            not isinstance(online_store, dict)
-                            or online_store["type"] != "dynamodb"
-                        )
-                    ):
-                        continue
-
-                    c = IntegrationTestRepoConfig(**config)
-
-                    if c not in _config_cache:
-                        _config_cache[c] = c
-
-                    configs.append(_config_cache[c])
+        if offline_stores:
+            for provider, offline_store_creator in offline_stores:
+                for online_store, online_store_creator in online_stores:
+                    for dim in extra_dimensions:
+                        config = {
+                            "provider": provider,
+                            "offline_store_creator": offline_store_creator,
+                            "online_store": online_store,
+                            "online_store_creator": online_store_creator,
+                            **dim,
+                        }
+                        # temporary Go works only with redis
+                        if config.get("go_feature_retrieval") and (
+                            not isinstance(online_store, dict)
+                            or online_store["type"] != "redis"
+                        ):
+                            continue
+
+                        # aws lambda works only with dynamo
+                        if (
+                            config.get("python_feature_server")
+                            and config.get("provider") == "aws"
+                            and (
+                                not isinstance(online_store, dict)
+                                or online_store["type"] != "dynamodb"
+                            )
+                        ):
+                            continue
+
+                        c = IntegrationTestRepoConfig(**config)
+
+                        if c not in _config_cache:
+                            _config_cache[c] = c
+
+                        configs.append(_config_cache[c])
+        else:
+            # No offline stores requested -> fall back to the default offline store
+            offline_stores = [("local", FileDataSourceCreator)]
 
         metafunc.parametrize(
             "environment", configs, indirect=True, ids=[str(c) for c in configs]
diff --git a/sdk/python/tests/integration/feature_repos/repo_configuration.py b/sdk/python/tests/integration/feature_repos/repo_configuration.py
index 5a48115dbe..f4d5defcad 100644
--- a/sdk/python/tests/integration/feature_repos/repo_configuration.py
+++ b/sdk/python/tests/integration/feature_repos/repo_configuration.py
@@ -74,6 +74,13 @@
     "connection_string": "127.0.0.1:6001,127.0.0.1:6002,127.0.0.1:6003",
 }
 
+OFFLINE_STORE_TO_PROVIDER_CONFIG: Dict[str, Tuple[str, Type[DataSourceCreator]]] = {
+    "file": ("local", FileDataSourceCreator),
+    "gcp": ("gcp", BigQueryDataSourceCreator),
+    "redshift": ("aws", RedshiftDataSourceCreator),
+    "snowflake": ("aws", RedshiftDataSourceCreator),
+}
+
 AVAILABLE_OFFLINE_STORES: List[Tuple[str, Type[DataSourceCreator]]] = [
     ("local", FileDataSourceCreator),
 ]
diff --git a/sdk/python/tests/integration/offline_store/test_offline_write.py b/sdk/python/tests/integration/offline_store/test_offline_write.py
index 41f6ea89fa..5e7a242513 100644
--- a/sdk/python/tests/integration/offline_store/test_offline_write.py
+++ b/sdk/python/tests/integration/offline_store/test_offline_write.py
@@ -11,8 +11,9 @@
 
 
 @pytest.mark.integration
-@pytest.mark.universal_online_stores
-def test_writing_incorrect_order_fails(environment, universal_data_sources):
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
+def test_writing_columns_in_incorrect_order_fails(environment, universal_data_sources):
     # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
@@ -59,7 +60,8 @@ def test_writing_incorrect_order_fails(environment, universal_data_sources):
 
 
 @pytest.mark.integration
-@pytest.mark.universal_online_stores
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
 def test_writing_incorrect_schema_fails(environment, universal_data_sources):
     # TODO(kevjumba) handle incorrect order later, for now schema must be in the order that the filesource is in
     store = environment.feature_store
@@ -107,7 +109,8 @@ def test_writing_incorrect_schema_fails(environment, universal_data_sources):
 
 
 @pytest.mark.integration
-@pytest.mark.universal_online_stores
+@pytest.mark.universal_offline_stores(only=["file", "redshift"])
+@pytest.mark.universal_online_stores(only=["sqlite"])
 def test_writing_consecutively_to_offline_store(environment, universal_data_sources):
     store = environment.feature_store
     _, _, data_sources = universal_data_sources
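
As a usage illustration (the repository path, feature view name, and column names below are hypothetical placeholders rather than part of this patch), the new RedshiftOfflineStore.offline_write_batch entry point can be exercised directly with a pyarrow Table whose columns match the batch source's names and order:

import pandas as pd
import pyarrow as pa

from feast import FeatureStore
from feast.infra.offline_stores.redshift import RedshiftOfflineStore

# Hypothetical repo and feature view; any Redshift-backed feature view whose
# batch source columns match the dataframe would work the same way.
store = FeatureStore(repo_path=".")
feature_view = store.get_feature_view("driver_hourly_stats")

# Column names and order must match the batch source exactly, otherwise
# offline_write_batch raises a ValueError before uploading anything.
df = pd.DataFrame(
    {
        "driver_id": [1001],
        "event_timestamp": [pd.Timestamp.utcnow()],
        "created": [pd.Timestamp.utcnow()],
        "conv_rate": [0.85],
        "acc_rate": [0.91],
        "avg_daily_trips": [10],
    }
)

RedshiftOfflineStore.offline_write_batch(
    config=store.config,
    feature_view=feature_view,
    table=pa.Table.from_pandas(df.reset_index(drop=True)),
    progress=None,
)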