Rename get_pyarrow_schema
Signed-off-by: Felix Wang <wangfelix98@gmail.com>
felixwang9817 committed Jun 22, 2022
1 parent 6315d6d commit 1420aa0
Showing 5 changed files with 19 additions and 14 deletions.
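In short, the helper offline_utils.get_pyarrow_schema is renamed to get_pyarrow_schema_from_batch_source and its second parameter narrows from a FeatureView to a DataSource, so each offline store now passes feature_view.batch_source explicitly. A minimal before/after sketch of the call site, taken from the hunks below:

# Before: the helper received the whole feature view and reached into its batch source internally.
pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)

# After: the caller hands the batch source over directly.
pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
    config, feature_view.batch_source
)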
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/offline_stores/bigquery.py
@@ -324,7 +324,9 @@ def offline_write_batch(
                 f"feature view batch source is {type(feature_view.batch_source)} not bigquery source"
             )

-        pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
         if column_names != table.column_names:
             raise ValueError(
                 f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
6 changes: 4 additions & 2 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -27,7 +27,7 @@
 )
 from feast.infra.offline_stores.offline_utils import (
     DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL,
-    get_pyarrow_schema,
+    get_pyarrow_schema_from_batch_source,
 )
 from feast.infra.provider import (
     _get_requested_feature_views_to_features_dict,
@@ -425,7 +425,9 @@ def offline_write_batch(
                 f"feature view batch source is {type(feature_view.batch_source)} not file source"
             )

-        pa_schema, column_names = get_pyarrow_schema(config, feature_view)
+        pa_schema, column_names = get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
         if column_names != table.column_names:
             raise ValueError(
                 f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
15 changes: 6 additions & 9 deletions sdk/python/feast/infra/offline_stores/offline_utils.py
@@ -9,6 +9,7 @@
 from jinja2 import BaseLoader, Environment
 from pandas import Timestamp

+from feast.data_source import DataSource
 from feast.errors import (
     EntityTimestampInferenceException,
     FeastEntityDFMissingColumnsError,
@@ -222,13 +223,11 @@ def get_offline_store_from_config(offline_store_config: Any) -> OfflineStore:
     return offline_store_class()


-def get_pyarrow_schema(
-    config: RepoConfig, feature_view: FeatureView
+def get_pyarrow_schema_from_batch_source(
+    config: RepoConfig, batch_source: DataSource
 ) -> Tuple[pa.Schema, List[str]]:
-    """Returns the pyarrow schema and column names for the specified feature view's batch source."""
-    column_names_and_types = feature_view.batch_source.get_table_column_names_and_types(
-        config
-    )
+    """Returns the pyarrow schema and column names for the given batch source."""
+    column_names_and_types = batch_source.get_table_column_names_and_types(config)

     pa_schema = []
     column_names = []
@@ -237,9 +236,7 @@ def get_pyarrow_schema(
             (
                 column_name,
                 feast_value_type_to_pa(
-                    feature_view.batch_source.source_datatype_to_feast_value_type()(
-                        column_type
-                    )
+                    batch_source.source_datatype_to_feast_value_type()(column_type)
                 ),
             )
         )
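For context, here is a sketch of how the renamed helper plausibly reads in full after this commit, as it would appear inside offline_utils.py with that module's existing imports. The loop header, the trailing column_names.append(...), and the return statement are collapsed out of the hunks above, so those lines are assumptions rather than part of the diff:

def get_pyarrow_schema_from_batch_source(
    config: RepoConfig, batch_source: DataSource
) -> Tuple[pa.Schema, List[str]]:
    """Returns the pyarrow schema and column names for the given batch source."""
    column_names_and_types = batch_source.get_table_column_names_and_types(config)

    pa_schema = []
    column_names = []
    # Assumed loop: the collapsed hunk only shows the (name, pyarrow type) tuple built inside it.
    for column_name, column_type in column_names_and_types:
        pa_schema.append(
            (
                column_name,
                feast_value_type_to_pa(
                    batch_source.source_datatype_to_feast_value_type()(column_type)
                ),
            )
        )
        column_names.append(column_name)  # assumed

    return pa.schema(pa_schema), column_names  # assumed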
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/offline_stores/redshift.py
@@ -318,7 +318,9 @@ def offline_write_batch(
                 f"feature view batch source is {type(feature_view.batch_source)} not redshift source"
             )

-        pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
         if column_names != table.column_names:
             raise ValueError(
                 f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
4 changes: 3 additions & 1 deletion sdk/python/feast/infra/offline_stores/snowflake.py
@@ -327,7 +327,9 @@ def offline_write_batch(
                 f"feature view batch source is {type(feature_view.batch_source)} not snowflake source"
             )

-        pa_schema, column_names = offline_utils.get_pyarrow_schema(config, feature_view)
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
         if column_names != table.column_names:
             raise ValueError(
                 f"The input pyarrow table has schema {pa_schema} with the incorrect columns {column_names}. "
