Skip to content

Commit

Permalink
feat: Added SnowflakeConnection caching (#3531)
Browse files Browse the repository at this point in the history
Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
  • Loading branch information
sfc-gh-madkins authored Mar 24, 2023
1 parent 03924a2 commit f9f8df2
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 102 deletions.
15 changes: 8 additions & 7 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
_run_snowflake_field_mapping,
assert_snowflake_feature_names,
execute_snowflake_statement,
get_snowflake_conn,
get_snowflake_online_store_path,
package_snowpark_zip,
)
Expand Down Expand Up @@ -121,7 +121,7 @@ def update(
):
stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"'
stage_path = f'{stage_context}."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
query = f"SHOW STAGES IN {stage_context}"
cursor = execute_snowflake_statement(conn, query)
stage_list = pd.DataFrame(
Expand Down Expand Up @@ -173,7 +173,7 @@ def teardown_infra(
):

stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with GetSnowflakeConnection(self.repo_config.batch_engine) as conn:
query = f"DROP STAGE IF EXISTS {stage_path}"
execute_snowflake_statement(conn, query)

Expand Down Expand Up @@ -263,10 +263,11 @@ def _materialize_one(

# Let's check and see if we can skip this query, because the table hasn't changed
# since before the start date of this query
with get_snowflake_conn(self.repo_config.offline_store) as conn:
with GetSnowflakeConnection(self.repo_config.offline_store) as conn:
query = f"""SELECT SYSTEM$LAST_CHANGE_COMMIT_TIME('{feature_view.batch_source.get_table_query_string()}') AS last_commit_change_time"""
last_commit_change_time = (
conn.cursor().execute(query).fetchall()[0][0] / 1_000_000_000
execute_snowflake_statement(conn, query).fetchall()[0][0]
/ 1_000_000_000
)
if last_commit_change_time < start_date.astimezone(tz=utc).timestamp():
return SnowflakeMaterializationJob(
Expand Down Expand Up @@ -432,7 +433,7 @@ def materialize_to_snowflake_online_store(
)
"""

with get_snowflake_conn(repo_config.batch_engine) as conn:
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
query_id = execute_snowflake_statement(conn, query).sfqid

click.echo(
Expand All @@ -450,7 +451,7 @@ def materialize_to_external_online_store(

feature_names = [feature.name for feature in feature_view.features]

with get_snowflake_conn(repo_config.batch_engine) as conn:
with GetSnowflakeConnection(repo_config.batch_engine) as conn:
query = materialization_sql
cursor = execute_snowflake_statement(conn, query)
for i, df in enumerate(cursor.fetch_pandas_batches()):
Expand Down
24 changes: 14 additions & 10 deletions sdk/python/feast/infra/offline_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
write_pandas,
write_parquet,
)
Expand All @@ -74,13 +74,13 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
"""Offline store config for Snowflake"""

type: Literal["snowflake.offline"] = "snowflake.offline"
""" Offline store type selector"""
""" Offline store type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake config path -- absolute path required (Cant use ~)"""

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """
Expand All @@ -89,7 +89,7 @@ class SnowflakeOfflineStoreConfig(FeastConfigBaseModel):
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name"""
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """
Expand Down Expand Up @@ -155,7 +155,8 @@ def pull_latest_from_table_or_query(
if data_source.snowflake_options.warehouse:
config.offline_store.warehouse = data_source.snowflake_options.warehouse

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

start_date = start_date.astimezone(tz=utc)
end_date = end_date.astimezone(tz=utc)
Expand Down Expand Up @@ -208,7 +209,8 @@ def pull_all_from_table_or_query(
if data_source.snowflake_options.warehouse:
config.offline_store.warehouse = data_source.snowflake_options.warehouse

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

start_date = start_date.astimezone(tz=utc)
end_date = end_date.astimezone(tz=utc)
Expand Down Expand Up @@ -241,7 +243,8 @@ def get_historical_features(
for fv in feature_views:
assert isinstance(fv.batch_source, SnowflakeSource)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

entity_schema = _get_entity_schema(entity_df, snowflake_conn, config)

Expand Down Expand Up @@ -319,7 +322,8 @@ def write_logged_features(
):
assert isinstance(logging_config.destination, SnowflakeLoggingDestination)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

if isinstance(data, Path):
write_parquet(
Expand Down Expand Up @@ -359,7 +363,8 @@ def offline_write_batch(
if table.schema != pa_schema:
table = table.cast(pa_schema)

snowflake_conn = get_snowflake_conn(config.offline_store)
with GetSnowflakeConnection(config.offline_store) as conn:
snowflake_conn = conn

write_pandas(
snowflake_conn,
Expand Down Expand Up @@ -427,7 +432,6 @@ def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
).fetch_arrow_all()

if pa_table:

return pa_table
else:
empty_result = execute_snowflake_statement(self.snowflake_conn, query)
Expand Down
6 changes: 3 additions & 3 deletions sdk/python/feast/infra/offline_stores/snowflake_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,13 @@ def get_table_column_names_and_types(
"""
from feast.infra.offline_stores.snowflake import SnowflakeOfflineStoreConfig
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
)

assert isinstance(config.offline_store, SnowflakeOfflineStoreConfig)

with get_snowflake_conn(config.offline_store) as conn:
with GetSnowflakeConnection(config.offline_store) as conn:
query = f"SELECT * FROM {self.get_table_query_string()} LIMIT 5"
cursor = execute_snowflake_statement(conn, query)

Expand Down Expand Up @@ -250,7 +250,7 @@ def get_table_column_names_and_types(
else:
column = row["column_name"]

with get_snowflake_conn(config.offline_store) as conn:
with GetSnowflakeConnection(config.offline_store) as conn:
query = f'SELECT MAX("{column}") AS "{column}" FROM {self.get_table_query_string()}'
result = execute_snowflake_statement(
conn, query
Expand Down
16 changes: 8 additions & 8 deletions sdk/python/feast/infra/online_stores/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
from feast.infra.key_encoding_utils import serialize_entity_key
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
get_snowflake_online_store_path,
write_pandas_binary,
)
Expand All @@ -29,13 +29,13 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
"""Online store config for Snowflake"""

type: Literal["snowflake.online"] = "snowflake.online"
""" Online store type selector"""
""" Online store type selector """

config_path: Optional[str] = os.path.expanduser("~/.snowsql/config")
""" Snowflake config path -- absolute path required (Can't use ~)"""

account: Optional[str] = None
""" Snowflake deployment identifier -- drop .snowflakecomputing.com"""
""" Snowflake deployment identifier -- drop .snowflakecomputing.com """

user: Optional[str] = None
""" Snowflake user name """
Expand All @@ -44,7 +44,7 @@ class SnowflakeOnlineStoreConfig(FeastConfigBaseModel):
""" Snowflake password """

role: Optional[str] = None
""" Snowflake role name"""
""" Snowflake role name """

warehouse: Optional[str] = None
""" Snowflake warehouse name """
Expand Down Expand Up @@ -114,7 +114,7 @@ def online_write_batch(

# This combines both the data upload plus the overwrite in the same transaction
online_path = get_snowflake_online_store_path(config, table)
with get_snowflake_conn(config.online_store, autocommit=False) as conn:
with GetSnowflakeConnection(config.online_store, autocommit=False) as conn:
write_pandas_binary(
conn,
agg_df,
Expand Down Expand Up @@ -178,7 +178,7 @@ def online_read(
)

online_path = get_snowflake_online_store_path(config, table)
with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
query = f"""
SELECT
"entity_key", "feature_name", "value", "event_ts"
Expand Down Expand Up @@ -220,7 +220,7 @@ def update(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables_to_keep:
online_path = get_snowflake_online_store_path(config, table)
query = f"""
Expand Down Expand Up @@ -248,7 +248,7 @@ def teardown(
):
assert isinstance(config.online_store, SnowflakeOnlineStoreConfig)

with get_snowflake_conn(config.online_store) as conn:
with GetSnowflakeConnection(config.online_store) as conn:
for table in tables:
online_path = get_snowflake_online_store_path(config, table)
query = f'DROP TABLE IF EXISTS {online_path}."[online-transient] {config.project}_{table.name}"'
Expand Down
28 changes: 14 additions & 14 deletions sdk/python/feast/infra/registry/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from feast.infra.registry import proto_registry_utils
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.utils.snowflake.snowflake_utils import (
GetSnowflakeConnection,
execute_snowflake_statement,
get_snowflake_conn,
)
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.project_metadata import ProjectMetadata
Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(
f'"{self.registry_config.database}"."{self.registry_config.schema_}"'
)

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_creation.sql"
with open(sql_function_file, "r") as file:
sqlFile = file.read()
Expand Down Expand Up @@ -177,7 +177,7 @@ def _refresh_cached_registry_if_necessary(self):
self.refresh()

def teardown(self):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/registry/snowflake_table_deletion.sql"
with open(sql_function_file, "r") as file:
sqlFile = file.read()
Expand Down Expand Up @@ -284,7 +284,7 @@ def _apply_object(
if hasattr(obj, "last_updated_timestamp"):
obj.last_updated_timestamp = update_datetime

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down Expand Up @@ -405,7 +405,7 @@ def _delete_object(
id_field_name: str,
not_found_exception: Optional[Callable],
):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
DELETE FROM {self.registry_path}."{table}"
WHERE
Expand Down Expand Up @@ -616,7 +616,7 @@ def _get_object(
not_found_exception: Optional[Callable],
):
self._maybe_init_project_metadata(project)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
{proto_field_name}
Expand Down Expand Up @@ -776,7 +776,7 @@ def _list_objects(
proto_field_name: str,
):
self._maybe_init_project_metadata(project)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
{proto_field_name}
Expand Down Expand Up @@ -839,7 +839,7 @@ def list_project_metadata(
return proto_registry_utils.list_project_metadata(
self.cached_registry_proto, project
)
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_key,
Expand Down Expand Up @@ -869,7 +869,7 @@ def apply_user_metadata(
):
fv_table_str = self._infer_fv_table(feature_view)
fv_column_name = fv_table_str[:-1].lower()
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down Expand Up @@ -905,7 +905,7 @@ def get_user_metadata(
) -> Optional[bytes]:
fv_table_str = self._infer_fv_table(feature_view)
fv_column_name = fv_table_str[:-1].lower()
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
user_metadata
Expand Down Expand Up @@ -971,7 +971,7 @@ def _get_all_projects(self) -> Set[str]:
"STREAM_FEATURE_VIEWS",
]

with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
for table in base_tables:
query = (
f'SELECT DISTINCT project_id FROM {self.registry_path}."{table}"'
Expand All @@ -984,7 +984,7 @@ def _get_all_projects(self) -> Set[str]:
return projects

def _get_last_updated_metadata(self, project: str):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_value
Expand Down Expand Up @@ -1029,7 +1029,7 @@ def _infer_fv_table(self, feature_view) -> str:
return table

def _maybe_init_project_metadata(self, project):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
metadata_value
Expand All @@ -1056,7 +1056,7 @@ def _maybe_init_project_metadata(self, project):
usage.set_current_project_uuid(new_project_uuid)

def _set_last_updated_metadata(self, last_updated: datetime, project: str):
with get_snowflake_conn(self.registry_config) as conn:
with GetSnowflakeConnection(self.registry_config) as conn:
query = f"""
SELECT
project_id
Expand Down
Loading

0 comments on commit f9f8df2

Please sign in to comment.