From 60756cb4637a7961b6caffef3242e2886e77f78a Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Thu, 2 May 2024 18:41:26 +0400 Subject: [PATCH] fix: Pass native input values to `get_online_features` from feature server (#4117) * fix: Pass native input values to get_online_features from feature server Signed-off-by: tokoko * remove unnecessary type ignore hint Signed-off-by: tokoko --------- Signed-off-by: tokoko --- sdk/python/feast/feature_server.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 4b0e50a06d..fda8745c2d 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -10,7 +10,7 @@ from fastapi import FastAPI, HTTPException, Request, Response, status from fastapi.logger import logger from fastapi.params import Depends -from google.protobuf.json_format import MessageToDict, Parse +from google.protobuf.json_format import MessageToDict from pydantic import BaseModel import feast @@ -18,7 +18,6 @@ from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL from feast.data_source import PushMode from feast.errors import PushSourceNotFoundException -from feast.protos.feast.serving.ServingService_pb2 import GetOnlineFeaturesRequest # TODO: deprecate this in favor of push features @@ -83,34 +82,25 @@ def shutdown_event(): @app.post("/get-online-features") def get_online_features(body=Depends(get_body)): try: - # Validate and parse the request data into GetOnlineFeaturesRequest Protobuf object - request_proto = GetOnlineFeaturesRequest() - Parse(body, request_proto) - + body = json.loads(body) # Initialize parameters for FeatureStore.get_online_features(...) call - if request_proto.HasField("feature_service"): + if "feature_service" in body: features = store.get_feature_service( - request_proto.feature_service, allow_cache=True + body["feature_service"], allow_cache=True ) else: - features = list(request_proto.features.val) - - full_feature_names = request_proto.full_feature_names + features = body["features"] - batch_sizes = [len(v.val) for v in request_proto.entities.values()] - num_entities = batch_sizes[0] - if any(batch_size != num_entities for batch_size in batch_sizes): - raise HTTPException(status_code=500, detail="Uneven number of columns") + full_feature_names = body.get("full_feature_names", False) response_proto = store._get_online_features( features=features, - entity_values=request_proto.entities, + entity_values=body["entities"], full_feature_names=full_feature_names, - native_entity_values=False, ).proto # Convert the Protobuf object to JSON and return it - return MessageToDict( # type: ignore + return MessageToDict( response_proto, preserving_proto_field_name=True, float_precision=18 ) except Exception as e: