fix: Update udf tests and add base functions to streaming fcos and fix some nonetype errors #2776

Merged: 3 commits, Jun 10, 2022
2 changes: 1 addition & 1 deletion sdk/python/feast/data_source.py
@@ -503,7 +503,7 @@ def __hash__(self):
@staticmethod
def from_proto(data_source: DataSourceProto):
watermark = None
if data_source.kafka_options.HasField("watermark"):
if data_source.kafka_options.watermark:
watermark = (
timedelta(days=0)
if data_source.kafka_options.watermark.ToNanoseconds() == 0
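
For context, the change above touches the guard around the optional watermark field on the Kafka options proto. A minimal sketch of the guarded conversion, assuming watermark is an optional google.protobuf Duration message on KafkaOptions (the non-zero branch below is an assumption, since the diff cuts off the original else branch):

from datetime import timedelta

def watermark_from_kafka_options(kafka_options):
    # Presence check avoids acting on an unset message field.
    watermark = None
    if kafka_options.HasField("watermark"):
        if kafka_options.watermark.ToNanoseconds() == 0:
            watermark = timedelta(days=0)
        else:
            # Assumed conversion for the non-zero case (not shown in the diff).
            watermark = kafka_options.watermark.ToTimedelta()
    return watermark
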
99 changes: 78 additions & 21 deletions sdk/python/feast/stream_feature_view.py
@@ -1,3 +1,4 @@
import copy
import functools
import warnings
from datetime import timedelta
@@ -9,7 +10,7 @@

from feast import utils
from feast.aggregation import Aggregation
from feast.data_source import DataSource, KafkaSource
from feast.data_source import DataSource, KafkaSource, PushSource
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
@@ -39,6 +40,26 @@ class StreamFeatureView(FeatureView):
"""
NOTE: Stream Feature Views are not yet fully implemented and exist to allow users to register their stream sources and
schemas with Feast.

Attributes:
name: str. The unique name of the stream feature view.
entities: Union[List[Entity], List[str]]. List of entities or entity join keys.
ttl: timedelta. The amount of time this group of features lives. A ttl of 0 indicates that
this group of features lives forever. Note that large ttls or a ttl of 0
can result in extremely computationally intensive queries.
tags: Dict[str, str]. A dictionary of key-value pairs to store arbitrary metadata.
online: bool. Defines whether this stream feature view is used in online feature retrieval.
description: str. A human-readable description.
owner: str. The owner of the stream feature view, typically the email of the primary
maintainer.
schema: List[Field]. The schema of the feature view, including feature, timestamp, and entity
columns. If not specified, can be inferred from the underlying data source.
source: DataSource. The stream source of data where this group of features
is stored.
aggregations (optional): List[Aggregation]. List of aggregations registered with the stream feature view.
mode (optional): str. The mode of execution.
timestamp_field (optional): str. The timestamp column on which to aggregate windows. Must be specified if aggregations are specified.
udf (optional): MethodType. The user-defined transformation function. The function should perform all of its non-built-in imports inside its own body.
"""

def __init__(
@@ -54,18 +75,19 @@ def __init__(
schema: Optional[List[Field]] = None,
source: Optional[DataSource] = None,
aggregations: Optional[List[Aggregation]] = None,
mode: Optional[str] = "spark", # Mode of ingestion/transformation
timestamp_field: Optional[str] = "", # Timestamp for aggregation
mode: Optional[str] = "spark",
timestamp_field: Optional[str] = "",
udf: Optional[MethodType] = None,
):
warnings.warn(
"Stream Feature Views are experimental features in alpha development. "
"Some functionality may still be unstable so functionality can change in the future.",
RuntimeWarning,
)

if source is None:
raise ValueError("Stream Feature views need a source specified")
# The stream source's batch_source is used as the batch_source of the underlying FeatureView.
raise ValueError("Stream Feature views need a source to be specified")

if (
type(source).__name__ not in SUPPORTED_STREAM_SOURCES
and source.to_proto().type != DataSourceProto.SourceType.CUSTOM_SOURCE
@@ -74,18 +96,26 @@
f"Stream feature views need a stream source, expected one of {SUPPORTED_STREAM_SOURCES} "
f"or CUSTOM_SOURCE, got {type(source).__name__}: {source.name} instead "
)

if aggregations and not timestamp_field:
raise ValueError(
"aggregations must have a timestamp field associated with them to perform the aggregations"
)

self.aggregations = aggregations or []
self.mode = mode
self.timestamp_field = timestamp_field
self.mode = mode or ""
self.timestamp_field = timestamp_field or ""
self.udf = udf
_batch_source = None
if isinstance(source, KafkaSource):
if isinstance(source, KafkaSource) or isinstance(source, PushSource):
_batch_source = source.batch_source if source.batch_source else None

_ttl = ttl
if not _ttl:
_ttl = timedelta(days=0)
super().__init__(
name=name,
entities=entities,
ttl=ttl,
ttl=_ttl,
batch_source=_batch_source,
stream_source=source,
tags=tags,
@@ -102,7 +132,10 @@ def __eq__(self, other):

if not super().__eq__(other):
return False

if not self.udf:
return not other.udf
if not other.udf:
return False
if (
self.mode != other.mode
or self.timestamp_field != other.timestamp_field
@@ -113,13 +146,14 @@

return True

def __hash__(self):
def __hash__(self) -> int:
return super().__hash__()

def to_proto(self):
meta = StreamFeatureViewMetaProto(materialization_intervals=[])
if self.created_timestamp:
meta.created_timestamp.FromDatetime(self.created_timestamp)

if self.last_updated_timestamp:
meta.last_updated_timestamp.FromDatetime(self.last_updated_timestamp)

@@ -134,6 +168,7 @@ def to_proto(self):
ttl_duration = Duration()
ttl_duration.FromTimedelta(self.ttl)

batch_source_proto = None
if self.batch_source:
batch_source_proto = self.batch_source.to_proto()
batch_source_proto.data_source_class_type = f"{self.batch_source.__class__.__module__}.{self.batch_source.__class__.__name__}"
@@ -143,23 +178,24 @@
stream_source_proto = self.stream_source.to_proto()
stream_source_proto.data_source_class_type = f"{self.stream_source.__class__.__module__}.{self.stream_source.__class__.__name__}"

udf_proto = None
if self.udf:
udf_proto = UserDefinedFunctionProto(
name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
)
spec = StreamFeatureViewSpecProto(
name=self.name,
entities=self.entities,
entity_columns=[field.to_proto() for field in self.entity_columns],
features=[field.to_proto() for field in self.schema],
user_defined_function=UserDefinedFunctionProto(
name=self.udf.__name__, body=dill.dumps(self.udf, recurse=True),
)
if self.udf
else None,
user_defined_function=udf_proto,
description=self.description,
tags=self.tags,
owner=self.owner,
ttl=(ttl_duration if ttl_duration is not None else None),
ttl=ttl_duration,
online=self.online,
batch_source=batch_source_proto or None,
stream_source=stream_source_proto,
stream_source=stream_source_proto or None,
timestamp_field=self.timestamp_field,
aggregations=[agg.to_proto() for agg in self.aggregations],
mode=self.mode,
@@ -239,6 +275,25 @@ def from_proto(cls, sfv_proto):

return sfv_feature_view

def __copy__(self):
fv = StreamFeatureView(
name=self.name,
schema=self.schema,
entities=self.entities,
ttl=self.ttl,
tags=self.tags,
online=self.online,
description=self.description,
owner=self.owner,
aggregations=self.aggregations,
mode=self.mode,
timestamp_field=self.timestamp_field,
sources=self.sources,
udf=self.udf,
)
fv.projection = copy.copy(self.projection)
return fv


def stream_feature_view(
*,
@@ -251,11 +306,13 @@ def stream_feature_view(
schema: Optional[List[Field]] = None,
source: Optional[DataSource] = None,
aggregations: Optional[List[Aggregation]] = None,
mode: Optional[str] = "spark", # Mode of ingestion/transformation
timestamp_field: Optional[str] = "", # Timestamp for aggregation
mode: Optional[str] = "spark",
timestamp_field: Optional[str] = "",
):
"""
Creates a StreamFeatureView object with the given user function as the udf.
Make sure the udf performs all non-built-in imports inside the function body so that the
deserialized function does not miss any imports when it is executed.
"""

def mainify(obj):
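
As the decorator's docstring notes, the udf must do its non-built-in imports inside its own body: to_proto() pickles the function with dill, and from_proto() later deserializes it and executes it in a different process where the original module-level imports may not be present. A minimal sketch of that round trip using only the dill package (the helper names here are illustrative, not Feast API):

import dill

def my_udf(pandas_df):
    # Non-built-in imports live inside the function so the deserialized body
    # is self-contained when it is executed elsewhere.
    import pandas as pd

    assert isinstance(pandas_df, pd.DataFrame)
    return pandas_df.transform(lambda x: x + 10, axis=1)

serialized = dill.dumps(my_udf, recurse=True)  # what is stored in UserDefinedFunctionProto.body
restored = dill.loads(serialized)              # what from_proto() reconstructs

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
assert restored(df).equals(pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]}))
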
@@ -70,3 +70,71 @@ def simple_sfv(df):
assert features["test_key"] == [1001]
assert "dummy_field" in features
assert features["dummy_field"] == [None]


@pytest.mark.integration
def test_stream_feature_view_udf(environment) -> None:
"""
Test that StreamFeatureView udfs are serialized correctly on apply and remain usable.
"""
fs = environment.feature_store

# Create Feature Views
entity = Entity(name="driver_entity", join_keys=["test_key"])

stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=FileSource(path="test_path", timestamp_field="event_timestamp"),
watermark=timedelta(days=1),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field", function="max", time_window=timedelta(days=1),
),
Aggregation(
column="dummy_field2", function="count", time_window=timedelta(days=24),
),
],
timestamp_field="event_timestamp",
mode="spark",
source=stream_source,
tags={},
)
def pandas_view(pandas_df):
import pandas as pd

assert type(pandas_df) == pd.DataFrame
df = pandas_df.transform(lambda x: x + 10, axis=1)
df.insert(2, "C", [20.2, 230.0, 34.0], True)
return df

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})

fs.apply([entity, pandas_view])
stream_feature_views = fs.list_stream_feature_views()
assert len(stream_feature_views) == 1
assert stream_feature_views[0].name == "pandas_view"
assert stream_feature_views[0] == pandas_view

sfv = stream_feature_views[0]

new_df = sfv.udf(df)

expected_df = pd.DataFrame(
{"A": [11, 12, 13], "B": [20, 30, 40], "C": [20.2, 230.0, 34.0]}
)
assert new_df.equals(expected_df)
74 changes: 73 additions & 1 deletion sdk/python/tests/unit/test_feature_views.py
@@ -9,7 +9,7 @@
from feast.entity import Entity
from feast.field import Field
from feast.infra.offline_stores.file_source import FileSource
from feast.stream_feature_view import StreamFeatureView
from feast.stream_feature_view import StreamFeatureView, stream_feature_view
from feast.types import Float32


@@ -129,3 +129,75 @@ def test_stream_feature_view_serialization():

new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
assert new_sfv == sfv


def test_stream_feature_view_udfs():
entity = Entity(name="driver_entity", join_keys=["test_key"])
stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=FileSource(path="some path"),
)

@stream_feature_view(
entities=[entity],
ttl=timedelta(days=30),
owner="test@example.com",
online=True,
schema=[Field(name="dummy_field", dtype=Float32)],
description="desc",
aggregations=[
Aggregation(
column="dummy_field", function="max", time_window=timedelta(days=1),
)
],
timestamp_field="event_timestamp",
source=stream_source,
)
def pandas_udf(pandas_df):
import pandas as pd

assert type(pandas_df) == pd.DataFrame
df = pandas_df.transform(lambda x: x + 10, axis=1)
return df

import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
sfv = pandas_udf
sfv_proto = sfv.to_proto()
new_sfv = StreamFeatureView.from_proto(sfv_proto)
new_df = new_sfv.udf(df)

expected_df = pd.DataFrame({"A": [11, 12, 13], "B": [20, 30, 40]})

assert new_df.equals(expected_df)


def test_stream_feature_view_initialization_with_optional_fields_omitted():
entity = Entity(name="driver_entity", join_keys=["test_key"])
stream_source = KafkaSource(
name="kafka",
timestamp_field="event_timestamp",
bootstrap_servers="",
message_format=AvroFormat(""),
topic="topic",
batch_source=FileSource(path="some path"),
)

sfv = StreamFeatureView(
name="test kafka stream feature view",
entities=[entity],
schema=[],
description="desc",
timestamp_field="event_timestamp",
source=stream_source,
tags={},
)
sfv_proto = sfv.to_proto()

new_sfv = StreamFeatureView.from_proto(sfv_proto=sfv_proto)
assert new_sfv == sfv
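
The round-trip equality asserted above depends on the None-safe udf handling added to __eq__ earlier in this diff: two views without a udf compare equal, and a view with a udf never equals one without. A minimal sketch of that logic, with the elided function comparison approximated by comparing compiled bytecode (an assumption; the diff cuts that line off):

from types import FunctionType
from typing import Optional

def udfs_match(a: Optional[FunctionType], b: Optional[FunctionType]) -> bool:
    if not a:
        return not b  # both udfs missing -> equal
    if not b:
        return False  # only one side has a udf -> not equal
    # Assumed stand-in for the elided comparison: identical compiled bytecode
    # is treated as the same transformation.
    return a.__code__.co_code == b.__code__.co_code
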