diff --git a/protos/feast/core/DataSource.proto b/protos/feast/core/DataSource.proto index 8fe84274a1..07b2f978b8 100644 --- a/protos/feast/core/DataSource.proto +++ b/protos/feast/core/DataSource.proto @@ -161,6 +161,9 @@ message DataSource { // Snowflake schema name string database = 4; + + // Snowflake warehouse name + string warehouse = 5; } // Defines configuration for custom third-party data sources. diff --git a/sdk/python/feast/infra/offline_stores/snowflake.py b/sdk/python/feast/infra/offline_stores/snowflake.py index cc346251a8..968055fcee 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake.py +++ b/sdk/python/feast/infra/offline_stores/snowflake.py @@ -128,6 +128,9 @@ def pull_latest_from_table_or_query( + '"' ) + if data_source.snowflake_options.warehouse: + config.offline_store.warehouse = data_source.snowflake_options.warehouse + snowflake_conn = get_snowflake_conn(config.offline_store) query = f""" @@ -173,6 +176,9 @@ def pull_all_from_table_or_query( + '"' ) + if data_source.snowflake_options.warehouse: + config.offline_store.warehouse = data_source.snowflake_options.warehouse + snowflake_conn = get_snowflake_conn(config.offline_store) start_date = start_date.astimezone(tz=utc) diff --git a/sdk/python/feast/infra/offline_stores/snowflake_source.py b/sdk/python/feast/infra/offline_stores/snowflake_source.py index 40868ef64d..f094d2b329 100644 --- a/sdk/python/feast/infra/offline_stores/snowflake_source.py +++ b/sdk/python/feast/infra/offline_stores/snowflake_source.py @@ -16,6 +16,7 @@ class SnowflakeSource(DataSource): def __init__( self, database: Optional[str] = None, + warehouse: Optional[str] = None, schema: Optional[str] = None, table: Optional[str] = None, query: Optional[str] = None, @@ -33,6 +34,7 @@ def __init__( Args: database (optional): Snowflake database where the features are stored. + warehouse (optional): Snowflake warehouse to use when querying this source; if set, it overrides the offline store's configured warehouse. schema (optional): Snowflake schema in which the table is located. 
table (optional): Snowflake table where the features are stored. event_timestamp_column (optional): Event timestamp column used for point in @@ -55,7 +57,11 @@ def __init__( _schema = "PUBLIC" if (database and table and not schema) else schema self.snowflake_options = SnowflakeOptions( - database=database, schema=_schema, table=table, query=query + database=database, + schema=_schema, + table=table, + query=query, + warehouse=warehouse, ) # If no name, use the table as the default name @@ -99,6 +105,7 @@ def from_proto(data_source: DataSourceProto): database=data_source.snowflake_options.database, schema=data_source.snowflake_options.schema, table=data_source.snowflake_options.table, + warehouse=data_source.snowflake_options.warehouse, event_timestamp_column=data_source.event_timestamp_column, created_timestamp_column=data_source.created_timestamp_column, date_partition_column=data_source.date_partition_column, @@ -124,6 +131,7 @@ def __eq__(self, other): and self.snowflake_options.schema == other.snowflake_options.schema and self.snowflake_options.table == other.snowflake_options.table and self.snowflake_options.query == other.snowflake_options.query + and self.snowflake_options.warehouse == other.snowflake_options.warehouse and self.event_timestamp_column == other.event_timestamp_column and self.created_timestamp_column == other.created_timestamp_column and self.field_mapping == other.field_mapping @@ -152,6 +160,11 @@ def query(self): """Returns the snowflake options of this snowflake source.""" return self.snowflake_options.query + @property + def warehouse(self): + """Returns the warehouse of this snowflake source.""" + return self.snowflake_options.warehouse + def to_proto(self) -> DataSourceProto: """ Converts a SnowflakeSource object to its protobuf representation. 
@@ -239,11 +252,13 @@ def __init__( schema: Optional[str], table: Optional[str], query: Optional[str], + warehouse: Optional[str], ): self._database = database self._schema = schema self._table = table self._query = query + self._warehouse = warehouse @property def query(self): @@ -285,6 +300,16 @@ def table(self, table): """Sets the table ref of this snowflake table.""" self._table = table + @property + def warehouse(self): + """Returns the warehouse name of this snowflake table.""" + return self._warehouse + + @warehouse.setter + def warehouse(self, warehouse): + """Sets the warehouse name of this snowflake table.""" + self._warehouse = warehouse + @classmethod def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions): """ @@ -301,6 +326,7 @@ def from_proto(cls, snowflake_options_proto: DataSourceProto.SnowflakeOptions): schema=snowflake_options_proto.schema, table=snowflake_options_proto.table, query=snowflake_options_proto.query, + warehouse=snowflake_options_proto.warehouse, ) return snowflake_options @@ -317,6 +343,7 @@ def to_proto(self) -> DataSourceProto.SnowflakeOptions: schema=self.schema, table=self.table, query=self.query, + warehouse=self.warehouse, ) return snowflake_options_proto @@ -329,7 +356,7 @@ class SavedDatasetSnowflakeStorage(SavedDatasetStorage): def __init__(self, table_ref: str): self.snowflake_options = SnowflakeOptions( - database=None, schema=None, table=table_ref, query=None + database=None, schema=None, table=table_ref, query=None, warehouse=None ) @staticmethod diff --git a/sdk/python/feast/templates/snowflake/bootstrap.py b/sdk/python/feast/templates/snowflake/bootstrap.py index 3712651a5d..194ba08c08 100644 --- a/sdk/python/feast/templates/snowflake/bootstrap.py +++ b/sdk/python/feast/templates/snowflake/bootstrap.py @@ -68,7 +68,7 @@ def bootstrap(): repo_path = pathlib.Path(__file__).parent.absolute() config_file = repo_path / "feature_store.yaml" - + driver_file = repo_path / "driver_repo.py" 
replace_str_in_file( config_file, "SNOWFLAKE_DEPLOYMENT_URL", snowflake_deployment_url ) @@ -78,6 +78,8 @@ def bootstrap(): replace_str_in_file(config_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse) replace_str_in_file(config_file, "SNOWFLAKE_DATABASE", snowflake_database) + replace_str_in_file(driver_file, "SNOWFLAKE_WAREHOUSE", snowflake_warehouse) + def replace_str_in_file(file_path, match_str, sub_str): with open(file_path, "r") as f: diff --git a/sdk/python/feast/templates/snowflake/driver_repo.py b/sdk/python/feast/templates/snowflake/driver_repo.py index a63c6cb503..0ecdad7f05 100644 --- a/sdk/python/feast/templates/snowflake/driver_repo.py +++ b/sdk/python/feast/templates/snowflake/driver_repo.py @@ -24,6 +24,7 @@ # The Snowflake table where features can be found database=yaml.safe_load(open("feature_store.yaml"))["offline_store"]["database"], table=f"{project_name}_feast_driver_hourly_stats", + warehouse="SNOWFLAKE_WAREHOUSE", # The event timestamp is used for point-in-time joins and for ensuring only # features within the TTL are returned event_timestamp_column="event_timestamp", diff --git a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py index f76656f5b7..05cdea82f0 100644 --- a/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py +++ b/sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py @@ -57,6 +57,7 @@ def create_data_source( created_timestamp_column=created_timestamp_column, date_partition_column="", field_mapping=field_mapping or {"ts_1": "ts"}, + warehouse=self.offline_store_config.warehouse, ) def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage: