Skip to content

Commit

Permalink
send cluster status to den when running runhouse status cli command
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Aug 27, 2024
1 parent f21641d commit da456c2
Show file tree
Hide file tree
Showing 8 changed files with 180 additions and 185 deletions.
64 changes: 31 additions & 33 deletions runhouse/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import importlib
import logging
import math
Expand Down Expand Up @@ -33,7 +32,7 @@
START_NOHUP_CMD,
START_SCREEN_CMD,
)
from runhouse.globals import obj_store, rns_client
from runhouse.globals import rns_client
from runhouse.logger import logger
from runhouse.resources.hardware.ray_utils import (
check_for_existing_ray_instance,
Expand Down Expand Up @@ -240,9 +239,10 @@ def _print_envs_info(
env_servlet_processes: Dict[str, Dict[str, Any]], current_cluster: Cluster
):
"""
Prints info about the envs in the current_cluster.
Prints the resources in each env, and the CPU and GPU usage of the env (if exists).
Prints info about the envs in the current_cluster: the resources in each env, and the CPU and GPU usage of the env
(if available).
"""

# Print headline
envs_in_cluster_headline = "Serving 🍦 :"
console.print(envs_in_cluster_headline)
Expand Down Expand Up @@ -395,29 +395,36 @@ def _print_status(status_data: dict, current_cluster: Cluster) -> None:
)
console.print(daemon_headline_txt, style="bold royal_blue1")

console.print(f'Runhouse v{status_data.get("runhouse_version")}')
console.print(f'server pid: {status_data.get("server_pid")}')
console.print(f"Runhouse v{status_data.get('runhouse_version')}")
console.print(f"server pid: {status_data.get('server_pid')}")

# Print relevant info from cluster config.
_print_cluster_config(cluster_config)

# print the environments in the cluster, and the resources associated with each environment.
_print_envs_info(env_servlet_processes, current_cluster)

return status_data


@app.command()
def status(
cluster_name: str = typer.Argument(
None,
help="Name of cluster to check. If not specified will check the local cluster.",
)
),
send_to_den: bool = typer.Option(
default=False,
help="Whether to update Den with the status",
),
):
"""Load the status of the Runhouse daemon running on a cluster."""
cluster_or_local = rh.here

if cluster_name:
if cluster_or_local == "file" and not cluster_name:
# If running outside the cluster, a cluster name must be specified
console.print("Missing argument `cluster_name`.")
return

