Skip to content

Commit

Permalink
Consolidate periodic loops into one function updating Den and updatin…
Browse files Browse the repository at this point in the history
…g autostop. (#873)

autostop.
  • Loading branch information
rohinb2 committed Jun 8, 2024
1 parent bdc1618 commit deacb8e
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 119 deletions.
4 changes: 2 additions & 2 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@
BULLET_UNICODE = "\u2022"
MINUTE = 60
HOUR = 3600
DEFAULT_STATUS_CHECK_INTERVAL = 2 * MINUTE
DEFAULT_STATUS_CHECK_INTERVAL = 1 * MINUTE
INCREASED_STATUS_CHECK_INTERVAL = 1 * HOUR
STATUS_CHECK_DELAY = 2 * MINUTE
STATUS_CHECK_DELAY = 1 * MINUTE
38 changes: 37 additions & 1 deletion runhouse/rns/rns_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import os
import shutil
from pathlib import Path
from typing import Dict, Optional, Set, Union
from typing import Any, Dict, List, Optional, Set, Union

import dotenv
import httpx

import requests
from pydantic import BaseModel

from runhouse.rns.utils.api import (
generate_uuid,
Expand All @@ -21,6 +23,17 @@

logger = logging.getLogger(__name__)

# This is a copy of the Pydantic model that we use to validate in Den
class ResourceStatusData(BaseModel):
    """Schema of the cluster status payload that is sent to Den.

    This mirrors the Pydantic model Den validates against, so field names and types
    have to stay in sync with the Den side.
    """

    # Cluster config dict; "resource_type" and "api_server_url" are read from it when the status is sent
    cluster_config: dict
    # Per-key list of resource info dicts (presumably keyed by env name -- confirm against the producer)
    env_resource_mapping: Dict[str, List[Dict[str, Any]]]
    system_cpu_usage: float
    system_memory_usage: Dict[str, Any]
    system_disk_usage: Dict[str, Any]
    # Per-key process info dicts (presumably keyed by env servlet name -- confirm against the producer)
    env_servlet_processes: Dict[str, Dict[str, Any]]
    server_pid: int
    runhouse_version: str


class RNSClient:
"""Manage a particular resource with the runhouse database"""
Expand Down Expand Up @@ -659,3 +672,26 @@ def contents(self, name_or_path, full_paths):
return folder(name=name_or_path, path=folder_url).resources(
full_paths=full_paths
)

async def send_status(self, status: ResourceStatusData, cluster_rns_address: str):
    """Post the cluster's current status to Den.

    Args:
        status: Status snapshot to report. Its ``cluster_config`` supplies the resource type
            and, if present, an ``api_server_url`` that overrides this client's default.
        cluster_rns_address: RNS address of the cluster, used to build the resource URI.

    The outcome is only logged: a non-200 response is reported via ``logger.error`` and is
    not raised as an exception.
    """
    status_data = {
        "status": "running",
        "resource_type": status.cluster_config.get("resource_type"),
        "data": dict(status),
    }
    cluster_uri = self.format_rns_address(cluster_rns_address)
    api_server_url = status.cluster_config.get(
        "api_server_url", self.api_server_url
    )
    # Open the client as a context manager so its connection pool is closed once the
    # request is done; this method runs periodically and would otherwise leak one
    # unclosed client per call.
    async with httpx.AsyncClient() as client:
        resp = await client.post(
            f"{api_server_url}/resource/{cluster_uri}/cluster/status",
            data=json.dumps(status_data),
            headers=self.request_headers(),
        )
    # The response body has already been read by post(), so it is safe to use after the client closed
    if resp.status_code != 200:
        logger.error(
            f"Received [{resp.status_code}]: Failed to send cluster status info to Den: {resp.text}."
        )
    else:
        logger.info("Successfully sent cluster status info to Den.")
32 changes: 32 additions & 0 deletions runhouse/servers/autostop_servlet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import time


class AutostopServlet:
    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env.

    Keeps the time of the most recent activity in memory and, when asked, registers it in
    SkyPilot's skylet config.
    """

    def __init__(self):
        # Creating the servlet counts as activity
        self._last_activity = time.time()
        # Activity timestamp most recently written to SkyPilot's config; None until the first write
        self._last_register = None

    def set_last_active_time_to_now(self):
        """Record the current time as the most recent activity."""
        self._last_activity = time.time()

    def set_autostop(self, value=None):
        """Mark the cluster as active and pass ``value`` on to SkyPilot's ``autostop_lib.set_autostop``."""
        # Imported lazily: sky is only importable inside SkyPilot's env
        from sky.skylet import autostop_lib

        self.set_last_active_time_to_now()
        autostop_lib.set_autostop(value, None, True)

    def update_autostop_in_sky_config(self):
        """Register the latest activity time in SkyPilot's config, if there is anything new to register.

        ``autostop_last_active_time`` is written only when autostop is enabled (idle minutes > 0)
        and the recorded activity is newer than the one registered last. When no autostop config
        has been stored, this is a no-op.
        """
        import pickle

        from sky.skylet import configs as sky_configs

        raw_config = sky_configs.get_config("autostop_config")
        if raw_config is None:
            # Nothing stored yet: pickle.loads(None) would raise TypeError and fail the caller's
            # whole status check. Treat a missing entry as "autostop disabled".
            return

        # NOTE: this unpickles the value SkyPilot itself stored in its local config; it must
        # never be pointed at untrusted data.
        autostop_mins = pickle.loads(raw_config).autostop_idle_minutes
        if autostop_mins > 0 and (
            self._last_register is None or self._last_register < self._last_activity
        ):
            sky_configs.set_config("autostop_last_active_time", self._last_activity)
            self._last_register = self._last_activity
155 changes: 39 additions & 116 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@
import json
import logging
import threading
import time
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union

import requests
from pydantic import BaseModel

import runhouse

from runhouse.constants import (
Expand All @@ -21,7 +17,9 @@

from runhouse.globals import configs, obj_store, rns_client
from runhouse.resources.hardware import load_cluster_config_from_file
from runhouse.rns.rns_client import ResourceStatusData
from runhouse.rns.utils.api import ResourceAccess
from runhouse.servers.autostop_servlet import AutostopServlet
from runhouse.servers.http.auth import AuthCache

from runhouse.utils import sync_function
Expand All @@ -33,18 +31,6 @@ class ClusterServletError(Exception):
pass


# This is a copy of the Pydantic model that we use to validate in Den
class ResourceStatusData(BaseModel):
    """Schema of the cluster status payload that is sent to Den.

    This mirrors the Pydantic model Den validates against, so field names and types
    have to stay in sync with the Den side.
    """

    # Cluster config dict; "resource_type" and "api_server_url" are read from it when the status is sent
    cluster_config: dict
    # Per-key list of resource info dicts (presumably keyed by env name -- confirm against the producer)
    env_resource_mapping: Dict[str, List[Dict[str, Any]]]
    system_cpu_usage: float
    system_memory_usage: Dict[str, Any]
    system_disk_usage: Dict[str, Any]
    # Per-key process info dicts (presumably keyed by env servlet name -- confirm against the producer)
    env_servlet_processes: Dict[str, Dict[str, Any]]
    server_pid: int
    runhouse_version: str


class ClusterServlet:
async def __init__(
self, cluster_config: Optional[Dict[str, Any]] = None, *args, **kwargs
Expand Down Expand Up @@ -85,10 +71,10 @@ async def __init__(

# Only send for clusters that have den_auth enabled and if we are logged in with a user's token
# to authenticate the request
if self.cluster_config.get("den_auth", False) and configs.token:
logger.debug("Creating send_status_info_to_den thread.")
if self.cluster_config.get("den_auth", False):
logger.info("Creating periodic_status_check thread.")
post_status_thread = threading.Thread(
target=self.send_status_info_to_den, daemon=True
target=self.periodic_status_check, daemon=True
)
post_status_thread.start()

Expand Down Expand Up @@ -260,63 +246,43 @@ async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str)
# Cluster status functions
##############################################

async def asend_status_info_to_den(self):
async def aperiodic_status_check(self):
# Delay the start of post_status_thread, so we'll finish the cluster startup properly
await asyncio.sleep(STATUS_CHECK_DELAY)
while True:
logger.info("Trying to send cluster status to Den.")
try:
is_config_updated = (
await self.aupdate_status_check_interval_in_cluster_config()
)
interval_size = (await self.aget_cluster_config()).get(
await self.aupdate_status_check_interval_in_cluster_config()

cluster_config = await self.aget_cluster_config()
interval_size = cluster_config.get(
"status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL
)
den_auth = (await self.aget_cluster_config()).get(
"den_auth", DEFAULT_STATUS_CHECK_INTERVAL
)
if interval_size == -1:
if is_config_updated:
logger.info(
f"Disabled periodic cluster status check. For enabling it, please run "
f"cluster.restart_server(). If you want to set up an interval size that is not the "
f"default value {round(DEFAULT_STATUS_CHECK_INTERVAL/60,2)} please run "
f"cluster._enable_or_update_status_check(interval_size) after restarting the server."
)
break
if not den_auth:
logger.info(
f"Disabled periodic cluster status check because den_auth is disabled. For enabling it, please run "
f"cluster.restart_server() and make sure that den_auth is enabled. If you want to set up an interval size that is not the "
f"default value {round(DEFAULT_STATUS_CHECK_INTERVAL / 60, 2)} please run "
f"cluster._enable_or_update_status_check(interval_size) after restarting the server."
)
break
status: ResourceStatusData = await self.astatus()
status_data = {
"status": "running",
"resource_type": status.cluster_config.get("resource_type"),
"data": dict(status),
}
cluster_uri = rns_client.format_rns_address(
(await self.aget_cluster_config()).get("name")
)
api_server_url = status.cluster_config.get(
"api_server_url", rns_client.api_server_url
)
post_status_data_resp = requests.post(
f"{api_server_url}/resource/{cluster_uri}/cluster/status",
data=json.dumps(status_data),
headers=rns_client.request_headers(),
)
if post_status_data_resp.status_code != 200:
logger.error(
f"({post_status_data_resp.status_code}) Failed to send cluster status check to Den: {post_status_data_resp.text}"
)
else:
den_auth = cluster_config.get("den_auth", False)

# Only if one of these is true, do we actually need to get the status from each EnvServlet
should_send_status_to_den = den_auth and interval_size != -1
should_update_autostop = self.autostop_servlet is not None
if should_send_status_to_den or should_update_autostop:
logger.info(
f"Successfully updated cluster status in Den. Next status check will be in {round(interval_size / 60, 2)} minutes."
"Performing cluster status check: potentially sending to Den or updating autostop."
)
status: ResourceStatusData = await self.astatus()

if should_update_autostop:
function_running = any(
any(
len(resource_info["active_function_calls"]) > 0
for resource_info in resources
)
for resources in status.env_resource_mapping.values()
)
if function_running:
await self.autostop_servlet.set_last_active_time_to_now.remote()
await self.autostop_servlet.update_autostop_in_sky_config.remote()

if should_send_status_to_den:
cluster_rns_address = cluster_config.get("name")
await rns_client.send_status(status, cluster_rns_address)
except Exception as e:
logger.error(
f"Cluster status check has failed: {e}. Please check cluster logs for more info."
Expand All @@ -328,11 +294,13 @@ async def asend_status_info_to_den(self):
f"If a value is not provided, interval size will be set to {DEFAULT_STATUS_CHECK_INTERVAL}"
)
await asyncio.sleep(INCREASED_STATUS_CHECK_INTERVAL)
finally:
else:
await asyncio.sleep(interval_size)

def send_status_info_to_den(self):
    """Blocking entry point: run ``asend_status_info_to_den`` to completion via ``asyncio.run``."""
    asyncio.run(self.asend_status_info_to_den())
def periodic_status_check(self):
    """Blocking entry point: run ``aperiodic_status_check`` to completion via ``asyncio.run``."""
    # This is only ever called once in its own thread, so we can do asyncio.run here instead of
    # sync_function.
    asyncio.run(self.aperiodic_status_check())

async def _status_for_env_servlet(self, env_servlet_name):
try:
Expand Down Expand Up @@ -420,48 +388,3 @@ async def astatus(self):

def status(self):
    """Synchronous counterpart of ``astatus``: wraps it with ``sync_function`` and returns its result."""
    return sync_function(self.astatus)()


class AutostopServlet:
    """A helper class strictly to run SkyPilot methods on OnDemandClusters inside SkyPilot's conda env.

    On construction it starts a daemon thread that keeps polling SkyPilot's config and
    registers the most recent activity time there.
    """

    def __init__(self):
        # Creating the servlet counts as activity
        self._last_activity = time.time()
        # Last-active time as read from / written to SkyPilot's config; None until the loop first runs
        self._last_register = None
        # Daemon thread: it loops forever and must not keep the process alive on shutdown
        autostop_thread = threading.Thread(target=self.update_autostop, daemon=True)
        autostop_thread.start()

    def set_last_active_time_to_now(self):
        """Record the current time as the most recent activity."""
        self._last_activity = time.time()

    def set_autostop(self, value=None):
        """Mark the cluster as active and pass ``value`` on to SkyPilot's ``autostop_lib.set_autostop``."""
        from sky.skylet import autostop_lib

        self.set_last_active_time_to_now()
        autostop_lib.set_autostop(value, None, True)

    def update_autostop(self):
        """Poll SkyPilot's config every 30 seconds and push newer activity into it. Never returns."""
        import pickle

        from sky.skylet import configs as sky_configs

        while True:

            autostop_mins = pickle.loads(
                sky_configs.get_config("autostop_config")
            ).autostop_idle_minutes
            # Refresh from the config on every iteration rather than trusting the cached value
            self._last_register = float(
                sky_configs.get_config("autostop_last_active_time")
            )
            # Only write when autostop is enabled and either nothing (or 0) is registered yet,
            # or the condition spelled out below holds
            if autostop_mins > 0 and (
                not self._last_register
                or (
                    # within 2 min of autostop and there's more recent activity
                    60 * autostop_mins - (time.time() - self._last_register) < 120
                    and self._last_activity > self._last_register
                )
            ):
                sky_configs.set_config("autostop_last_active_time", self._last_activity)
                self._last_register = self._last_activity

            time.sleep(30)

0 comments on commit deacb8e

Please sign in to comment.