Skip to content

Commit df966af

Browse files
Cache: Add working directory parameter (#446)
* Cache: Add working directory parameter * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix tests --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 451cc92 commit df966af

File tree

3 files changed

+30
-4
lines changed

3 files changed

+30
-4
lines changed

executorlib/cache/executor.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
from typing import Optional
23

34
from executorlib.base.executor import ExecutorBase
45
from executorlib.cache.shared import execute_in_subprocess, execute_tasks_h5
@@ -11,6 +12,7 @@ def __init__(
1112
cache_directory: str = "cache",
1213
execute_function: callable = execute_in_subprocess,
1314
cores_per_worker: int = 1,
15+
cwd: Optional[str] = None,
1416
):
1517
"""
1618
Initialize the FileExecutor.
@@ -19,6 +21,7 @@ def __init__(
1921
cache_directory (str, optional): The directory to store cache files. Defaults to "cache".
2022
execute_function (callable, optional): The function to execute tasks. Defaults to execute_in_subprocess.
2123
cores_per_worker (int, optional): The number of CPU cores per worker. Defaults to 1.
24+
cwd (str/None): current working directory where the parallel python task is executed
2225
"""
2326
super().__init__()
2427
cache_directory_path = os.path.abspath(cache_directory)
@@ -31,6 +34,7 @@ def __init__(
3134
"execute_function": execute_function,
3235
"cache_directory": cache_directory_path,
3336
"cores_per_worker": cores_per_worker,
37+
"cwd": cwd,
3438
},
3539
)
3640
)

executorlib/cache/shared.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import subprocess
55
import sys
66
from concurrent.futures import Future
7-
from typing import Tuple
7+
from typing import Optional, Tuple
88

99
from executorlib.standalone.command import get_command_path
1010
from executorlib.standalone.hdf import dump, get_output
@@ -48,14 +48,17 @@ def done(self) -> bool:
4848

4949

5050
def execute_in_subprocess(
51-
command: list, task_dependent_lst: list = []
51+
command: list,
52+
task_dependent_lst: list = [],
53+
cwd: Optional[str] = None,
5254
) -> subprocess.Popen:
5355
"""
5456
Execute a command in a subprocess.
5557
5658
Args:
5759
command (list): The command to be executed.
58-
task_dependent_lst (list, optional): A list of subprocesses that the current subprocess depends on. Defaults to [].
60+
task_dependent_lst (list): A list of subprocesses that the current subprocess depends on. Defaults to [].
61+
cwd (str/None): current working directory where the parallel python task is executed
5962
6063
Returns:
6164
subprocess.Popen: The subprocess object.
@@ -65,14 +68,15 @@ def execute_in_subprocess(
6568
task_dependent_lst = [
6669
task for task in task_dependent_lst if task.poll() is None
6770
]
68-
return subprocess.Popen(command, universal_newlines=True)
71+
return subprocess.Popen(command, universal_newlines=True, cwd=cwd)
6972

7073

7174
def execute_tasks_h5(
7275
future_queue: queue.Queue,
7376
cache_directory: str,
7477
cores_per_worker: int,
7578
execute_function: callable,
79+
cwd: Optional[str],
7680
) -> None:
7781
"""
7882
Execute tasks stored in a queue using HDF5 files.
@@ -82,6 +86,7 @@ def execute_tasks_h5(
8286
cache_directory (str): The directory to store the HDF5 files.
8387
cores_per_worker (int): The number of cores per worker.
8488
execute_function (callable): The function to execute the tasks.
89+
cwd (str/None): current working directory where the parallel python task is executed
8590
8691
Returns:
8792
None
@@ -123,6 +128,7 @@ def execute_tasks_h5(
123128
task_dependent_lst=[
124129
process_dict[k] for k in future_wait_key_lst
125130
],
131+
cwd=cwd,
126132
)
127133
file_name_dict[task_key] = os.path.join(
128134
cache_directory, task_key + ".h5out"

tests/test_cache_executor_serial.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ def my_funct(a, b):
1919
return a + b
2020

2121

22+
def list_files_in_working_directory():
23+
return os.listdir(os.getcwd())
24+
25+
2226
@unittest.skipIf(
2327
skip_h5io_test, "h5io is not installed, so the h5io tests are skipped."
2428
)
@@ -38,6 +42,12 @@ def test_executor_dependence_mixed(self):
3842
self.assertEqual(fs2.result(), 4)
3943
self.assertTrue(fs2.done())
4044

45+
def test_executor_working_directory(self):
46+
cwd = os.path.join(os.path.dirname(__file__), "executables")
47+
with FileExecutor(cwd=cwd) as exe:
48+
fs1 = exe.submit(list_files_in_working_directory)
49+
self.assertEqual(fs1.result(), os.listdir(cwd))
50+
4151
def test_executor_function(self):
4252
fs1 = Future()
4353
q = Queue()
@@ -51,13 +61,15 @@ def test_executor_function(self):
5161
"cache_directory": cache_dir,
5262
"execute_function": execute_in_subprocess,
5363
"cores_per_worker": 1,
64+
"cwd": None,
5465
},
5566
)
5667
process.start()
5768
self.assertFalse(fs1.done())
5869
self.assertEqual(fs1.result(), 3)
5970
self.assertTrue(fs1.done())
6071
q.put({"shutdown": True, "wait": True})
72+
process.join()
6173

6274
def test_executor_function_dependence_kwargs(self):
6375
fs1 = Future()
@@ -74,13 +86,15 @@ def test_executor_function_dependence_kwargs(self):
7486
"cache_directory": cache_dir,
7587
"execute_function": execute_in_subprocess,
7688
"cores_per_worker": 1,
89+
"cwd": None,
7790
},
7891
)
7992
process.start()
8093
self.assertFalse(fs2.done())
8194
self.assertEqual(fs2.result(), 4)
8295
self.assertTrue(fs2.done())
8396
q.put({"shutdown": True, "wait": True})
97+
process.join()
8498

8599
def test_executor_function_dependence_args(self):
86100
fs1 = Future()
@@ -97,13 +111,15 @@ def test_executor_function_dependence_args(self):
97111
"cache_directory": cache_dir,
98112
"execute_function": execute_in_subprocess,
99113
"cores_per_worker": 1,
114+
"cwd": None,
100115
},
101116
)
102117
process.start()
103118
self.assertFalse(fs2.done())
104119
self.assertEqual(fs2.result(), 5)
105120
self.assertTrue(fs2.done())
106121
q.put({"shutdown": True, "wait": True})
122+
process.join()
107123

108124
def tearDown(self):
109125
if os.path.exists("cache"):

0 commit comments

Comments
 (0)