Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix and test active function calls in status #896

Merged
merged 1 commit into from
Jun 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 63 additions & 48 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,50 @@ async def acall_local(
)
log_ctx.__enter__()

# Use a finally to track the active functions so that it is always removed
request_id = req_ctx.get().request_id

# There can be many function calls in one request_id, since request_id is tied to a call from the
# client to the server.
# We store with this func_call_id so we can easily pop the active call info out after the function
# concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
func_call_id = uuid.uuid4()
self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
key=key,
method_name=method_name,
request_id=request_id,
start_time=time.time(),
)
try:
res = await self._acall_local_helper(
key,
method_name,
*args,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
**kwargs,
)
finally:
del self.active_function_calls[func_call_id]
if log_ctx:
log_ctx.__exit__(None, None, None)

return res

async def _acall_local_helper(
self,
key: str,
method_name: Optional[str] = None,
*args,
run_name: Optional[str] = None,
stream_logs: bool = False,
remote: bool = False,
**kwargs,
):
"""acall_local primarily sets up the logging and tracking for the function call, then calls
_acall_local_helper to actually do the work. This is so we can have a finally block in acall_local to clean up
the active function calls tracking."""
obj = self.get_local(key, default=KeyError)

from runhouse.resources.module import Module
Expand Down Expand Up @@ -1236,8 +1280,6 @@ async def acall_local(
)
fut = self._construct_call_retrievable(res, run_name, laziness_type)
await self.aput_local(run_name, fut)
if log_ctx:
log_ctx.__exit__(None, None, None)
return fut

from runhouse.resources.resource import Resource
Expand All @@ -1259,9 +1301,6 @@ async def acall_local(
# If remote is True and the result is a resource, we return just the config
res = res.config()

if log_ctx:
log_ctx.__exit__(None, None, None)

return res

@staticmethod
Expand Down Expand Up @@ -1313,32 +1352,15 @@ async def acall(
"kwargs", {}
)

# Use a finally to track the active functions so that it is always removed
request_id = req_ctx.get().request_id

# There can be many function calls in one request_id, since request_id is tied to a call from the
# client to the server.
# We store with this func_call_id so we can easily pop the active call info out after the function
# concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
func_call_id = uuid.uuid4()
self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
key=key,
method_name=method_name,
request_id=request_id,
start_time=time.time(),
res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
)
try:
res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
)
finally:
del self.active_function_calls[func_call_id]
else:
res = await self.acall_for_env_servlet_name(
env_servlet_name_containing_key,
Expand Down Expand Up @@ -1531,21 +1553,10 @@ def keys_with_info(self):
if not self.has_local_storage or self.servlet_name is None:
raise NoLocalObjStoreError()

# Need to copy to avoid race conditions here, and build a new dict that maps keys to all the info we need
# Need to copy to avoid race conditions
current_active_function_calls = copy.copy(self.active_function_calls)
current_active_function_calls_by_key = {}
for _, active_function_call_info in current_active_function_calls.items():
if (
active_function_call_info.key
not in current_active_function_calls_by_key
):
current_active_function_calls_by_key[active_function_call_info.key] = []

current_active_function_calls_by_key[active_function_call_info.key].append(
active_function_call_info
)

keys_with_info = []
keys_and_info = []
for k, v in self._kv_store.items():
cls = type(v)
py_module = cls.__module__
Expand All @@ -1555,17 +1566,21 @@ def keys_with_info(self):
else (py_module + "." + cls.__qualname__)
)

keys_with_info.append(
active_fn_calls = [
call_info.dict()
for call_info in current_active_function_calls.values()
if call_info.key == k
]

keys_and_info.append(
{
"name": k,
"resource_type": cls_name,
"active_function_calls": current_active_function_calls_by_key.get(
k, []
),
"active_function_calls": active_fn_calls,
}
)

return keys_with_info
return keys_and_info

async def astatus(self):
return await self.acall_actor_method(self.cluster_servlet, "status")
Expand Down
37 changes: 36 additions & 1 deletion tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import subprocess
import time
from threading import Thread

import pandas as pd
import pytest
Expand Down Expand Up @@ -76,6 +77,12 @@ def assume_caller_and_get_token():
return token_default, token_as_caller


def sleep_fn(secs):
    """Block the calling thread for *secs* seconds and return ``None``.

    The ``time`` import is intentionally inside the function body so the
    function stays self-contained when it is serialized and executed on a
    remote cluster (module-level imports are not shipped with it).
    """
    from time import sleep

    sleep(secs)


class TestCluster(tests.test_resources.test_resource.TestResource):
MAP_FIXTURES = {"resource": "cluster"}

Expand Down Expand Up @@ -271,9 +278,17 @@ def test_rh_here_objects(self, cluster):

@pytest.mark.level("local")
def test_rh_status_pythonic(self, cluster):
rh.env(reqs=["pytest"], name="worker_env").to(cluster)
sleep_remote = rh.function(sleep_fn).to(
cluster, env=rh.env(reqs=["pytest", "pandas"], name="worker_env")
)
cluster.put(key="status_key1", obj="status_value1", env="worker_env")
# Run these in a separate thread so that the main thread can continue
call_threads = [Thread(target=sleep_remote, args=[3]) for _ in range(3)]
for call_thread in call_threads:
call_thread.start()

# Wait a second so the calls can start
time.sleep(1)
cluster_data = cluster.status()

expected_cluster_status_data_keys = [
Expand Down Expand Up @@ -305,6 +320,26 @@ def test_rh_status_pythonic(self, cluster):
"resource_type": "str",
"active_function_calls": [],
} in cluster_data.get("env_resource_mapping")["worker_env"]
sleep_calls = cluster_data.get("env_resource_mapping")["worker_env"][1][
"active_function_calls"
]
assert len(sleep_calls) == 3
assert sleep_calls[0]["key"] == "sleep_fn"
assert sleep_calls[0]["method_name"] == "call"
assert sleep_calls[0]["request_id"]
assert sleep_calls[0]["start_time"]

# wait for threads to finish
for call_thread in call_threads:
call_thread.join()
updated_status = cluster.status()
# Check that the sleep calls are no longer active
assert (
updated_status.get("env_resource_mapping")["worker_env"][1][
"active_function_calls"
]
== []
)

# test memory usage info
expected_env_servlet_keys = [
Expand Down
Loading