diff --git a/executorlib/base/executor.py b/executorlib/base/executor.py
index 74831dcf..199a8841 100644
--- a/executorlib/base/executor.py
+++ b/executorlib/base/executor.py
@@ -6,12 +6,12 @@ from concurrent.futures import (
     Future,
 )
+from threading import Thread
 from typing import Callable, Optional, Union
 
 from executorlib.standalone.inputcheck import check_resource_dict
 from executorlib.standalone.queue import cancel_items_in_queue
 from executorlib.standalone.serialize import cloudpickle_register
-from executorlib.standalone.thread import RaisingThread
 
 
 class ExecutorBase(FutureExecutor):
@@ -29,7 +29,7 @@ def __init__(self, max_cores: Optional[int] = None):
         cloudpickle_register(ind=3)
         self._max_cores = max_cores
         self._future_queue: Optional[queue.Queue] = queue.Queue()
-        self._process: Optional[Union[RaisingThread, list[RaisingThread]]] = None
+        self._process: Optional[Union[Thread, list[Thread]]] = None
 
     @property
     def info(self) -> Optional[dict]:
@@ -40,13 +40,13 @@ def info(self) -> Optional[dict]:
         """
             Optional[dict]: Information about the executor.
         """
         if self._process is not None and isinstance(self._process, list):
-            meta_data_dict = self._process[0].get_kwargs().copy()
+            meta_data_dict = self._process[0]._kwargs.copy()  # type: ignore
             if "future_queue" in meta_data_dict:
                 del meta_data_dict["future_queue"]
             meta_data_dict["max_workers"] = len(self._process)
             return meta_data_dict
         elif self._process is not None:
-            meta_data_dict = self._process.get_kwargs().copy()
+            meta_data_dict = self._process._kwargs.copy()  # type: ignore
             if "future_queue" in meta_data_dict:
                 del meta_data_dict["future_queue"]
             return meta_data_dict
@@ -138,13 +138,13 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
             cancel_items_in_queue(que=self._future_queue)
         if self._process is not None and self._future_queue is not None:
             self._future_queue.put({"shutdown": True, "wait": wait})
-            if wait and isinstance(self._process, RaisingThread):
+            if wait and isinstance(self._process, Thread):
                 self._process.join()
                 self._future_queue.join()
         self._process = None
         self._future_queue = None
 
-    def _set_process(self, process: RaisingThread):
+    def _set_process(self, process: Thread):
         """
         Set the process for the executor.
diff --git a/executorlib/cache/executor.py b/executorlib/cache/executor.py
index 46938005..36b29d26 100644
--- a/executorlib/cache/executor.py
+++ b/executorlib/cache/executor.py
@@ -1,4 +1,5 @@
 import os
+from threading import Thread
 from typing import Callable, Optional
 
 from executorlib.base.executor import ExecutorBase
@@ -15,7 +16,6 @@
     check_max_workers_and_cores,
     check_nested_flux_executor,
 )
-from executorlib.standalone.thread import RaisingThread
 
 try:
     from executorlib.cache.queue_spawner import execute_with_pysqa
@@ -64,7 +64,7 @@ def __init__(
         cache_directory_path = os.path.abspath(cache_directory)
         os.makedirs(cache_directory_path, exist_ok=True)
         self._set_process(
-            RaisingThread(
+            Thread(
                 target=execute_tasks_h5,
                 kwargs={
                     "future_queue": self._future_queue,
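The info property above now reads the private _kwargs attribute of threading.Thread directly, which is what RaisingThread.get_kwargs() used to expose; the # type: ignore comments acknowledge that this is not a public API. A minimal sketch of the idea, not part of the diff (the worker function and its keyword arguments are made up for illustration); note that CPython's Thread.run() deletes _kwargs once the target returns, so it is only safe to read while the worker is pending or still running:

    from threading import Thread

    def worker(future_queue=None, max_cores=None):
        pass  # stand-in target; the real workers loop over a task queue

    t = Thread(target=worker, kwargs={"future_queue": None, "max_cores": 2})
    meta = t._kwargs.copy()  # private CPython attribute, hence "# type: ignore"
    del meta["future_queue"]
    print(meta)  # {'max_cores': 2}
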
diff --git a/executorlib/cache/shared.py b/executorlib/cache/shared.py
index 4a0f65b5..d7eea67c 100644
--- a/executorlib/cache/shared.py
+++ b/executorlib/cache/shared.py
@@ -115,8 +115,10 @@ def execute_tasks_h5(
             ]
         else:
             if len(future_wait_key_lst) > 0:
-                raise ValueError(
-                    "Future objects are not supported as input if disable_dependencies=True."
+                task_dict["future"].set_exception(
+                    ValueError(
+                        "Future objects are not supported as input if disable_dependencies=True."
+                    )
                 )
             task_dependent_lst = []
         process_dict[task_key] = execute_function(
diff --git a/executorlib/interactive/executor.py b/executorlib/interactive/executor.py
index b2019aa6..8d46b1bc 100644
--- a/executorlib/interactive/executor.py
+++ b/executorlib/interactive/executor.py
@@ -1,4 +1,5 @@
 from concurrent.futures import Future
+from threading import Thread
 from typing import Any, Callable, Optional
 
 from executorlib.base.executor import ExecutorBase
@@ -8,7 +9,6 @@
     generate_nodes_and_edges,
     generate_task_hash,
 )
-from executorlib.standalone.thread import RaisingThread
 
 
 class ExecutorWithDependencies(ExecutorBase):
@@ -41,7 +41,7 @@ def __init__(
     ) -> None:
         super().__init__(max_cores=max_cores)
         self._set_process(
-            RaisingThread(
+            Thread(
                 target=execute_tasks_with_dependencies,
                 kwargs={
                     # Executor Arguments
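The executorlib/cache/shared.py hunk above replaces an exception raised inside the worker thread with one stored on the task's Future. A minimal sketch of why this works, using only the standard library: result() re-raises the stored exception in the calling thread, whereas an exception raised inside a plain Thread would never reach the caller.

    from concurrent.futures import Future

    f = Future()
    f.set_exception(ValueError("stored on the future instead of raised"))
    try:
        f.result()  # re-raises the stored ValueError in the caller's thread
    except ValueError as error:
        print(error)
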
diff --git a/executorlib/interactive/shared.py b/executorlib/interactive/shared.py
index cf8b6464..408758b5 100644
--- a/executorlib/interactive/shared.py
+++ b/executorlib/interactive/shared.py
@@ -5,6 +5,7 @@
 import time
 from asyncio.exceptions import CancelledError
 from concurrent.futures import Future, TimeoutError
+from threading import Thread
 from time import sleep
 from typing import Any, Callable, Optional, Union
 
@@ -20,7 +21,6 @@
 )
 from executorlib.standalone.interactive.spawner import BaseSpawner, MpiExecSpawner
 from executorlib.standalone.serialize import serialize_funct_h5
-from executorlib.standalone.thread import RaisingThread
 
 
 class ExecutorBroker(ExecutorBase):
@@ -89,7 +89,7 @@ def shutdown(self, wait: bool = True, *, cancel_futures: bool = False):
         self._process = None
         self._future_queue = None
 
-    def _set_process(self, process: list[RaisingThread]):  # type: ignore
+    def _set_process(self, process: list[Thread]):  # type: ignore
         """
         Set the process for the executor.
@@ -149,7 +149,7 @@ def __init__(
         executor_kwargs["queue_join_on_shutdown"] = False
         self._set_process(
             process=[
-                RaisingThread(
+                Thread(
                     target=execute_parallel_tasks,
                     kwargs=executor_kwargs,
                 )
@@ -205,7 +205,7 @@ def __init__(
         executor_kwargs["max_cores"] = max_cores
         executor_kwargs["max_workers"] = max_workers
         self._set_process(
-            RaisingThread(
+            Thread(
                 target=execute_separate_tasks,
                 kwargs=executor_kwargs,
             )
@@ -363,17 +363,18 @@ def execute_tasks_with_dependencies(
         ):
             future_lst, ready_flag = _get_future_objects_from_input(task_dict=task_dict)
             exception_lst = _get_exception_lst(future_lst=future_lst)
-            if len(exception_lst) > 0:
-                task_dict["future"].set_exception(exception_lst[0])
-            elif len(future_lst) == 0 or ready_flag:
-                # No future objects are used in the input or all future objects are already done
-                task_dict["args"], task_dict["kwargs"] = _update_futures_in_input(
-                    args=task_dict["args"], kwargs=task_dict["kwargs"]
-                )
-                executor_queue.put(task_dict)
-            else:  # Otherwise add the function to the wait list
-                task_dict["future_lst"] = future_lst
-                wait_lst.append(task_dict)
+            if not _get_exception(future_obj=task_dict["future"]):
+                if len(exception_lst) > 0:
+                    task_dict["future"].set_exception(exception_lst[0])
+                elif len(future_lst) == 0 or ready_flag:
+                    # No future objects are used in the input or all future objects are already done
+                    task_dict["args"], task_dict["kwargs"] = _update_futures_in_input(
+                        args=task_dict["args"], kwargs=task_dict["kwargs"]
+                    )
+                    executor_queue.put(task_dict)
+                else:  # Otherwise add the function to the wait list
+                    task_dict["future_lst"] = future_lst
+                    wait_lst.append(task_dict)
             future_queue.task_done()
         elif len(wait_lst) > 0:
             number_waiting = len(wait_lst)
@@ -589,7 +590,7 @@ def _submit_function_to_separate_process(
             "init_function": None,
         }
     )
-    process = RaisingThread(
+    process = Thread(
         target=execute_parallel_tasks,
         kwargs=task_kwargs,
     )
@@ -610,14 +611,13 @@ def _execute_task(
         future_queue (Queue): Queue for receiving new tasks.
     """
     f = task_dict.pop("future")
-    if f.set_running_or_notify_cancel():
+    if not f.done() and f.set_running_or_notify_cancel():
         try:
             f.set_result(interface.send_and_receive_dict(input_dict=task_dict))
         except Exception as thread_exception:
             interface.shutdown(wait=True)
             future_queue.task_done()
             f.set_exception(exception=thread_exception)
-            raise thread_exception
         else:
             future_queue.task_done()
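Two guards appear in the hunks above: execute_tasks_with_dependencies now skips tasks whose future already failed (via the _get_exception helper, which this excerpt assumes is defined elsewhere in the change), and _execute_task checks f.done() before set_running_or_notify_cancel(). The second guard matters because calling set_running_or_notify_cancel() on a future that is already finished raises RuntimeError. A sketch of the pattern with standard-library types only (run_task and compute are illustrative names):

    from concurrent.futures import Future

    def run_task(f: Future, compute) -> None:
        # A future whose exception was already set reports done() == True;
        # transitioning it to RUNNING again would raise RuntimeError.
        if not f.done() and f.set_running_or_notify_cancel():
            try:
                f.set_result(compute())
            except Exception as exc:
                f.set_exception(exc)  # surfaces later, at f.result()
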
diff --git a/executorlib/standalone/__init__.py b/executorlib/standalone/__init__.py
index c14857eb..c752f544 100644
--- a/executorlib/standalone/__init__.py
+++ b/executorlib/standalone/__init__.py
@@ -7,7 +7,6 @@
     interface_shutdown,
 )
 from executorlib.standalone.interactive.spawner import MpiExecSpawner
-from executorlib.standalone.thread import RaisingThread
 
 __all__ = [
     "SocketInterface",
@@ -16,6 +15,5 @@
     "interface_send",
     "interface_shutdown",
     "interface_receive",
-    "RaisingThread",
     "MpiExecSpawner",
 ]
diff --git a/executorlib/standalone/thread.py b/executorlib/standalone/thread.py
deleted file mode 100644
index f9cdaa2c..00000000
--- a/executorlib/standalone/thread.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from threading import Thread
-
-
-class RaisingThread(Thread):
-    """
-    A subclass of Thread that allows catching exceptions raised in the thread.
-
-    Based on https://stackoverflow.com/questions/2829329/catch-a-threads-exception-in-the-caller-thread
-    """
-
-    def __init__(
-        self, group=None, target=None, name=None, args=(), kwargs=None, *, daemon=None
-    ):
-        super().__init__(
-            group=group,
-            target=target,
-            name=name,
-            args=args,
-            kwargs=kwargs,
-            daemon=daemon,
-        )
-        self._exception = None
-
-    def get_kwargs(self):
-        return self._kwargs
-
-    def run(self) -> None:
-        """
-        Run the thread's target function and catch any exceptions raised.
-        """
-        try:
-            super().run()
-        except Exception as e:
-            self._exception = e
-
-    def join(self, timeout=None) -> None:
-        """
-        Wait for the thread to complete and re-raise any exceptions caught during execution.
-        """
-        super().join(timeout=timeout)
-        if self._exception:
-            raise self._exception
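Background for the deletion above: a plain threading.Thread reports an uncaught exception in its target through the threading excepthook, which prints to stderr by default, but join() never re-raises it in the caller. That is the gap RaisingThread papered over, and the reason errors now travel through Future objects instead. A minimal standard-library demonstration:

    from threading import Thread

    def boom():
        raise ValueError("printed to stderr, invisible to join()")

    t = Thread(target=boom)
    t.start()
    t.join()  # returns normally; the ValueError does not propagate here
    print("caller continues; errors must be carried by a Future instead")
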
"slurm_cmd_args": [], }, ) - process = RaisingThread( + process = Thread( target=execute_tasks_with_dependencies, kwargs={ "future_queue": q, @@ -210,7 +210,10 @@ def test_dependency_steps_error_before(self): self.assertTrue(fs2.exception() is not None) with self.assertRaises(RuntimeError): fs2.result() + executor.shutdown(wait=True) q.put({"shutdown": True, "wait": True}) + q.join() + process.join() def test_many_to_one(self): length = 5 @@ -254,25 +257,29 @@ def test_block_allocation_false_one_worker(self): with self.assertRaises(RuntimeError): with SingleNodeExecutor(max_cores=1, block_allocation=False) as exe: cloudpickle_register(ind=1) - _ = exe.submit(raise_error, parameter=0) + fs = exe.submit(raise_error, parameter=0) + fs.result() def test_block_allocation_true_one_worker(self): with self.assertRaises(RuntimeError): with SingleNodeExecutor(max_cores=1, block_allocation=True) as exe: cloudpickle_register(ind=1) - _ = exe.submit(raise_error, parameter=0) + fs = exe.submit(raise_error, parameter=0) + fs.result() def test_block_allocation_false_two_workers(self): with self.assertRaises(RuntimeError): with SingleNodeExecutor(max_cores=2, block_allocation=False) as exe: cloudpickle_register(ind=1) - _ = exe.submit(raise_error, parameter=0) + fs = exe.submit(raise_error, parameter=0) + fs.result() def test_block_allocation_true_two_workers(self): with self.assertRaises(RuntimeError): with SingleNodeExecutor(max_cores=2, block_allocation=True) as exe: cloudpickle_register(ind=1) - _ = exe.submit(raise_error, parameter=0) + fs = exe.submit(raise_error, parameter=0) + fs.result() def test_block_allocation_false_one_worker_loop(self): with self.assertRaises(RuntimeError): diff --git a/tests/test_local_executor.py b/tests/test_local_executor.py index 27b1828f..be453da0 100644 --- a/tests/test_local_executor.py +++ b/tests/test_local_executor.py @@ -318,7 +318,8 @@ def test_executor_exception(self): executor_kwargs={"cores": 1}, spawner=MpiExecSpawner, ) as p: - p.submit(raise_error) + fs = p.submit(raise_error) + fs.result() def test_executor_exception_future(self): with self.assertRaises(RuntimeError): @@ -424,6 +425,7 @@ def test_execute_task_failed_no_argument(self): f = Future() q = Queue() q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f}) + q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) with self.assertRaises(TypeError): execute_parallel_tasks( @@ -432,12 +434,14 @@ def test_execute_task_failed_no_argument(self): openmpi_oversubscribe=False, spawner=MpiExecSpawner, ) + f.result() q.join() def test_execute_task_failed_wrong_argument(self): f = Future() q = Queue() q.put({"fn": calc_array, "args": (), "kwargs": {"j": 4}, "future": f}) + q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) with self.assertRaises(TypeError): execute_parallel_tasks( @@ -446,6 +450,7 @@ def test_execute_task_failed_wrong_argument(self): openmpi_oversubscribe=False, spawner=MpiExecSpawner, ) + f.result() q.join() def test_execute_task(self): diff --git a/tests/test_shared_thread.py b/tests/test_shared_thread.py deleted file mode 100644 index accd4585..00000000 --- a/tests/test_shared_thread.py +++ /dev/null @@ -1,15 +0,0 @@ -import unittest - -from executorlib.standalone.thread import RaisingThread - - -def raise_error(): - raise ValueError - - -class TestRaisingThread(unittest.TestCase): - def test_raising_thread(self): - with self.assertRaises(ValueError): - process = RaisingThread(target=raise_error) - process.start() - process.join() diff --git 
diff --git a/tests/test_shell_executor.py b/tests/test_shell_executor.py
index 971befd0..caea210e 100644
--- a/tests/test_shell_executor.py
+++ b/tests/test_shell_executor.py
@@ -53,6 +53,9 @@ def test_wrong_error(self):
                 "future": f,
             }
         )
+        test_queue.put(
+            {"shutdown": True, "wait": True}
+        )
         cloudpickle_register(ind=1)
         with self.assertRaises(TypeError):
             execute_parallel_tasks(
@@ -61,6 +64,7 @@
                 openmpi_oversubscribe=False,
                 spawner=MpiExecSpawner,
             )
+            f.result()
 
     def test_broken_executable(self):
         test_queue = queue.Queue()
@@ -73,6 +77,12 @@
                 "future": f,
             }
         )
+        test_queue.put(
+            {
+                "shutdown": True,
+                "wait": True,
+            }
+        )
         cloudpickle_register(ind=1)
         with self.assertRaises(FileNotFoundError):
             execute_parallel_tasks(
@@ -81,6 +91,7 @@
                 openmpi_oversubscribe=False,
                 spawner=MpiExecSpawner,
             )
+            f.result()
 
     def test_shell_static_executor_args(self):
         with SingleNodeExecutor(max_workers=1) as exe:
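Taken together, the diff moves the error-handling contract from "the worker thread re-raises on join()" to "the Future carries the error". A usage sketch of the new contract, assuming the SingleNodeExecutor API exercised by the tests above (the tests additionally call cloudpickle_register to ship locally defined functions):

    from executorlib import SingleNodeExecutor

    def raise_error():
        raise RuntimeError("fails on the worker side")

    if __name__ == "__main__":
        with SingleNodeExecutor(max_cores=1) as exe:
            fs = exe.submit(raise_error)  # returns a Future, never raises here
            try:
                fs.result()  # the worker-side RuntimeError re-raises here
            except RuntimeError as error:
                print("recovered in the caller:", error)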