Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in Connection.serve and AsyncResult.wait #531

Merged
merged 5 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions rpyc/core/async_.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,19 @@ def wait(self):
"""Waits for the result to arrive. If the AsyncResult object has an
expiry set, and the result did not arrive within that timeout,
an :class:`AsyncResultTimeout` exception is raised"""
while not (self._is_ready or self.expired):
while self._waiting():
# Serve the connection since we are not ready. Suppose
# the reply for our seq is served. The callback is this class
# so __call__ sets our obj and _is_ready to true.
self._conn.serve(self._ttl)
self._conn.serve(self._ttl, waiting=self._waiting)

# Check if we timed out before result was ready
if not self._is_ready:
raise AsyncResultTimeout("result expired")

def _waiting(self):
return not (self._is_ready or self.expired)

def add_callback(self, func):
"""Adds a callback to be invoked when the result arrives. The callback
function takes a single argument, which is the current AsyncResult
Expand Down
24 changes: 20 additions & 4 deletions rpyc/core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def _get_seq_id(self): # IO
return next(self._seqcounter)

def _send(self, msg, seq, args): # IO
data = brine.dump((msg, seq, args))
data = brine.I1.pack(msg) + brine.dump((seq, args)) # see _dispatch
if self._bind_threads:
this_thread = self._get_thread()
data = brine.I8I8.pack(this_thread.id, this_thread._remote_thread_id) + data
Expand Down Expand Up @@ -392,10 +392,13 @@ def _seq_request_callback(self, msg, seq, is_exc, obj):
self._config["logger"].debug(debug_msg.format(msg, seq))

def _dispatch(self, data): # serving---dispatch?
msg, seq, args = brine.load(data)
msg, = brine.I1.unpack(data[:1]) # unpack just msg to minimize time to release
if msg == consts.MSG_REQUEST:
if self._bind_threads:
self._get_thread()._occupation_count += 1
else:
self._recvlock.release()
seq, args = brine.load(data[1:])
self._dispatch_request(seq, args)
else:
if self._bind_threads:
Expand All @@ -404,15 +407,21 @@ def _dispatch(self, data): # serving---dispatch?
if this_thread._occupation_count == 0:
this_thread._remote_thread_id = UNBOUND_THREAD_ID
if msg == consts.MSG_REPLY:
seq, args = brine.load(data[1:])
obj = self._unbox(args)
self._seq_request_callback(msg, seq, False, obj)
if not self._bind_threads:
self._recvlock.release() # releasing here fixes race condition with AsyncResult.wait
elif msg == consts.MSG_EXCEPTION:
if not self._bind_threads:
self._recvlock.release()
seq, args = brine.load(data[1:])
obj = self._unbox_exc(args)
self._seq_request_callback(msg, seq, True, obj)
else:
raise ValueError(f"invalid message type: {msg!r}")

def serve(self, timeout=1, wait_for_lock=True): # serving
def serve(self, timeout=1, wait_for_lock=True, waiting=lambda: True): # serving
"""Serves a single request or reply that arrives within the given
time frame (default is 1 sec). Note that the dispatching of a request
might trigger multiple (nested) requests, thus this function may be
Expand All @@ -427,10 +436,17 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
# Exit early if we cannot acquire the recvlock
if not self._recvlock.acquire(False):
if wait_for_lock:
if not waiting(): # unlikely, but the result could've arrived and another thread could've won the race to acquire
return False
# Wait condition for recvlock release; recvlock is not underlying lock for condition
return self._recv_event.wait(timeout.timeleft())
else:
return False
if not waiting(): # the result arrived and we won the race to acquire, unlucky
self._recvlock.release()
with self._recv_event:
self._recv_event.notify_all()
return False
# Assume the receive rlock is acquired and incremented
# We must release once BEFORE dispatch, dispatch any data, and THEN notify all (see issue #527 and #449)
try:
Expand All @@ -442,11 +458,11 @@ def serve(self, timeout=1, wait_for_lock=True): # serving
self.close() # sends close async request
raise
else:
self._recvlock.release()
if data:
self._dispatch(data) # Dispatch will unbox, invoke callbacks, etc.
return True
else:
self._recvlock.release()
return False
finally:
with self._recv_event:
Expand Down
70 changes: 70 additions & 0 deletions tests/test_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import rpyc
import rpyc.core.async_ as rc_async_
import rpyc.core.protocol as rc_protocol
import contextlib
import signal
import threading
import time
import unittest


class TestRace(unittest.TestCase):
def setUp(self):
self.connection = rpyc.classic.connect_thread()

self.a_str = rpyc.async_(self.connection.builtin.str)

def tearDown(self):
self.connection.close()

def test_asyncresult_race(self):
with _patch():
def hook():
time.sleep(0.2) # loose race

_AsyncResult._HOOK = hook

threading.Thread(target=self.connection.serve_all).start()
time.sleep(0.1) # wait for thread to serve

# schedule KeyboardInterrupt
thread_id = threading.get_ident()
_ = lambda: signal.pthread_kill(thread_id, signal.SIGINT)
timer = threading.Timer(1, _)
timer.start()

a_result = self.a_str("") # request
time.sleep(0.1) # wait for race to start
try:
a_result.wait()
except KeyboardInterrupt:
raise Exception("deadlock")

timer.cancel()


class _AsyncResult(rc_async_.AsyncResult):
_HOOK = None

def __call__(self, *args, **kwargs):
hook = type(self)._HOOK
if hook is not None:
hook()
return super().__call__(*args, **kwargs)


@contextlib.contextmanager
def _patch():
AsyncResult = rc_async_.AsyncResult
try:
rc_async_.AsyncResult = _AsyncResult
rc_protocol.AsyncResult = _AsyncResult # from import
yield

finally:
rc_async_.AsyncResult = AsyncResult
rc_protocol.AsyncResult = AsyncResult


if __name__ == "__main__":
unittest.main()