Skip to content

Commit

Permalink
Optimize memory usage during materialization
Browse files (browse the repository at this point in the history)
Signed-off-by: Judah Rand <17158624+judahrand@users.noreply.github.com>
  • Loading branch information
judahrand committed Nov 19, 2021
1 parent 91b37e7 commit 1d31495
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
16 changes: 11 additions & 5 deletions sdk/python/feast/infra/passthrough_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from feast.repo_config import RepoConfig
from feast.usage import RatioSampler, log_exceptions_and_usage, set_usage_attribute

DEFAULT_BATCH_SIZE = 10_000


class PassthroughProvider(Provider):
"""
Expand Down Expand Up @@ -145,12 +147,16 @@ def materialize_single_feature_view(
table = _run_field_mapping(table, feature_view.batch_source.field_mapping)

join_keys = [entity.join_key for entity in entities]
rows_to_write = _convert_arrow_to_proto(table, feature_view, join_keys)

with tqdm_builder(len(rows_to_write)) as pbar:
self.online_write_batch(
self.repo_config, feature_view, rows_to_write, lambda x: pbar.update(x)
)
with tqdm_builder(table.num_rows) as pbar:
for batch in table.to_batches(DEFAULT_BATCH_SIZE):
rows_to_write = _convert_arrow_to_proto(batch, feature_view, join_keys)
self.online_write_batch(
self.repo_config,
feature_view,
rows_to_write,
lambda x: pbar.update(x),
)

def get_historical_features(
self,
Expand Down
6 changes: 4 additions & 2 deletions sdk/python/feast/infra/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,9 @@ def _run_field_mapping(


def _convert_arrow_to_proto(
table: pyarrow.Table, feature_view: FeatureView, join_keys: List[str],
table: Union[pyarrow.Table, pyarrow.RecordBatch],
feature_view: FeatureView,
join_keys: List[str],
) -> List[Tuple[EntityKeyProto, Dict[str, ValueProto], datetime, Optional[datetime]]]:
rows_to_write = []

Expand All @@ -305,7 +307,7 @@ def _coerce_datetime(ts):
else:
return ts

column_names_idx = {k: i for i, k in enumerate(table.column_names)}
column_names_idx = {field.name: i for i, field in enumerate(table.schema)}
for row in zip(*table.to_pydict().values()):
entity_key = EntityKeyProto()
for join_key in join_keys:
Expand Down

0 comments on commit 1d31495

Please sign in to comment.