From 65473ec726a3bc98d52f83377ba3e647ea632504 Mon Sep 17 00:00:00 2001 From: Rohin Bhasin Date: Tue, 4 Jun 2024 16:03:23 -0400 Subject: [PATCH] Consolidate periodic loops into one function updating Den and updating autostop. --- runhouse/constants.py | 4 +- runhouse/rns/rns_client.py | 38 ++++++- runhouse/servers/autostop_servlet.py | 32 ++++++ runhouse/servers/cluster_servlet.py | 155 +++++++-------------------- 4 files changed, 110 insertions(+), 119 deletions(-) create mode 100644 runhouse/servers/autostop_servlet.py diff --git a/runhouse/constants.py b/runhouse/constants.py index 96d53c22b..d85519215 100644 --- a/runhouse/constants.py +++ b/runhouse/constants.py @@ -67,6 +67,6 @@ BULLET_UNICODE = "\u2022" MINUTE = 60 HOUR = 3600 -DEFAULT_STATUS_CHECK_INTERVAL = 2 * MINUTE +DEFAULT_STATUS_CHECK_INTERVAL = 1 * MINUTE INCREASED_STATUS_CHECK_INTERVAL = 1 * HOUR -STATUS_CHECK_DELAY = 2 * MINUTE +STATUS_CHECK_DELAY = 1 * MINUTE diff --git a/runhouse/rns/rns_client.py b/runhouse/rns/rns_client.py index 45b723a04..f25b61a75 100644 --- a/runhouse/rns/rns_client.py +++ b/runhouse/rns/rns_client.py @@ -5,11 +5,13 @@ import os import shutil from pathlib import Path -from typing import Dict, Optional, Set, Union +from typing import Any, Dict, List, Optional, Set, Union import dotenv +import httpx import requests +from pydantic import BaseModel from runhouse.rns.utils.api import ( generate_uuid, @@ -21,6 +23,17 @@ logger = logging.getLogger(__name__) +# This is a copy of the Pydantic model that we use to validate in Den +class ResourceStatusData(BaseModel): + cluster_config: dict + env_resource_mapping: Dict[str, List[Dict[str, Any]]] + system_cpu_usage: float + system_memory_usage: Dict[str, Any] + system_disk_usage: Dict[str, Any] + env_servlet_processes: Dict[str, Dict[str, Any]] + server_pid: int + runhouse_version: str + class RNSClient: """Manage a particular resource with the runhouse database""" @@ -659,3 +672,26 @@ def contents(self, name_or_path, full_paths): 
return folder(name=name_or_path, path=folder_url).resources(
             full_paths=full_paths
         )
+
+    async def send_status(self, status: ResourceStatusData, cluster_rns_address: str):
+        status_data = {
+            "status": "running",
+            "resource_type": status.cluster_config.get("resource_type"),
+            "data": dict(status),
+        }
+        cluster_uri = self.format_rns_address(cluster_rns_address)
+        api_server_url = status.cluster_config.get(
+            "api_server_url", self.api_server_url
+        )
+        async with httpx.AsyncClient() as client:  # closes the client after the POST
+            resp = await client.post(
+                f"{api_server_url}/resource/{cluster_uri}/cluster/status",
+                content=json.dumps(status_data),
+                headers=self.request_headers(),
+            )
+        if resp.status_code != 200:
+            logger.error(
+                f"Received [{resp.status_code}]: Failed to send cluster status info to Den: {resp.text}."
+            )
+        else:
+            logger.info("Successfully sent cluster status info to Den.")
diff --git a/runhouse/servers/autostop_servlet.py b/runhouse/servers/autostop_servlet.py
new file mode 100644
index 000000000..33a6a3665
--- /dev/null
+++ b/runhouse/servers/autostop_servlet.py
@@ -0,0 +1,32 @@
+import time
+
+
+class AutostopServlet:
+    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env."""
+
+    def __init__(self):
+        self._last_activity = time.time()
+        self._last_register = None
+
+    def set_last_active_time_to_now(self):
+        self._last_activity = time.time()
+
+    def set_autostop(self, value=None):
+        from sky.skylet import autostop_lib
+
+        self.set_last_active_time_to_now()
+        autostop_lib.set_autostop(value, None, True)
+
+    def update_autostop_in_sky_config(self):
+        import pickle
+
+        from sky.skylet import configs as sky_configs
+
+        autostop_mins = pickle.loads(
+            sky_configs.get_config("autostop_config")
+        ).autostop_idle_minutes
+        if autostop_mins > 0 and (
+            self._last_register is None or self._last_register < self._last_activity
+        ):
+            sky_configs.set_config("autostop_last_active_time", self._last_activity)
+            self._last_register = self._last_activity
diff
--git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 90f2d66b7..5ffe00fdd 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -3,13 +3,9 @@ import json import logging import threading -import time from pathlib import Path from typing import Any, Dict, List, Optional, Set, Union -import requests -from pydantic import BaseModel - import runhouse from runhouse.constants import ( @@ -21,7 +17,9 @@ from runhouse.globals import configs, obj_store, rns_client from runhouse.resources.hardware import load_cluster_config_from_file +from runhouse.rns.rns_client import ResourceStatusData from runhouse.rns.utils.api import ResourceAccess +from runhouse.servers.autostop_servlet import AutostopServlet from runhouse.servers.http.auth import AuthCache from runhouse.utils import sync_function @@ -33,18 +31,6 @@ class ClusterServletError(Exception): pass -# This is a copy of the Pydantic model that we use to validate in Den -class ResourceStatusData(BaseModel): - cluster_config: dict - env_resource_mapping: Dict[str, List[Dict[str, Any]]] - system_cpu_usage: float - system_memory_usage: Dict[str, Any] - system_disk_usage: Dict[str, Any] - env_servlet_processes: Dict[str, Dict[str, Any]] - server_pid: int - runhouse_version: str - - class ClusterServlet: async def __init__( self, cluster_config: Optional[Dict[str, Any]] = None, *args, **kwargs @@ -85,10 +71,10 @@ async def __init__( # Only send for clusters that have den_auth enabled and if we are logged in with a user's token # to authenticate the request - if self.cluster_config.get("den_auth", False) and configs.token: - logger.debug("Creating send_status_info_to_den thread.") + if self.cluster_config.get("den_auth", False): + logger.info("Creating periodic_status_check thread.") post_status_thread = threading.Thread( - target=self.send_status_info_to_den, daemon=True + target=self.periodic_status_check, daemon=True ) post_status_thread.start() @@ -260,63 
+246,43 @@ async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str) # Cluster status functions ############################################## - async def asend_status_info_to_den(self): + async def aperiodic_status_check(self): # Delay the start of post_status_thread, so we'll finish the cluster startup properly await asyncio.sleep(STATUS_CHECK_DELAY) while True: - logger.info("Trying to send cluster status to Den.") try: - is_config_updated = ( - await self.aupdate_status_check_interval_in_cluster_config() - ) - interval_size = (await self.aget_cluster_config()).get( + await self.aupdate_status_check_interval_in_cluster_config() + + cluster_config = await self.aget_cluster_config() + interval_size = cluster_config.get( "status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL ) - den_auth = (await self.aget_cluster_config()).get( - "den_auth", DEFAULT_STATUS_CHECK_INTERVAL - ) - if interval_size == -1: - if is_config_updated: - logger.info( - f"Disabled periodic cluster status check. For enabling it, please run " - f"cluster.restart_server(). If you want to set up an interval size that is not the " - f"default value {round(DEFAULT_STATUS_CHECK_INTERVAL/60,2)} please run " - f"cluster._enable_or_update_status_check(interval_size) after restarting the server." - ) - break - if not den_auth: - logger.info( - f"Disabled periodic cluster status check because den_auth is disabled. For enabling it, please run " - f"cluster.restart_server() and make sure that den_auth is enabled. If you want to set up an interval size that is not the " - f"default value {round(DEFAULT_STATUS_CHECK_INTERVAL / 60, 2)} please run " - f"cluster._enable_or_update_status_check(interval_size) after restarting the server." 
- ) - break - status: ResourceStatusData = await self.astatus() - status_data = { - "status": "running", - "resource_type": status.cluster_config.get("resource_type"), - "data": dict(status), - } - cluster_uri = rns_client.format_rns_address( - (await self.aget_cluster_config()).get("name") - ) - api_server_url = status.cluster_config.get( - "api_server_url", rns_client.api_server_url - ) - post_status_data_resp = requests.post( - f"{api_server_url}/resource/{cluster_uri}/cluster/status", - data=json.dumps(status_data), - headers=rns_client.request_headers(), - ) - if post_status_data_resp.status_code != 200: - logger.error( - f"({post_status_data_resp.status_code}) Failed to send cluster status check to Den: {post_status_data_resp.text}" - ) - else: + den_auth = cluster_config.get("den_auth", False) + + # Only if one of these is true, do we actually need to get the status from each EnvServlet + should_send_status_to_den = den_auth and interval_size != -1 + should_update_autostop = self.autostop_servlet is not None + if should_send_status_to_den or should_update_autostop: logger.info( - f"Successfully updated cluster status in Den. Next status check will be in {round(interval_size / 60, 2)} minutes." + "Performing cluster status check: potentially sending to Den or updating autostop." ) + status: ResourceStatusData = await self.astatus() + + if should_update_autostop: + function_running = any( + any( + len(resource_info["active_function_calls"]) > 0 + for resource_info in resources + ) + for resources in status.env_resource_mapping.values() + ) + if function_running: + await self.autostop_servlet.set_last_active_time_to_now.remote() + await self.autostop_servlet.update_autostop_in_sky_config.remote() + + if should_send_status_to_den: + cluster_rns_address = cluster_config.get("name") + await rns_client.send_status(status, cluster_rns_address) except Exception as e: logger.error( f"Cluster status check has failed: {e}. Please check cluster logs for more info." 
@@ -328,11 +294,17 @@ async def asend_status_info_to_den(self):
                     f"If a value is not provided, interval size will be set to {DEFAULT_STATUS_CHECK_INTERVAL}"
                 )
                 await asyncio.sleep(INCREASED_STATUS_CHECK_INTERVAL)
-            finally:
+            else:
+                # asyncio.sleep() returns immediately for a delay <= 0, so the -1
+                # "disabled" interval would turn this loop into a busy spin.
+                if interval_size <= 0:
+                    interval_size = DEFAULT_STATUS_CHECK_INTERVAL
                 await asyncio.sleep(interval_size)
 
-    def send_status_info_to_den(self):
-        asyncio.run(self.asend_status_info_to_den())
+    def periodic_status_check(self):
+        # This is only ever called once in its own thread, so we can do asyncio.run here instead of
+        # sync_function.
+        asyncio.run(self.aperiodic_status_check())
 
     async def _status_for_env_servlet(self, env_servlet_name):
         try:
@@ -420,48 +392,3 @@ async def astatus(self):
 
     def status(self):
         return sync_function(self.astatus)()
-
-
-class AutostopServlet:
-    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env."""
-
-    def __init__(self):
-        self._last_activity = time.time()
-        self._last_register = None
-        autostop_thread = threading.Thread(target=self.update_autostop, daemon=True)
-        autostop_thread.start()
-
-    def set_last_active_time_to_now(self):
-        self._last_activity = time.time()
-
-    def set_autostop(self, value=None):
-        from sky.skylet import autostop_lib
-
-        self.set_last_active_time_to_now()
-        autostop_lib.set_autostop(value, None, True)
-
-    def update_autostop(self):
-        import pickle
-
-        from sky.skylet import configs as sky_configs
-
-        while True:
-
-            autostop_mins = pickle.loads(
-                sky_configs.get_config("autostop_config")
-            ).autostop_idle_minutes
-            self._last_register = float(
-                sky_configs.get_config("autostop_last_active_time")
-            )
-            if autostop_mins > 0 and (
-                not self._last_register
-                or (
-                    # within 2 min of autostop and there's more recent activity
-                    60 * autostop_mins - (time.time() - self._last_register) < 120
-                    and self._last_activity > self._last_register
-                )
-            ):
-                sky_configs.set_config("autostop_last_active_time", self._last_activity)
-                self._last_register = self._last_activity
-
-            time.sleep(30)