diff --git a/Makefile b/Makefile
index 1c421e8..8e884fa 100644
--- a/Makefile
+++ b/Makefile
@@ -24,19 +24,19 @@ publish-testpypi: ## Publish to testpipy
 publish-pypi: ## Publish to pipy
 	twine upload --repository pypi dist/*
 
-test-polars:
+test-polars: ## Run tests for polars backend
 	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --polars
 
-test-dask:
+test-dask: ## Run tests for dask backend
 	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --dask
 
-test-ray:
+test-ray: ## Run tests for ray backend
 	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --ray
 
-test-nospark:
+test-nospark: ## Run tests without spark backend
 	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --nospark
 
-test-spark:
+test-spark: ## Run tests for spark backend
 	PYSPARK_PYTHON=/opt/conda/bin/python3 PYSPARK_DRIVER_PYTHON=/opt/conda/bin/python3 FEAST_USAGE=False IS_TEST=True spark-submit \
 		--packages io.delta:delta-core_2.12:1.1.0 \
 		--conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" \
diff --git a/setup.py b/setup.py
index aa4fff2..9b92be1 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@
 REQUIRES_PYTHON = ">=3.7.0"
 
 INSTALL_REQUIRE = [
-    "feast>=0.18.0",
+    "feast~=0.22.1",
     "polars>=0.13.18",
 ]
 
@@ -38,7 +38,7 @@ setup(
     name=NAME,
-    version="0.0.3",
+    version="0.0.4",
     author="Qooba",
     long_description=open("README.md").read(),
     long_description_content_type="text/markdown",
diff --git a/tests/integration/e2e/test_e2e_basic.py b/tests/integration/e2e/test_e2e_basic.py
index 80bd4ee..2b89ec3 100644
--- a/tests/integration/e2e/test_e2e_basic.py
+++ b/tests/integration/e2e/test_e2e_basic.py
@@ -5,9 +5,14 @@ import pandas as pd
 from feast import FeatureStore
 from tests.generators import Generator, csv, parquet, start_date, end_date
-from datetime import datetime
+from datetime import datetime, timezone
+from yummy import select_all
 
-def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str):
+
+def entity_df_all():
+    return select_all(datetime(2021, 10, 2, 1, 0, 0, tzinfo=timezone.utc))
+
+def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str, entity_df: pd.DataFrame):
     """
     This will test all backends with basic data stores (parquet and csv)
     """
@@ -19,7 +24,6 @@ def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend:
     entity = Generator.entity()
 
     feature_store.apply([entity, csv_fv, parquet_fv])
-    entity_df = Generator.entity_df()
 
     feature_vector = feature_store.get_historical_features(
         features=[
@@ -30,7 +34,7 @@ def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend:
 
     feature_store.materialize(start_date=start_date, end_date=end_date)
 
-    fv = feature_vector[feature_vector.entity_id == 1].to_dict(orient="records")[0]
+    fv = feature_vector[feature_vector.entity_id.astype('int') == 1].to_dict(orient="records")[0]
 
     csv_f0 = float(fv[f"{csv_fv_name}__f0"])
     parquet_f0 = float(fv[f"{parquet_fv_name}__f0"])
@@ -54,22 +58,47 @@ def test_e2e_basic_polars(feature_store: FeatureStore, tmp_dir: TemporaryDirecto
     """
     This will test polars backend with basic data stores (parquet and csv)
     """
-    e2e_basic(feature_store, tmp_dir, "polars")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "polars", entity_df)
+
+@pytest.mark.polars
+def test_e2e_basic_select_all_polars(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test polars backend with a select_all entity dataframe and basic data stores (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "polars", entity_df_all())
+
 
 @pytest.mark.dask
 def test_e2e_basic_dask(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
     """
     This will test dask backend with basic data stores (parquet and csv)
     """
-    e2e_basic(feature_store, tmp_dir, "dask")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "dask", entity_df)
+
+@pytest.mark.dask
+def test_e2e_basic_select_all_dask(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test dask backend with a select_all entity dataframe and basic data stores (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "dask", entity_df_all())
 
 @pytest.mark.ray
 def test_e2e_basic_ray(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
     """
     This will test ray backend with basic data stores (parquet and csv)
     """
-    e2e_basic(feature_store, tmp_dir, "ray")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "ray", entity_df)
+
+@pytest.mark.ray
+def test_e2e_basic_select_all_ray(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test ray backend with a select_all entity dataframe and basic data stores (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "ray", entity_df_all())
+
 
 @pytest.mark.parametrize("backend", ["polars", "dask", "ray"])
@@ -78,7 +107,8 @@ def test_e2e_basic_nospark(feature_store: FeatureStore, tmp_dir: TemporaryDirect
     """
     This will test all backends with basic data stores (parquet and csv)
     """
-    e2e_basic(feature_store, tmp_dir, backend)
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, backend, entity_df)
 
 
 @pytest.mark.parametrize("backend", ["spark", "polars", "dask", "ray"])
@@ -87,4 +117,13 @@ def test_e2e_basic_spark(feature_store: FeatureStore, tmp_dir: TemporaryDirector
     """
     This will test all backends with basic data stores (parquet and csv)
     """
-    e2e_basic(feature_store, tmp_dir, backend)
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, backend, entity_df)
+
+@pytest.mark.parametrize("backend", ["spark", "polars", "dask", "ray"])
+@pytest.mark.spark
+def test_e2e_basic_select_all_spark(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str):
+    """
+    This will test all backends with a select_all entity dataframe and basic data stores (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, backend, entity_df_all())
diff --git a/tests/integration/e2e/test_e2e_iceberg.py b/tests/integration/e2e/test_e2e_iceberg.py
index 26a52f8..b3bfa7f 100644
--- a/tests/integration/e2e/test_e2e_iceberg.py
+++ b/tests/integration/e2e/test_e2e_iceberg.py
@@ -36,7 +36,6 @@ def test_e2e_iceberg_only(feature_store: FeatureStore, tmp_dir: TemporaryDirecto
         ],
         entity_df=entity_df, full_feature_names=True
     ).to_df()
-    print(feature_vector)
 
     assert(feature_vector[feature_vector.entity_id == 1][f"{iceberg_fv_name}__f0"] is not None)
 
 @pytest.mark.parametrize("backend", ["spark"])
diff --git a/yummy/__init__.py b/yummy/__init__.py
index a2a1d30..b3800b6 100644
--- a/yummy/__init__.py
+++ b/yummy/__init__.py
@@ -1,8 +1,10 @@
 import os
-from .backends.backend import YummyOfflineStore, YummyOfflineStoreConfig
+from .backends.backend import YummyOfflineStore, YummyOfflineStoreConfig, select_all
 from .sources.file import ParquetSource, CsvSource
 from .sources.delta import DeltaSource
 from .sources.iceberg import IcebergSource
+from .providers.provider import YummyProvider
+from .registries.registry import YummyRegistryStore
 
 class DeprecationHelper(object):
     def __init__(self, new_target, old_name):
@@ -27,4 +29,16 @@ def __getattr__(self, attr):
 IcebergDataSource=DeprecationHelper(IcebergSource, "IcebergDataSource")
 
 os.environ["FEAST_USAGE"]="False"
-__all__ = ["YummyOfflineStore", "YummyOfflineStoreConfig", "ParquetSource", "CsvSource", "DeltaSource", "IcebergSource", "ParquetDataSource", "CsvDataSource", "DeltaDataSource", "IcebergDataSource"] +__all__ = ["YummyProvider", + "YummyRegistryStore", + "YummyOfflineStore", + "YummyOfflineStoreConfig", + "ParquetSource", + "CsvSource", + "DeltaSource", + "IcebergSource", + "ParquetDataSource", + "CsvDataSource", + "DeltaDataSource", + "IcebergDataSource", + "select_all"] diff --git a/yummy/backends/backend.py b/yummy/backends/backend.py index f8a660a..b3117b9 100644 --- a/yummy/backends/backend.py +++ b/yummy/backends/backend.py @@ -28,6 +28,17 @@ from feast.usage import log_exceptions_and_usage from feast.importer import import_class from enum import Enum +import pandas as pd +from datetime import datetime + +YUMMY_ALL = "@yummy*" + +def select_all(event_timestamp: datetime): + """ + selects all entities during fetching historical features for specified event timestamp + """ + return pd.DataFrame.from_dict({YUMMY_ALL: ["*"], DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL: [ event_timestamp ]}) + class BackendType(str, Enum): dask = "dask" @@ -60,6 +71,18 @@ def backend_type(self) -> BackendType: def retrival_job_type(self): raise NotImplementedError("Retrival job type not defined") + @abstractmethod + def first_event_timestamp( + self, + entity_df: Union[pd.DataFrame, Any], + column_name: str + ) -> datetime: + """ + Fetch first event timestamp + """ + ... + + @abstractmethod def prepare_entity_df( self, @@ -184,17 +207,20 @@ def merge( # tmp join keys needed for cross join with null join table view tmp_join_keys = [] if not join_keys: - self.add_static_column(entity_df_with_features, "__tmp", 1) - self.add_static_column(df_to_join, "__tmp", 1) + entity_df_with_features=self.add_static_column(entity_df_with_features, "__tmp", 1) + df_to_join=self.add_static_column(df_to_join, "__tmp", 1) tmp_join_keys = ["__tmp"] # Get only data with requested entities - df_to_join = self.join( - entity_df_with_features, - df_to_join, - join_keys or tmp_join_keys, - feature_view, - ) + if YUMMY_ALL in entity_df_with_features.columns: + df_to_join=self.add_static_column(df_to_join, DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, self.first_event_timestamp(entity_df_with_features, DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL)) + else: + df_to_join = self.join( + entity_df_with_features, + df_to_join, + join_keys or tmp_join_keys, + feature_view, + ) if tmp_join_keys: df_to_join = self.drop(df_to_join, tmp_join_keys) diff --git a/yummy/backends/dask.py b/yummy/backends/dask.py index 8fab349..076a37c 100644 --- a/yummy/backends/dask.py +++ b/yummy/backends/dask.py @@ -61,6 +61,17 @@ def prepare_entity_df( return entity_df + def first_event_timestamp( + self, + entity_df: Union[pd.DataFrame, Any], + column_name: str + ) -> datetime: + """ + Fetch first event timestamp + """ + return entity_df.compute()[column_name][0] + + def normalize_timezone( self, entity_df_with_features: Union[pd.DataFrame, Any], diff --git a/yummy/backends/polars.py b/yummy/backends/polars.py index 8e7c852..a5cbc8f 100644 --- a/yummy/backends/polars.py +++ b/yummy/backends/polars.py @@ -37,6 +37,17 @@ def backend_type(self) -> BackendType: def retrival_job_type(self): return PolarsRetrievalJob + def first_event_timestamp( + self, + entity_df: Union[pd.DataFrame, Any], + column_name: str + ) -> datetime: + """ + Fetch first event timestamp + """ + return entity_df[column_name][0] + + def prepare_entity_df( self, entity_df: Union[pd.DataFrame, Any], @@ -159,7 
@@ -159,7 +170,7 @@ def drop_duplicates(
         df_to_join: Union[pd.DataFrame, Any],
         subset: List[str],
     ) -> Union[pd.DataFrame, Any]:
-        return df_to_join.distinct(subset=subset, keep='last')
+        return df_to_join.unique(subset=subset, keep='last')
 
     def drop(
         self,
diff --git a/yummy/backends/spark.py b/yummy/backends/spark.py
index 42755b4..815371f 100644
--- a/yummy/backends/spark.py
+++ b/yummy/backends/spark.py
@@ -45,6 +45,16 @@ def retrival_job_type(self):
     def spark_session(self) -> SparkSession:
         return self._spark_session
 
+    def first_event_timestamp(
+        self,
+        entity_df: Union[pd.DataFrame, Any],
+        column_name: str
+    ) -> datetime:
+        """
+        Fetch the first event timestamp from the entity dataframe column.
+        """
+        return entity_df.toPandas()[column_name][0]
+
     def prepare_entity_df(
         self,
         entity_df: Union[pd.DataFrame, Any],
diff --git a/yummy/providers/provider.py b/yummy/providers/provider.py
new file mode 100644
index 0000000..6923e9e
--- /dev/null
+++ b/yummy/providers/provider.py
@@ -0,0 +1,18 @@
+from feast.infra.local import LocalProvider
+from feast.repo_config import RepoConfig
+from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
+from yummy import YummyOfflineStoreConfig
+
+class YummyProvider(LocalProvider):
+
+    def __init__(self, config: RepoConfig):
+        super().__init__(config)
+
+        if hasattr(config, "backend"):
+            config.offline_store.backend = config.backend
+
+        if hasattr(config, "backend_config"):
+            config.offline_store.config = config.backend_config
+
+
+
diff --git a/yummy/registries/registry.py b/yummy/registries/registry.py
new file mode 100644
index 0000000..36ebca0
--- /dev/null
+++ b/yummy/registries/registry.py
@@ -0,0 +1,43 @@
+import uuid
+import os
+from datetime import datetime
+from pathlib import Path
+
+from feast.registry_store import RegistryStore
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
+from feast.repo_config import RegistryConfig
+from feast.usage import log_exceptions_and_usage
+
+class YummyRegistryStore(RegistryStore):
+    def __init__(self, registry_config: RegistryConfig, repo_path: Path):
+        self._registry: RegistryStore = self._factory(registry_config, repo_path)
+
+    def _factory(self, registry_config: RegistryConfig, repo_path: Path):
+        path = registry_config.path
+
+        if path.startswith("s3://"):
+            from feast.infra.aws import S3RegistryStore
+            if hasattr(registry_config, "s3_endpoint_override"):
+                os.environ["FEAST_S3_ENDPOINT_URL"] = registry_config.s3_endpoint_override
+            return S3RegistryStore(registry_config, repo_path)
+        elif path.startswith("gs://"):
+            from feast.infra.gcp import GCSRegistryStore
+            return GCSRegistryStore(registry_config, repo_path)
+        else:
+            from feast.infra.local import LocalRegistryStore
+            return LocalRegistryStore(registry_config, repo_path)
+
+    @log_exceptions_and_usage(registry="local")
+    def get_registry_proto(self):
+        return self._registry.get_registry_proto()
+
+    @log_exceptions_and_usage(registry="local")
+    def update_registry_proto(self, registry_proto: RegistryProto):
+        self._registry.update_registry_proto(registry_proto)
+
+    def teardown(self):
+        self._registry.teardown()
+
+    def _write_registry(self, registry_proto: RegistryProto):
+        self._registry._write_registry(registry_proto)
+
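Usage note: a minimal sketch of how the newly exported select_all helper can be called from application code, mirroring the entity_df_all() fixture in the tests above. The repo path and the "my_feature_view:f0" feature reference are illustrative placeholders, not part of this change.

```python
from datetime import datetime, timezone

from feast import FeatureStore
from yummy import select_all

# Assumes a feature repo in the current directory whose feature_store.yaml is
# configured to use the yummy offline store (placeholder setup).
store = FeatureStore(repo_path=".")

# select_all() builds a one-row entity dataframe carrying the "@yummy*" marker
# column, so the backend skips the entity join and returns features for all
# entities as of the given event timestamp.
entity_df = select_all(datetime(2022, 5, 1, tzinfo=timezone.utc))

feature_vector = store.get_historical_features(
    features=["my_feature_view:f0"],  # placeholder feature reference
    entity_df=entity_df,
    full_feature_names=True,
).to_df()
```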