Skip to content

Commit

Permalink
Save information about active function calls. (#871)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohinb2 authored Jun 8, 2024
1 parent 6dde929 commit bdc1618
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 25 deletions.
2 changes: 1 addition & 1 deletion runhouse/servers/env_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _get_env_gpu_usage(self, env_servlet_pid: int):
return env_gpu_usage

def _status_local_helper(self):
objects_in_env_servlet = obj_store.keys_with_type()
objects_in_env_servlet = obj_store.keys_with_info()

(
env_memory_usage,
Expand Down
95 changes: 74 additions & 21 deletions runhouse/servers/obj_store.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import asyncio
import copy
import inspect
import logging
import os
import time
import uuid
from enum import Enum
from functools import wraps
from typing import Any, Dict, List, Optional, Set, Union

import ray
from pydantic import BaseModel

from runhouse.rns.defaults import req_ctx
from runhouse.rns.utils.api import ResourceVisibility
Expand Down Expand Up @@ -35,6 +38,13 @@ class RunhouseStopIteration(Exception):
pass


class ActiveFunctionCallInfo(BaseModel):
key: str
method_name: str
request_id: str
start_time: float


class NoLocalObjStoreError(ObjStoreError):
def __init__(self, *args):
super().__init__("No local object store exists; cannot perform operation.")
Expand Down Expand Up @@ -132,6 +142,7 @@ def __init__(self):
self.installed_envs = {} # TODO: consider deleting it?
self._kv_store: Dict[Any, Any] = None
self.env_servlet_cache = {}
self.active_function_calls = {}

async def ainitialize(
self,
Expand Down Expand Up @@ -1302,15 +1313,32 @@ async def acall(
"kwargs", {}
)

res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
# Use a finally to track the active functions so that it is always removed
request_id = req_ctx.get().request_id

# There can be many function calls in one request_id, since request_id is tied to a call from the
# client to the server.
# We store with this func_call_id so we can easily pop the active call info out after the function
# concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
func_call_id = uuid.uuid4()
self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
key=key,
method_name=method_name,
request_id=request_id,
start_time=time.time(),
)
try:
res = await self.acall_local(
key,
method_name,
run_name=run_name,
stream_logs=stream_logs,
remote=remote,
*args,
**kwargs,
)
finally:
del self.active_function_calls[func_call_id]
else:
res = await self.acall_for_env_servlet_name(
env_servlet_name_containing_key,
Expand Down Expand Up @@ -1499,20 +1527,45 @@ async def aput_resource_local(
##############################################
# Cluster info methods
##############################################
def keys_with_type(self):
keys_with_type = []
if self.has_local_storage and self.servlet_name is not None:
for k, v in self._kv_store.items():
cls = type(v)
py_module = cls.__module__
cls_name = (
cls.__qualname__
if py_module == "builtins"
else (py_module + "." + cls.__qualname__)
)
def keys_with_info(self):
if not self.has_local_storage or self.servlet_name is None:
raise NoLocalObjStoreError()

# Need to copy to avoid race conditions here, and build a new dict that maps keys to all the info we need
current_active_function_calls = copy.copy(self.active_function_calls)
current_active_function_calls_by_key = {}
for _, active_function_call_info in current_active_function_calls.items():
if (
active_function_call_info.key
not in current_active_function_calls_by_key
):
current_active_function_calls_by_key[active_function_call_info.key] = []

current_active_function_calls_by_key[active_function_call_info.key].append(
active_function_call_info
)

keys_with_info = []
for k, v in self._kv_store.items():
cls = type(v)
py_module = cls.__module__
cls_name = (
cls.__qualname__
if py_module == "builtins"
else (py_module + "." + cls.__qualname__)
)

keys_with_info.append(
{
"name": k,
"resource_type": cls_name,
"active_function_calls": current_active_function_calls_by_key.get(
k, []
),
}
)

keys_with_type.append({"name": k, "resource_type": cls_name})
return keys_with_type
return keys_with_info

async def astatus(self):
return await self.acall_actor_method(self.cluster_servlet, "status")
Expand Down
8 changes: 5 additions & 3 deletions tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,11 @@ def test_rh_status_pythonic(self, cluster):
assert res.get("ips") == cluster.ips

assert "worker_env" in cluster_data.get("env_resource_mapping")
assert {"name": "status_key1", "resource_type": "str"} in cluster_data.get(
"env_resource_mapping"
)["worker_env"]
assert {
"name": "status_key1",
"resource_type": "str",
"active_function_calls": [],
} in cluster_data.get("env_resource_mapping")["worker_env"]

# test memory usage info
expected_env_servlet_keys = [
Expand Down

0 comments on commit bdc1618

Please sign in to comment.