Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Implement apply_materialization and infra methods in sql registry #2775

Merged
merged 1 commit into from
Jun 9, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 68 additions & 31 deletions sdk/python/feast/infra/registry_stores/sql.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from datetime import datetime
from pathlib import Path
from threading import Lock
from typing import Any, List, Optional
from typing import Any, List, Optional, Union

from sqlalchemy import ( # type: ignore
BigInteger,
Expand Down Expand Up @@ -39,6 +39,7 @@
FeatureService as FeatureServiceProto,
)
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.protos.feast.core.InfraObject_pb2 import Infra as InfraProto
from feast.protos.feast.core.OnDemandFeatureView_pb2 import (
OnDemandFeatureView as OnDemandFeatureViewProto,
)
Expand Down Expand Up @@ -138,6 +139,14 @@
Column("validation_reference_proto", LargeBinary, nullable=False),
)

# Table backing the registry's infra storage: rows are keyed by infra_name and
# carry a last-updated timestamp plus the infra proto as binary data.
# update_infra / get_infra pass this table to _apply_object / _get_object.
managed_infra = Table(
"managed_infra",
metadata,
Column("infra_name", String(50), primary_key=True),
Column("last_updated_timestamp", BigInteger, nullable=False),
Column("infra_proto", LargeBinary, nullable=False),
)


class SqlRegistry(BaseRegistry):
def __init__(
Expand Down Expand Up @@ -168,6 +177,7 @@ def teardown(self):
conn.execute(stmt)

def refresh(self):
    """Do nothing.

    The latest values are always read from the db, so there is no local
    state to refresh.
    """
    return None

def get_stream_feature_view(
Expand Down Expand Up @@ -353,16 +363,7 @@ def apply_data_source(
def apply_feature_view(
self, feature_view: BaseFeatureView, project: str, commit: bool = True
):
if isinstance(feature_view, StreamFeatureView):
fv_table = stream_feature_views
elif isinstance(feature_view, FeatureView):
fv_table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
fv_table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
fv_table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
fv_table = self._infer_fv_table(feature_view)

return self._apply_object(
fv_table, "feature_view_name", feature_view, "feature_view_proto"
Expand Down Expand Up @@ -457,7 +458,25 @@ def apply_materialization(
end_date: datetime,
commit: bool = True,
):
pass
table = self._infer_fv_table(feature_view)
python_class, proto_class = self._infer_fv_classes(feature_view)

if python_class in {RequestFeatureView, OnDemandFeatureView}:
raise ValueError(
f"Cannot apply materialization for feature {feature_view.name} of type {python_class}"
)
fv: Union[FeatureView, StreamFeatureView] = self._get_object(
table,
feature_view.name,
project,
proto_class,
python_class,
"feature_view_name",
"feature_view_proto",
FeatureViewNotFoundException,
)
fv.materialization_intervals.append((start_date, end_date))
self._apply_object(table, "feature_view_name", fv, "feature_view_proto")

def delete_validation_reference(self, name: str, project: str, commit: bool = True):
self._delete_object(
Expand All @@ -469,27 +488,29 @@ def delete_validation_reference(self, name: str, project: str, commit: bool = Tr
)

# Stores `infra` in the managed_infra table through _apply_object, always under
# the fixed row name "infra_obj" (column "infra_name"), with the proto written
# to the "infra_proto" column.
# NOTE(review): `project` and `commit` are not used in the visible body, so the
# stored row does not appear to be scoped per project — confirm this is intended.
def update_infra(self, infra: Infra, project: str, commit: bool = True):
# NOTE(review): this text comes from a rendered diff; the bare `pass` below
# looks like the removed pre-change body shown next to its replacement — confirm.
pass
self._apply_object(
managed_infra, "infra_name", infra, "infra_proto", name="infra_obj"
)

# Looks up the row named "infra_obj" in the managed_infra table through
# _get_object, with InfraProto / Infra as the proto and Python classes.
# NOTE(review): `allow_cache` is not used in the visible body.
def get_infra(self, project: str, allow_cache: bool = False) -> Infra:
# NOTE(review): rendered-diff artifact — `return Infra()` looks like the removed
# pre-change line; read literally it would make the lookup below unreachable.
return Infra()
return self._get_object(
managed_infra,
"infra_obj",
project,
InfraProto,
Infra,
"infra_name",
"infra_proto",
# NOTE(review): apply_materialization passes an exception class in this
# position; None presumably means no not-found error is raised for a
# missing row — verify against _get_object.
None,
)

def apply_user_metadata(
self,
project: str,
feature_view: BaseFeatureView,
metadata_bytes: Optional[bytes],
):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
table = feature_views
elif isinstance(feature_view, OnDemandFeatureView):
table = on_demand_feature_views
elif isinstance(feature_view, RequestFeatureView):
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand All @@ -511,9 +532,7 @@ def apply_user_metadata(
else:
raise FeatureViewNotFoundException(feature_view.name, project=project)

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
def _infer_fv_table(self, feature_view):
if isinstance(feature_view, StreamFeatureView):
table = stream_feature_views
elif isinstance(feature_view, FeatureView):
Expand All @@ -524,6 +543,25 @@ def get_user_metadata(
table = request_feature_views
else:
raise ValueError(f"Unexpected feature view type: {type(feature_view)}")
return table

def _infer_fv_classes(self, feature_view):
    """Return the ``(python_class, proto_class)`` pair matching *feature_view*.

    Raises:
        ValueError: if *feature_view* is not an instance of any of the
            supported feature view classes.
    """
    # The pairs are tried in the same order as the original isinstance chain.
    # NOTE(review): presumably StreamFeatureView has to be tested before
    # FeatureView — confirm against the class hierarchy before reordering.
    known_pairs = (
        (StreamFeatureView, StreamFeatureViewProto),
        (FeatureView, FeatureViewProto),
        (OnDemandFeatureView, OnDemandFeatureViewProto),
        (RequestFeatureView, RequestFeatureViewProto),
    )
    for candidate_class, candidate_proto in known_pairs:
        if isinstance(feature_view, candidate_class):
            return candidate_class, candidate_proto
    raise ValueError(f"Unexpected feature view type: {type(feature_view)}")

def get_user_metadata(
self, project: str, feature_view: BaseFeatureView
) -> Optional[bytes]:
table = self._infer_fv_table(feature_view)

name = feature_view.name
with self.engine.connect() as conn:
Expand Down Expand Up @@ -556,12 +594,11 @@ def proto(self) -> RegistryProto:
return r

def commit(self):
    """Do nothing.

    Values are always written eagerly to the db, so there is nothing
    pending to commit.
    """
    return None

def _apply_object(
self, table, id_field_name, obj, proto_field_name,
):
name = obj.name
def _apply_object(self, table, id_field_name, obj, proto_field_name, name=None):
name = name or obj.name
with self.engine.connect() as conn:
stmt = select(table).where(getattr(table.c, id_field_name) == name)
row = conn.execute(stmt).first()
Expand Down