diff --git a/sdk/python/feast/errors.py b/sdk/python/feast/errors.py index d39009ae7a..fd5955fd98 100644 --- a/sdk/python/feast/errors.py +++ b/sdk/python/feast/errors.py @@ -57,7 +57,7 @@ def from_error_detail(detail: str) -> Optional["FeastError"]: module = importlib.import_module(module_name) class_reference = getattr(module, class_name) - instance = class_reference(message) + instance = class_reference.__new__(class_reference) setattr(instance, "__overridden_message__", message) return instance except Exception as e: @@ -451,6 +451,9 @@ class PushSourceNotFoundException(FeastError): def __init__(self, push_source_name: str): super().__init__(f"Unable to find push source '{push_source_name}'.") + def http_status_code(self) -> int: + return HttpStatusCode.HTTP_422_UNPROCESSABLE_ENTITY + class ReadOnlyRegistryException(FeastError): def __init__(self): diff --git a/sdk/python/feast/feature_server.py b/sdk/python/feast/feature_server.py index 7f24580b7a..4f8de1eef5 100644 --- a/sdk/python/feast/feature_server.py +++ b/sdk/python/feast/feature_server.py @@ -9,8 +9,9 @@ import pandas as pd import psutil from dateutil import parser -from fastapi import Depends, FastAPI, HTTPException, Request, Response, status +from fastapi import Depends, FastAPI, Request, Response, status from fastapi.logger import logger +from fastapi.responses import JSONResponse from google.protobuf.json_format import MessageToDict from prometheus_client import Gauge, start_http_server from pydantic import BaseModel @@ -19,7 +20,10 @@ from feast import proto_json, utils from feast.constants import DEFAULT_FEATURE_SERVER_REGISTRY_TTL from feast.data_source import PushMode -from feast.errors import FeatureViewNotFoundException, PushSourceNotFoundException +from feast.errors import ( + FeastError, + FeatureViewNotFoundException, +) from feast.permissions.action import WRITE, AuthzedAction from feast.permissions.security_manager import assert_permissions from feast.permissions.server.rest import inject_user_details @@ -101,147 +105,119 @@ async def lifespan(app: FastAPI): async def get_body(request: Request): return await request.body() - # TODO RBAC: complete the dependencies for the other endpoints @app.post( "/get-online-features", dependencies=[Depends(inject_user_details)], ) def get_online_features(body=Depends(get_body)): - try: - body = json.loads(body) - full_feature_names = body.get("full_feature_names", False) - entity_rows = body["entities"] - # Initialize parameters for FeatureStore.get_online_features(...) call - if "feature_service" in body: - feature_service = store.get_feature_service( - body["feature_service"], allow_cache=True + body = json.loads(body) + full_feature_names = body.get("full_feature_names", False) + entity_rows = body["entities"] + # Initialize parameters for FeatureStore.get_online_features(...) call + if "feature_service" in body: + feature_service = store.get_feature_service( + body["feature_service"], allow_cache=True + ) + assert_permissions( + resource=feature_service, actions=[AuthzedAction.READ_ONLINE] + ) + features = feature_service + else: + features = body["features"] + all_feature_views, all_on_demand_feature_views = ( + utils._get_feature_views_to_use( + store.registry, + store.project, + features, + allow_cache=True, + hide_dummy_entity=False, ) + ) + for feature_view in all_feature_views: assert_permissions( - resource=feature_service, actions=[AuthzedAction.READ_ONLINE] + resource=feature_view, actions=[AuthzedAction.READ_ONLINE] ) - features = feature_service - else: - features = body["features"] - all_feature_views, all_on_demand_feature_views = ( - utils._get_feature_views_to_use( - store.registry, - store.project, - features, - allow_cache=True, - hide_dummy_entity=False, - ) + for od_feature_view in all_on_demand_feature_views: + assert_permissions( + resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] ) - for feature_view in all_feature_views: - assert_permissions( - resource=feature_view, actions=[AuthzedAction.READ_ONLINE] - ) - for od_feature_view in all_on_demand_feature_views: - assert_permissions( - resource=od_feature_view, actions=[AuthzedAction.READ_ONLINE] - ) - - response_proto = store.get_online_features( - features=features, - entity_rows=entity_rows, - full_feature_names=full_feature_names, - ).proto - - # Convert the Protobuf object to JSON and return it - return MessageToDict( - response_proto, preserving_proto_field_name=True, float_precision=18 - ) - except Exception as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=500, detail=str(e)) + + response_proto = store.get_online_features( + features=features, + entity_rows=entity_rows, + full_feature_names=full_feature_names, + ).proto + + # Convert the Protobuf object to JSON and return it + return MessageToDict( + response_proto, preserving_proto_field_name=True, float_precision=18 + ) @app.post("/push", dependencies=[Depends(inject_user_details)]) def push(body=Depends(get_body)): - try: - request = PushFeaturesRequest(**json.loads(body)) - df = pd.DataFrame(request.df) - actions = [] - if request.to == "offline": - to = PushMode.OFFLINE - actions = [AuthzedAction.WRITE_OFFLINE] - elif request.to == "online": - to = PushMode.ONLINE - actions = [AuthzedAction.WRITE_ONLINE] - elif request.to == "online_and_offline": - to = PushMode.ONLINE_AND_OFFLINE - actions = WRITE - else: - raise ValueError( - f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." - ) - - from feast.data_source import PushSource + request = PushFeaturesRequest(**json.loads(body)) + df = pd.DataFrame(request.df) + actions = [] + if request.to == "offline": + to = PushMode.OFFLINE + actions = [AuthzedAction.WRITE_OFFLINE] + elif request.to == "online": + to = PushMode.ONLINE + actions = [AuthzedAction.WRITE_ONLINE] + elif request.to == "online_and_offline": + to = PushMode.ONLINE_AND_OFFLINE + actions = WRITE + else: + raise ValueError( + f"{request.to} is not a supported push format. Please specify one of these ['online', 'offline', 'online_and_offline']." + ) - all_fvs = store.list_feature_views( - allow_cache=request.allow_registry_cache - ) + store.list_stream_feature_views( - allow_cache=request.allow_registry_cache + from feast.data_source import PushSource + + all_fvs = store.list_feature_views( + allow_cache=request.allow_registry_cache + ) + store.list_stream_feature_views(allow_cache=request.allow_registry_cache) + fvs_with_push_sources = { + fv + for fv in all_fvs + if ( + fv.stream_source is not None + and isinstance(fv.stream_source, PushSource) + and fv.stream_source.name == request.push_source_name ) - fvs_with_push_sources = { - fv - for fv in all_fvs - if ( - fv.stream_source is not None - and isinstance(fv.stream_source, PushSource) - and fv.stream_source.name == request.push_source_name - ) - } + } - for feature_view in fvs_with_push_sources: - assert_permissions(resource=feature_view, actions=actions) + for feature_view in fvs_with_push_sources: + assert_permissions(resource=feature_view, actions=actions) - store.push( - push_source_name=request.push_source_name, - df=df, - allow_registry_cache=request.allow_registry_cache, - to=to, - ) - except PushSourceNotFoundException as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=422, detail=str(e)) - except Exception as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=500, detail=str(e)) + store.push( + push_source_name=request.push_source_name, + df=df, + allow_registry_cache=request.allow_registry_cache, + to=to, + ) @app.post("/write-to-online-store", dependencies=[Depends(inject_user_details)]) def write_to_online_store(body=Depends(get_body)): + request = WriteToFeatureStoreRequest(**json.loads(body)) + df = pd.DataFrame(request.df) + feature_view_name = request.feature_view_name + allow_registry_cache = request.allow_registry_cache try: - request = WriteToFeatureStoreRequest(**json.loads(body)) - df = pd.DataFrame(request.df) - feature_view_name = request.feature_view_name - allow_registry_cache = request.allow_registry_cache - try: - feature_view = store.get_stream_feature_view( - feature_view_name, allow_registry_cache=allow_registry_cache - ) - except FeatureViewNotFoundException: - feature_view = store.get_feature_view( - feature_view_name, allow_registry_cache=allow_registry_cache - ) - - assert_permissions( - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] + feature_view = store.get_stream_feature_view( + feature_view_name, allow_registry_cache=allow_registry_cache ) - store.write_to_online_store( - feature_view_name=feature_view_name, - df=df, - allow_registry_cache=allow_registry_cache, + except FeatureViewNotFoundException: + feature_view = store.get_feature_view( + feature_view_name, allow_registry_cache=allow_registry_cache ) - except Exception as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=500, detail=str(e)) + + assert_permissions(resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE]) + store.write_to_online_store( + feature_view_name=feature_view_name, + df=df, + allow_registry_cache=allow_registry_cache, + ) @app.get("/health") def health(): @@ -249,39 +225,43 @@ def health(): @app.post("/materialize", dependencies=[Depends(inject_user_details)]) def materialize(body=Depends(get_body)): - try: - request = MaterializeRequest(**json.loads(body)) - for feature_view in request.feature_views: - assert_permissions( - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] - ) - store.materialize( - utils.make_tzaware(parser.parse(request.start_ts)), - utils.make_tzaware(parser.parse(request.end_ts)), - request.feature_views, + request = MaterializeRequest(**json.loads(body)) + for feature_view in request.feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] ) - except Exception as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=500, detail=str(e)) + store.materialize( + utils.make_tzaware(parser.parse(request.start_ts)), + utils.make_tzaware(parser.parse(request.end_ts)), + request.feature_views, + ) @app.post("/materialize-incremental", dependencies=[Depends(inject_user_details)]) def materialize_incremental(body=Depends(get_body)): - try: - request = MaterializeIncrementalRequest(**json.loads(body)) - for feature_view in request.feature_views: - assert_permissions( - resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] - ) - store.materialize_incremental( - utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views + request = MaterializeIncrementalRequest(**json.loads(body)) + for feature_view in request.feature_views: + assert_permissions( + resource=feature_view, actions=[AuthzedAction.WRITE_ONLINE] + ) + store.materialize_incremental( + utils.make_tzaware(parser.parse(request.end_ts)), request.feature_views + ) + + @app.exception_handler(Exception) + async def rest_exception_handler(request: Request, exc: Exception): + # Print the original exception on the server side + logger.exception(traceback.format_exc()) + + if isinstance(exc, FeastError): + return JSONResponse( + status_code=exc.http_status_code(), + content=exc.to_error_detail(), + ) + else: + return JSONResponse( + status_code=500, + content=str(exc), ) - except Exception as e: - # Print the original exception on the server side - logger.exception(traceback.format_exc()) - # Raise HTTPException to return the error message to the client - raise HTTPException(status_code=500, detail=str(e)) return app diff --git a/sdk/python/feast/infra/online_stores/remote.py b/sdk/python/feast/infra/online_stores/remote.py index 93fbcaf771..5f65d8da8b 100644 --- a/sdk/python/feast/infra/online_stores/remote.py +++ b/sdk/python/feast/infra/online_stores/remote.py @@ -16,16 +16,15 @@ from datetime import datetime from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple +import requests from pydantic import StrictStr from feast import Entity, FeatureView, RepoConfig from feast.infra.online_stores.online_store import OnlineStore -from feast.permissions.client.http_auth_requests_wrapper import ( - get_http_auth_requests_session, -) from feast.protos.feast.types.EntityKey_pb2 import EntityKey as EntityKeyProto from feast.protos.feast.types.Value_pb2 import Value as ValueProto from feast.repo_config import FeastConfigBaseModel +from feast.rest_error_handler import rest_error_handling_decorator from feast.type_map import python_values_to_proto_values from feast.value_type import ValueType @@ -72,9 +71,7 @@ def online_read( req_body = self._construct_online_read_api_json_request( entity_keys, table, requested_features ) - response = get_http_auth_requests_session(config.auth_config).post( - f"{config.online_store.path}/get-online-features", data=req_body - ) + response = get_remote_online_features(config=config, req_body=req_body) if response.status_code == 200: logger.debug("Able to retrieve the online features from feature server.") response_json = json.loads(response.text) @@ -167,3 +164,12 @@ def teardown( entities: Sequence[Entity], ): pass + + +@rest_error_handling_decorator +def get_remote_online_features( + session: requests.Session, config: RepoConfig, req_body: str +) -> requests.Response: + return session.post( + f"{config.online_store.path}/get-online-features", data=req_body + ) diff --git a/sdk/python/feast/rest_error_handler.py b/sdk/python/feast/rest_error_handler.py new file mode 100644 index 0000000000..fc802866f9 --- /dev/null +++ b/sdk/python/feast/rest_error_handler.py @@ -0,0 +1,57 @@ +import logging +from functools import wraps + +import requests + +from feast import RepoConfig +from feast.errors import FeastError +from feast.permissions.client.http_auth_requests_wrapper import ( + get_http_auth_requests_session, +) + +logger = logging.getLogger(__name__) + + +def rest_error_handling_decorator(func): + @wraps(func) + def wrapper(config: RepoConfig, *args, **kwargs): + assert isinstance(config, RepoConfig) + + # Get a Session object + with get_http_auth_requests_session(config.auth_config) as session: + # Define a wrapper for session methods + def method_wrapper(method_name): + original_method = getattr(session, method_name) + + @wraps(original_method) + def wrapped_method(*args, **kwargs): + logger.debug( + f"Calling {method_name} with args: {args}, kwargs: {kwargs}" + ) + response = original_method(*args, **kwargs) + logger.debug( + f"{method_name} response status code: {response.status_code}" + ) + + try: + response.raise_for_status() + except requests.RequestException: + logger.debug(f"response.json() = {response.json()}") + mapped_error = FeastError.from_error_detail(response.json()) + logger.debug(f"mapped_error = {str(mapped_error)}") + if mapped_error is not None: + raise mapped_error + return response + + return wrapped_method + + # Enhance session methods + session.get = method_wrapper("get") # type: ignore[method-assign] + session.post = method_wrapper("post") # type: ignore[method-assign] + session.put = method_wrapper("put") # type: ignore[method-assign] + session.delete = method_wrapper("delete") # type: ignore[method-assign] + + # Pass the enhanced session object to the decorated function + return func(session, config, *args, **kwargs) + + return wrapper diff --git a/sdk/python/tests/integration/online_store/test_python_feature_server.py b/sdk/python/tests/integration/online_store/test_python_feature_server.py index 1010e73178..d08e1104eb 100644 --- a/sdk/python/tests/integration/online_store/test_python_feature_server.py +++ b/sdk/python/tests/integration/online_store/test_python_feature_server.py @@ -4,6 +4,7 @@ import pytest from fastapi.testclient import TestClient +from feast.errors import PushSourceNotFoundException from feast.feast_object import FeastObject from feast.feature_server import get_app from feast.utils import _utc_now @@ -90,21 +91,24 @@ def test_push_source_does_not_exist(python_fs_client): initial_temp = _get_temperatures_from_feature_server( python_fs_client, location_ids=[1] )[0] - response = python_fs_client.post( - "/push", - data=json.dumps( - { - "push_source_name": "push_source_does_not_exist", - "df": { - "location_id": [1], - "temperature": [initial_temp * 100], - "event_timestamp": [str(_utc_now())], - "created": [str(_utc_now())], - }, - } - ), - ) - assert response.status_code == 422 + with pytest.raises( + PushSourceNotFoundException, + match="Unable to find push source 'push_source_does_not_exist'", + ): + python_fs_client.post( + "/push", + data=json.dumps( + { + "push_source_name": "push_source_does_not_exist", + "df": { + "location_id": [1], + "temperature": [initial_temp * 100], + "event_timestamp": [str(_utc_now())], + "created": [str(_utc_now())], + }, + } + ), + ) def _get_temperatures_from_feature_server(client, location_ids: List[int]): diff --git a/sdk/python/tests/unit/test_rest_error_decorator.py b/sdk/python/tests/unit/test_rest_error_decorator.py new file mode 100644 index 0000000000..147ae767bd --- /dev/null +++ b/sdk/python/tests/unit/test_rest_error_decorator.py @@ -0,0 +1,78 @@ +from unittest.mock import Mock, patch + +import assertpy +import pytest +import requests + +from feast import RepoConfig +from feast.errors import PermissionNotFoundException +from feast.infra.online_stores.remote import ( + RemoteOnlineStoreConfig, + get_remote_online_features, +) + + +@pytest.fixture +def feast_exception() -> PermissionNotFoundException: + return PermissionNotFoundException("dummy_name", "dummy_project") + + +@pytest.fixture +def none_feast_exception() -> RuntimeError: + return RuntimeError("dummy_name", "dummy_project") + + +@patch("feast.infra.online_stores.remote.requests.sessions.Session.post") +def test_rest_error_handling_with_feast_exception( + mock_post, environment, feast_exception +): + # Create a mock response object + mock_response = Mock() + mock_response.status_code = feast_exception.http_status_code() + mock_response.json.return_value = feast_exception.to_error_detail() + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() + + # Configure the mock to return the mock response + mock_post.return_value = mock_response + + store = environment.feature_store + online_config = RemoteOnlineStoreConfig(type="remote", path="dummy") + + with pytest.raises( + PermissionNotFoundException, + match="Permission dummy_name does not exist in project dummy_project", + ): + get_remote_online_features( + config=RepoConfig( + project="test", online_store=online_config, registry=store.registry + ), + req_body="{test:test}", + ) + + +@patch("feast.infra.online_stores.remote.requests.sessions.Session.post") +def test_rest_error_handling_with_none_feast_exception( + mock_post, environment, none_feast_exception +): + # Create a mock response object + mock_response = Mock() + mock_response.status_code = 500 + mock_response.json.return_value = str(none_feast_exception) + mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError() + + # Configure the mock to return the mock response + mock_post.return_value = mock_response + + store = environment.feature_store + online_config = RemoteOnlineStoreConfig(type="remote", path="dummy") + + response = get_remote_online_features( + config=RepoConfig( + project="test", online_store=online_config, registry=store.registry + ), + req_body="{test:test}", + ) + + assertpy.assert_that(response).is_not_none() + assertpy.assert_that(response.status_code).is_equal_to(500) + assertpy.assert_that(response.json()).is_equal_to("('dummy_name', 'dummy_project')")