diff --git a/.coveragerc b/.coveragerc index d505ff27ef..59da024909 100644 --- a/.coveragerc +++ b/.coveragerc @@ -37,4 +37,4 @@ exclude_also = ignore_errors = True [html] -directory = coverage_html_report +directory = coverage_html_report \ No newline at end of file diff --git a/integration/test_event_exception_handlers.py b/integration/test_event_exception_handlers.py new file mode 100644 index 0000000000..904a25e0ab --- /dev/null +++ b/integration/test_event_exception_handlers.py @@ -0,0 +1,223 @@ +"""Integration tests for event exception handlers.""" +from __future__ import annotations + +from typing import Generator +from unittest.mock import AsyncMock + +import pytest + +import time + +from reflex.app import process +from reflex.event import Event +from reflex.state import StateManagerRedis + +from reflex.testing import AppHarness +from reflex.config import get_config + +from typing import Generator + +import pytest +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.support.ui import WebDriverWait +from selenium.webdriver.support import expected_conditions as EC + + +def TestApp(): + """A test app for event exception handler integration.""" + + import reflex as rx + + def frontend_exception_handler(message: str, stack: str): + print(f"[Fasfadgasdg] {message} {stack}") + + class TestAppConfig(rx.Config): + """Config for the TestApp app.""" + + class TestAppState(rx.State): + """State for the TestApp app.""" + + value: int + + def go(self, c: int): + """Increment the value c times and update each time. + + Args: + c: The number of times to increment. + + Yields: + After each increment. + """ + for _ in range(c): + self.value += 1 + yield + + app = rx.App(state=rx.State) + + @app.add_page + def index(): + return rx.vstack( + rx.button( + "induce_frontend_error", + on_click=rx.call_script("induce_frontend_error()"), + id="induce-frontend-error-btn", + ), + ) + + +@pytest.fixture(scope="module") +def test_app(tmp_path_factory) -> Generator[AppHarness, None, None]: + """Start TestApp app at tmp_path via AppHarness. + + Args: + tmp_path_factory: pytest tmp_path_factory fixture + + Yields: + running AppHarness instance + """ + with AppHarness.create( + root=tmp_path_factory.mktemp("test_app"), + app_source=TestApp, # type: ignore + ) as harness: + yield harness + + +@pytest.fixture +def driver(test_app: AppHarness) -> Generator[WebDriver, None, None]: + """Get an instance of the browser open to the test_app app. + + Args: + test_app: harness for TestApp app + + Yields: + WebDriver instance. + """ + assert test_app.app_instance is not None, "app is not running" + driver = test_app.frontend() + try: + yield driver + finally: + driver.quit() + + +def test_frontend_exception_handler_during_runtime( + driver: WebDriver, + capsys, +): + """Test calling frontend exception handler during runtime. + + We send an event containing a call to a non-existent function in the frontend. + This should trigger the default frontend exception handler. + + Args: + driver: WebDriver instance. + capsys: pytest fixture for capturing stdout and stderr. + """ + reset_button = WebDriverWait(driver, 20).until( + EC.element_to_be_clickable((By.ID, "induce-frontend-error-btn")) + ) + + # 1. Test the default frontend exception handler + + reset_button.click() + + # Wait for the error to be logged + time.sleep(2) + + captured_default_handler_output = capsys.readouterr() + assert ( + "[Reflex Frontend Exception]" in captured_default_handler_output.out + and "induce_frontend_error" in captured_default_handler_output.out + and "ReferenceError" in captured_default_handler_output.out + ) + + # 2. Test the custom frontend exception handler + + def custom_frontend_exception_handler(message: str, stack: str) -> None: + print(f"[Custom Frontend Exception] {message} {stack}") + + # Set the custom frontend exception handler + config = get_config() + config.frontend_exception_handler = custom_frontend_exception_handler + + reset_button.click() + + # Wait for the error to be logged + time.sleep(2) + + captured_custom_handler_output = capsys.readouterr() + assert ( + "[Custom Frontend Exception]" in captured_custom_handler_output.out + and "induce_frontend_error" in captured_custom_handler_output.out + and "ReferenceError" in captured_custom_handler_output.out + ) + + +@pytest.mark.asyncio +async def test_backend_exception_handler_during_runtime(mocker, capsys, test_app): + """Test calling backend exception handler during runtime. + + Args: + mocker: mocker object. + capsys: capsys fixture. + test_app: harness for CallScript app. + driver: WebDriver instance. + + """ + token = "mock_token" + + router_data = { + "pathname": "/", + "query": {}, + "token": token, + "sid": "mock_sid", + "headers": {}, + "ip": "127.0.0.1", + } + + app = test_app.app_instance + mocker.patch.object(app, "postprocess", AsyncMock()) + + payload = {"c": "5"} # should be an int + + # 1. Test the default backend exception handler + + event = Event( + token=token, name="test_app_state.go", payload=payload, router_data=router_data + ) + + async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): + pass + + captured_default_handler_output = capsys.readouterr() + assert ( + "[Reflex Backend Exception]" in captured_default_handler_output.out + and "'str' object cannot be interpreted as an integer" + in captured_default_handler_output.out + ) + + # 2. Test the custom backend exception handler + + def custom_backend_exception_handler(message: str, stack: str) -> None: + print(f"[Custom Backend Exception] {message} {stack}") + + config = get_config() + config.backend_exception_handler = custom_backend_exception_handler + + event = Event( + token=token, name="test_app_state.go", payload=payload, router_data=router_data + ) + + async for _update in process(app, event, "mock_sid", {}, "127.0.0.1"): + pass + + captured_custom_handler_output = capsys.readouterr() + assert ( + "[Custom Backend Exception]" in captured_custom_handler_output.out + and "'str' object cannot be interpreted as an integer" + in captured_custom_handler_output.out + ) + + if isinstance(app.state_manager, StateManagerRedis): + await app.state_manager.close() diff --git a/reflex/.templates/web/utils/state.js b/reflex/.templates/web/utils/state.js index 7ff8256c4a..dc467ef1b9 100644 --- a/reflex/.templates/web/utils/state.js +++ b/reflex/.templates/web/utils/state.js @@ -202,6 +202,9 @@ export const applyEvent = async (event, socket) => { } } catch (e) { console.log("_call_script", e); + if (window && window?.onerror) { + window.onerror(e.message, null, null, null, e) + } } return false; } @@ -574,6 +577,32 @@ export const useEventLoop = ( queueEvents(events, socket); }; + // Handle frontend errors and send them to the backend via websocket. + useEffect(() => { + + if (typeof window === 'undefined') { + return; + + } + + window.onerror = function (msg, url, lineNo, columnNo, error) { + addEvents([Event("state.handle_frontend_exception", { + message: error.message, + stack: error.stack, + })]) + return false; + } + + window.onunhandledrejection = function (event) { + addEvents([Event("state.handle_frontend_exception", { + message: event.reason.message, + stack: event.reason.stack, + })]) + return false; + } + + },[]) + const sentHydrate = useRef(false); // Avoid double-hydrate due to React strict-mode useEffect(() => { if (router.isReady && !sentHydrate.current) { diff --git a/reflex/config.py b/reflex/config.py index f10b38e965..e3b7222b90 100644 --- a/reflex/config.py +++ b/reflex/config.py @@ -5,8 +5,10 @@ import importlib import os import sys +import inspect +import functools import urllib.parse -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Callable try: # TODO The type checking guard can be removed once @@ -24,6 +26,11 @@ from reflex import constants from reflex.base import Base from reflex.utils import console +from reflex.utils.exceptions import ( + default_frontend_exception_handler, + default_backend_exception_handler, +) +from reflex.event import EventSpec class DBConfig(Base): @@ -207,6 +214,16 @@ class Config: # Attributes that were explicitly set by the user. _non_default_attributes: Set[str] = pydantic.PrivateAttr(set()) + # Frontend Error Handler Function + frontend_exception_handler: Optional[ + Callable[[str, str], None] + ] = default_frontend_exception_handler + + # Backend Error Handler Function + backend_exception_handler: Optional[ + Callable[[str, str], EventSpec | list[EventSpec] | None] + ] = default_backend_exception_handler + def __init__(self, *args, **kwargs): """Initialize the config values. @@ -229,6 +246,9 @@ def __init__(self, *args, **kwargs): self._non_default_attributes.update(kwargs) self._replace_defaults(**kwargs) + # Check the exception handlers + self._validate_exception_handlers() + @property def module(self) -> str: """Get the module name of the app. @@ -357,6 +377,86 @@ def _set_persistent(self, **kwargs): self._non_default_attributes.update(kwargs) self._replace_defaults(**kwargs) + def _validate_exception_handlers(self): + """Validate the custom event exception handlers for front- and backend. + + Raises: + ValueError: If the custom exception handlers are invalid. + + """ + FRONTEND_ARG_SPEC = { + "message": str, + "stack": str, + } + + BACKEND_ARG_SPEC = { + "message": str, + "stack": str, + } + + for handler_domain, handler_fn, handler_spec in zip( + ["frontend", "backend"], + [self.frontend_exception_handler, self.backend_exception_handler], + [ + FRONTEND_ARG_SPEC, + BACKEND_ARG_SPEC, + ], + ): + if hasattr(handler_fn, "__name__"): + _fn_name = handler_fn.__name__ + else: + _fn_name = handler_fn.__class__.__name__ + + if isinstance(handler_fn, functools.partial): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is a partial function. Please provide a named function instead." + ) + + if not callable(handler_fn): + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is not a function." + ) + + # Allow named functions only as lambda functions cannot be introspected + if _fn_name == "": + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` is a lambda function. Please use a named function instead." + ) + + # Check if the function has the necessary annotations and types + arg_annotations = inspect.get_annotations(handler_fn) + + for required_arg in handler_spec: + + if required_arg not in arg_annotations: + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` does not take the required argument `{required_arg}`" + ) + + if arg_annotations[required_arg] != handler_spec[required_arg]: + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong type for {required_arg} argument." + f"Expected `{handler_spec[required_arg]}` but got `{arg_annotations[required_arg]}`" + ) + + # Check if the return type is valid for backend exception handler + if handler_domain == "backend": + sig = inspect.signature(self.backend_exception_handler) + return_type = sig.return_annotation + + valid = bool( + return_type == EventSpec + or return_type == Optional[EventSpec] + or return_type == list[EventSpec] + or return_type == inspect.Signature.empty + ) + + if not valid: + raise ValueError( + f"Provided custom {handler_domain} exception handler `{_fn_name}` has the wrong return type." + f"Expected `EventSpec | list[EventSpec] | None` but got `{return_type}`" + ) + def get_config(reload: bool = False) -> Config: """Get the app config. diff --git a/reflex/config.pyi b/reflex/config.pyi index 57ce1123de..5458bb4098 100644 --- a/reflex/config.pyi +++ b/reflex/config.pyi @@ -3,7 +3,8 @@ from reflex import constants as constants from reflex.base import Base as Base from reflex.utils import console as console -from typing import Any, Dict, List, Optional, overload +from reflex.event import EventSpec as EventSpec +from typing import Any, Dict, List, Optional, Callable, overload class DBConfig(Base): engine: str @@ -70,6 +71,10 @@ class Config(Base): cp_web_url: str username: Optional[str] gunicorn_worker_class: str + frontend_exception_handler: Optional[Callable[[str, str], None]] + backend_exception_handler: Optional[ + Callable[[str, str], EventSpec | list[EventSpec] | None] + ] def __init__( self, diff --git a/reflex/state.py b/reflex/state.py index 7545aed542..8c2969361a 100644 --- a/reflex/state.py +++ b/reflex/state.py @@ -59,6 +59,9 @@ from reflex.utils.serializers import SerializedType, serialize, serializer from reflex.vars import BaseVar, ComputedVar, Var, computed_var +from reflex.config import get_config + + if TYPE_CHECKING: from reflex.components.component import Component @@ -1532,12 +1535,16 @@ async def _process_event( yield state._as_state_update(handler, events, final=True) # If an error occurs, throw a window alert. - except Exception: - error = traceback.format_exc() - print(error) + except Exception as e: + stack_trace = traceback.format_exc() + + config = get_config() + + events = config.backend_exception_handler(message=str(e), stack=stack_trace) + yield state._as_state_update( handler, - window_alert("An error occurred. See logs for details."), + events, final=True, ) @@ -1806,6 +1813,21 @@ class State(BaseState): # The hydrated bool. is_hydrated: bool = False + def handle_frontend_exception(self, message: str, stack: str) -> None: + """Handle frontend exceptions. + + If a frontend exception handler is provided, it will be called. + Otherwise, the default frontend exception handler will be called. + + Args: + message: The message of the exception. + stack: The stack trace of the exception. + + """ + config = get_config() + + config.frontend_exception_handler(message=message, stack=stack) + class UpdateVarsInternalState(State): """Substate for handling internal state var updates.""" diff --git a/reflex/utils/exceptions.py b/reflex/utils/exceptions.py index e4a9a6e6c9..7c00f1ee2a 100644 --- a/reflex/utils/exceptions.py +++ b/reflex/utils/exceptions.py @@ -1,4 +1,7 @@ -"""Custom Exceptions.""" +"""Custom Exceptions and Exception Handlers.""" + +import reflex.utils.console as console +from reflex.event import EventSpec, window_alert class InvalidStylePropError(TypeError): @@ -19,3 +22,33 @@ class MatchTypeError(TypeError): """Raised when the return types of match cases are different.""" pass + + +def default_frontend_exception_handler(message: str, stack: str) -> None: + """Default frontend exception handler function. + + Args: + message: The error message. + stack: The stack trace. + + """ + console.error( + f"[Reflex Frontend Exception]\n - Message: {message}\n - Stack: {stack}\n" + ) + + +def default_backend_exception_handler(message: str, stack: str) -> EventSpec: + """Default backend exception handler function. + + Args: + message: The error message. + stack: The stack trace. + + Returns: + EventSpec: The window alert event. + """ + console.error( + f"[Reflex Backend Exception]\n - Message: {message}\n - Stack: {stack}\n" + ) + + return window_alert("An error occurred. See logs for details.") diff --git a/tests/test_config.py b/tests/test_config.py index 1ba2f548de..92fda23787 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -3,10 +3,11 @@ from typing import Any, Dict import pytest - +import functools import reflex as rx import reflex.config from reflex.constants import Endpoint +from contextlib import nullcontext as does_not_raise def test_requires_app_name(): @@ -219,3 +220,135 @@ def test_reflex_dir_env_var(monkeypatch, tmp_path): mp_ctx = multiprocessing.get_context(method="spawn") with mp_ctx.Pool(processes=1) as pool: assert pool.apply(reflex_dir_constant) == str(tmp_path) + + +def valid_custom_handler(message: str, stack: str, logger: str = "test"): + print("Custom Backend Exception") + print(stack) + + +def custom_exception_handler_with_wrong_argspec( + message: int, stack: str # Should be str +): + print("Custom Backend Exception") + print(stack) + + +class SomeHandler: + def handle(self, message: str, stack: str): + print("Custom Backend Exception") + print(stack) + + +custom_exception_handlers = { + "lambda": lambda message, stack: print("Custom Exception Handler", message, stack), + "wrong_argspec": custom_exception_handler_with_wrong_argspec, + "valid": valid_custom_handler, + "partial": functools.partial(valid_custom_handler, logger="test"), + "method": SomeHandler().handle, +} + + +@pytest.mark.parametrize( + "handler_fn, expected", + [ + pytest.param( + custom_exception_handlers["partial"], + pytest.raises(ValueError), + id="partial", + ), + pytest.param( + custom_exception_handlers["lambda"], + pytest.raises(ValueError), + id="lambda", + ), + pytest.param( + custom_exception_handlers["wrong_argspec"], + pytest.raises(ValueError), + id="wrong_argspec", + ), + pytest.param( + custom_exception_handlers["valid"], + does_not_raise(), + id="valid_handler", + ), + pytest.param( + custom_exception_handlers["method"], + does_not_raise(), + id="valid_class_method", + ), + ], +) +def test_frontend_exception_handler_validation(handler_fn, expected): + """Test that the custom frontend exception handler is properly validated. + + Args: + handler_fn: The handler function. + expected: The expected result. + + """ + with expected: + rx.Config(app_name="a", frontend_exception_handler=handler_fn) + + +def backend_exception_handler_with_wrong_return_type(message: str, stack: str) -> int: + """Custom backend exception handler with wrong return type. + + Args: + message: The error message. + stack: The stack trace. + + Returns: + int: The wrong return type. + """ + print("Custom Backend Exception") + print(stack) + + return 5 + + +@pytest.mark.parametrize( + "handler_fn, expected", + [ + pytest.param( + backend_exception_handler_with_wrong_return_type, + pytest.raises(ValueError), + id="wrong_return_type", + ), + pytest.param( + custom_exception_handlers["partial"], + pytest.raises(ValueError), + id="partial", + ), + pytest.param( + custom_exception_handlers["lambda"], + pytest.raises(ValueError), + id="lambda", + ), + pytest.param( + custom_exception_handlers["wrong_argspec"], + pytest.raises(ValueError), + id="wrong_argspec", + ), + pytest.param( + custom_exception_handlers["valid"], + does_not_raise(), + id="valid_handler", + ), + pytest.param( + custom_exception_handlers["method"], + does_not_raise(), + id="valid_class_method", + ), + ], +) +def test_backend_exception_handler_validation(handler_fn, expected): + """Test that the custom backend exception handler is properly validated. + + Args: + handler_fn: The handler function. + expected: The expected result. + + """ + with expected: + rx.Config(app_name="a", backend_exception_handler=handler_fn)