[#722] fix segfault and hung threads on KeyboardInterrupt during parallel get #728

Open · wants to merge 2 commits into base: main
65 changes: 48 additions & 17 deletions irods/parallel.py
@@ -9,13 +9,22 @@
import concurrent.futures
import threading
import multiprocessing
import weakref

from irods.data_object import iRODSDataObject
from irods.exception import DataObjectDoesNotExist
import irods.keywords as kw
from queue import Queue, Full, Empty


transfer_managers = weakref.WeakKeyDictionary()


def abort_asynchronous_transfers():
for mgr in transfer_managers:
mgr.quit()


logger = logging.getLogger(__name__)
_nullh = logging.NullHandler()
logger.addHandler(_nullh)
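
The new module-level WeakKeyDictionary tracks live transfer managers without keeping them alive, and abort_asynchronous_transfers() asks each registered manager to quit. A minimal sketch, assuming an application wants to trigger that cleanup from its own SIGTERM handler; the handler wiring below is illustrative only and not part of this diff:

import signal
import irods.parallel

def _shutdown_transfers(signum, frame):
    # Ask every registered transfer manager to stop its worker threads,
    # then restore the default disposition and re-deliver the signal.
    irods.parallel.abort_asynchronous_transfers()
    signal.signal(signum, signal.SIG_DFL)
    signal.raise_signal(signum)

signal.signal(signal.SIGTERM, _shutdown_transfers)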
@@ -90,9 +99,11 @@ def __init__(
for future in self._futures:
future.add_done_callback(self)
else:
self.__invoke_done_callback()
self.__invoke_futures_done_logic()
return

self.progress = [0, 0]

if (progress_Queue) and (total is not None):
self.progress[1] = total

@@ -111,7 +122,7 @@ def _progress(Q, this): # - thread to update progress indicator

self._progress_fn = _progress
self._progress_thread = threading.Thread(
target=self._progress_fn, args=(progress_Queue, self)
target=self._progress_fn, args=(progress_Queue, self), daemon=True
)
self._progress_thread.start()

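Setting daemon=True on the progress thread matters because a non-daemon thread blocks interpreter shutdown until it returns. A standalone illustration of that behavior, independent of this diff:

import threading
import time

def ticker():
    # Simulates a never-ending background task such as a progress updater.
    while True:
        time.sleep(1)

# Because the thread is a daemon, the interpreter exits as soon as the main
# thread finishes; with daemon=False the process would hang here forever.
threading.Thread(target=ticker, daemon=True).start()
print("main thread done")
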
@@ -152,11 +163,13 @@ def __call__(
with self._lock:
self._futures_done[future] = future.result()
if len(self._futures) == len(self._futures_done):
self.__invoke_done_callback()
self.__invoke_futures_done_logic(
skip_user_callback=(None in self._futures_done.values())
)

def __invoke_done_callback(self):
def __invoke_futures_done_logic(self, skip_user_callback=False):
try:
if callable(self.done_callback):
if not skip_user_callback and callable(self.done_callback):
self.done_callback(self)
finally:
self.keep.pop("mgr", None)
@@ -239,6 +252,9 @@ def _copy_part(src, dst, length, queueObject, debug_info, mgr, updatables=()):
bytecount = 0
accum = 0
while True and bytecount < length:
if mgr._quit:
bytecount = None
break
buf = src.read(min(COPY_BUF_SIZE, length - bytecount))
buf_len = len(buf)
if 0 == buf_len:
@@ -274,11 +290,16 @@ class _Multipart_close_manager:
"""

def __init__(self, initial_io_, exit_barrier_):
self._quit = False
self.exit_barrier = exit_barrier_
self.initial_io = initial_io_
self.__lock = threading.Lock()
self.aux = []

def quit(self):
self._quit = True
self.exit_barrier.abort()

def __contains__(self, Io):
with self.__lock:
return Io is self.initial_io or Io in self.aux
@@ -303,8 +324,12 @@ def remove_io(self, Io):
Io.close()
self.aux.remove(Io)
is_initial = False
self.exit_barrier.wait()
if is_initial:
broken = False
try:
self.exit_barrier.wait()
except threading.BrokenBarrierError:
broken = True
if is_initial and not (broken or self._quit):
self.finalize()

def finalize(self):
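
quit() calls exit_barrier.abort(), which wakes any thread already blocked in exit_barrier.wait() with BrokenBarrierError so it can exit instead of waiting for peers that will never arrive. A self-contained illustration of that threading.Barrier behavior, not tied to iRODS:

import threading

barrier = threading.Barrier(2)  # expects two parties that will never both arrive

def worker():
    try:
        barrier.wait(timeout=5)
        print("all parties arrived")
    except threading.BrokenBarrierError:
        print("barrier aborted; worker exits without finalizing")

t = threading.Thread(target=worker)
t.start()
barrier.abort()  # breaks the barrier, releasing the waiting worker immediately
t.join()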
@@ -439,13 +464,19 @@ def bytes_range_for_thread(i, num_threads, total_bytes, chunk):
Io = File = None

if Operation.isNonBlocking():
if queueLength:
return futures, queueObject, mgr
else:
return futures
return futures, queueObject, mgr
else:
bytecounts = [f.result() for f in futures]
return sum(bytecounts), total_size
bytes_transferred = 0
try:
bytecounts = [f.result() for f in futures]
if None not in bytecounts:
bytes_transferred = sum(bytecounts)
except KeyboardInterrupt:
if any(not f.done() for f in futures):
# Induce any threads still alive to quit the transfer and exit.
mgr.quit()
raise
return bytes_transferred, total_size


def io_main(session, Data, opr_, fname, R="", **kwopt):
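
On the blocking code path, an interrupted transfer now calls mgr.quit() for any still-running futures and re-raises KeyboardInterrupt instead of hanging in f.result(). A hedged caller-side sketch; the object path and local filename are placeholders, and make_session() is the helper also used by the test module below:

import irods.helpers

session = irods.helpers.make_session()
try:
    # A long multi-threaded download; Ctrl-C during this call now propagates
    # KeyboardInterrupt promptly rather than leaving worker threads alive to
    # block process exit.
    session.data_objects.get("/tempZone/home/rods/big_object.dat", "/tmp/big_object.dat")
except KeyboardInterrupt:
    print("download interrupted; transfer threads were told to quit")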
@@ -558,10 +589,10 @@ def io_main(session, Data, opr_, fname, R="", **kwopt):

if Operation.isNonBlocking():

if queueLength > 0:
(futures, chunk_notify_queue, mgr) = retval
else:
futures = retval
(futures, chunk_notify_queue, mgr) = retval
transfer_managers[mgr] = None

if queueLength <= 0:
chunk_notify_queue = total_bytes = None

return AsyncNotify(
7 changes: 7 additions & 0 deletions irods/test/data_obj_test.py
@@ -2955,6 +2955,13 @@ def test_replica_truncate__issue_534(self):
if data_objs.exists(data_path):
data_objs.unlink(data_path, force=True)

def test_handling_of_termination_signals_during_multithread_get__issue_722(self):
from irods.test.modules.test_signal_handling_in_multithread_get import (
test as test__issue_722,
)

test__issue_722(self)


if __name__ == "__main__":
# let the tests find the parent irods lib
120 changes: 120 additions & 0 deletions irods/test/modules/test_signal_handling_in_multithread_get.py
@@ -0,0 +1,120 @@
import os
import re
import signal
import subprocess
import sys
import tempfile
import time

import irods
import irods.helpers
from irods.test import modules as test_modules

OBJECT_SIZE = 2 * 1024**3
OBJECT_NAME = "data_get_issue__722"
LOCAL_TEMPFILE_NAME = "data_object_for_issue_722.dat"


_clock_polling_interval = max(0.01, time.clock_getres(time.CLOCK_BOOTTIME))


def wait_till_true(function, timeout=None):
start_time = time.clock_gettime_ns(time.CLOCK_BOOTTIME)
while not (truth_value := function()):
if (
timeout is not None
and (time.clock_gettime_ns(time.CLOCK_BOOTTIME) - start_time) * 1e-9
> timeout
):
break
time.sleep(_clock_polling_interval)
return truth_value


def test(test_case, signal_names=("SIGTERM", "SIGINT")):
"""Creates a child process executing a long get() and ensures the process can be
terminated using SIGINT or SIGTERM.
"""
program = os.path.join(test_modules.__path__[0], os.path.basename(__file__))

for signal_name in signal_names:

test_case.subTest(f"Testing with signal {signal_name}")

# Call into this same module as a command. This will initiate another Python process that
# performs a lengthy data object "get" operation (see the main body of the script, below.)
process = subprocess.Popen(
[sys.executable, program],
stderr=subprocess.PIPE,
stdout=subprocess.PIPE,
text=True,
)

# Wait for download process to reach the point of spawning data transfer threads. In Python 3.9+ versions
# of the concurrent.futures module, these are nondaemon threads and will block the exit of the main thread
# unless measures are taken (#722).
localfile = process.stdout.readline().strip()
test_case.assertTrue(
wait_till_true(
lambda: os.path.exists(localfile)
and os.stat(localfile).st_size > OBJECT_SIZE // 2
),
"Parallel download from data_objects.get() probably experienced a fatal error before spawning auxiliary data transfer threads.",
)

sig = getattr(signal, signal_name)

# Interrupt the subprocess with the given signal.
process.send_signal(sig)
# Assert that this signal is what killed the subprocess, rather than a timed out process "wait" or a natural exit
# due to improper or incomplete handling of the signal.
try:
test_case.assertEqual(
process.wait(timeout=15),
-sig,
"Unexpected subprocess return code.",
)
except subprocess.TimeoutExpired as timeout_exc:
test_case.fail(
f"Subprocess timed out before terminating. "
"Non-daemon thread(s) probably prevented subprocess's main thread from exiting."
)
# Assert that in the case of SIGINT, the process registered a KeyboardInterrupt.
if sig == signal.SIGINT:
test_case.assertTrue(
re.search("KeyboardInterrupt", process.stderr.read()),
"Did not find expected string 'KeyboardInterrupt' in log output.",
)


if __name__ == "__main__":
# These lines are run only if the module is launched as a process.
session = irods.helpers.make_session()
hc = irods.helpers.home_collection(session)
TESTFILE_FILL = b"_" * (1024 * 1024)
object_path = f"{hc}/{OBJECT_NAME}"

# Create the object to be downloaded.
with session.data_objects.open(object_path, "w") as f:
for y in range(OBJECT_SIZE // len(TESTFILE_FILL)):
f.write(TESTFILE_FILL)
local_path = None
# Establish the absolute path at which to place the downloaded file, i.e. the get() target.
try:
with tempfile.NamedTemporaryFile(
prefix="local_file_issue_722.dat", delete=True
) as t:
local_path = t.name

# Tell the parent process the name of the local file being "get"ted (got) from iRODS
print(local_path)
sys.stdout.flush()

# "get" the object
session.data_objects.get(object_path, local_path)
finally:
# Clean up, whether or not the download succeeded.
if local_path is not None and os.path.exists(local_path):
os.unlink(local_path)
if session.data_objects.exists(object_path):
session.data_objects.unlink(object_path, force=True)