diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index 4f11b1ac42..dd13ffc96c 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -615,12 +615,17 @@ def to_remote_storage(self) -> List[str]: HEADER = TRUE """ cursor = execute_snowflake_statement(self.snowflake_conn, query) + # s3gov schema is used by Snowflake in AWS govcloud regions + # remove gov portion from schema and pass it to online store upload + native_export_path = self.export_path.replace("s3gov://", "s3://") + return self._get_file_names_from_copy_into(cursor, native_export_path) + def _get_file_names_from_copy_into(self, cursor, native_export_path) -> List[str]: file_name_column_index = [ idx for idx, rm in enumerate(cursor.description) if rm.name == "FILE_NAME" ][0] return [ - f"{self.export_path}/{row[file_name_column_index]}" + f"{native_export_path}/{row[file_name_column_index]}" for row in cursor.fetchall() ] diff --git a/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py b/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py new file mode 100644 index 0000000000..afc3ae97ae --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/test_snowflake.py @@ -0,0 +1,57 @@ +import re +from unittest.mock import ANY, MagicMock, patch + +import pytest + +from feast.infra.offline_stores.snowflake import ( + SnowflakeOfflineStoreConfig, + SnowflakeRetrievalJob, +) +from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig +from feast.repo_config import RepoConfig + + +@pytest.fixture(params=["s3", "s3gov"]) +def retrieval_job(request): + offline_store_config = SnowflakeOfflineStoreConfig( + type="snowflake.offline", + account="snow", + user="snow", + password="snow", + role="snow", + warehouse="snow", + database="FEAST", + schema="OFFLINE", + storage_integration_name="FEAST_S3", + blob_export_location=f"{request.param}://feast-snowflake-offload/export", + ) + retrieval_job = SnowflakeRetrievalJob( + query="SELECT * FROM snowflake", + snowflake_conn=MagicMock(), + config=RepoConfig( + registry="s3://ml-test/repo/registry.db", + project="test", + provider="snowflake.offline", + online_store=SqliteOnlineStoreConfig(type="sqlite"), + offline_store=offline_store_config, + ), + full_feature_names=True, + on_demand_feature_views=[], + ) + return retrieval_job + + +def test_to_remote_storage(retrieval_job): + stored_files = ["just a path", "maybe another"] + with patch.object( + retrieval_job, "to_snowflake", return_value=None + ) as mock_to_snowflake, patch.object( + retrieval_job, "_get_file_names_from_copy_into", return_value=stored_files + ) as mock_get_file_names_from_copy: + assert ( + retrieval_job.to_remote_storage() == stored_files + ), "should return the list of files" + mock_to_snowflake.assert_called_once() + mock_get_file_names_from_copy.assert_called_once_with(ANY, ANY) + native_path = mock_get_file_names_from_copy.call_args[0][1] + assert re.match("^s3://.*", native_path), "path should be s3://*"