Skip to content

Commit

Permalink
feat: Add Snowflake Registry
Browse files Browse the repository at this point in the history
Signed-off-by: miles.adkins <miles.adkins@snowflake.com>
  • Loading branch information
sfc-gh-madkins committed Dec 22, 2022
1 parent 98a24a3 commit 0c0d468
Show file tree
Hide file tree
Showing 13 changed files with 1,382 additions and 113 deletions.
27 changes: 27 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,26 @@ def __init__(self, name, project=None):
super().__init__(f"On demand feature view {name} does not exist")


class RequestFeatureViewNotFoundException(FeastObjectNotFoundException):
def __init__(self, name, project=None):
if project:
super().__init__(
f"Request feature view {name} does not exist in project {project}"
)
else:
super().__init__(f"Request feature view {name} does not exist")


class StreamFeatureViewNotFoundException(FeastObjectNotFoundException):
def __init__(self, name, project=None):
if project:
super().__init__(
f"Stream feature view {name} does not exist in project {project}"
)
else:
super().__init__(f"Stream feature view {name} does not exist")


class RequestDataNotFoundInEntityDfException(FeastObjectNotFoundException):
def __init__(self, feature_name, feature_view_name):
super().__init__(
Expand Down Expand Up @@ -146,6 +166,13 @@ def __init__(self, feature_server_type: str):
)


class FeastRegistryTypeInvalidError(Exception):
def __init__(self, registry_type: str):
super().__init__(
f"Feature server type was set to {registry_type}, but this type is invalid"
)


