From d861ab902c94764d87d78b7341616ca63873fb25 Mon Sep 17 00:00:00 2001 From: notEvil Date: Wed, 15 Mar 2023 11:24:42 +0100 Subject: [PATCH 1/5] - tests: added test_race.py --- tests/test_race.py | 74 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/test_race.py diff --git a/tests/test_race.py b/tests/test_race.py new file mode 100644 index 00000000..3ef4018e --- /dev/null +++ b/tests/test_race.py @@ -0,0 +1,74 @@ +import rpyc +import rpyc.core.async_ as rc_async_ +import rpyc.core.protocol as rc_protocol +import contextlib +import logging +import signal +import threading +import time +import unittest + + +class TestRace(unittest.TestCase): + def setUp(self): + self.connection = rpyc.classic.connect_thread() + + self.a_str = rpyc.async_(self.connection.builtin.str) + + def tearDown(self): + self.connection.close() + + def test_asyncresult_race(self): + with _patch(): + event = threading.Event() + + def hook(): + event.set() # start race + time.sleep(0.1) # loose race + + _AsyncResult._HOOK = hook + + threading.Thread(target=self.connection.serve_all).start() + time.sleep(0.1) # wait for thread to serve + + # schedule KeyboardInterrupt + thread_id = threading.get_ident() + _ = lambda: signal.pthread_kill(thread_id, signal.SIGINT) + timer = threading.Timer(1, _) + timer.start() + + a_result = self.a_str("") # request + event.wait() # wait for race to start + try: + a_result.wait() + except KeyboardInterrupt: + raise Exception("deadlock") + + timer.cancel() + + +class _AsyncResult(rc_async_.AsyncResult): + _HOOK = None + + def __call__(self, *args, **kwargs): + hook = type(self)._HOOK + if hook is not None: + hook() + return super().__call__(*args, **kwargs) + + +@contextlib.contextmanager +def _patch(): + AsyncResult = rc_async_.AsyncResult + try: + rc_async_.AsyncResult = _AsyncResult + rc_protocol.AsyncResult = _AsyncResult # from import + yield + + finally: + rc_async_.AsyncResult = AsyncResult + rc_protocol.AsyncResult = AsyncResult + + +if __name__ == "__main__": + unittest.main() From 43027f9447aa88f598fb0fbccee1f462fb128eec Mon Sep 17 00:00:00 2001 From: notEvil Date: Wed, 15 Mar 2023 11:27:35 +0100 Subject: [PATCH 2/5] - rpyc/core/protocol.Connection: fixed race condition with AsyncResult.wait --- rpyc/core/async_.py | 7 +++++-- rpyc/core/protocol.py | 20 ++++++++++++++++---- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/rpyc/core/async_.py b/rpyc/core/async_.py index 0af147db..816626a6 100644 --- a/rpyc/core/async_.py +++ b/rpyc/core/async_.py @@ -44,16 +44,19 @@ def wait(self): """Waits for the result to arrive. If the AsyncResult object has an expiry set, and the result did not arrive within that timeout, an :class:`AsyncResultTimeout` exception is raised""" - while not (self._is_ready or self.expired): + while self._waiting(): # Serve the connection since we are not ready. Suppose # the reply for our seq is served. The callback is this class # so __call__ sets our obj and _is_ready to true. - self._conn.serve(self._ttl) + self._conn.serve(self._ttl, waiting=self._waiting) # Check if we timed out before result was ready if not self._is_ready: raise AsyncResultTimeout("result expired") + def _waiting(self): + return not (self._is_ready or self.expired) + def add_callback(self, func): """Adds a callback to be invoked when the result arrives. The callback function takes a single argument, which is the current AsyncResult diff --git a/rpyc/core/protocol.py b/rpyc/core/protocol.py index 69643c72..bc725403 100644 --- a/rpyc/core/protocol.py +++ b/rpyc/core/protocol.py @@ -260,7 +260,7 @@ def _get_seq_id(self): # IO return next(self._seqcounter) def _send(self, msg, seq, args): # IO - data = brine.dump((msg, seq, args)) + data = brine.I1.pack(msg) + brine.dump((seq, args)) # see _dispatch if self._bind_threads: this_thread = self._get_thread() data = brine.I8I8.pack(this_thread.id, this_thread._remote_thread_id) + data @@ -392,8 +392,10 @@ def _seq_request_callback(self, msg, seq, is_exc, obj): self._config["logger"].debug(debug_msg.format(msg, seq)) def _dispatch(self, data): # serving---dispatch? - msg, seq, args = brine.load(data) + msg, = brine.I1.unpack(data[:1]) # unpack just msg to minimize time to release if msg == consts.MSG_REQUEST: + self._recvlock.release() + seq, args = brine.load(data[1:]) if self._bind_threads: self._get_thread()._occupation_count += 1 self._dispatch_request(seq, args) @@ -404,15 +406,19 @@ def _dispatch(self, data): # serving---dispatch? if this_thread._occupation_count == 0: this_thread._remote_thread_id = UNBOUND_THREAD_ID if msg == consts.MSG_REPLY: + seq, args = brine.load(data[1:]) obj = self._unbox(args) self._seq_request_callback(msg, seq, False, obj) + self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait elif msg == consts.MSG_EXCEPTION: + self._recvlock.release() + seq, args = brine.load(data[1:]) obj = self._unbox_exc(args) self._seq_request_callback(msg, seq, True, obj) else: raise ValueError(f"invalid message type: {msg!r}") - def serve(self, timeout=1, wait_for_lock=True): # serving + def serve(self, timeout=1, wait_for_lock=True, waiting=lambda: True): # serving """Serves a single request or reply that arrives within the given time frame (default is 1 sec). Note that the dispatching of a request might trigger multiple (nested) requests, thus this function may be @@ -427,10 +433,17 @@ def serve(self, timeout=1, wait_for_lock=True): # serving # Exit early if we cannot acquire the recvlock if not self._recvlock.acquire(False): if wait_for_lock: + if not waiting(): # unlikely, but the result could've arrived and another thread could've won the race to acquire + return False # Wait condition for recvlock release; recvlock is not underlying lock for condition return self._recv_event.wait(timeout.timeleft()) else: return False + if not waiting(): # the result arrived and we won the race to acquire, unlucky + self._recvlock.release() + with self._recv_event: + self._recv_event.notify_all() + return False # Assume the receive rlock is acquired and incremented # We must release once BEFORE dispatch, dispatch any data, and THEN notify all (see issue #527 and #449) try: @@ -442,7 +455,6 @@ def serve(self, timeout=1, wait_for_lock=True): # serving self.close() # sends close async request raise else: - self._recvlock.release() if data: self._dispatch(data) # Dispatch will unbox, invoke callbacks, etc. return True From b04c98d48990db732bd7bacdd7ab49261d3a95e2 Mon Sep 17 00:00:00 2001 From: notEvil Date: Wed, 15 Mar 2023 12:04:40 +0100 Subject: [PATCH 3/5] - fixed minor regression --- rpyc/core/protocol.py | 12 ++++++++---- tests/test_race.py | 4 ++++ 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/rpyc/core/protocol.py b/rpyc/core/protocol.py index bc725403..8b3ae94b 100644 --- a/rpyc/core/protocol.py +++ b/rpyc/core/protocol.py @@ -394,10 +394,11 @@ def _seq_request_callback(self, msg, seq, is_exc, obj): def _dispatch(self, data): # serving---dispatch? msg, = brine.I1.unpack(data[:1]) # unpack just msg to minimize time to release if msg == consts.MSG_REQUEST: - self._recvlock.release() - seq, args = brine.load(data[1:]) if self._bind_threads: self._get_thread()._occupation_count += 1 + else: + self._recvlock.release() + seq, args = brine.load(data[1:]) self._dispatch_request(seq, args) else: if self._bind_threads: @@ -409,9 +410,11 @@ def _dispatch(self, data): # serving---dispatch? seq, args = brine.load(data[1:]) obj = self._unbox(args) self._seq_request_callback(msg, seq, False, obj) - self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait + if not self._bind_threads: + self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait elif msg == consts.MSG_EXCEPTION: - self._recvlock.release() + if not self._bind_threads: + self._recvlock.release() seq, args = brine.load(data[1:]) obj = self._unbox_exc(args) self._seq_request_callback(msg, seq, True, obj) @@ -459,6 +462,7 @@ def serve(self, timeout=1, wait_for_lock=True, waiting=lambda: True): # serving self._dispatch(data) # Dispatch will unbox, invoke callbacks, etc. return True else: + self._recvlock.release() return False finally: with self._recv_event: diff --git a/tests/test_race.py b/tests/test_race.py index 3ef4018e..84b73ec9 100644 --- a/tests/test_race.py +++ b/tests/test_race.py @@ -3,6 +3,7 @@ import rpyc.core.protocol as rc_protocol import contextlib import logging +import os import signal import threading import time @@ -18,6 +19,9 @@ def setUp(self): def tearDown(self): self.connection.close() + @unittest.skipIf( + os.environ.get("RPYC_BIND_THREADS") == "true", "bind threads is unaffected" + ) def test_asyncresult_race(self): with _patch(): event = threading.Event() From 185bd72ba54eebb4d035b57985d2822bf2b88bc6 Mon Sep 17 00:00:00 2001 From: notEvil Date: Wed, 15 Mar 2023 13:55:25 +0100 Subject: [PATCH 4/5] - tests/test_race: changed to support thread binding --- tests/test_race.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/tests/test_race.py b/tests/test_race.py index 84b73ec9..c59287ef 100644 --- a/tests/test_race.py +++ b/tests/test_race.py @@ -2,8 +2,6 @@ import rpyc.core.async_ as rc_async_ import rpyc.core.protocol as rc_protocol import contextlib -import logging -import os import signal import threading import time @@ -24,11 +22,8 @@ def tearDown(self): ) def test_asyncresult_race(self): with _patch(): - event = threading.Event() - def hook(): - event.set() # start race - time.sleep(0.1) # loose race + time.sleep(0.2) # loose race _AsyncResult._HOOK = hook @@ -42,7 +37,7 @@ def hook(): timer.start() a_result = self.a_str("") # request - event.wait() # wait for race to start + time.sleep(0.1) # wait for race to start try: a_result.wait() except KeyboardInterrupt: From 25430e5a413f8687c90790a18dc7fe4844c42b87 Mon Sep 17 00:00:00 2001 From: notEvil Date: Wed, 15 Mar 2023 13:58:30 +0100 Subject: [PATCH 5/5] - fixed minor regression --- tests/test_race.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_race.py b/tests/test_race.py index c59287ef..f71d5c6b 100644 --- a/tests/test_race.py +++ b/tests/test_race.py @@ -17,9 +17,6 @@ def setUp(self): def tearDown(self): self.connection.close() - @unittest.skipIf( - os.environ.get("RPYC_BIND_THREADS") == "true", "bind threads is unaffected" - ) def test_asyncresult_race(self): with _patch(): def hook():