diff --git a/sdk/python/feast/feature_store.py b/sdk/python/feast/feature_store.py
index e27a786083..421d17f999 100644
--- a/sdk/python/feast/feature_store.py
+++ b/sdk/python/feast/feature_store.py
@@ -71,6 +71,10 @@ def __init__(
                 ),
             )
 
+    @property
+    def project(self) -> str:
+        return self.config.project
+
     def _get_provider(self) -> Provider:
         return get_provider(self.config)
 
@@ -108,17 +112,25 @@ def apply(self, objects: List[Union[FeatureView, Entity]]):
 
         # TODO: Add locking
         # TODO: Optimize by only making a single call (read/write)
-        # TODO: Add infra update operation (currently we are just writing to registry)
         registry = self._get_registry()
+
+        views_to_update = []
        for ob in objects:
             if isinstance(ob, FeatureView):
                 registry.apply_feature_view(ob, project=self.config.project)
+                views_to_update.append(ob)
             elif isinstance(ob, Entity):
                 registry.apply_entity(ob, project=self.config.project)
             else:
                 raise ValueError(
                     f"Unknown object type ({type(ob)}) provided as part of apply() call"
                 )
+        self._get_provider().update_infra(
+            project=self.config.project,
+            tables_to_delete=[],
+            tables_to_keep=views_to_update,
+            partial=True,
+        )
 
     def get_historical_features(
         self, entity_df: Union[pd.DataFrame, str], feature_refs: List[str],
diff --git a/sdk/python/feast/infra/gcp.py b/sdk/python/feast/infra/gcp.py
index 79b1df0806..28566d3bb2 100644
--- a/sdk/python/feast/infra/gcp.py
+++ b/sdk/python/feast/infra/gcp.py
@@ -1,5 +1,5 @@
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 import mmh3
 from pytz import utc
@@ -59,15 +59,16 @@ def _initialize_client(self):
         from google.cloud import datastore
 
         if self._gcp_project_id is not None:
-            return datastore.Client(self.project_id)
+            return datastore.Client(self._gcp_project_id)
         else:
             return datastore.Client()
 
     def update_infra(
         self,
         project: str,
-        tables_to_delete: List[Union[FeatureTable, FeatureView]],
-        tables_to_keep: List[Union[FeatureTable, FeatureView]],
+        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
+        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
+        partial: bool,
     ):
         from google.cloud import datastore
 
@@ -89,7 +90,7 @@ def update_infra(
             client.delete(key)
 
     def teardown_infra(
-        self, project: str, tables: List[Union[FeatureTable, FeatureView]]
+        self, project: str, tables: Sequence[Union[FeatureTable, FeatureView]]
     ) -> None:
         client = self._initialize_client()
 
diff --git a/sdk/python/feast/infra/local_sqlite.py b/sdk/python/feast/infra/local_sqlite.py
index 762add2fce..eda37bdeac 100644
--- a/sdk/python/feast/infra/local_sqlite.py
+++ b/sdk/python/feast/infra/local_sqlite.py
@@ -1,7 +1,7 @@
 import os
 import sqlite3
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from feast import FeatureTable, FeatureView
 from feast.infra.key_encoding_utils import serialize_entity_key
@@ -29,8 +29,9 @@ def _get_conn(self):
     def update_infra(
         self,
         project: str,
-        tables_to_delete: List[Union[FeatureTable, FeatureView]],
-        tables_to_keep: List[Union[FeatureTable, FeatureView]],
+        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
+        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
+        partial: bool,
     ):
         conn = self._get_conn()
         for table in tables_to_keep:
@@ -45,7 +46,7 @@ def update_infra(
             conn.execute(f"DROP TABLE IF EXISTS {_table_id(project, table)}")
 
     def teardown_infra(
-        self, project: str, tables: List[Union[FeatureTable, FeatureView]]
+        self, project: str, tables: Sequence[Union[FeatureTable, FeatureView]]
     ) -> None:
         os.unlink(self._db_path)
 
diff --git a/sdk/python/feast/infra/provider.py b/sdk/python/feast/infra/provider.py
index 719d04a0c7..824619bc9f 100644
--- a/sdk/python/feast/infra/provider.py
+++ b/sdk/python/feast/infra/provider.py
@@ -1,6 +1,6 @@
 import abc
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Sequence, Tuple, Union
 
 from feast import FeatureTable, FeatureView
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
@@ -13,8 +13,9 @@ class Provider(abc.ABC):
     def update_infra(
         self,
         project: str,
-        tables_to_delete: List[Union[FeatureTable, FeatureView]],
-        tables_to_keep: List[Union[FeatureTable, FeatureView]],
+        tables_to_delete: Sequence[Union[FeatureTable, FeatureView]],
+        tables_to_keep: Sequence[Union[FeatureTable, FeatureView]],
+        partial: bool,
     ):
         """
         Reconcile cloud resources with the objects declared in the feature repo.
@@ -24,12 +25,14 @@ def update_infra(
                 clean up the corresponding cloud resources.
             tables_to_keep: Tables that are still in the feature repo. Depending on implementation,
                 provider may or may not need to update the corresponding resources.
+            partial: if true, then tables_to_delete and tables_to_keep are *not* exhaustive lists.
+                There may be other tables that are not touched by this update.
         """
         ...
 
     @abc.abstractmethod
     def teardown_infra(
-        self, project: str, tables: List[Union[FeatureTable, FeatureView]]
+        self, project: str, tables: Sequence[Union[FeatureTable, FeatureView]]
     ):
         """
         Tear down all cloud resources for a repo.
diff --git a/sdk/python/feast/repo_operations.py b/sdk/python/feast/repo_operations.py
index bd584d74c9..235ed79109 100644
--- a/sdk/python/feast/repo_operations.py
+++ b/sdk/python/feast/repo_operations.py
@@ -103,7 +103,10 @@ def apply_total(repo_config: RepoConfig, repo_path: Path):
     all_to_keep.extend(repo.feature_views)
 
     infra_provider.update_infra(
-        project, tables_to_delete=all_to_delete, tables_to_keep=all_to_keep
+        project,
+        tables_to_delete=all_to_delete,
+        tables_to_keep=all_to_keep,
+        partial=False,
     )
 
     print("Done!")
diff --git a/sdk/python/tests/cli/online_read_write_test.py b/sdk/python/tests/cli/online_read_write_test.py
index 9ec32f5115..edde2dc743 100644
--- a/sdk/python/tests/cli/online_read_write_test.py
+++ b/sdk/python/tests/cli/online_read_write_test.py
@@ -1,19 +1,17 @@
 from datetime import datetime, timedelta
-from pathlib import Path
 
 from feast.feature_store import FeatureStore
 from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto
 from feast.protos.feast.types.Value_pb2 import Value as ValueProto
 
 
-def basic_rw_test(repo_path: Path, project_name: str) -> None:
+def basic_rw_test(store: FeatureStore, view_name: str) -> None:
     """
     This is a provider-independent test suite for reading and writing from the
     online store, to be used by provider-specific tests.
     """
-    store = FeatureStore(repo_path=repo_path, config=None)
     registry = store._get_registry()
-    table = registry.get_feature_view(project=project_name, name="driver_locations")
+    table = registry.get_feature_view(project=store.project, name=view_name)
 
     provider = store._get_provider()
 
@@ -26,7 +24,7 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
         write_lat, write_lon = write
         expect_lat, expect_lon = expect_read
         provider.online_write_batch(
-            project=project_name,
+            project=store.project,
             table=table,
             data=[
                 (
@@ -42,7 +40,7 @@ def _driver_rw_test(event_ts, created_ts, write, expect_read):
         )
 
         read_rows = provider.online_read(
-            project=project_name, table=table, entity_keys=[entity_key]
+            project=store.project, table=table, entity_keys=[entity_key]
         )
         assert len(read_rows) == 1
         _, val = read_rows[0]
diff --git a/sdk/python/tests/cli/test_cli_local.py b/sdk/python/tests/cli/test_cli_local.py
index bee80647a9..f9f5eb38c0 100644
--- a/sdk/python/tests/cli/test_cli_local.py
+++ b/sdk/python/tests/cli/test_cli_local.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from textwrap import dedent
 
+from feast.feature_store import FeatureStore
 from tests.cli.online_read_write_test import basic_rw_test
 from tests.cli.utils import CliRunner
 
@@ -42,7 +43,10 @@ def test_basic(self) -> None:
             result = runner.run(["apply", str(repo_path)], cwd=repo_path)
             assert result.returncode == 0
 
-            basic_rw_test(repo_path, "foo")
+            basic_rw_test(
+                FeatureStore(repo_path=str(repo_path), config=None),
+                view_name="driver_locations",
+            )
 
             result = runner.run(["teardown", str(repo_path)], cwd=repo_path)
             assert result.returncode == 0
diff --git a/sdk/python/tests/cli/test_datastore.py b/sdk/python/tests/cli/test_datastore.py
index 11e73f843b..239f4d12d0 100644
--- a/sdk/python/tests/cli/test_datastore.py
+++ b/sdk/python/tests/cli/test_datastore.py
@@ -6,6 +6,7 @@
 
 import pytest
 
+from feast.feature_store import FeatureStore
 from tests.cli.online_read_write_test import basic_rw_test
 from tests.cli.utils import CliRunner
 
@@ -48,7 +49,10 @@ def test_basic(self) -> None:
             result = runner.run(["apply", str(repo_path)], cwd=repo_path)
             assert result.returncode == 0
 
-            basic_rw_test(repo_path, project_name=self._project_id)
+            basic_rw_test(
+                FeatureStore(repo_path=str(repo_path), config=None),
+                view_name="driver_locations",
+            )
 
             result = runner.run(["teardown", str(repo_path)], cwd=repo_path)
             assert result.returncode == 0
diff --git a/sdk/python/tests/cli/test_partial_apply.py b/sdk/python/tests/cli/test_partial_apply.py
new file mode 100644
index 0000000000..21f4aa1187
--- /dev/null
+++ b/sdk/python/tests/cli/test_partial_apply.py
@@ -0,0 +1,41 @@
+from google.protobuf.duration_pb2 import Duration
+
+from feast import BigQuerySource, Feature, FeatureView, ValueType
+from tests.cli.online_read_write_test import basic_rw_test
+from tests.cli.utils import CliRunner
+
+
+class TestOnlineRetrieval:
+    def test_basic(self) -> None:
+        """
+        Add another table to an existing repo using the partial apply API. Make sure
+        both the table applied via CLI apply and the new table pass the RW test.
+        """
+
+        runner = CliRunner()
+        with runner.local_repo("example_feature_repo_1.py") as store:
+
+            driver_locations_source = BigQuerySource(
+                table_ref="rh_prod.ride_hailing_co.drivers",
+                event_timestamp_column="event_timestamp",
+                created_timestamp_column="created_timestamp",
+            )
+
+            driver_locations_100 = FeatureView(
+                name="driver_locations_100",
+                entities=["driver"],
+                ttl=Duration(seconds=86400 * 1),
+                features=[
+                    Feature(name="lat", dtype=ValueType.FLOAT),
+                    Feature(name="lon", dtype=ValueType.STRING),
+                    Feature(name="name", dtype=ValueType.STRING),
+                ],
+                online=True,
+                input=driver_locations_source,
+                tags={},
+            )
+
+            store.apply([driver_locations_100])
+
+            basic_rw_test(store, view_name="driver_locations")
+            basic_rw_test(store, view_name="driver_locations_100")
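For context, a minimal sketch of how the new partial-apply path is driven from the SDK. This is illustrative rather than part of the patch: the repo path, BigQuery table_ref, feature name, and value type are made up, and the `driver` entity is assumed to have been registered by an earlier `feast apply` of the repo.

    from google.protobuf.duration_pb2 import Duration

    from feast import BigQuerySource, Feature, FeatureView, ValueType
    from feast.feature_store import FeatureStore

    # Load the store from an existing feature repo; the repo's feature_store.yaml
    # is assumed to supply the project, provider, and online store settings.
    store = FeatureStore(repo_path="example_feature_repo", config=None)

    driver_stats_source = BigQuerySource(
        table_ref="my_gcp_project.my_dataset.driver_stats",
        event_timestamp_column="event_timestamp",
        created_timestamp_column="created_timestamp",
    )

    driver_stats = FeatureView(
        name="driver_stats",
        entities=["driver"],
        ttl=Duration(seconds=86400),
        features=[Feature(name="trips_today", dtype=ValueType.INT64)],
        online=True,
        input=driver_stats_source,
        tags={},
    )

    # apply() writes the view to the registry and then calls
    # provider.update_infra(tables_to_delete=[], tables_to_keep=[driver_stats],
    # partial=True), so only this view's online-store table is created or
    # updated; tables applied earlier are left untouched.
    store.apply([driver_stats])

Contrast this with the CLI path (`apply_total` in repo_operations.py), which passes exhaustive lists and `partial=False`, allowing the provider to also delete tables that are no longer declared in the repo.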