Skip to content

Commit 0f59f57

Browse files
committed
Cache: Add working directory parameter
1 parent 057e531 commit 0f59f57

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

executorlib/cache/executor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Optional
23

34
from executorlib.base.executor import ExecutorBase
45
from executorlib.cache.shared import execute_in_subprocess, execute_tasks_h5
@@ -11,6 +12,7 @@ def __init__(
1112
cache_directory: str = "cache",
1213
execute_function: callable = execute_in_subprocess,
1314
cores_per_worker: int = 1,
15+
cwd: Optional[str] = None,
1416
):
1517
"""
1618
Initialize the FileExecutor.
@@ -19,6 +21,7 @@ def __init__(
1921
cache_directory (str, optional): The directory to store cache files. Defaults to "cache".
2022
execute_function (callable, optional): The function to execute tasks. Defaults to execute_in_subprocess.
2123
cores_per_worker (int, optional): The number of CPU cores per worker. Defaults to 1.
24+
cwd (str/None): current working directory where the parallel python task is executed
2225
"""
2326
super().__init__()
2427
cache_directory_path = os.path.abspath(cache_directory)
@@ -31,6 +34,7 @@ def __init__(
3134
"execute_function": execute_function,
3235
"cache_directory": cache_directory_path,
3336
"cores_per_worker": cores_per_worker,
37+
"cwd": cwd,
3438
},
3539
)
3640
)

executorlib/cache/shared.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import subprocess
55
import sys
66
from concurrent.futures import Future
7-
from typing import Tuple
7+
from typing import Tuple, Optional
88

99
from executorlib.standalone.command import get_command_path
1010
from executorlib.standalone.hdf import dump, get_output
@@ -48,14 +48,15 @@ def done(self) -> bool:
4848

4949

5050
def execute_in_subprocess(
51-
command: list, task_dependent_lst: list = []
51+
command: list, task_dependent_lst: list = [], cwd: Optional[str] = None,
5252
) -> subprocess.Popen:
5353
"""
5454
Execute a command in a subprocess.
5555
5656
Args:
5757
command (list): The command to be executed.
58-
task_dependent_lst (list, optional): A list of subprocesses that the current subprocess depends on. Defaults to [].
58+
task_dependent_lst (list): A list of subprocesses that the current subprocess depends on. Defaults to [].
59+
cwd (str/None): current working directory where the parallel python task is executed
5960
6061
Returns:
6162
subprocess.Popen: The subprocess object.
@@ -65,14 +66,15 @@ def execute_in_subprocess(
6566
task_dependent_lst = [
6667
task for task in task_dependent_lst if task.poll() is None
6768
]
68-
return subprocess.Popen(command, universal_newlines=True)
69+
return subprocess.Popen(command, universal_newlines=True, cwd=cwd)
6970

7071

7172
def execute_tasks_h5(
7273
future_queue: queue.Queue,
7374
cache_directory: str,
7475
cores_per_worker: int,
7576
execute_function: callable,
77+
cwd: Optional[str],
7678
) -> None:
7779
"""
7880
Execute tasks stored in a queue using HDF5 files.
@@ -82,6 +84,7 @@ def execute_tasks_h5(
8284
cache_directory (str): The directory to store the HDF5 files.
8385
cores_per_worker (int): The number of cores per worker.
8486
execute_function (callable): The function to execute the tasks.
87+
cwd (str/None): current working directory where the parallel python task is executed
8588
8689
Returns:
8790
None
@@ -123,6 +126,7 @@ def execute_tasks_h5(
123126
task_dependent_lst=[
124127
process_dict[k] for k in future_wait_key_lst
125128
],
129+
cwd=cwd,
126130
)
127131
file_name_dict[task_key] = os.path.join(
128132
cache_directory, task_key + ".h5out"

tests/test_cache_executor_serial.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def my_funct(a, b):
1919
return a + b
2020

2121

22+
def list_files_in_working_directory():
23+
return os.listdir(os.getcwd())
24+
25+
2226
@unittest.skipIf(
2327
skip_h5io_test, "h5io is not installed, so the h5io tests are skipped."
2428
)
@@ -38,6 +42,12 @@ def test_executor_dependence_mixed(self):
3842
self.assertEqual(fs2.result(), 4)
3943
self.assertTrue(fs2.done())
4044

45+
def test_executor_working_directory(self):
46+
cwd = os.path.join(os.path.dirname(__file__), "executables")
47+
with FileExecutor(cwd=cwd) as exe:
48+
fs1 = exe.submit(list_files_in_working_directory)
49+
self.assertEqual(fs1.result(), os.listdir(cwd))
50+
4151
def test_executor_function(self):
4252
fs1 = Future()
4353
q = Queue()

0 commit comments

Comments
 (0)