Skip to content

Commit

Permalink
Respect specified ValueTypes for features during materialization (fea…
Browse files Browse the repository at this point in the history
…st-dev#1906)

* assert float feature is still float from online store

Signed-off-by: Jeff <jeffxl@apple.com>

* ensure float features retain float type from online store

Floats were converted to doubles when materialized to
the online store.  There is a broader bug trend around
type conversions and this particular conversion utility
function looks like it could use some cleanup. This
commit is a quick fix.

Signed-off-by: Jeff <jeffxl@apple.com>

* make fix more general

Signed-off-by: Achal Shah <achals@gmail.com>

* Use assertAlmostEquals

Signed-off-by: Achal Shah <achals@gmail.com>

* format

Signed-off-by: Achal Shah <achals@gmail.com>

* Support pandas timestamps correctly

Signed-off-by: Achal Shah <achals@gmail.com>

* Support pandas timestamps correctly

Signed-off-by: Achal Shah <achals@gmail.com>

* Correct import

Signed-off-by: Achal Shah <achals@gmail.com>

Co-authored-by: Achal Shah <achals@gmail.com>
  • Loading branch information
Agent007 and achals authored Sep 29, 2021
1 parent 6faf3a2 commit ce5a130
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 14 deletions.
10 changes: 8 additions & 2 deletions sdk/python/feast/type_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,13 @@ def _type_err(item, dtype):
ValueType, Tuple[str, Any, Optional[Set[Type]]]
] = {
ValueType.INT32: ("int32_val", lambda x: int(x), None),
ValueType.INT64: ("int64_val", lambda x: int(x), None),
ValueType.INT64: (
"int64_val",
lambda x: int(x.timestamp())
if isinstance(x, pd._libs.tslibs.timestamps.Timestamp)
else int(x),
None,
),
ValueType.FLOAT: ("float_val", lambda x: float(x), None),
ValueType.DOUBLE: ("double_val", lambda x: x, {float, np.float64}),
ValueType.STRING: ("string_val", lambda x: str(x), None),
Expand Down Expand Up @@ -317,7 +323,7 @@ def python_value_to_proto_value(
value: Any, feature_type: ValueType = ValueType.UNKNOWN
) -> ProtoValue:
value_type = feature_type
if value is not None:
if value is not None and feature_type == ValueType.UNKNOWN:
if isinstance(value, (list, np.ndarray)):
value_type = (
feature_type
Expand Down
10 changes: 8 additions & 2 deletions sdk/python/tests/integration/online_store/test_e2e_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def _assert_online_features(
):
"""Assert that features in online store are up to date with `max_date` date."""
# Read features back
result = store.get_online_features(
response = store.get_online_features(
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:avg_daily_trips",
Expand All @@ -36,8 +36,14 @@ def _assert_online_features(
],
entity_rows=[{"driver_id": 1001}],
full_feature_names=True,
).to_dict()
)

# Float features should still be floats from the online store...
assert (
response.field_values[0].fields["driver_hourly_stats__conv_rate"].float_val > 0
)

result = response.to_dict()
assert len(result) == 5
assert "driver_hourly_stats__avg_daily_trips" in result
assert "driver_hourly_stats__conv_rate" in result
Expand Down
25 changes: 15 additions & 10 deletions sdk/python/tests/integration/online_store/test_universal_online.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,27 @@ def test_online_retrieval(environment, universal_data_sources, full_feature_name

assert df_features["customer_id"] == online_features_dict["customer_id"][i]
assert df_features["driver_id"] == online_features_dict["driver_id"][i]
assert (
tc.assertAlmostEqual(
online_features_dict[
response_feature_name("conv_rate_plus_100", full_feature_names)
][i]
== df_features["conv_rate"] + 100
][i],
df_features["conv_rate"] + 100,
delta=0.0001,
)
assert (
tc.assertAlmostEqual(
online_features_dict[
response_feature_name("conv_rate_plus_val_to_add", full_feature_names)
][i]
== df_features["conv_rate"] + df_features["val_to_add"]
][i],
df_features["conv_rate"] + df_features["val_to_add"],
delta=0.0001,
)
for unprefixed_feature_ref in unprefixed_feature_refs:
tc.assertEqual(
tc.assertAlmostEqual(
df_features[unprefixed_feature_ref],
online_features_dict[
response_feature_name(unprefixed_feature_ref, full_feature_names)
][i],
delta=0.0001,
)

# Check what happens for missing values
Expand Down Expand Up @@ -254,13 +257,15 @@ def assert_feature_service_correctness(
+ 3
) # Add two for the driver id and the customer id entity keys and val_to_add request data

tc = unittest.TestCase()
for i, entity_row in enumerate(entity_rows):
df_features = get_latest_feature_values_from_dataframes(
drivers_df, customers_df, orders_df, global_df, entity_row
)
assert (
tc.assertAlmostEqual(
feature_service_online_features_dict[
response_feature_name("conv_rate_plus_100", full_feature_names)
][i]
== df_features["conv_rate"] + 100
][i],
df_features["conv_rate"] + 100,
delta=0.0001,
)

0 comments on commit ce5a130

Please sign in to comment.