From 7855261b6a0f046db72ea49b6b8fb667aaa134b5 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 30 Aug 2024 09:15:28 -0400 Subject: [PATCH] feat: allow certain exceptions to commit --- superset/utils/decorators.py | 9 +++ tests/unit_tests/utils/test_decorators.py | 72 +++++++++++++++++++++++ 2 files changed, 81 insertions(+) diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 844a8f063c1b8..687b60b1911e8 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -238,6 +238,7 @@ def on_error( def transaction( # pylint: disable=redefined-outer-name on_error: Callable[..., Any] | None = on_error, + allowed: tuple[type[Exception], ...] = (), ) -> Callable[..., Any]: """ Perform a "unit of work". @@ -246,7 +247,12 @@ def transaction( # pylint: disable=redefined-outer-name proved rather complicated, likely due to many architectural facets, and thus has been left for a follow up exercise. + In certain cases it might be desirable to commit even though an exception was + raised. For OAuth2, for example, we use exceptions as a way to signal the client to + redirect to the login page. In this case, we re-raise the exception and commit. + :param on_error: Callback invoked when an exception is caught + :param allowed: Exception types to re-raise after committing :see: https://github.com/apache/superset/issues/25108 """ @@ -259,6 +265,9 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: result = func(*args, **kwargs) db.session.commit() # pylint: disable=consider-using-transaction return result + except allowed: + db.session.commit() # pylint: disable=consider-using-transaction + raise except Exception as ex: db.session.rollback() # pylint: disable=consider-using-transaction diff --git a/tests/unit_tests/utils/test_decorators.py b/tests/unit_tests/utils/test_decorators.py index 0a622f4bce974..5f19d58c433c3 100644 --- a/tests/unit_tests/utils/test_decorators.py +++ b/tests/unit_tests/utils/test_decorators.py @@ -24,6 +24,7 @@ from unittest.mock import call, Mock, patch import pytest +from pytest_mock import MockerFixture from superset import app from superset.utils import decorators @@ -294,3 +295,74 @@ def func() -> None: decorated = decorators.suppress_logging("test-logger", logging.CRITICAL + 1)(func) decorated() assert len(handler.log_records) == 0 + + +def test_transaction_no_error(mocker: MockerFixture) -> None: + """ + Test the `transaction` decorator when the function works as expected. + """ + db = mocker.patch("superset.db") + + @decorators.transaction() + def f() -> int: + return 42 + + assert f() == 42 + db.session.commit.assert_called_once() + db.session.rollback.assert_not_called() + + +def test_transaction_with_on_error(mocker: MockerFixture) -> None: + """ + Test the `transaction` decorator when the function captures an exception. + """ + db = mocker.patch("superset.db") + + def on_error(ex: Exception) -> Exception: + return ex + + ex = ValueError("error") + + @decorators.transaction(on_error) + def f() -> None: + raise ex + + assert f() == ex + db.session.commit.assert_not_called() + db.session.rollback.assert_called_once() + + +def test_transaction_without_on_error(mocker: MockerFixture) -> None: + """ + Test the `transaction` decorator when the function raises an exception. + """ + db = mocker.patch("superset.db") + + @decorators.transaction() + def f() -> None: + raise ValueError("error") + + with pytest.raises(ValueError) as excinfo: + f() + assert str(excinfo.value) == "error" + db.session.commit.assert_not_called() + db.session.rollback.assert_called_once() + + +def test_transaction_with_allowed(mocker: MockerFixture) -> None: + """ + Test the `transaction` decorator with allowed exceptions. + + In this case the decorator will commit before re-raising the exception. + """ + db = mocker.patch("superset.db") + + @decorators.transaction(allowed=(ValueError,)) + def f() -> None: + raise ValueError("error") + + with pytest.raises(ValueError) as excinfo: + f() + assert str(excinfo.value) == "error" + db.session.commit.assert_called_once() + db.session.rollback.assert_not_called()