diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py
index eb3b32919..a24e3c459 100644
--- a/runhouse/servers/obj_store.py
+++ b/runhouse/servers/obj_store.py
@@ -1102,6 +1102,50 @@ async def acall_local(
             )
             log_ctx.__enter__()
 
+        # Use a finally to track the active functions so that it is always removed
+        request_id = req_ctx.get().request_id
+
+        # There can be many function calls in one request_id, since request_id is tied to a call from the
+        # client to the server.
+        # We store with this func_call_id so we can easily pop the active call info out after the function
+        # concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
+        func_call_id = uuid.uuid4()
+        self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
+            key=key,
+            method_name=method_name,
+            request_id=request_id,
+            start_time=time.time(),
+        )
+        try:
+            res = await self._acall_local_helper(
+                key,
+                method_name,
+                *args,
+                run_name=run_name,
+                stream_logs=stream_logs,
+                remote=remote,
+                **kwargs,
+            )
+        finally:
+            del self.active_function_calls[func_call_id]
+            if log_ctx:
+                log_ctx.__exit__(None, None, None)
+
+        return res
+
+    async def _acall_local_helper(
+        self,
+        key: str,
+        method_name: Optional[str] = None,
+        *args,
+        run_name: Optional[str] = None,
+        stream_logs: bool = False,
+        remote: bool = False,
+        **kwargs,
+    ):
+        """acall_local primarily sets up the logging and tracking for the function call, then calls
+        _acall_local_helper to actually do the work. This is so we can have a finally block in acall_local to clean up
+        the active function calls tracking."""
         obj = self.get_local(key, default=KeyError)
 
         from runhouse.resources.module import Module
@@ -1259,9 +1303,6 @@ async def acall_local(
             # If remote is True and the result is a resource, we return just the config
             res = res.config()
 
-        if log_ctx:
-            log_ctx.__exit__(None, None, None)
-
         return res
 
     @staticmethod
@@ -1313,32 +1354,15 @@ async def acall(
                 "kwargs", {}
             )
 
-            # Use a finally to track the active functions so that it is always removed
-            request_id = req_ctx.get().request_id
-
-            # There can be many function calls in one request_id, since request_id is tied to a call from the
-            # client to the server.
-            # We store with this func_call_id so we can easily pop the active call info out after the function
-            # concludes. In theory we could use a tuple of (key, start_time, etc), but it doesn't accomplish much
-            func_call_id = uuid.uuid4()
-            self.active_function_calls[func_call_id] = ActiveFunctionCallInfo(
-                key=key,
-                method_name=method_name,
-                request_id=request_id,
-                start_time=time.time(),
+            res = await self.acall_local(
+                key,
+                method_name,
+                run_name=run_name,
+                stream_logs=stream_logs,
+                remote=remote,
+                *args,
+                **kwargs,
             )
-            try:
-                res = await self.acall_local(
-                    key,
-                    method_name,
-                    run_name=run_name,
-                    stream_logs=stream_logs,
-                    remote=remote,
-                    *args,
-                    **kwargs,
-                )
-            finally:
-                del self.active_function_calls[func_call_id]
         else:
             res = await self.acall_for_env_servlet_name(
                 env_servlet_name_containing_key,
@@ -1531,21 +1555,10 @@ def keys_with_info(self):
         if not self.has_local_storage or self.servlet_name is None:
             raise NoLocalObjStoreError()
 
-        # Need to copy to avoid race conditions here, and build a new dict that maps keys to all the info we need
+        # Need to copy to avoid race conditions
         current_active_function_calls = copy.copy(self.active_function_calls)
 
-        current_active_function_calls_by_key = {}
-        for _, active_function_call_info in current_active_function_calls.items():
-            if (
-                active_function_call_info.key
-                not in current_active_function_calls_by_key
-            ):
-                current_active_function_calls_by_key[active_function_call_info.key] = []
-            current_active_function_calls_by_key[active_function_call_info.key].append(
-                active_function_call_info
-            )
-
-        keys_with_info = []
+        keys_and_info = []
         for k, v in self._kv_store.items():
             cls = type(v)
             py_module = cls.__module__
@@ -1555,17 +1568,21 @@ def keys_with_info(self):
                 else (py_module + "." + cls.__qualname__)
             )
 
-            keys_with_info.append(
+            active_fn_calls = [
+                call_info.dict()
+                for call_info in current_active_function_calls.values()
+                if call_info.key == k
+            ]
+
+            keys_and_info.append(
                 {
                     "name": k,
                     "resource_type": cls_name,
-                    "active_function_calls": current_active_function_calls_by_key.get(
-                        k, []
-                    ),
+                    "active_function_calls": active_fn_calls,
                 }
             )
 
-        return keys_with_info
+        return keys_and_info
 
     async def astatus(self):
         return await self.acall_actor_method(self.cluster_servlet, "status")
diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py
index 1247c36a0..25bd134ed 100644
--- a/tests/test_resources/test_clusters/test_cluster.py
+++ b/tests/test_resources/test_clusters/test_cluster.py
@@ -1,5 +1,6 @@
 import subprocess
 import time
+from threading import Thread
 
 import pandas as pd
 import pytest
@@ -76,6 +77,12 @@ def assume_caller_and_get_token():
     return token_default, token_as_caller
 
 
+def sleep_fn(secs):
+    import time
+
+    time.sleep(secs)
+
+
 class TestCluster(tests.test_resources.test_resource.TestResource):
     MAP_FIXTURES = {"resource": "cluster"}
 
@@ -271,9 +278,17 @@ def test_rh_here_objects(self, cluster):
 
     @pytest.mark.level("local")
     def test_rh_status_pythonic(self, cluster):
-        rh.env(reqs=["pytest"], name="worker_env").to(cluster)
+        sleep_remote = rh.function(sleep_fn).to(
+            cluster, env=rh.env(reqs=["pytest", "pandas"], name="worker_env")
+        )
         cluster.put(key="status_key1", obj="status_value1", env="worker_env")
 
+        # Run these in a separate thread so that the main thread can continue
+        call_threads = [Thread(target=sleep_remote, args=[3]) for _ in range(3)]
+        for call_thread in call_threads:
+            call_thread.start()
+        # Wait a second so the calls can start
+        time.sleep(1)
         cluster_data = cluster.status()
 
         expected_cluster_status_data_keys = [
@@ -305,6 +320,26 @@ def test_rh_status_pythonic(self, cluster):
             "resource_type": "str",
             "active_function_calls": [],
         } in cluster_data.get("env_resource_mapping")["worker_env"]
+        sleep_calls = cluster_data.get("env_resource_mapping")["worker_env"][1][
+            "active_function_calls"
+        ]
+        assert len(sleep_calls) == 3
+        assert sleep_calls[0]["key"] == "sleep_fn"
+        assert sleep_calls[0]["method_name"] == "call"
+        assert sleep_calls[0]["request_id"]
+        assert sleep_calls[0]["start_time"]
+
+        # wait for threads to finish
+        for call_thread in call_threads:
+            call_thread.join()
+        updated_status = cluster.status()
+        # Check that the sleep calls are no longer active
+        assert (
+            updated_status.get("env_resource_mapping")["worker_env"][1][
+                "active_function_calls"
+            ]
+            == []
+        )
 
         # test memory usage info
         expected_env_servlet_keys = [
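
For illustration, the tracking pattern these changes introduce can be sketched in isolation as a plain Python snippet. This is a simplified, hypothetical stand-in (a synchronous CallTracker class instead of the real async ObjStore, and a dataclass instead of the ActiveFunctionCallInfo model); it only shows the try/finally bookkeeping and the per-key filtering that keys_with_info performs:

# Simplified, hypothetical sketch of the active-call tracking pattern above.
# CallTracker and ActiveCallInfo are stand-ins, not the real ObjStore API.
import time
import uuid
from dataclasses import dataclass


@dataclass
class ActiveCallInfo:
    key: str
    method_name: str
    start_time: float


class CallTracker:
    def __init__(self):
        # Keyed by a per-call uuid so the same key can have several concurrent
        # calls, each removable independently when it finishes.
        self.active_function_calls = {}

    def call(self, key, method_name, fn, *args, **kwargs):
        func_call_id = uuid.uuid4()
        self.active_function_calls[func_call_id] = ActiveCallInfo(
            key=key, method_name=method_name, start_time=time.time()
        )
        try:
            return fn(*args, **kwargs)
        finally:
            # Runs whether fn returns or raises, so no stale entries are left behind
            del self.active_function_calls[func_call_id]

    def active_calls_for_key(self, key):
        # Copy before filtering, mirroring the copy.copy in keys_with_info
        current = dict(self.active_function_calls)
        return [info for info in current.values() if info.key == key]


if __name__ == "__main__":
    tracker = CallTracker()
    tracker.call("sleep_fn", "call", time.sleep, 0.1)
    assert tracker.active_calls_for_key("sleep_fn") == []  # cleaned up after the call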