diff --git a/executorlib/interactive/shared.py b/executorlib/interactive/shared.py
index 3eb79986..7fdc95f9 100644
--- a/executorlib/interactive/shared.py
+++ b/executorlib/interactive/shared.py
@@ -624,13 +624,20 @@ def _execute_task_with_cache(
     os.makedirs(cache_directory, exist_ok=True)
     file_name = os.path.join(cache_directory, task_key + ".h5out")
     if task_key + ".h5out" not in os.listdir(cache_directory):
-        _execute_task(
-            interface=interface,
-            task_dict=task_dict,
-            future_queue=future_queue,
-        )
-        data_dict["output"] = future.result()
-        dump(file_name=file_name, data_dict=data_dict)
+        f = task_dict.pop("future")
+        if f.set_running_or_notify_cancel():
+            try:
+                result = interface.send_and_receive_dict(input_dict=task_dict)
+                data_dict["output"] = result
+                dump(file_name=file_name, data_dict=data_dict)
+                f.set_result(result)
+            except Exception as thread_exception:
+                interface.shutdown(wait=True)
+                future_queue.task_done()
+                f.set_exception(exception=thread_exception)
+                raise thread_exception
+        else:
+            future_queue.task_done()
     else:
         _, result = get_output(file_name=file_name)
         future = task_dict["future"]
diff --git a/tests/test_executor_backend_mpi.py b/tests/test_executor_backend_mpi.py
index 9a002136..4876cbfc 100644
--- a/tests/test_executor_backend_mpi.py
+++ b/tests/test_executor_backend_mpi.py
@@ -97,7 +97,7 @@ def tearDown(self):
     )
     def test_meta_executor_parallel_cache(self):
         with Executor(
-            max_cores=2,
+            max_workers=2,
             resource_dict={"cores": 2},
             backend="local",
             block_allocation=True,
diff --git a/tests/test_local_executor.py b/tests/test_local_executor.py
index 29c5e72b..a40803cf 100644
--- a/tests/test_local_executor.py
+++ b/tests/test_local_executor.py
@@ -2,6 +2,7 @@
 import importlib.util
 from queue import Queue
 from time import sleep
+import shutil
 import unittest
 
 import numpy as np
@@ -16,6 +17,12 @@
 from executorlib.standalone.interactive.backend import call_funct
 from executorlib.standalone.serialize import cloudpickle_register
 
+try:
+    import h5py
+
+    skip_h5py_test = False
+except ImportError:
+    skip_h5py_test = True
 
 skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None
 
@@ -473,3 +480,45 @@ def test_execute_task_parallel(self):
         )
         self.assertEqual(f.result(), [np.array(4), np.array(4)])
         q.join()
+
+
+class TestFuturePoolCache(unittest.TestCase):
+    def tearDown(self):
+        shutil.rmtree("./cache")
+
+    @unittest.skipIf(
+        skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
+    )
+    def test_execute_task_cache(self):
+        f = Future()
+        q = Queue()
+        q.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f})
+        q.put({"shutdown": True, "wait": True})
+        cloudpickle_register(ind=1)
+        execute_parallel_tasks(
+            future_queue=q,
+            cores=1,
+            openmpi_oversubscribe=False,
+            spawner=MpiExecSpawner,
+            cache_directory="./cache",
+        )
+        self.assertEqual(f.result(), 1)
+        q.join()
+
+    @unittest.skipIf(
+        skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
+    )
+    def test_execute_task_cache_failed_no_argument(self):
+        f = Future()
+        q = Queue()
+        q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f})
+        cloudpickle_register(ind=1)
+        with self.assertRaises(TypeError):
+            execute_parallel_tasks(
+                future_queue=q,
+                cores=1,
+                openmpi_oversubscribe=False,
+                spawner=MpiExecSpawner,
+                cache_directory="./cache",
+            )
+        q.join()
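
Note (illustration, not part of the patch): the new cache branch in _execute_task_with_cache replaces the _execute_task() call with explicit Future bookkeeping — pop the future out of the task dictionary, mark it as running, then attach either the result or the raised exception. A minimal sketch of that pattern using only the Python standard library follows; the run_task helper and the calc stand-in function are hypothetical and exist only for this example.

from concurrent.futures import Future


def calc(i):
    # Stand-in for the real task function submitted through the queue.
    return i


def run_task(task_dict):
    # Pop the future so the remaining dict only describes the work to execute.
    f = task_dict.pop("future")
    # set_running_or_notify_cancel() returns False if the future was already cancelled.
    if f.set_running_or_notify_cancel():
        try:
            result = task_dict["fn"](*task_dict["args"], **task_dict["kwargs"])
            f.set_result(result)
        except Exception as thread_exception:
            # Mirror the patch: surface the failure to the caller through the future.
            f.set_exception(thread_exception)


f = Future()
run_task({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f})
assert f.result() == 1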