diff --git a/executorlib/task_scheduler/interactive/shared.py b/executorlib/task_scheduler/interactive/shared.py index 02162308..8ce33ada 100644 --- a/executorlib/task_scheduler/interactive/shared.py +++ b/executorlib/task_scheduler/interactive/shared.py @@ -77,22 +77,18 @@ def execute_tasks( if error_log_file is not None: task_dict["error_log_file"] = error_log_file if cache_directory is None: - _execute_task_without_cache( - interface=interface, task_dict=task_dict, future_queue=future_queue - ) + _execute_task_without_cache(interface=interface, task_dict=task_dict) else: _execute_task_with_cache( interface=interface, task_dict=task_dict, - future_queue=future_queue, cache_directory=cache_directory, cache_key=cache_key, ) + _task_done(future_queue=future_queue) -def _execute_task_without_cache( - interface: SocketInterface, task_dict: dict, future_queue: queue.Queue -): +def _execute_task_without_cache(interface: SocketInterface, task_dict: dict): """ Execute the task in the task_dict by communicating it via the interface. @@ -100,7 +96,6 @@ def _execute_task_without_cache( interface (SocketInterface): socket interface for zmq communication task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} - future_queue (Queue): Queue for receiving new tasks. """ f = task_dict.pop("future") if not f.done() and f.set_running_or_notify_cancel(): @@ -108,16 +103,12 @@ def _execute_task_without_cache( f.set_result(interface.send_and_receive_dict(input_dict=task_dict)) except Exception as thread_exception: interface.shutdown(wait=True) - _task_done(future_queue=future_queue) f.set_exception(exception=thread_exception) - else: - _task_done(future_queue=future_queue) def _execute_task_with_cache( interface: SocketInterface, task_dict: dict, - future_queue: queue.Queue, cache_directory: str, cache_key: Optional[str] = None, ): @@ -128,7 +119,6 @@ def _execute_task_with_cache( interface (SocketInterface): socket interface for zmq communication task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} - future_queue (Queue): Queue for receiving new tasks. cache_directory (str): The directory to store cache files. cache_key (str, optional): By default the cache_key is generated based on the function hash, this can be overwritten by setting the cache_key. @@ -155,16 +145,11 @@ def _execute_task_with_cache( f.set_result(result) except Exception as thread_exception: interface.shutdown(wait=True) - _task_done(future_queue=future_queue) f.set_exception(exception=thread_exception) - raise thread_exception - else: - _task_done(future_queue=future_queue) else: _, _, result = get_output(file_name=file_name) future = task_dict["future"] future.set_result(result) - _task_done(future_queue=future_queue) def _task_done(future_queue: queue.Queue): diff --git a/tests/test_mpiexecspawner.py b/tests/test_mpiexecspawner.py index 1811cbae..9a9b861f 100644 --- a/tests/test_mpiexecspawner.py +++ b/tests/test_mpiexecspawner.py @@ -443,13 +443,13 @@ def test_execute_task_failed_no_argument(self): q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f}) q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) + execute_tasks( + future_queue=q, + cores=1, + openmpi_oversubscribe=False, + spawner=MpiExecSpawner, + ) with self.assertRaises(TypeError): - execute_tasks( - future_queue=q, - cores=1, - openmpi_oversubscribe=False, - spawner=MpiExecSpawner, - ) f.result() q.join() @@ -459,13 +459,13 @@ def test_execute_task_failed_wrong_argument(self): q.put({"fn": calc_array, "args": (), "kwargs": {"j": 4}, "future": f}) q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) + execute_tasks( + future_queue=q, + cores=1, + openmpi_oversubscribe=False, + spawner=MpiExecSpawner, + ) with self.assertRaises(TypeError): - execute_tasks( - future_queue=q, - cores=1, - openmpi_oversubscribe=False, - spawner=MpiExecSpawner, - ) f.result() q.join() @@ -533,13 +533,15 @@ def test_execute_task_cache_failed_no_argument(self): f = Future() q = Queue() q.put({"fn": calc_array, "args": (), "kwargs": {}, "future": f}) + q.put({"shutdown": True, "wait": True}) cloudpickle_register(ind=1) + execute_tasks( + future_queue=q, + cores=1, + openmpi_oversubscribe=False, + spawner=MpiExecSpawner, + cache_directory="executorlib_cache", + ) with self.assertRaises(TypeError): - execute_tasks( - future_queue=q, - cores=1, - openmpi_oversubscribe=False, - spawner=MpiExecSpawner, - cache_directory="executorlib_cache", - ) + f.result() q.join()