feat: Update snowflake offline store job output formats -- added arrow (#3589)

Signed-off-by: Miles Adkins <miles.adkins@snowflake.com>
sfc-gh-madkins authored Apr 21, 2023
1 parent 58ce148 commit be3e349
Showing 3 changed files with 84 additions and 55 deletions.
6 changes: 3 additions & 3 deletions docs/reference/offline-stores/overview.md
@@ -46,11 +46,11 @@ Below is a matrix indicating which `RetrievalJob`s support what functionality.
 | --------------------------------- | --- | --- | --- | --- | --- | --- | --- |
 | export to dataframe | yes | yes | yes | yes | yes | yes | yes |
 | export to arrow table | yes | yes | yes | yes | yes | yes | yes |
-| export to arrow batches | no | no | no | yes | no | no | no |
-| export to SQL | no | yes | no | yes | yes | no | yes |
+| export to arrow batches | no | no | yes | yes | no | no | no |
+| export to SQL | no | yes | yes | yes | yes | no | yes |
 | export to data lake (S3, GCS, etc.) | no | no | yes | no | yes | no | no |
 | export to data warehouse | no | yes | yes | yes | yes | no | no |
-| export as Spark dataframe | no | no | no | no | no | yes | no |
+| export as Spark dataframe | no | no | yes | no | no | yes | no |
 | local execution of Python-based on-demand transforms | yes | yes | yes | yes | yes | no | yes |
 | remote execution of Python-based on-demand transforms | no | no | no | no | no | no | no |
 | persist results in the offline store | yes | yes | yes | yes | yes | yes | no |
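
For context, the matrix above describes the export surface of Feast `RetrievalJob` objects. A minimal sketch of the two universally supported calls, assuming a configured feature repository (the entity dataframe and feature references below are placeholders):

```python
# Hedged illustration of the RetrievalJob export surface summarized in the matrix above.
# Assumes a configured feature repo in the current directory; entity_df and the feature
# list are placeholders.
import pandas as pd
from feast import FeatureStore

store = FeatureStore(repo_path=".")
entity_df = pd.DataFrame(
    {"driver_id": [1001], "event_timestamp": [pd.Timestamp.now()]}
)

job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)

df = job.to_df()        # "export to dataframe"
table = job.to_arrow()  # "export to arrow table"
```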
4 changes: 2 additions & 2 deletions docs/reference/offline-stores/snowflake.md
@@ -53,11 +53,11 @@ Below is a matrix indicating which functionality is supported by `SnowflakeRetri
 | ----------------------------------------------------- | --------- |
 | export to dataframe | yes |
 | export to arrow table | yes |
-| export to arrow batches | no |
+| export to arrow batches | yes |
 | export to SQL | yes |
 | export to data lake (S3, GCS, etc.) | yes |
 | export to data warehouse | yes |
-| export as Spark dataframe | no |
+| export as Spark dataframe | yes |
 | local execution of Python-based on-demand transforms | yes |
 | remote execution of Python-based on-demand transforms | no |
 | persist results in the offline store | yes |
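
With this change a `SnowflakeRetrievalJob` can also stream results in batches instead of materializing everything at once. A hedged sketch, reusing the hypothetical `store` and `entity_df` from the example above (the processing function is a placeholder):

```python
# Sketch of the export paths the matrix above now marks "yes" for Snowflake.
# Assumes the offline store is Snowflake; handle_batch is a placeholder.
job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)

for arrow_table in job.to_arrow_batches():  # each item is a pyarrow.Table
    handle_batch(arrow_table)

for pdf in job.to_pandas_batches():         # each item is a pd.DataFrame
    handle_batch(pdf)

sql = job.to_sql()                          # the query Snowflake will execute
```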
129 changes: 79 additions & 50 deletions sdk/python/feast/infra/offline_stores/snowflake.py
@@ -436,52 +436,85 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         return self._on_demand_feature_views
 
     def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
-        with self._query_generator() as query:
-
-            df = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_pandas_all()
+        df = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_pandas_all()
 
         return df
 
     def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
-        with self._query_generator() as query:
+        pa_table = execute_snowflake_statement(
+            self.snowflake_conn, self.to_sql()
+        ).fetch_arrow_all()
 
-            pa_table = execute_snowflake_statement(
-                self.snowflake_conn, query
-            ).fetch_arrow_all()
+        if pa_table:
+            return pa_table
+        else:
+            empty_result = execute_snowflake_statement(
+                self.snowflake_conn, self.to_sql()
+            )
 
-            if pa_table:
-                return pa_table
-            else:
-                empty_result = execute_snowflake_statement(self.snowflake_conn, query)
+            return pyarrow.Table.from_pandas(
+                pd.DataFrame(columns=[md.name for md in empty_result.description])
+            )
 
-                return pyarrow.Table.from_pandas(
-                    pd.DataFrame(columns=[md.name for md in empty_result.description])
-                )
+    def to_sql(self) -> str:
+        """
+        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
+        """
+        with self._query_generator() as query:
+            return query
 
-    def to_snowflake(self, table_name: str, temporary=False) -> None:
+    def to_snowflake(
+        self, table_name: str, allow_overwrite: bool = False, temporary: bool = False
+    ) -> None:
         """Save dataset as a new Snowflake table"""
         if self.on_demand_feature_views:
             transformed_df = self.to_df()
 
+            if allow_overwrite:
+                query = f'DROP TABLE IF EXISTS "{table_name}"'
+                execute_snowflake_statement(self.snowflake_conn, query)
+
             write_pandas(
-                self.snowflake_conn, transformed_df, table_name, auto_create_table=True
+                self.snowflake_conn,
+                transformed_df,
+                table_name,
+                auto_create_table=True,
+                create_temp_table=temporary,
             )
 
             return None
+        else:
+            query = f'CREATE {"OR REPLACE" if allow_overwrite else ""} {"TEMPORARY" if temporary else ""} TABLE {"IF NOT EXISTS" if not allow_overwrite else ""} "{table_name}" AS ({self.to_sql()});\n'
+            execute_snowflake_statement(self.snowflake_conn, query)
 
-        with self._query_generator() as query:
-            query = f'CREATE {"TEMPORARY" if temporary else ""} TABLE IF NOT EXISTS "{table_name}" AS ({query});\n'
+            return None
 
-            execute_snowflake_statement(self.snowflake_conn, query)
+    def to_arrow_batches(self) -> Iterator[pyarrow.Table]:
 
-    def to_sql(self) -> str:
-        """
-        Returns the SQL query that will be executed in Snowflake to build the historical feature table.
-        """
-        with self._query_generator() as query:
-            return query
+        table_name = "temp_arrow_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_arrow_batches()
+
+        return arrow_batches
+
+    def to_pandas_batches(self) -> Iterator[pd.DataFrame]:
+
+        table_name = "temp_pandas_batches_" + uuid.uuid4().hex
+
+        self.to_snowflake(table_name=table_name, allow_overwrite=True, temporary=True)
+
+        query = f'SELECT * FROM "{table_name}"'
+        arrow_batches = execute_snowflake_statement(
+            self.snowflake_conn, query
+        ).fetch_pandas_batches()
+
+        return arrow_batches
 
     def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
         """
@@ -502,37 +535,33 @@ def to_spark_df(self, spark_session: "SparkSession") -> "DataFrame":
             raise FeastExtrasDependencyImportError("spark", str(e))
 
         if isinstance(spark_session, SparkSession):
-            with self._query_generator() as query:
-
-                arrow_batches = execute_snowflake_statement(
-                    self.snowflake_conn, query
-                ).fetch_arrow_batches()
-
-                if arrow_batches:
-                    spark_df = reduce(
-                        DataFrame.unionAll,
-                        [
-                            spark_session.createDataFrame(batch.to_pandas())
-                            for batch in arrow_batches
-                        ],
-                    )
-
-                    return spark_df
-
-                else:
-                    raise EntitySQLEmptyResults(query)
-
+            arrow_batches = self.to_arrow_batches()
+
+            if arrow_batches:
+                spark_df = reduce(
+                    DataFrame.unionAll,
+                    [
+                        spark_session.createDataFrame(batch.to_pandas())
+                        for batch in arrow_batches
+                    ],
+                )
+                return spark_df
+            else:
+                raise EntitySQLEmptyResults(self.to_sql())
         else:
             raise InvalidSparkSessionException(spark_session)
 
     def persist(
         self,
         storage: SavedDatasetStorage,
-        allow_overwrite: Optional[bool] = False,
+        allow_overwrite: bool = False,
         timeout: Optional[int] = None,
     ):
         assert isinstance(storage, SavedDatasetSnowflakeStorage)
-        self.to_snowflake(table_name=storage.snowflake_options.table)
+
+        self.to_snowflake(
+            table_name=storage.snowflake_options.table, allow_overwrite=allow_overwrite
+        )
 
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
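
Beyond the batch exports, the commit also changes the write paths: `to_snowflake()` gains an `allow_overwrite` flag and now honors `temporary` when writing via `write_pandas`, and `persist()` forwards `allow_overwrite`. A hedged usage sketch, with placeholder table names and the hypothetical `job` object carried over from the examples above:

```python
# Sketch of the updated write paths; table names and saved_dataset_storage are placeholders.
job.to_snowflake(table_name="DRIVER_TRAINING_SET", allow_overwrite=True)  # drop/replace an existing table
job.to_snowflake(table_name="DRIVER_SCRATCH", temporary=True)             # session-scoped temporary table

# persist() now passes allow_overwrite through to to_snowflake(), so a saved dataset
# can overwrite a previously persisted table of the same name.
job.persist(storage=saved_dataset_storage, allow_overwrite=True)
```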