elif cluster_name:
current_cluster = cluster(name=cluster_name)
if not current_cluster.is_up():
console.print(
Expand All @@ -434,32 +441,22 @@ def status(
f"`runhouse ssh {cluster_name}` or `sky status -r` for on-demand clusters."
)
raise typer.Exit(1)
else:
if not cluster_or_local or cluster_or_local == "file":
console.print(
"\N{smiling face with horns} Runhouse Daemon is not running... \N{No Entry} \N{Runner}. "
"Start it with `runhouse restart` or specify a remote "
"cluster to poll with `runhouse status <cluster_name>`."
)
raise typer.Exit(1)

# case we are inside the cluster
if cluster_or_local != "file":
# If we are on the cluster load status directly from the object store
cluster_status: dict = dict(obj_store.status())
cluster_config = copy.deepcopy(cluster_status.get("cluster_config"))
current_cluster: Cluster = Cluster.from_config(cluster_config)
return _print_status(cluster_status, current_cluster)
elif not cluster_or_local:
console.print(
"\N{smiling face with horns} Runhouse Daemon is not running... \N{No Entry} \N{Runner}. "
"Start it with `runhouse restart` or specify a remote "
"cluster to poll with `runhouse status <cluster_name>`."
)
raise typer.Exit(1)

if cluster_name is None:
# If running outside the cluster must specify a cluster name
console.print("Missing argument `cluster_name`.")
return
else:
# we are inside the cluster
current_cluster = cluster_or_local # cluster_or_local = rh.here

try:
current_cluster: Cluster = Cluster.from_name(name=cluster_name)
cluster_status: dict = current_cluster.status(
resource_address=current_cluster.rns_address
cluster_status = current_cluster.status(
resource_address=current_cluster.rns_address, send_to_den=send_to_den
)

except ValueError:
Expand All @@ -470,7 +467,8 @@ def status(
"\N{smiling face with horns} Runhouse Daemon is not running... \N{No Entry} \N{Runner}"
)
return
return _print_status(cluster_status, current_cluster)

_print_status(cluster_status, current_cluster)


def load_cluster(cluster_name: str):
Expand Down
24 changes: 20 additions & 4 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,17 +748,33 @@ def connect_server_client(self, force_reconnect=False):
system=self,
)

def status(self, resource_address: str = None):
"""Loads the status of the Runhouse daemon running on the cluster."""
def status(self, resource_address: str = None, send_to_den: bool = False):
"""Load the status of the Runhouse daemon running on a cluster."""

# Note: If running outside a local cluster, a resource address must be included to construct the cluster subtoken
# Allow for specifying a resource address explicitly in case the resource has no rns address yet
if self.on_this_cluster():
status = obj_store.status()
status, den_resp = obj_store.status(send_to_den=send_to_den)
else:
status = self.call_client_method(
status, den_resp = self.call_client_method(
"status",
resource_address=resource_address or self.rns_address,
send_to_den=send_to_den,
)

if send_to_den:
send_to_den_status_code = den_resp.status_code

if send_to_den_status_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs."
)

elif send_to_den_status_code != 200:
logger.warning(
f"Failed to send cluster status to den: {den_resp.json()}"
)

return status

def ssh_tunnel(
Expand Down
103 changes: 63 additions & 40 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,17 @@ async def __init__(
if cluster_config.get("resource_subtype", None) == "OnDemandCluster":
self.autostop_helper = AutostopHelper()

self._cluster_name = self.cluster_config.get("name", None)
self._cluster_uri = (
rns_client.format_rns_address(self._cluster_name)
if self._cluster_name
else None
)

self._api_server_url = self.cluster_config.get(
"api_server_url", rns_client.api_server_url
)

logger.info("Creating periodic_cluster_checks thread.")
cluster_checks_thread = threading.Thread(
target=self.periodic_cluster_checks, daemon=True
Expand Down Expand Up @@ -134,12 +145,11 @@ async def ahas_resource_access(self, token: str, resource_uri=None) -> bool:
"""Checks whether user has read or write access to a given module saved on the cluster."""
from runhouse.rns.utils.api import ResourceAccess

if token is None:
# If no token is provided assume no access
if token is None or self._cluster_name is None:
# If no token or cluster name is provided, assume no access
return False

cluster_uri = self.cluster_config["name"]
cluster_access = await self.aresource_access_level(token, cluster_uri)
cluster_access = await self.aresource_access_level(token, self._cluster_name)
if cluster_access == ResourceAccess.WRITE:
# if user has write access to cluster will have access to all resources
return True
Expand Down Expand Up @@ -214,27 +224,33 @@ async def aclear_all_references_to_env_servlet_name(self, env_servlet_name: str)
##############################################
# Periodic Cluster Checks APIs
##############################################
@staticmethod
async def save_status_metrics_to_den(
status: ResourceStatusData, cluster_uri: str, api_server_url: str
):

async def asave_status_metrics_to_den(self, status: dict):
from runhouse.resources.hardware.utils import ResourceServerStatus

resource_info = dict(status)
env_servlet_processes = dict(resource_info.pop("env_servlet_processes"))
# Making a deep copy so the status dict won't be mutated by pop, since it is returned after being sent to Den
# (status is passed by reference).
status_copy = copy.deepcopy(status)
env_servlet_processes = status_copy.pop("env_servlet_processes")

status_data = {
"status": ResourceServerStatus.running,
"resource_type": status.cluster_config.get("resource_type"),
"resource_info": resource_info,
"resource_type": status_copy.get("cluster_config").pop("resource_type"),
"resource_info": status_copy,
"env_servlet_processes": env_servlet_processes,
}

client = httpx.AsyncClient()

return await client.post(
f"{api_server_url}/resource/{cluster_uri}/cluster/status",
f"{self._api_server_url}/resource/{self._cluster_uri}/cluster/status",
data=json.dumps(status_data),
headers=rns_client.request_headers(),
)

def save_status_metrics_to_den(self, status: dict):
    """Synchronous wrapper around ``asave_status_metrics_to_den``.

    Posts the given cluster status payload to Den and returns the HTTP
    response from the async implementation.
    """
    return sync_function(self.asave_status_metrics_to_den)(status)

async def aperiodic_cluster_checks(self):
"""Periodically check the status of the cluster, gather metrics about the cluster's utilization & memory,
and save it to Den."""
Expand All @@ -247,7 +263,9 @@ async def aperiodic_cluster_checks(self):
try:
# Only if one of these is true, do we actually need to get the status from each EnvServlet
should_send_status_and_logs_to_den: bool = (
configs.token is not None and interval_size != -1
configs.token is not None
and interval_size != -1
and self._cluster_uri is not None
)
should_update_autostop: bool = self.autostop_helper is not None

Expand All @@ -258,46 +276,36 @@ async def aperiodic_cluster_checks(self):
break

logger.debug("Performing cluster checks")
status: ResourceStatusData = await self.astatus()
status, den_resp = await self.astatus(
send_to_den=should_send_status_and_logs_to_den
)

if should_update_autostop:
logger.debug("Updating autostop")
await self._update_autostop(status)

if not should_send_status_and_logs_to_den:
break

logger.debug("Sending cluster status to Den")
cluster_rns_address = cluster_config.get("name")
cluster_uri = rns_client.format_rns_address(cluster_rns_address)
api_server_url = status.cluster_config.get(
"api_server_url", rns_client.api_server_url
)

resp = await ClusterServlet.save_status_metrics_to_den(
status=status,
cluster_uri=cluster_uri,
api_server_url=api_server_url,
)
status_code = resp.status_code
status_code = den_resp.status_code

if status_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs."
)
elif status_code != 200:
logger.error(
f"{status_code}: Failed to send cluster status to Den: {resp.json()}"
f"Failed to send cluster status to Den: {den_resp.json()}"
)
else:
logger.debug("Successfully sent cluster status to Den.")

prev_end_log_line = cluster_config.get("end_log_line", 0)
(
logs_resp_status_code,
new_start_log_line,
new_end_log_line,
) = await self.send_cluster_logs_to_den(
cluster_uri=cluster_uri,
api_server_url=api_server_url,
prev_end_log_line=prev_end_log_line,
)
if not logs_resp_status_code:
Expand All @@ -313,6 +321,9 @@ async def aperiodic_cluster_checks(self):
await self.aset_cluster_config_value(
key="end_log_line", value=new_end_log_line
)
# Since we are setting new values in the cluster_config, we need to reload it so the next
# cluster check iteration will reference the updated cluster config.
cluster_config = await self.aget_cluster_config()

except Exception:
self.logger.error(
Expand All @@ -336,7 +347,7 @@ def periodic_cluster_checks(self):
# sync_function.
asyncio.run(self.aperiodic_cluster_checks())

async def _update_autostop(self, status: ResourceStatusData):
async def _update_autostop(self, status: dict):
function_running = any(
any(
len(
Expand All @@ -347,7 +358,7 @@ async def _update_autostop(self, status: ResourceStatusData):
> 0
for resource_name in resource["env_resource_mapping"].keys()
)
for resource in status.env_servlet_processes.values()
for resource in status.get("env_servlet_processes", {}).values()
)
if function_running:
await self.autostop_helper.set_last_active_time_to_now()
Expand Down Expand Up @@ -409,7 +420,9 @@ def _get_node_gpu_usage(self, server_pid: int):
"server_pid": server_pid, # will be useful for multi-node clusters.
}

async def astatus(self):
async def astatus(
self, send_to_den: bool = False
) -> Tuple[Dict, Optional[httpx.Response]]:
import psutil

from runhouse.utils import get_pid
Expand Down Expand Up @@ -492,11 +505,21 @@ async def astatus(self):
"server_memory_usage": memory_usage,
"server_gpu_usage": server_gpu_usage,
}
status_data = ResourceStatusData(**status_data)
return status_data

def status(self):
return sync_function(self.astatus)()
# Convert status_data to a ResourceStatusData instance to verify the status data was constructed correctly
status_data = ResourceStatusData(**status_data).dict()

if send_to_den:

logger.debug("Sending cluster status to Den")
den_resp = self.save_status_metrics_to_den(status=status_data)

return status_data, den_resp

return status_data, None

def status(self, send_to_den: bool = False):
    """Synchronous wrapper around ``astatus``.

    Returns a tuple of (status dict, Den response). The Den response is an
    ``httpx.Response`` when ``send_to_den`` is True, otherwise ``None``.
    """
    return sync_function(self.astatus)(send_to_den=send_to_den)

##############################################
# Save cluster logs to Den
Expand All @@ -512,7 +535,7 @@ def _generate_logs_file_name(self):
return f"{current_timestamp}_{SERVER_LOGS_FILE_NAME}"

async def send_cluster_logs_to_den(
self, cluster_uri: str, api_server_url: str, prev_end_log_line: int
self, prev_end_log_line: int
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
"""Load the most recent logs from the server's log file and send them to Den."""
# setting to a list, so it will be easier to get the end line num + the logs delta to send to den.
Expand All @@ -538,7 +561,7 @@ async def send_cluster_logs_to_den(
}

post_logs_resp = requests.post(
f"{api_server_url}/resource/{cluster_uri}/logs",
f"{self._api_server_url}/resource/{self._cluster_uri}/logs",
data=json.dumps(logs_data),
headers=rns_client.request_headers(),
)
Expand Down
Loading

0 comments on commit da456c2

Please sign in to comment.