Skip to content

Commit

Permalink
Fix native types multiple entities retrieval (#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
terryyylim authored Sep 4, 2020
1 parent 35a9afc commit 0a8bf95
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 7 deletions.
6 changes: 3 additions & 3 deletions sdk/python/feast/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,6 +993,7 @@ def _infer_online_entity_rows(
entity_type_map = dict()

for entity in entity_rows_dicts:
fields = {}
for key, value in entity.items():
# Allow for feast.types.Value
if isinstance(value, Value):
Expand All @@ -1009,9 +1010,8 @@ def _infer_online_entity_rows(
f"Input entity {key} has mixed types, {current_dtype} and {entity_type_map[key]}. That is not allowed. "
)
proto_value = _python_value_to_proto_value(current_dtype, value)
entity_row_list.append(
GetOnlineFeaturesRequest.EntityRow(fields={key: proto_value})
)
fields[key] = proto_value
entity_row_list.append(GetOnlineFeaturesRequest.EntityRow(fields=fields))
return entity_row_list


Expand Down
64 changes: 60 additions & 4 deletions tests/e2e/redis/basic-ingest-redis-serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,62 @@ def try_get_features():

@pytest.mark.timeout(600)
@pytest.mark.run(order=18)
def test_basic_ingest_retrieval_multi_entities(client):
# Set to another project to test ingestion based on current project context
client.set_project(PROJECT_NAME + "_NS1")
merchant_fs = FeatureSet(
name="merchant_fs",
features=[Feature(name="merchant_sales", dtype=ValueType.FLOAT)],
entities=[
Entity("driver_id", ValueType.INT64),
Entity("merchant_id", ValueType.INT64),
],
max_age=Duration(seconds=3600),
)
client.apply(merchant_fs)

N_ROWS = 2
time_offset = datetime.utcnow().replace(tzinfo=pytz.utc)
merchant_df = pd.DataFrame(
{
"datetime": [time_offset] * N_ROWS,
"driver_id": [i for i in range(N_ROWS)],
"merchant_id": [i for i in range(N_ROWS)],
"merchant_sales": [float(i) + 0.5 for i in range(N_ROWS)],
}
)
client.ingest("merchant_fs", merchant_df, timeout=600)

online_request_entity = [
{"driver_id": 0, "merchant_id": 0},
{"driver_id": 1, "merchant_id": 1},
]
online_request_features = ["merchant_sales"]

def try_get_features():
response = client.get_online_features(
entity_rows=online_request_entity, feature_refs=online_request_features
)
is_ok = check_online_response("merchant_sales", merchant_df, response)
return response, is_ok

online_features_actual = wait_retry_backoff(
retry_fn=try_get_features,
timeout_secs=90,
timeout_msg="Timed out trying to get online feature values",
)

online_features_expected = {
"driver_id": [0, 1],
"merchant_id": [0, 1],
"merchant_sales": [0.5, 1.5],
}

assert online_features_actual.to_dict() == online_features_expected


@pytest.mark.timeout(600)
@pytest.mark.run(order=19)
def test_basic_retrieve_feature_row_missing_fields(client, cust_trans_df):
feature_refs = ["daily_transactions", "total_transactions", "null_values"]

Expand Down Expand Up @@ -756,7 +812,7 @@ def try_get_features():


@pytest.mark.timeout(600)
@pytest.mark.run(order=19)
@pytest.mark.run(order=20)
def test_basic_retrieve_feature_row_extra_fields(client, cust_trans_df):
feature_refs = ["daily_transactions", "total_transactions"]
# apply cust_trans_fs and ingest dataframe
Expand Down Expand Up @@ -851,7 +907,7 @@ def all_types_dataframe():


@pytest.mark.timeout(45)
@pytest.mark.run(order=20)
@pytest.mark.run(order=21)
def test_all_types_register_feature_set_success(client):
client.set_project(PROJECT_NAME)

Expand Down Expand Up @@ -897,7 +953,7 @@ def test_all_types_register_feature_set_success(client):


@pytest.mark.timeout(300)
@pytest.mark.run(order=21)
@pytest.mark.run(order=22)
def test_all_types_ingest_success(client, all_types_dataframe):
# Get all_types feature set
all_types_fs = client.get_feature_set(name="all_types")
Expand All @@ -907,7 +963,7 @@ def test_all_types_ingest_success(client, all_types_dataframe):


@pytest.mark.timeout(90)
@pytest.mark.run(order=22)
@pytest.mark.run(order=23)
def test_all_types_retrieve_online_success(client, all_types_dataframe):
# Poll serving for feature values until the correct values are returned_float_list
feature_refs = [
Expand Down

0 comments on commit 0a8bf95

Please sign in to comment.