Skip to content

Commit

Permalink
pythongh-126434: Detect when .set() is called by a thread already .wa…
Browse files Browse the repository at this point in the history
…it()-ing. Raise an exception if that is the case. Fix race condition in multiprocessing.Event.wait() as described in python#95826
  • Loading branch information
ivarref committed Nov 13, 2024
1 parent 0ca4a85 commit f9307f6
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 22 deletions.
15 changes: 10 additions & 5 deletions Lib/multiprocessing/synchronize.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ class Event(object):

def __init__(self, *, ctx):
self._cond = ctx.Condition(ctx.Lock())
self._threadlocal_wait_lock = None
self._flag = ctx.Value('i', 0)

def is_set(self):
Expand All @@ -338,6 +339,9 @@ def set(self):
assert not self._cond._lock._semlock._is_mine(), \
'multiprocessing.Event is not reentrant for clear(), set() and wait()'
with self._cond:
if self._threadlocal_wait_lock is not None:
assert not self._threadlocal_wait_lock.v.locked(), \
'multiprocessing.Event.set() cannot be called from a thread that is already wait()-ing'
self._flag.value = 1
self._cond.notify_all()

Expand All @@ -351,14 +355,15 @@ def wait(self, timeout=None):
assert not self._cond._lock._semlock._is_mine(), \
'multiprocessing.Event is not reentrant for clear(), set() and wait()'
with self._cond:
if self._flag.value == 1:
return True
else:
self._cond.wait(timeout)
if self._threadlocal_wait_lock is None:
self._threadlocal_wait_lock = threading.local()
self._threadlocal_wait_lock.v = threading.Lock()

if self._flag.value == 1:
return True
return False
else:
with self._threadlocal_wait_lock.v:
return self._cond.wait(timeout)

def __repr__(self) -> str:
set_status = 'set' if self.is_set() else 'unset'
Expand Down
34 changes: 34 additions & 0 deletions Lib/test/multiprocessingdata/set_clear_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import multiprocessing
import sys


# Reproduction code copied and modified from https://github.com/python/cpython/issues/95826
class SimpleRepro:
def __init__(self):
self.heartbeat_event = multiprocessing.Event()
self.shutdown_event = multiprocessing.Event()
self.child_proc = multiprocessing.Process(target=self.child_process, daemon=True)
self.child_proc.start()

def child_process(self):
while True:
if self.shutdown_event.is_set():
return
self.heartbeat_event.set()
self.heartbeat_event.clear()

def test_heartbeat(self):
exit_code = 0
for i in range(2000):
success = self.heartbeat_event.wait(100)
if not success:
exit_code = 1
break
self.shutdown_event.set()
self.child_proc.join()
sys.exit(exit_code)


if __name__ == '__main__':
foo = SimpleRepro()
foo.test_heartbeat()
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
import os
import signal
import concurrent.futures
import time


def send_sigint(pid):
time.sleep(1)
os.kill(pid, signal.SIGINT)


def run_signal_handler_test():
def run_signal_handler_set_is_set_test():
shutdown_event = multiprocessing.Event()

def sigterm_handler(_signo, _stack_frame):
Expand All @@ -24,4 +26,4 @@ def sigterm_handler(_signo, _stack_frame):


if __name__ == '__main__':
run_signal_handler_test()
run_signal_handler_set_is_set_test()
34 changes: 34 additions & 0 deletions Lib/test/multiprocessingdata/wait_set_throws.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import multiprocessing
import signal
import concurrent.futures
import time
import sys
import os


def send_sigint(pid):
time.sleep(1)
os.kill(pid, signal.SIGINT)


def run_signal_handler_wait_set_test():
shutdown_event = multiprocessing.Event()

def sigterm_handler(_signo, _stack_frame):
shutdown_event.set()

signal.signal(signal.SIGINT, sigterm_handler)

with concurrent.futures.ProcessPoolExecutor() as executor:
f = executor.submit(send_sigint, os.getpid())
shutdown_event.wait()
f.result()


if __name__ == '__main__':
try:
run_signal_handler_wait_set_test()
sys.exit(1)
except AssertionError as e:
assert 'multiprocessing.Event.set() cannot be called from a thread that is already wait()-ing' in str(e)
sys.exit(0)
39 changes: 24 additions & 15 deletions Lib/test/test_multiprocessing_event/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,36 @@
_have_multiprocessing = False


class TestEventSignalHandling(unittest.TestCase):
@unittest.skipUnless(_have_multiprocessing,
"requires multiprocessing")
@unittest.skipUnless(hasattr(signal, 'signal'),
"Requires signal.signal")
@unittest.skipUnless(hasattr(signal, 'SIGINT'),
"Requires signal.SIGINT")
@unittest.skipUnless(hasattr(os, 'kill'),
@unittest.skipUnless(_have_multiprocessing,
"requires multiprocessing")
@unittest.skipUnless(hasattr(signal, 'signal'),
"Requires signal.signal")
@unittest.skipUnless(hasattr(signal, 'SIGINT'),
"Requires signal.SIGINT")
@unittest.skipUnless(hasattr(os, 'kill'),
"Requires os.kill")
@unittest.skipUnless(hasattr(os, 'getppid'),
@unittest.skipUnless(hasattr(os, 'getppid'),
"Requires os.getppid")
@support.requires_subprocess()
def test_event_signal_handling(self):
@support.requires_subprocess()
class TestEventSignalHandling(unittest.TestCase):
def test_no_race_for_set_is_set(self):
import subprocess
script = support.findfile("event_signal.py", subdir="multiprocessingdata")
script = support.findfile("set_is_set.py", subdir="multiprocessingdata")
for x in range(10):
try:
exit_code = subprocess.call([sys.executable, script], stdout=subprocess.DEVNULL, timeout=30)
assert exit_code == 0
assert 0 == subprocess.call([sys.executable, script], timeout=60)
except subprocess.TimeoutExpired:
assert False, 'subprocess.Timeoutexpired for event_signal.py'
assert False, 'subprocess.Timeoutexpired for set_is_set.py'

def test_no_race_set_clear(self):
import subprocess
script = support.findfile("set_clear_race.py", subdir="multiprocessingdata")
assert 0 == subprocess.call([sys.executable, script])

def test_wait_set_throws(self):
import subprocess
script = support.findfile("wait_set_throws.py", subdir="multiprocessingdata")
assert 0 == subprocess.call([sys.executable, script])


if __name__ == '__main__':
Expand Down

0 comments on commit f9307f6

Please sign in to comment.