class FeastModuleImportError(Exception):
def __init__(self, module_name: str, class_name: str):
super().__init__(
Expand Down
3 changes: 3 additions & 0 deletions sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
from feast.infra.provider import Provider, RetrievalJob, get_provider
from feast.infra.registry.base_registry import BaseRegistry
from feast.infra.registry.registry import Registry
from feast.infra.registry.snowflake import SnowflakeRegistry
from feast.infra.registry.sql import SqlRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.online_response import OnlineResponse
Expand Down Expand Up @@ -167,6 +168,8 @@ def __init__(
registry_config = self.config.get_registry_config()
if registry_config.registry_type == "sql":
self._registry = SqlRegistry(registry_config, None)
elif registry_config.registry_type == "snowflake.registry":
self._registry = SnowflakeRegistry(registry_config, None)
else:
r = Registry(registry_config, repo_path=self.repo_path)
r._initialize_registry(self.config.project)
Expand Down
81 changes: 39 additions & 42 deletions sdk/python/feast/infra/materialization/snowflake_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,27 @@ def url(self) -> Optional[str]:


class SnowflakeMaterializationEngine(BatchMaterializationEngine):
def __init__(
self,
*,
repo_config: RepoConfig,
offline_store: OfflineStore,
online_store: OnlineStore,
**kwargs,
):
assert (
repo_config.offline_store.type == "snowflake.offline"
), "To use SnowflakeMaterializationEngine, you must use Snowflake as an offline store."

self.snowflake_conn = get_snowflake_conn(repo_config.batch_engine)

super().__init__(
repo_config=repo_config,
offline_store=offline_store,
online_store=online_store,
**kwargs,
)

def update(
self,
project: str,
Expand All @@ -121,9 +142,9 @@ def update(
):
stage_context = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"'
stage_path = f'{stage_context}."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with self.snowflake_conn.cursor() as cur:
query = f"SHOW STAGES IN {stage_context}"
cursor = execute_snowflake_statement(conn, query)
cursor = execute_snowflake_statement(cur, query)
stage_list = pd.DataFrame(
cursor.fetchall(),
columns=[column.name for column in cursor.description],
Expand All @@ -140,11 +161,11 @@ def update(
click.echo()

query = f"CREATE STAGE {stage_path}"
execute_snowflake_statement(conn, query)
execute_snowflake_statement(cur, query)

copy_path, zip_path = package_snowpark_zip(project)
query = f"PUT file://{zip_path} @{stage_path}"
execute_snowflake_statement(conn, query)
execute_snowflake_statement(cur, query)

shutil.rmtree(copy_path)

Expand All @@ -157,7 +178,7 @@ def update(
for command in sqlCommands:
command = command.replace("STAGE_HOLDER", f"{stage_path}")
query = command.replace("PROJECT_NAME", f"{project}")
execute_snowflake_statement(conn, query)
execute_snowflake_statement(cur, query)

return None

Expand All @@ -169,9 +190,9 @@ def teardown_infra(
):

stage_path = f'"{self.repo_config.batch_engine.database}"."{self.repo_config.batch_engine.schema_}"."feast_{project}"'
with get_snowflake_conn(self.repo_config.batch_engine) as conn:
with self.snowflake_conn.cursor() as cur:
query = f"DROP STAGE IF EXISTS {stage_path}"
execute_snowflake_statement(conn, query)
execute_snowflake_statement(cur, query)

# Execute snowflake python udf deletion functions
sql_function_file = f"{os.path.dirname(feast.__file__)}/infra/utils/snowflake/snowpark/snowflake_python_udfs_deletion.sql"
Expand All @@ -181,29 +202,10 @@ def teardown_infra(
sqlCommands = sqlFile.split(";")
for command in sqlCommands:
query = command.replace("PROJECT_NAME", f"{project}")
execute_snowflake_statement(conn, query)
execute_snowflake_statement(cur, query)

return None

def __init__(
self,
*,
repo_config: RepoConfig,
offline_store: OfflineStore,
online_store: OnlineStore,
**kwargs,
):
assert (
repo_config.offline_store.type == "snowflake.offline"
), "To use SnowflakeMaterializationEngine, you must use Snowflake as an offline store."

super().__init__(
repo_config=repo_config,
offline_store=offline_store,
online_store=online_store,
**kwargs,
)

def materialize(
self, registry, tasks: List[MaterializationTask]
) -> List[MaterializationJob]:
Expand Down Expand Up @@ -259,10 +261,11 @@ def _materialize_one(

# Lets check and see if we can skip this query, because the table hasnt changed
# since before the start date of this query
with get_snowflake_conn(self.repo_config.offline_store) as conn:
with self.snowflake_conn.cursor() as cur:
query = f"""SELECT SYSTEM$LAST_CHANGE_COMMIT_TIME('{feature_view.batch_source.get_table_query_string()}') AS last_commit_change_time"""
last_commit_change_time = (
conn.cursor().execute(query).fetchall()[0][0] / 1_000_000_000
execute_snowflake_statement(cur, query).fetchall()[0][0]
/ 1_000_000_000
)
if last_commit_change_time < start_date.astimezone(tz=utc).timestamp():
return SnowflakeMaterializationJob(
Expand All @@ -277,22 +280,19 @@ def _materialize_one(
)

fv_to_proto_sql = self.generate_snowflake_materialization_query(
self.repo_config,
fv_latest_mapped_values_sql,
feature_view,
project,
)

if self.repo_config.online_store.type == "snowflake.online":
self.materialize_to_snowflake_online_store(
self.repo_config,
fv_to_proto_sql,
feature_view,
project,
)
else:
self.materialize_to_external_online_store(
self.repo_config,
fv_to_proto_sql,
feature_view,
tqdm_builder,
Expand All @@ -308,7 +308,6 @@ def _materialize_one(

def generate_snowflake_materialization_query(
self,
repo_config: RepoConfig,
fv_latest_mapped_values_sql: str,
feature_view: Union[BatchFeatureView, FeatureView],
project: str,
Expand Down Expand Up @@ -353,7 +352,7 @@ def generate_snowflake_materialization_query(

features_str = ",\n".join(feature_sql_list)

if repo_config.online_store.type == "snowflake.online":
if self.repo_config.online_store.type == "snowflake.online":
serial_func = f"feast_{project}_serialize_entity_keys"
else:
serial_func = f"feast_{project}_entity_key_proto_to_string"
Expand All @@ -373,7 +372,6 @@ def generate_snowflake_materialization_query(

def materialize_to_snowflake_online_store(
self,
repo_config: RepoConfig,
materialization_sql: str,
feature_view: Union[BatchFeatureView, FeatureView],
project: str,
Expand All @@ -389,7 +387,7 @@ def materialize_to_snowflake_online_store(
else:
fv_created_str = None

online_path = get_snowflake_online_store_path(repo_config, feature_view)
online_path = get_snowflake_online_store_path(self.repo_config, feature_view)
online_table = (
f'{online_path}."[online-transient] {project}_{feature_view.name}"'
)
Expand Down Expand Up @@ -428,8 +426,8 @@ def materialize_to_snowflake_online_store(
)
"""

with get_snowflake_conn(repo_config.batch_engine) as conn:
query_id = execute_snowflake_statement(conn, query).sfqid
with self.snowflake_conn.cursor() as cur:
query_id = execute_snowflake_statement(cur, query).sfqid

click.echo(
f"Snowflake Query ID: {Style.BRIGHT + Fore.GREEN}{query_id}{Style.RESET_ALL}"
Expand All @@ -438,17 +436,16 @@ def materialize_to_snowflake_online_store(

def materialize_to_external_online_store(
self,
repo_config: RepoConfig,
materialization_sql: str,
feature_view: Union[StreamFeatureView, FeatureView],
tqdm_builder: Callable[[int], tqdm],
) -> None:

feature_names = [feature.name for feature in feature_view.features]

with get_snowflake_conn(repo_config.batch_engine) as conn:
with self.snowflake_conn.cursor() as cur:
query = materialization_sql
cursor = execute_snowflake_statement(conn, query)
cursor = execute_snowflake_statement(cur, query)
for i, df in enumerate(cursor.fetch_pandas_batches()):
click.echo(
f"Snowflake: Processing Materialization ResultSet Batch #{i+1}"
Expand Down Expand Up @@ -491,7 +488,7 @@ def materialize_to_external_online_store(

with tqdm_builder(len(rows_to_write)) as pbar:
self.online_store.online_write_batch(
repo_config,
self.repo_config,
feature_view,
rows_to_write,
lambda x: pbar.update(x),
Expand Down
Loading

0 comments on commit 0c0d468

Please sign in to comment.