Skip to content

Commit

Permalink
fix: Implement apply_materialization and infra methods in sql registry (
Browse files Browse the repository at this point in the history
#2775)

Signed-off-by: Achal Shah <achals@gmail.com>
  • Loading branch information
achals authored Jun 9, 2022
1 parent 846ff4a commit 4ed107c
Showing 1 changed file with 68 additions and 31 deletions.
99 changes: 68 additions & 31 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -39,6 +39,7 @@
FeatureService as FeatureServiceProto,
)
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
OnDemandFeatureView as OnDemandFeatureViewProto,
)
Expand Down Expand Up @@ -138,6 +139,14 @@
Column("validation_reference_proto", LargeBinary, nullable=False),
)

managed_infra = Table(
"managed_infra",
metadata,
Column("infra_name", String(50), primary_key=True),
Column("last_updated_timestamp", BigInteger, nullable=False),
Column("infra_proto", LargeBinary, nullable=False),
)


class SqlRegistry(BaseRegistry):
def __init__(
Expand Down Expand Up @@ -168,6 +177,7 @@ def teardown(self):
conn.execute(stmt)

def refresh(self):
# This method is a no-op since we're always reading the latest values from the db.
pass

def get_stream_feature_view(
Expand Down Expand Up @@ -353,16 +363,7 @@ def apply_data_source(
def apply_feature_view(
self, feature_view: BaseFeatureView, project: str, commit: bool = True
):
if isinstance(feature_view, StreamFeatureView):
fv_table = stream_feature_views
elif isinstance(feature_view, FeatureView):
fv_table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
fv_table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
fv_table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
fv_table = self._infer_fv_table(feature_view)

return self._apply_object(
fv_table, "feature_view_name", feature_view, "feature_view_proto"
Expand Down Expand Up @@ -457,7 +458,25 @@ def apply_materialization(
end_date: datetime,
commit: bool = True,
):
pass
table = self._infer_fv_table(feature_view)
python_class, proto_class = self._infer_fv_classes(feature_view)

if python_class in {RequestFeatureView, OnDemandFeatureView}:
raise ValueError(
f"Cannot apply materialization for feature {feature_view.name} of type {python_class}"
)
fv: Union[FeatureView, StreamFeatureView] = self._get_object(
table,
feature_view.name,
project,
proto_class,
python_class,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
fv.materialization_intervals.append((start_date, end_date))
self._apply_object(table, "feature_view_name", fv, "feature_view_proto")

def delete_validation_reference(self, name: str, project: str, commit: bool = True):
self._delete_object(
Expand All @@ -469,27 +488,29 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr
)

def update_infra(self, infra: Infra, project: str, commit: bool = True):
pass
self._apply_object(
managed_infra, "infra_name", infra, "infra_proto", name="infra_obj"
)

def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
return Infra()
return self._get_object(
managed_infra,
"infra_obj",
project,
InfraProto,
Infra,
"infra_name",
"infra_proto",
None,
)

def apply_user_metadata(
self,
project: str,
feature_view: BaseFeatureView,
metadata_bytes: Optional[bytes],
):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand All @@ -511,9 +532,7 @@ def apply_user_metadata(
else:
raise FeatureViewNotFoundException(feature_view.name, project=project)

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
def _infer_fv_table(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
Expand All @@ -524,6 +543,25 @@ def get_user_metadata(
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return table

def _infer_fv_classes(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
python_class, proto_class = StreamFeatureView, StreamFeatureViewProto
elif isinstance(feature_view, FeatureView):
python_class, proto_class = FeatureView, FeatureViewProto
elif isinstance(feature_view, OnDemandFeatureView):
python_class, proto_class = OnDemandFeatureView, OnDemandFeatureViewProto
elif isinstance(feature_view, RequestFeatureView):
python_class, proto_class = RequestFeatureView, RequestFeatureViewProto
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return python_class, proto_class

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand Down Expand Up @@ -556,12 +594,11 @@ def proto(self) -> RegistryProto:
return r

def commit(self):
# This method is a no-op since we're always writing values eagerly to the db.
pass

def _apply_object(
self, table, id_field_name, obj, proto_field_name,
):
name = obj.name
def _apply_object(self, table, id_field_name, obj, proto_field_name, name=None):
name = name or obj.name
with self.engine.connect() as conn:
stmt = select(table).where(getattr(table.c, id_field_name) == name)
row = conn.execute(stmt).first()
Expand Down

0 comments on commit 4ed107c

Please sign in to comment.