diff --git a/docs/source/api.rst b/docs/source/api.rst index 34ba14e7..ca23b4fb 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -1902,6 +1902,8 @@ Client-side errors * :class:`neo4j.exceptions.ResultError` + * :class:`neo4j.exceptions.ResultFailedError` + * :class:`neo4j.exceptions.ResultConsumedError` * :class:`neo4j.exceptions.ResultNotSingleError` @@ -1946,6 +1948,9 @@ Client-side errors :show-inheritance: :members: result +.. autoexception:: neo4j.exceptions.ResultFailedError() + :show-inheritance: + .. autoexception:: neo4j.exceptions.ResultConsumedError() :show-inheritance: diff --git a/src/neo4j/_async/work/result.py b/src/neo4j/_async/work/result.py index fe87d6cd..9e1b5c78 100644 --- a/src/neo4j/_async/work/result.py +++ b/src/neo4j/_async/work/result.py @@ -38,6 +38,7 @@ ) from ...exceptions import ( ResultConsumedError, + ResultFailedError, ResultNotSingleError, ) from ...time import ( @@ -57,6 +58,10 @@ _TResultKey = t.Union[int, str] +_RESULT_FAILED_ERROR = ( + "The result has failed. Either this result or another result in the same " + "transaction has encountered an error." +) _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -76,8 +81,11 @@ class AsyncResult: """ def __init__(self, connection, fetch_size, on_closed, on_error): - self._connection = ConnectionErrorHandler(connection, on_error) + self._connection = ConnectionErrorHandler( + connection, self._connection_error_handler + ) self._hydration_scope = connection.new_hydration_scope() + self._on_error = on_error self._on_closed = on_closed self._metadata = None self._keys = None @@ -101,6 +109,13 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._consumed = False # the result has been closed as a result of closing the transaction self._out_of_scope = False + # exception shared across all results of a transaction + self._exception = None + + async def _connection_error_handler(self, exc): + self._exception = exc + self._attached = False + await AsyncUtil.callback(self._on_error, exc) @property def _qid(self): @@ -257,6 +272,9 @@ async def __aiter__(self) -> t.AsyncIterator[Record]: await self._connection.send_all() self._exhausted = True + if self._exception is not None: + raise ResultFailedError(self, _RESULT_FAILED_ERROR) \ + from self._exception if self._out_of_scope: raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) if self._consumed: @@ -346,6 +364,11 @@ async def _tx_end(self): await self._exhaust() self._out_of_scope = True + def _tx_failure(self, exc): + # Handle failure of the associated transaction. + self._attached = False + self._exception = exc + async def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. diff --git a/src/neo4j/_async/work/transaction.py b/src/neo4j/_async/work/transaction.py index 42eeac7f..f009bf7c 100644 --- a/src/neo4j/_async/work/transaction.py +++ b/src/neo4j/_async/work/transaction.py @@ -92,6 +92,8 @@ async def _result_on_closed_handler(self): async def _error_handler(self, exc): self._last_error = exc + for result in self._results: + result._tx_failure(exc) if isinstance(exc, asyncio.CancelledError): self._cancel() return diff --git a/src/neo4j/_sync/work/result.py b/src/neo4j/_sync/work/result.py index be3ea2e3..b2a07996 100644 --- a/src/neo4j/_sync/work/result.py +++ b/src/neo4j/_sync/work/result.py @@ -38,6 +38,7 @@ ) from ...exceptions import ( ResultConsumedError, + ResultFailedError, ResultNotSingleError, ) from ...time import ( @@ -57,6 +58,10 @@ _TResultKey = t.Union[int, str] +_RESULT_FAILED_ERROR = ( + "The result has failed. Either this result or another result in the same " + "transaction has encountered an error." +) _RESULT_OUT_OF_SCOPE_ERROR = ( "The result is out of scope. The associated transaction " "has been closed. Results can only be used while the " @@ -76,8 +81,11 @@ class Result: """ def __init__(self, connection, fetch_size, on_closed, on_error): - self._connection = ConnectionErrorHandler(connection, on_error) + self._connection = ConnectionErrorHandler( + connection, self._connection_error_handler + ) self._hydration_scope = connection.new_hydration_scope() + self._on_error = on_error self._on_closed = on_closed self._metadata = None self._keys = None @@ -101,6 +109,13 @@ def __init__(self, connection, fetch_size, on_closed, on_error): self._consumed = False # the result has been closed as a result of closing the transaction self._out_of_scope = False + # exception shared across all results of a transaction + self._exception = None + + def _connection_error_handler(self, exc): + self._exception = exc + self._attached = False + Util.callback(self._on_error, exc) @property def _qid(self): @@ -257,6 +272,9 @@ def __iter__(self) -> t.Iterator[Record]: self._connection.send_all() self._exhausted = True + if self._exception is not None: + raise ResultFailedError(self, _RESULT_FAILED_ERROR) \ + from self._exception if self._out_of_scope: raise ResultConsumedError(self, _RESULT_OUT_OF_SCOPE_ERROR) if self._consumed: @@ -346,6 +364,11 @@ def _tx_end(self): self._exhaust() self._out_of_scope = True + def _tx_failure(self, exc): + # Handle failure of the associated transaction. + self._attached = False + self._exception = exc + def consume(self) -> ResultSummary: """Consume the remainder of this result and return a :class:`neo4j.ResultSummary`. diff --git a/src/neo4j/_sync/work/transaction.py b/src/neo4j/_sync/work/transaction.py index 9e1cc8c0..1eda4faa 100644 --- a/src/neo4j/_sync/work/transaction.py +++ b/src/neo4j/_sync/work/transaction.py @@ -92,6 +92,8 @@ def _result_on_closed_handler(self): def _error_handler(self, exc): self._last_error = exc + for result in self._results: + result._tx_failure(exc) if isinstance(exc, asyncio.CancelledError): self._cancel() return diff --git a/src/neo4j/exceptions.py b/src/neo4j/exceptions.py index 1c366fae..ac1d3ae9 100644 --- a/src/neo4j/exceptions.py +++ b/src/neo4j/exceptions.py @@ -40,6 +40,7 @@ + TransactionError + TransactionNestingError + ResultError + + ResultFailedError + ResultConsumedError + ResultNotSingleError + BrokenRecordError @@ -464,6 +465,17 @@ def __init__(self, result_, *args, **kwargs): self.result = result_ +# DriverError > ResultError > ResultFailedError +class ResultFailedError(ResultError): + """Raised when trying to access records of a failed result. + + A :class:`.Result` will be considered failed if + * itself encountered an error while fetching records + * another result within the same transaction encountered an error while + fetching records + """ + + # DriverError > ResultError > ResultConsumedError class ResultConsumedError(ResultError): """Raised when trying to access records of a consumed result.""" diff --git a/testkitbackend/_async/requests.py b/testkitbackend/_async/requests.py index 1ad3cbdb..2038c8ba 100644 --- a/testkitbackend/_async/requests.py +++ b/testkitbackend/_async/requests.py @@ -401,7 +401,7 @@ async def ExecuteQuery(backend, data): def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): # This solution (putting custom resolution together with DNS resolution - # into one function only works because the Python driver calls the custom + # into one function) only works because the Python driver calls the custom # resolver function for every connection, which is not true for all # drivers. Properly exposing a way to change the DNS lookup behavior is not # possible without changing the driver's code. diff --git a/testkitbackend/_sync/requests.py b/testkitbackend/_sync/requests.py index 9932021e..efaec61b 100644 --- a/testkitbackend/_sync/requests.py +++ b/testkitbackend/_sync/requests.py @@ -401,7 +401,7 @@ def ExecuteQuery(backend, data): def resolution_func(backend, custom_resolver=False, custom_dns_resolver=False): # This solution (putting custom resolution together with DNS resolution - # into one function only works because the Python driver calls the custom + # into one function) only works because the Python driver calls the custom # resolver function for every connection, which is not true for all # drivers. Properly exposing a way to change the DNS lookup behavior is not # possible without changing the driver's code. diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 5957054e..92af6d4a 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -13,13 +13,7 @@ "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_should_echo_all_timezone_ids'": "test_subtest_skips.dt_conversion", "'neo4j.datatypes.test_temporal_types.TestDataTypes.test_date_time_cypher_created_tz_id'": - "test_subtest_skips.tz_id", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_discard_after_tx_termination_on_run'": - "Fixme: transactions don't prevent further actions after failure.", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_pull'": - "Fixme: transactions don't prevent further actions after failure.", - "'stub.tx_run.test_tx_run.TestTxRun.test_should_prevent_pull_after_tx_termination_on_run'": - "Fixme: transactions don't prevent further actions after failure." + "test_subtest_skips.tz_id" }, "features": { "Feature:API:BookmarkManager": true, diff --git a/tests/unit/async_/fixtures/fake_connection.py b/tests/unit/async_/fixtures/fake_connection.py index 9e3995cf..5989b152 100644 --- a/tests/unit/async_/fixtures/fake_connection.py +++ b/tests/unit/async_/fixtures/fake_connection.py @@ -24,6 +24,7 @@ from neo4j._async.io import AsyncBolt from neo4j._deadline import Deadline from neo4j.auth_management import AsyncAuthManager +from neo4j.exceptions import Neo4jError __all__ = [ @@ -154,10 +155,12 @@ def set_script(self, callbacks): [ ("run", {"on_success": ({},), "on_summary": None}), ("pull", { + "on_records": ([some_record],), "on_success": None, "on_summary": None, - "on_records": }) + # use any exception to throw it instead of calling handlers + ("commit", RuntimeError("oh no!")) ] ``` Note that arguments can be `None`. In this case, ScriptedConnection @@ -180,6 +183,9 @@ def func(*args, **kwargs): self._script_pos += 1 async def callback(): + if isinstance(scripted_callbacks, BaseException): + raise scripted_callbacks + error = None for cb_name, default_cb_args in ( ("on_ignored", ({},)), ("on_failure", ({},)), @@ -197,10 +203,14 @@ async def callback(): if cb_args is None: cb_args = default_cb_args res = cb(*cb_args) + if cb_name == "on_failure": + error = Neo4jError.hydrate(**cb_args[0]) try: await res # maybe the callback is async except TypeError: pass # or maybe it wasn't ;) + if error is not None: + raise error self.callbacks.append(callback) diff --git a/tests/unit/async_/work/test_transaction.py b/tests/unit/async_/work/test_transaction.py index 33238c76..b7d4d719 100644 --- a/tests/unit/async_/work/test_transaction.py +++ b/tests/unit/async_/work/test_transaction.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock -from uuid import uuid4 import pytest @@ -26,6 +25,11 @@ NotificationMinimumSeverity, Query, ) +from neo4j.exceptions import ( + ClientError, + ResultFailedError, + ServiceUnavailable, +) from ...._async_compat import mark_async_test @@ -275,3 +279,51 @@ async def test_transaction_begin_pipelining( expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert async_fake_connection.method_calls == expected_calls + + +@pytest.mark.parametrize("error", ("server", "connection")) +@mark_async_test +async def test_server_error_propagates(async_scripted_connection, error): + connection = async_scripted_connection + script = [ + # res 1 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + # res 2 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + ] + if error == "server": + script.append( + ("pull", {"on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_summary": None}) + ) + expected_error = ClientError + elif error == "connection": + script.append(("pull", ServiceUnavailable())) + expected_error = ServiceUnavailable + else: + raise ValueError(f"Unknown error type {error}") + connection.set_script(script) + + tx = AsyncTransaction( + connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + res1 = await tx.run("UNWIND range(1, 1000) AS n RETURN n") + assert await res1.__anext__() == {"n": 1} + + res2 = await tx.run("RETURN 'causes error later'") + assert await res2.fetch(2) == [{"n": 1}, {"n": 2}] + with pytest.raises(expected_error) as exc1: + await res2.__anext__() + + # can finish the buffer + assert await res1.fetch(1) == [{"n": 2}] + # then fails because the connection was broken by res2 + with pytest.raises(ResultFailedError) as exc2: + await res1.__anext__() + + assert exc1.value is exc2.value.__cause__ diff --git a/tests/unit/sync/fixtures/fake_connection.py b/tests/unit/sync/fixtures/fake_connection.py index 659daebe..e504ef50 100644 --- a/tests/unit/sync/fixtures/fake_connection.py +++ b/tests/unit/sync/fixtures/fake_connection.py @@ -24,6 +24,7 @@ from neo4j._deadline import Deadline from neo4j._sync.io import Bolt from neo4j.auth_management import AuthManager +from neo4j.exceptions import Neo4jError __all__ = [ @@ -154,10 +155,12 @@ def set_script(self, callbacks): [ ("run", {"on_success": ({},), "on_summary": None}), ("pull", { + "on_records": ([some_record],), "on_success": None, "on_summary": None, - "on_records": }) + # use any exception to throw it instead of calling handlers + ("commit", RuntimeError("oh no!")) ] ``` Note that arguments can be `None`. In this case, ScriptedConnection @@ -180,6 +183,9 @@ def func(*args, **kwargs): self._script_pos += 1 def callback(): + if isinstance(scripted_callbacks, BaseException): + raise scripted_callbacks + error = None for cb_name, default_cb_args in ( ("on_ignored", ({},)), ("on_failure", ({},)), @@ -197,10 +203,14 @@ def callback(): if cb_args is None: cb_args = default_cb_args res = cb(*cb_args) + if cb_name == "on_failure": + error = Neo4jError.hydrate(**cb_args[0]) try: res # maybe the callback is async except TypeError: pass # or maybe it wasn't ;) + if error is not None: + raise error self.callbacks.append(callback) diff --git a/tests/unit/sync/work/test_transaction.py b/tests/unit/sync/work/test_transaction.py index 6bffe784..d13681b0 100644 --- a/tests/unit/sync/work/test_transaction.py +++ b/tests/unit/sync/work/test_transaction.py @@ -17,7 +17,6 @@ from unittest.mock import MagicMock -from uuid import uuid4 import pytest @@ -26,6 +25,11 @@ Query, Transaction, ) +from neo4j.exceptions import ( + ClientError, + ResultFailedError, + ServiceUnavailable, +) from ...._async_compat import mark_sync_test @@ -275,3 +279,51 @@ def test_transaction_begin_pipelining( expected_calls.append(("send_all",)) expected_calls.append(("fetch_all",)) assert fake_connection.method_calls == expected_calls + + +@pytest.mark.parametrize("error", ("server", "connection")) +@mark_sync_test +def test_server_error_propagates(scripted_connection, error): + connection = scripted_connection + script = [ + # res 1 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + # res 2 + ("run", {"on_success": ({"fields": ["n"]},), "on_summary": None}), + ("pull", {"on_records": ([[1], [2]],), + "on_success": ({"has_more": True},)}), + ] + if error == "server": + script.append( + ("pull", {"on_failure": ({"code": "Neo.ClientError.Made.Up"},), + "on_summary": None}) + ) + expected_error = ClientError + elif error == "connection": + script.append(("pull", ServiceUnavailable())) + expected_error = ServiceUnavailable + else: + raise ValueError(f"Unknown error type {error}") + connection.set_script(script) + + tx = Transaction( + connection, 2, lambda *args, **kwargs: None, + lambda *args, **kwargs: None, lambda *args, **kwargs: None + ) + res1 = tx.run("UNWIND range(1, 1000) AS n RETURN n") + assert res1.__next__() == {"n": 1} + + res2 = tx.run("RETURN 'causes error later'") + assert res2.fetch(2) == [{"n": 1}, {"n": 2}] + with pytest.raises(expected_error) as exc1: + res2.__next__() + + # can finish the buffer + assert res1.fetch(1) == [{"n": 2}] + # then fails because the connection was broken by res2 + with pytest.raises(ResultFailedError) as exc2: + res1.__next__() + + assert exc1.value is exc2.value.__cause__