Write cache first #492

Merged
merged 6 commits into from
Nov 11, 2024
Changes from 3 commits
21 changes: 14 additions & 7 deletions executorlib/interactive/shared.py
@@ -624,13 +624,20 @@ def _execute_task_with_cache(
     os.makedirs(cache_directory, exist_ok=True)
     file_name = os.path.join(cache_directory, task_key + ".h5out")
     if task_key + ".h5out" not in os.listdir(cache_directory):
-        _execute_task(
-            interface=interface,
-            task_dict=task_dict,
-            future_queue=future_queue,
-        )
-        data_dict["output"] = future.result()
-        dump(file_name=file_name, data_dict=data_dict)
+        f = task_dict.pop("future")
+        if f.set_running_or_notify_cancel():
+            try:
+                result = interface.send_and_receive_dict(input_dict=task_dict)
+                data_dict["output"] = result
+                dump(file_name=file_name, data_dict=data_dict)
+                f.set_result(result)
Comment on lines +631 to +633 (Contributor)

🛠️ Refactor suggestion

Consider atomic cache writes

The current implementation writes the cache file in place. If the process is interrupted between lines 631 and 633, a partially written file is left behind, and because the cache check only tests for the file's existence, it will be treated as a valid cache entry on the next run.

Consider using atomic writes with a temporary file:

-                data_dict["output"] = result
-                dump(file_name=file_name, data_dict=data_dict)
-                f.set_result(result)
+                data_dict["output"] = result
+                temp_file = file_name + '.tmp'
+                dump(file_name=temp_file, data_dict=data_dict)
+                os.replace(temp_file, file_name)
+                f.set_result(result)

Committable suggestion skipped: line range outside the PR's diff.
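The same idea as a standalone helper, as a minimal sketch (dump is assumed in scope with the signature used in this diff; the helper name is hypothetical):

import os

def dump_atomic(file_name: str, data_dict: dict) -> None:
    # Write to a sibling temporary file, then rename over the target.
    # os.replace() is atomic on both POSIX and Windows when source and
    # destination live on the same filesystem, so a concurrent reader
    # never observes a half-written cache file.
    temp_file = file_name + ".tmp"
    dump(file_name=temp_file, data_dict=data_dict)
    os.replace(temp_file, file_name)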

+            except Exception as thread_exception:
+                interface.shutdown(wait=True)
+                future_queue.task_done()
+                f.set_exception(exception=thread_exception)
+                raise thread_exception
+            else:
+                future_queue.task_done()
Comment on lines +627 to +640 (Contributor)

⚠️ Potential issue

Potential race condition in cache writing

The cache file is written between receiving the result and resolving the future. If multiple processes execute the same task_key concurrently, they can write to the same cache file at the same time, corrupting it.

Consider adding a file-locking mechanism:

 try:
     result = interface.send_and_receive_dict(input_dict=task_dict)
     data_dict["output"] = result
+    # Serialize writers on a sidecar lock file so the cache file itself
+    # is never opened in truncating mode, and the future `f` is not
+    # shadowed; the lock is released when the handle is closed.
+    import fcntl
+    with open(file_name + ".lock", "w") as lock_handle:
+        fcntl.flock(lock_handle.fileno(), fcntl.LOCK_EX)
-    dump(file_name=file_name, data_dict=data_dict)
+        dump(file_name=file_name, data_dict=data_dict)
     f.set_result(result)

Committable suggestion skipped: line range outside the PR's diff.
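Note that fcntl is POSIX-only. A lock-free alternative sketch that combines this with the atomic-write suggestion above: give each writer a unique temporary file and publish it atomically. Concurrent writers of the same task_key produce identical content, so last-writer-wins is safe (helper name hypothetical, dump assumed in scope):

import os
import tempfile

def dump_lock_free(file_name: str, data_dict: dict) -> None:
    # Each process writes to its own temporary file in the cache
    # directory, so the write itself needs no coordination.
    fd, temp_path = tempfile.mkstemp(
        dir=os.path.dirname(file_name), suffix=".tmp"
    )
    os.close(fd)  # dump() opens the path itself
    dump(file_name=temp_path, data_dict=data_dict)
    # Atomic publish: whichever process replaces last wins, and every
    # writer for the same task_key wrote the same payload.
    os.replace(temp_path, file_name)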

     else:
         _, result = get_output(file_name=file_name)
         future = task_dict["future"]
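Note the ordering that gives this PR its title: on a cache miss, dump() now runs before f.set_result(result), so by the time a caller observes the future as done, the cache file already exists on disk.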
2 changes: 1 addition & 1 deletion tests/test_executor_backend_mpi.py
@@ -97,7 +97,7 @@ def tearDown(self):
     )
     def test_meta_executor_parallel_cache(self):
         with Executor(
-            max_cores=2,
+            max_workers=2,
             resource_dict={"cores": 2},
             backend="local",
             block_allocation=True,
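Assuming max_workers counts workers while resource_dict["cores"] applies per worker, this now provisions two 2-core workers instead of capping the total budget at two cores.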
31 changes: 31 additions & 0 deletions tests/test_local_executor.py
@@ -2,6 +2,7 @@
 import importlib.util
 from queue import Queue
 from time import sleep
+import shutil
 import unittest

 import numpy as np
@@ -16,6 +17,12 @@
 from executorlib.standalone.interactive.backend import call_funct
 from executorlib.standalone.serialize import cloudpickle_register

+try:
+    import h5py
+
+    skip_h5py_test = False
+except ImportError:
+    skip_h5py_test = True

 skip_mpi4py_test = importlib.util.find_spec("mpi4py") is None

@@ -473,3 +480,27 @@ def test_execute_task_parallel(self):
         )
         self.assertEqual(f.result(), [np.array(4), np.array(4)])
         q.join()
+
+
+class TestFuturePoolCache(unittest.TestCase):
+    def tearDown(self):
+        shutil.rmtree("./cache")
+
+    @unittest.skipIf(
+        skip_h5py_test, "h5py is not installed, so the h5py tests are skipped."
+    )
+    def test_execute_task_cache(self):
+        f = Future()
+        q = Queue()
+        q.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f})
+        q.put({"shutdown": True, "wait": True})
+        cloudpickle_register(ind=1)
+        execute_parallel_tasks(
+            future_queue=q,
+            cores=1,
+            openmpi_oversubscribe=False,
+            spawner=MpiExecSpawner,
+            cache_directory="./cache",
+        )
+        self.assertEqual(f.result(), 1)
+        q.join()
Comment on lines +489 to +506 (Contributor)

🛠️ Refactor suggestion

Enhance cache testing coverage

The current test only verifies the basic cache-miss path. Consider adding tests for:

1. Cache hits (the same input executed twice)
2. Cache misses (different inputs)
3. Cache behavior when the task raises an error
4. Verification that results actually come from the cache

Here's a suggested test to verify cache hits:

def test_execute_task_cache_hit(self):
    # First execution
    f1 = Future()
    q1 = Queue()
    q1.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f1})
    q1.put({"shutdown": True, "wait": True})
    cloudpickle_register(ind=1)
    execute_parallel_tasks(
        future_queue=q1,
        cores=1,
        openmpi_oversubscribe=False,
        spawner=MpiExecSpawner,
        cache_directory="./cache",
    )
    result1 = f1.result()
    q1.join()

    # Second execution (should hit cache)
    f2 = Future()
    q2 = Queue()
    q2.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f2})
    q2.put({"shutdown": True, "wait": True})
    execute_parallel_tasks(
        future_queue=q2,
        cores=1,
        openmpi_oversubscribe=False,
        spawner=MpiExecSpawner,
        cache_directory="./cache",
    )
    result2 = f2.result()
    q2.join()

    self.assertEqual(result1, 1)
    self.assertEqual(result2, 1)
    # Verify the cache directory exists (requires `import os` at module level)
    self.assertTrue(os.path.exists("./cache"))

Would you like me to generate additional test cases or open a GitHub issue to track this enhancement?
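Along the same lines, a minimal sketch for point 4 that checks a cache file is actually produced (test name hypothetical; assumes import os at module level):

def test_execute_task_cache_file_created(self):
    f = Future()
    q = Queue()
    q.put({"fn": calc, "args": (), "kwargs": {"i": 1}, "future": f})
    q.put({"shutdown": True, "wait": True})
    cloudpickle_register(ind=1)
    execute_parallel_tasks(
        future_queue=q,
        cores=1,
        openmpi_oversubscribe=False,
        spawner=MpiExecSpawner,
        cache_directory="./cache",
    )
    self.assertEqual(f.result(), 1)
    q.join()
    # The cache-miss branch writes <task_key>.h5out into cache_directory,
    # so exactly one cache file should now exist.
    cached = [fn for fn in os.listdir("./cache") if fn.endswith(".h5out")]
    self.assertEqual(len(cached), 1)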
