Fix and test active function calls in status
dongreenberg committed Jun 16, 2024
1 parent 86b3c0a commit 93fdee7
Showing 2 changed files with 99 additions and 47 deletions.
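The gist of the obj_store.py change below: each local call registers an entry keyed by a fresh uuid (one request_id can fan out into many function calls) and removes it in a finally block, so the status endpoint never reports stale calls. A minimal standalone sketch of that pattern, using illustrative CallTracker/ActiveCallInfo names rather than the actual Runhouse classes:

import time
import uuid
from dataclasses import asdict, dataclass, field


@dataclass
class ActiveCallInfo:
    key: str
    method_name: str
    request_id: str
    start_time: float = field(default_factory=time.time)


class CallTracker:
    """Tracks in-flight calls so a status endpoint can report them."""

    def __init__(self):
        self.active_calls = {}

    async def call(self, key, method_name, request_id, fn, *args, **kwargs):
        # A per-call uuid (not the request_id) keys the entry, since one
        # request can trigger several concurrent function calls.
        call_id = uuid.uuid4()
        self.active_calls[call_id] = ActiveCallInfo(key, method_name, request_id)
        try:
            return await fn(*args, **kwargs)
        finally:
            # Runs whether the call returns or raises, so no stale entries remain.
            del self.active_calls[call_id]

    def active_for_key(self, key):
        # Copy before iterating to avoid races with concurrent calls.
        return [asdict(c) for c in dict(self.active_calls).values() if c.key == key]

The keys_with_info change later in the diff does the same per-key filtering: copy the dict first, then select the entries whose key matches.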
109 changes: 63 additions & 46 deletions runhouse/servers/obj_store.py
@@ -1102,6 +1102,50 @@ async def acall_local(
)
log_ctx.__enter__()

# Track the active function call and remove it in a finally block so the entry is always cleaned up
request_id = req_ctx.get().request_id

# There can be many function calls under one request_id, since a request_id is tied to a single call from the
# client to the server.
# We key the entry by this func_call_id so we can easily pop the active call info once the function
# concludes. In theory we could key by a tuple of (key, start_time, etc.), but that doesn't accomplish much.
func_call_id = uuid.uuid4()
self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
key=key,
method_name=method_name,
request_id=request_id,
start_time=time.time(),
)
try:
res = await self._acall_local_helper(
key,
method_name,
*args,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
**kwargs,
)
finally:
del self.active_function_calls[func_call_id]
if log_ctx:
log_ctx.__exit__(None, None, None)

return res

async def _acall_local_helper(
self,
key: str,
method_name: Optional[str] = None,
*args,
run_name: Optional[str] = None,
stream_logs: bool = False,
remote: bool = False,
**kwargs,
):
"""acall_local primarily sets up the logging and tracking for the function call, then calls
_acall_local_helper to actually do the work. This is so we can have a finally block in acall_local to clean up
the active function calls tracking."""
obj = self.get_local(key, default=KeyError)

from runhouse.resources.module import Module
@@ -1259,9 +1303,6 @@ async def acall_local(
# If remote is True and the result is a resource, we return just the config
res = res.config()

if log_ctx:
log_ctx.__exit__(None, None, None)

return res

@staticmethod
@@ -1313,32 +1354,15 @@ async def acall(
"kwargs", {}
)

# Use a finally to track the active functions so that it is always removed
request_id = req_ctx.get().request_id

# There can be many function calls in one request_id, since request_id is tied to a call from the
# client to the server.
# We store with this func_call_id so we can easily pop the active call info out after the function
# concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
func_call_id = uuid.uuid4()
self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
key=key,
method_name=method_name,
request_id=request_id,
start_time=time.time(),
res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
)
try:
res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
)
finally:
del self.active_function_calls[func_call_id]
else:
res = await self.acall_for_env_servlet_name(
env_servlet_name_containing_key,
@@ -1531,21 +1555,10 @@ def keys_with_info(self):
if not self.has_local_storage or self.servlet_name is None:
raise NoLocalObjStoreError()

# Need to copy to avoid race conditions here, and build a new dict that maps keys to all the info we need
# Need to copy to avoid race conditions
current_active_function_calls = copy.copy(self.active_function_calls)
current_active_function_calls_by_key = {}
for _, active_function_call_info in current_active_function_calls.items():
if (
active_function_call_info.key
not in current_active_function_calls_by_key
):
current_active_function_calls_by_key[active_function_call_info.key] = []

current_active_function_calls_by_key[active_function_call_info.key].append(
active_function_call_info
)

keys_with_info = []
keys_and_info = []
for k, v in self._kv_store.items():
cls = type(v)
py_module = cls.__module__
@@ -1555,17 +1568,21 @@ def keys_with_info(self):
else (py_module + "." + cls.__qualname__)
)

keys_with_info.append(
active_fn_calls = [
call_info.dict()
for call_info in current_active_function_calls.values()
if call_info.key == k
]

keys_and_info.append(
{
"name": k,
"resource_type": cls_name,
"active_function_calls": current_active_function_calls_by_key.get(
k, []
),
"active_function_calls": active_fn_calls,
}
)

return keys_with_info
return keys_and_info

async def astatus(self):
return await self.acall_actor_method(self.cluster_servlet, "status")
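On the consumer side, cluster.status() now reports these entries per key (as the test below exercises). A rough sketch of polling for idleness, assuming the env_resource_mapping layout asserted in the test; wait_until_idle and its parameters are illustrative helpers, not part of the Runhouse API:

import time


def wait_until_idle(cluster, env_name="worker_env", timeout=60, poll=2):
    """Poll cluster.status() until no resource in env_name reports active calls."""
    deadline = time.time() + timeout
    while time.time() < deadline:
        status = cluster.status()
        resources = status.get("env_resource_mapping", {}).get(env_name, [])
        # Every resource entry carries an "active_function_calls" list (possibly empty).
        if all(not r.get("active_function_calls") for r in resources):
            return True
        time.sleep(poll)
    return False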
37 changes: 36 additions & 1 deletion tests/test_resources/test_clusters/test_cluster.py
@@ -1,5 +1,6 @@
import subprocess
import time
from threading import Thread

import pandas as pd
import pytest
@@ -76,6 +77,12 @@ def assume_caller_and_get_token():
return token_default, token_as_caller


def sleep_fn(secs):
import time

time.sleep(secs)


class TestCluster(tests.test_resources.test_resource.TestResource):
MAP_FIXTURES = {"resource": "cluster"}

@@ -271,9 +278,17 @@ def test_rh_here_objects(self, cluster):

@pytest.mark.level("local")
def test_rh_status_pythonic(self, cluster):
rh.env(reqs=["pytest"], name="worker_env").to(cluster)
sleep_remote = rh.function(sleep_fn).to(
cluster, env=rh.env(reqs=["pytest", "pandas"], name="worker_env")
)
cluster.put(key="status_key1", obj="status_value1", env="worker_env")
# Run these in separate threads so the main thread can continue
call_threads = [Thread(target=sleep_remote, args=[3]) for _ in range(3)]
for call_thread in call_threads:
call_thread.start()

# Wait a second so the calls can start
time.sleep(1)
cluster_data = cluster.status()

expected_cluster_status_data_keys = [
Expand Down Expand Up @@ -305,6 +320,26 @@ def test_rh_status_pythonic(self, cluster):
"resource_type": "str",
"active_function_calls": [],
} in cluster_data.get("env_resource_mapping")["worker_env"]
sleep_calls = cluster_data.get("env_resource_mapping")["worker_env"][1][
"active_function_calls"
]
assert len(sleep_calls) == 3
assert sleep_calls[0]["key"] == "sleep_fn"
assert sleep_calls[0]["method_name"] == "call"
assert sleep_calls[0]["request_id"]
assert sleep_calls[0]["start_time"]

# wait for threads to finish
for call_thread in call_threads:
call_thread.join()
updated_status = cluster.status()
# Check that the sleep calls are no longer active
assert (
updated_status.get("env_resource_mapping")["worker_env"][1][
"active_function_calls"
]
== []
)

# test memory usage info
expected_env_servlet_keys = [
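For reference, each entry in active_function_calls serializes the fields the test asserts on (key, method_name, request_id, start_time). An illustrative entry, with invented values:

# Field names come from the assertions above; the values here are made up.
example_call = {
    "key": "sleep_fn",             # object-store key of the callable
    "method_name": "call",         # method invoked on the remote module
    "request_id": "<uuid of the originating client request>",
    "start_time": 1718563200.0,    # epoch seconds when the call started
}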
