Merge pull request #2 from yummyml/fetch_all_entities
add select_all, yummy provider and registry
qooba authored Aug 27, 2022
2 parents 245280d + 58bbbf8 commit 6af3ec1
Showing 11 changed files with 199 additions and 28 deletions.
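The headline change is the new `select_all` helper: instead of building an explicit entity dataframe, callers can ask `get_historical_features` for every entity present in the sources at a given event timestamp. A minimal sketch of the intended call pattern, assuming an already-configured Feast repo (the repo path and feature reference are illustrative, not from this commit):

```python
# Hedged sketch: fetch historical features for all entities at one
# timestamp using the select_all helper added in this commit.
from datetime import datetime, timezone

from feast import FeatureStore
from yummy import select_all

store = FeatureStore(repo_path=".")  # hypothetical repo location

feature_vector = store.get_historical_features(
    features=["my_feature_view:f0"],  # hypothetical feature reference
    entity_df=select_all(datetime(2021, 10, 2, 1, 0, 0, tzinfo=timezone.utc)),
    full_feature_names=True,
).to_df()
```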
10 changes: 5 additions & 5 deletions Makefile
@@ -24,19 +24,19 @@ publish-testpypi: ## Publish to TestPyPI
publish-pypi: ## Publish to PyPI
	twine upload --repository pypi dist/*

-test-polars:
+test-polars: ## Run tests for polars backend
	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --polars

-test-dask:
+test-dask: ## Run tests for dask backend
	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --dask

-test-ray:
+test-ray: ## Run tests for ray backend
	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --ray

-test-nospark:
+test-nospark: ## Run tests without spark backend
	FEAST_USAGE=False IS_TEST=True python3 ./tests/run.py --nospark

-test-spark:
+test-spark: ## Run tests for spark backend
	PYSPARK_PYTHON=/opt/conda/bin/python3 PYSPARK_DRIVER_PYTHON=/opt/conda/bin/python3 FEAST_USAGE=False IS_TEST=True spark-submit \
		--packages io.delta:delta-core_2.12:1.1.0 \
		--conf "spark.sql.extensions=io.delta.sql.DeltaSparkSessionExtension" \
4 changes: 2 additions & 2 deletions setup.py
@@ -4,7 +4,7 @@
REQUIRES_PYTHON = ">=3.7.0"

INSTALL_REQUIRE = [
-    "feast>=0.18.0",
+    "feast~=0.22.1",
    "polars>=0.13.18",
]

@@ -38,7 +38,7 @@

setup(
    name=NAME,
-    version="0.0.3",
+    version="0.0.4",
    author="Qooba",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
57 changes: 48 additions & 9 deletions tests/integration/e2e/test_e2e_basic.py
@@ -5,9 +5,14 @@
import pandas as pd
from feast import FeatureStore
from tests.generators import Generator, csv, parquet, start_date, end_date
-from datetime import datetime
+from datetime import datetime, timezone
+from yummy import select_all

-def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str):
+def entity_df_all():
+    return select_all(datetime(2021, 10, 2, 1, 0, 0, tzinfo=timezone.utc))
+
+def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str, entity_df: pd.DataFrame):
    """
    This will test all backends with basic data stores (parquet and csv)
    """
@@ -19,7 +24,6 @@ def e2e_basic(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend:
    entity = Generator.entity()
    feature_store.apply([entity, csv_fv, parquet_fv])

-    entity_df = Generator.entity_df()

    feature_vector = feature_store.get_historical_features(
        features=[
@@ -30,7 +34,7 @@

    feature_store.materialize(start_date=start_date, end_date=end_date)

-    fv = feature_vector[feature_vector.entity_id == 1].to_dict(orient="records")[0]
+    fv = feature_vector[feature_vector.entity_id.astype('int') == 1].to_dict(orient="records")[0]
    csv_f0 = float(fv[f"{csv_fv_name}__f0"])
    parquet_f0 = float(fv[f"{parquet_fv_name}__f0"])

@@ -54,22 +58,47 @@ def test_e2e_basic_polars(feature_store: FeatureStore, tmp_dir: TemporaryDirecto
    """
    This will test polars backend with basic data stores (parquet and csv)
    """
-    e2e_basic(feature_store, tmp_dir, "polars")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "polars", entity_df)

+@pytest.mark.polars
+def test_e2e_basic_select_all_polars(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test polars backend with the select_all entity dataframe (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "polars", entity_df_all())


@pytest.mark.dask
def test_e2e_basic_dask(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
    """
    This will test dask backend with basic data stores (parquet and csv)
    """
-    e2e_basic(feature_store, tmp_dir, "dask")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "dask", entity_df)

+@pytest.mark.dask
+def test_e2e_basic_select_all_dask(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test dask backend with the select_all entity dataframe (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "dask", entity_df_all())

@pytest.mark.ray
def test_e2e_basic_ray(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
    """
    This will test ray backend with basic data stores (parquet and csv)
    """
-    e2e_basic(feature_store, tmp_dir, "ray")
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, "ray", entity_df)

+@pytest.mark.ray
+def test_e2e_basic_select_all_ray(feature_store: FeatureStore, tmp_dir: TemporaryDirectory):
+    """
+    This will test ray backend with the select_all entity dataframe (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, "ray", entity_df_all())



@pytest.mark.parametrize("backend", ["polars", "dask", "ray"])
@@ -78,7 +107,8 @@ def test_e2e_basic_nospark(feature_store: FeatureStore, tmp_dir: TemporaryDirect
    """
    This will test all backends with basic data stores (parquet and csv)
    """
-    e2e_basic(feature_store, tmp_dir, backend)
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, backend, entity_df)


@pytest.mark.parametrize("backend", ["spark", "polars", "dask", "ray"])
@@ -87,4 +117,13 @@ def test_e2e_basic_spark(feature_store: FeatureStore, tmp_dir: TemporaryDirector
    """
    This will test all backends with basic data stores (parquet and csv)
    """
-    e2e_basic(feature_store, tmp_dir, backend)
+    entity_df = Generator.entity_df()
+    e2e_basic(feature_store, tmp_dir, backend, entity_df)

+@pytest.mark.parametrize("backend", ["spark", "polars", "dask", "ray"])
+@pytest.mark.spark
+def test_e2e_basic_select_all_spark(feature_store: FeatureStore, tmp_dir: TemporaryDirectory, backend: str):
+    """
+    This will test all backends with the select_all entity dataframe (parquet and csv)
+    """
+    e2e_basic(feature_store, tmp_dir, backend, entity_df_all())
1 change: 0 additions & 1 deletion tests/integration/e2e/test_e2e_iceberg.py
@@ -36,7 +36,6 @@ def test_e2e_iceberg_only(feature_store: FeatureStore, tmp_dir: TemporaryDirecto
        ], entity_df=entity_df, full_feature_names=True
    ).to_df()

-    print(feature_vector)
    assert(feature_vector[feature_vector.entity_id == 1][f"{iceberg_fv_name}__f0"] is not None)

@pytest.mark.parametrize("backend", ["spark"])
18 changes: 16 additions & 2 deletions yummy/__init__.py
@@ -1,8 +1,10 @@
import os
-from .backends.backend import YummyOfflineStore, YummyOfflineStoreConfig
+from .backends.backend import YummyOfflineStore, YummyOfflineStoreConfig, select_all
from .sources.file import ParquetSource, CsvSource
from .sources.delta import DeltaSource
from .sources.iceberg import IcebergSource
+from .providers.provider import YummyProvider
+from .registries.registry import YummyRegistryStore

class DeprecationHelper(object):
    def __init__(self, new_target, old_name):
@@ -27,4 +29,16 @@ def __getattr__(self, attr):
IcebergDataSource=DeprecationHelper(IcebergSource, "IcebergDataSource")

os.environ["FEAST_USAGE"]="False"
-__all__ = ["YummyOfflineStore", "YummyOfflineStoreConfig", "ParquetSource", "CsvSource", "DeltaSource", "IcebergSource", "ParquetDataSource", "CsvDataSource", "DeltaDataSource", "IcebergDataSource"]
+__all__ = ["YummyProvider",
+           "YummyRegistryStore",
+           "YummyOfflineStore",
+           "YummyOfflineStoreConfig",
+           "ParquetSource",
+           "CsvSource",
+           "DeltaSource",
+           "IcebergSource",
+           "ParquetDataSource",
+           "CsvDataSource",
+           "DeltaDataSource",
+           "IcebergDataSource",
+           "select_all"]
42 changes: 34 additions & 8 deletions yummy/backends/backend.py
@@ -28,6 +28,17 @@
from feast.usage import log_exceptions_and_usage
from feast.importer import import_class
from enum import Enum
+import pandas as pd
+from datetime import datetime
+
+YUMMY_ALL = "@yummy*"
+
+def select_all(event_timestamp: datetime):
+    """
+    Select all entities when fetching historical features for the specified event timestamp.
+    """
+    return pd.DataFrame.from_dict({YUMMY_ALL: ["*"], DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL: [event_timestamp]})


class BackendType(str, Enum):
    dask = "dask"
@@ -60,6 +71,18 @@ def backend_type(self) -> BackendType:
    def retrival_job_type(self):
        raise NotImplementedError("Retrival job type not defined")

+    @abstractmethod
+    def first_event_timestamp(
+        self,
+        entity_df: Union[pd.DataFrame, Any],
+        column_name: str
+    ) -> datetime:
+        """
+        Fetch first event timestamp
+        """
+        ...


    @abstractmethod
    def prepare_entity_df(
        self,
@@ -184,17 +207,20 @@ def merge(
        # tmp join keys needed for cross join with null join table view
        tmp_join_keys = []
        if not join_keys:
-            self.add_static_column(entity_df_with_features, "__tmp", 1)
-            self.add_static_column(df_to_join, "__tmp", 1)
+            entity_df_with_features = self.add_static_column(entity_df_with_features, "__tmp", 1)
+            df_to_join = self.add_static_column(df_to_join, "__tmp", 1)
            tmp_join_keys = ["__tmp"]

        # Get only data with requested entities
-        df_to_join = self.join(
-            entity_df_with_features,
-            df_to_join,
-            join_keys or tmp_join_keys,
-            feature_view,
-        )
+        if YUMMY_ALL in entity_df_with_features.columns:
+            df_to_join = self.add_static_column(df_to_join, DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL, self.first_event_timestamp(entity_df_with_features, DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL))
+        else:
+            df_to_join = self.join(
+                entity_df_with_features,
+                df_to_join,
+                join_keys or tmp_join_keys,
+                feature_view,
+            )

        if tmp_join_keys:
            df_to_join = self.drop(df_to_join, tmp_join_keys)
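To make the new branch in `merge` concrete: `select_all` returns a one-row frame whose only marker is the `@yummy*` column, and `merge` skips the entity join whenever it sees that column. A pandas-only sketch (the `event_timestamp` name is an assumption matching Feast's `DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL`):

```python
# Hedged sketch of the sentinel entity frame built by select_all.
from datetime import datetime, timezone
import pandas as pd

YUMMY_ALL = "@yummy*"
EVENT_TS_COL = "event_timestamp"  # assumed value of the Feast constant

entity_df = pd.DataFrame.from_dict({
    YUMMY_ALL: ["*"],
    EVENT_TS_COL: [datetime(2021, 10, 2, 1, 0, 0, tzinfo=timezone.utc)],
})

# merge() branches on this check and skips the entity join:
print(YUMMY_ALL in entity_df.columns)  # True
```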
11 changes: 11 additions & 0 deletions yummy/backends/dask.py
@@ -61,6 +61,17 @@ def prepare_entity_df(

        return entity_df

+    def first_event_timestamp(
+        self,
+        entity_df: Union[pd.DataFrame, Any],
+        column_name: str
+    ) -> datetime:
+        """
+        Fetch first event timestamp
+        """
+        return entity_df.compute()[column_name][0]


    def normalize_timezone(
        self,
        entity_df_with_features: Union[pd.DataFrame, Any],
13 changes: 12 additions & 1 deletion yummy/backends/polars.py
@@ -37,6 +37,17 @@ def backend_type(self) -> BackendType:
    def retrival_job_type(self):
        return PolarsRetrievalJob

+    def first_event_timestamp(
+        self,
+        entity_df: Union[pd.DataFrame, Any],
+        column_name: str
+    ) -> datetime:
+        """
+        Fetch first event timestamp
+        """
+        return entity_df[column_name][0]


    def prepare_entity_df(
        self,
        entity_df: Union[pd.DataFrame, Any],
@@ -159,7 +170,7 @@ def drop_duplicates(
        df_to_join: Union[pd.DataFrame, Any],
        subset: List[str],
    ) -> Union[pd.DataFrame, Any]:
-        return df_to_join.distinct(subset=subset, keep='last')
+        return df_to_join.unique(subset=subset, keep='last')

    def drop(
        self,
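The `distinct` to `unique` switch tracks polars' rename of that `DataFrame` method; the `polars>=0.13.18` floor in setup.py already has `unique`. A standalone check of the `keep='last'` semantics used by `drop_duplicates`, with illustrative data:

```python
import polars as pl

df = pl.DataFrame({"id": [1, 1, 2], "f0": [10, 20, 30]})

# Keep the last row per id, as drop_duplicates above does.
print(df.unique(subset=["id"], keep="last"))
```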
10 changes: 10 additions & 0 deletions yummy/backends/spark.py
@@ -45,6 +45,16 @@ def retrival_job_type(self):
    def spark_session(self) -> SparkSession:
        return self._spark_session

+    def first_event_timestamp(
+        self,
+        entity_df: Union[pd.DataFrame, Any],
+        column_name: str
+    ) -> datetime:
+        """
+        Fetch first event timestamp
+        """
+        return entity_df.toPandas()[column_name][0]

    def prepare_entity_df(
        self,
        entity_df: Union[pd.DataFrame, Any],
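All three backends implement `first_event_timestamp` the same way: get the entity frame into a directly indexable form, then read row 0 of the timestamp column. A hedged side-by-side with throwaway frames (spark is left as a comment since it needs a live SparkSession):

```python
import dask.dataframe as dd
import pandas as pd
import polars as pl

pdf = pd.DataFrame({"event_timestamp": pd.to_datetime(["2021-10-02 01:00:00"])})

# dask: materialize the lazy frame with compute(), then index
print(dd.from_pandas(pdf, npartitions=1).compute()["event_timestamp"][0])

# polars: index the column directly
print(pl.from_pandas(pdf)["event_timestamp"][0])

# spark: entity_df.toPandas()[column_name][0]
```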
18 changes: 18 additions & 0 deletions yummy/providers/provider.py
@@ -0,0 +1,18 @@
+from feast.infra.local import LocalProvider
+from feast.repo_config import RepoConfig
+from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
+from yummy import YummyOfflineStoreConfig
+
+class YummyProvider(LocalProvider):
+
+    def __init__(self, config: RepoConfig):
+        super().__init__(config)
+
+        if hasattr(config, "backend"):
+            config.offline_store.backend = config.backend
+
+        if hasattr(config, "backend_config"):
+            config.offline_store.config = config.backend_config
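Beyond what `LocalProvider` does, the provider's only job is forwarding optional top-level `backend` and `backend_config` keys onto the offline store config. A standalone sketch of that forwarding, with `RepoConfig` stubbed by `SimpleNamespace` so it runs outside a Feast repo (field values are illustrative):

```python
from types import SimpleNamespace

# Stub standing in for a RepoConfig that carries an extra "backend" key.
config = SimpleNamespace(
    backend="polars",
    offline_store=SimpleNamespace(backend=None, config=None),
)

if hasattr(config, "backend"):
    config.offline_store.backend = config.backend
if hasattr(config, "backend_config"):  # absent on this stub, so skipped
    config.offline_store.config = config.backend_config

print(config.offline_store.backend)  # -> polars
```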



43 changes: 43 additions & 0 deletions yummy/registries/registry.py
@@ -0,0 +1,43 @@
+import uuid
+import os
+from datetime import datetime
+from pathlib import Path
+
+from feast.registry_store import RegistryStore
+from feast.protos.feast.core.Registry_pb2 import Registry as RegistryProto
+from feast.repo_config import RegistryConfig
+from feast.usage import log_exceptions_and_usage
+
+class YummyRegistryStore(RegistryStore):
+    def __init__(self, registry_config: RegistryConfig, repo_path: Path):
+        self._registry: RegistryStore = self._factory(registry_config, repo_path)
+
+    def _factory(self, registry_config: RegistryConfig, repo_path: Path):
+        path = registry_config.path
+
+        if path.startswith("s3://"):
+            from feast.infra.aws import S3RegistryStore
+            if hasattr(registry_config, "s3_endpoint_override"):
+                os.environ["FEAST_S3_ENDPOINT_URL"] = registry_config.s3_endpoint_override
+            return S3RegistryStore(registry_config, repo_path)
+        elif path.startswith("gs://"):
+            from feast.infra.gcp import GCSRegistryStore
+            return GCSRegistryStore(registry_config, repo_path)
+        else:
+            from feast.infra.local import LocalRegistryStore
+            return LocalRegistryStore(registry_config, repo_path)
+
+    @log_exceptions_and_usage(registry="local")
+    def get_registry_proto(self):
+        return self._registry.get_registry_proto()
+
+    @log_exceptions_and_usage(registry="local")
+    def update_registry_proto(self, registry_proto: RegistryProto):
+        self._registry.update_registry_proto(registry_proto)
+
+    def teardown(self):
+        self._registry.teardown()
+
+    def _write_registry(self, registry_proto: RegistryProto):
+        self._registry._write_registry(registry_proto)
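`YummyRegistryStore` is a thin router over Feast's built-in registry stores, dispatching on the path scheme. A hedged usage sketch (bucket and repo path are illustrative; assumes feast is installed with the matching cloud extras):

```python
from pathlib import Path

from feast.repo_config import RegistryConfig
from yummy import YummyRegistryStore

store = YummyRegistryStore(
    RegistryConfig(path="s3://my-bucket/registry.db"),  # hypothetical bucket
    Path("."),
)
# "s3://"  -> feast.infra.aws.S3RegistryStore
# "gs://"  -> feast.infra.gcp.GCSRegistryStore
# otherwise -> feast.infra.local.LocalRegistryStore
```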
