Skip to content

Commit

Permalink
extractor: Make embedding type hashable (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
aditya-nambiar authored Jul 23, 2024
1 parent e98d14d commit aa41684
Show file tree
Hide file tree
Showing 6 changed files with 69 additions and 6 deletions.
62 changes: 61 additions & 1 deletion fennel/client_tests/test_featureset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import unittest
from datetime import datetime, timedelta, timezone
from typing import Optional, Dict, List
Expand All @@ -18,7 +19,7 @@
expectations,
expect_column_values_to_be_between,
)
from fennel.testing import mock
from fennel.testing import mock, log

################################################################################
# Feature Single Extractor Unit Tests
Expand Down Expand Up @@ -833,3 +834,62 @@ class FS1:
14,
pd.NA,
]


@pytest.mark.skipif(
    sys.version_info[:2] <= (3, 10),
    reason="Optional embedding not supported in python 3.9/10",
)
@mock
def test_embedding_features(client):
    """Check auto-generated extractors over an ``Embedding`` dataset field.

    Two featuresets read the same ``ImageEmbeddings.embedding`` column:

    * ``ImageFeature.embedding`` is ``Optional[Embedding[2]]`` with no
      default, so a key with no logged row is expected to come back as
      ``pd.NA``.
    * ``ImageFeatureWithDefault.embedding2`` is a plain ``Embedding[2]`` with
      a list default, so a key with no logged row is expected to come back
      as that default.
    """

    @dataset(version=4, index=True)
    class ImageEmbeddings:
        image_id: int = field(key=True)
        embedding: Embedding[2]
        ts: datetime

    @featureset
    class ImageFeature:
        image_id: int
        embedding: Optional[Embedding[2]] = F(ImageEmbeddings.embedding)

    @featureset
    class ImageFeatureWithDefault:
        image_id: int = F(ImageFeature.image_id)
        embedding2: Embedding[2] = F(
            ImageEmbeddings.embedding, default=[11.0, 13.2]
        )

    client.commit(
        datasets=[ImageEmbeddings],
        featuresets=[ImageFeature, ImageFeatureWithDefault],
        message="committing image embeddings",
    )

    # Use a single timezone-aware UTC timestamp for every row so the logged
    # data does not depend on the local timezone of the machine running the
    # test (a naive ``datetime.now()`` is local wall-clock time).
    now = datetime.now(timezone.utc)
    log(
        ImageEmbeddings,
        df=pd.DataFrame(
            {
                "image_id": [1, 2, 3],
                "embedding": [[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]],
                "ts": [now, now, now],
            }
        ),
    )

    # image_id 4 is never logged, so it exercises the "missing key" path of
    # both featuresets.
    feature_df = client.query(
        outputs=[ImageFeature, ImageFeatureWithDefault],
        inputs=[ImageFeature.image_id],
        input_dataframe=pd.DataFrame({"ImageFeature.image_id": [1, 2, 4]}),
    )
    # 3 input rows x (2 features per featureset * 2 featuresets) columns.
    assert feature_df.shape == (3, 4)
    assert feature_df["ImageFeature.embedding"].tolist() == [
        [1.0, 2.0],
        [2.0, 3.0],
        pd.NA,
    ]
    assert feature_df["ImageFeatureWithDefault.embedding2"].tolist() == [
        [1.0, 2.0],
        [2.0, 3.0],
        [11.0, 13.2],
    ]
2 changes: 1 addition & 1 deletion fennel/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def struct(cls):
# ---------------------------------------------------------------------


@dataclass
@dataclass(frozen=True)
class _Embedding:
dim: int

Expand Down
6 changes: 5 additions & 1 deletion fennel/featuresets/featureset.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,6 @@ def _create_extractor(extractor_func: Callable, version: int):
f"Series[feature, str] or dataframe[<list of features defined in this featureset>], "
f"found {type(return_annotation)}"
)

setattr(
extractor_func,
EXTRACTOR_ATTR,
Expand Down Expand Up @@ -615,6 +614,11 @@ def _get_generated_extractors(
extractor.set_inputs_from_featureset(self, feature)
extractor.featureset = self._name
extractor.outputs = [feature]
# If extractor already exists, throw an error
if extractor.name in [e.name for e in output]:
raise ValueError(
f"An auto-generated extractor `{extractor.name}` already exists in the featureset `{self._name}`."
)
output.append(extractor)
return output

Expand Down
2 changes: 1 addition & 1 deletion fennel/internal_lib/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def parse_json(annotation, json) -> Any:
f"Union must be of the form `Union[type, None]`, "
f"got `{annotation}`"
)
if json is None or pd.isna(json):
if json is None or (not isinstance(json, list) and pd.isna(json)):
return None
return parse_json(args[0], json)
if origin is list:
Expand Down
1 change: 0 additions & 1 deletion fennel/testing/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,6 @@ def _compute_lookup_extractor(
results = results[extractor.derived_extractor_info.field.name]
default_value = extractor.derived_extractor_info.default
proto_dtype = get_datatype(extractor.derived_extractor_info.field.dtype)

# Custom operations on default value for datetime and decimal type
if proto_dtype.HasField("decimal_type"):
if pd.notna(default_value) and not isinstance(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "fennel-ai"
version = "1.4.4"
version = "1.4.5"
description = "The modern realtime feature engineering platform"
authors = ["Fennel AI <developers@fennel.ai>"]
packages = [{ include = "fennel" }]
Expand Down

0 comments on commit aa41684

Please sign in to comment.