Skip to content

Commit

Permalink
modify cluster status checks tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Sep 2, 2024
1 parent 56cd67b commit 74be517
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 67 deletions.
130 changes: 89 additions & 41 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import datetime
import json
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import httpx
Expand Down Expand Up @@ -66,11 +67,11 @@ async def __init__(
"api_server_url", rns_client.api_server_url
)

logger.info("Creating periodic_cluster_checks thread.")
cluster_checks_thread = threading.Thread(
logger.debug("Creating periodic_cluster_checks thread.")
self.cluster_checks_thread = threading.Thread(
target=self.periodic_cluster_checks, daemon=True
)
cluster_checks_thread.start()
self.cluster_checks_thread.start()

##############################################
# Cluster config state storage methods
Expand All @@ -84,6 +85,16 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):

self.cluster_config = cluster_config

new_cluster_name = self.cluster_config.get("name", None)

if self._cluster_name != new_cluster_name:
self._cluster_name = new_cluster_name
self._cluster_uri = (
rns_client.format_rns_address(self._cluster_name)
if self._cluster_name
else None
)

# Propagate the changes to all other process's obj_stores
await asyncio.gather(
*[
Expand Down Expand Up @@ -232,7 +243,9 @@ async def asave_status_metrics_to_den(self, status: dict):

status_data = {
"status": ResourceServerStatus.running,
"resource_type": status_copy.get("cluster_config").pop("resource_type"),
"resource_type": status_copy.get("cluster_config").pop(
"resource_type", "cluster"
),
"resource_info": status_copy,
"env_servlet_processes": env_servlet_processes,
}
Expand All @@ -248,6 +261,56 @@ async def asave_status_metrics_to_den(self, status: dict):
def save_status_metrics_to_den(self, status: dict):
return sync_function(self.asave_status_metrics_to_den)(status)

async def acheck_cluster_status(self, send_to_den: bool = True):

logger.debug("Performing cluster status checks")
status, den_resp_status_code = await self.astatus(send_to_den=send_to_den)

if not send_to_den:
return status, den_resp_status_code

if den_resp_status_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs"
)
elif den_resp_status_code != 200:
logger.error("Failed to send cluster status to Den")
else:
logger.debug("Successfully sent cluster status to Den")

return status, den_resp_status_code

async def acheck_cluster_logs(self, interval_size: int):

logger.debug("Performing logs checks")

cluster_config = await self.aget_cluster_config()
prev_end_log_line = cluster_config.get("end_log_line", 0)
(
logs_den_resp,
new_start_log_line,
new_end_log_line,
) = await self.send_cluster_logs_to_den(
prev_end_log_line=prev_end_log_line,
)
if not logs_den_resp:
logger.debug(
f"No logs were generated in the past {interval_size} minute(s), logs were not sent to Den"
)

elif logs_den_resp.status_code == 200:
logger.debug("Successfully sent cluster logs to Den")
await self.aset_cluster_config_value(
key="start_log_line", value=new_start_log_line
)
await self.aset_cluster_config_value(
key="end_log_line", value=new_end_log_line
)
else:
logger.error("Failed to send logs to Den")

return logs_den_resp, new_start_log_line, new_end_log_line

async def aperiodic_cluster_checks(self):
"""Periodically check the status of the cluster, gather metrics about the cluster's utilization & memory,
and save it to Den."""
Expand All @@ -257,6 +320,16 @@ async def aperiodic_cluster_checks(self):
"status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL
)
while True:
should_send_status_and_logs_to_den: bool = (
configs.token is not None
and interval_size != -1
and self._cluster_uri is not None
)
should_update_autostop: bool = self.autostop_helper is not None

if not should_send_status_and_logs_to_den and not should_update_autostop:
break

try:
# Only if one of these is true, do we actually need to get the status from each EnvServlet
should_send_status_and_logs_to_den: bool = (
Expand All @@ -273,7 +346,7 @@ async def aperiodic_cluster_checks(self):
break

logger.debug("Performing cluster checks")
status, status_code = await self.astatus(
status, den_resp_code = await self.acheck_cluster_status(
send_to_den=should_send_status_and_logs_to_den
)

Expand All @@ -284,39 +357,8 @@ async def aperiodic_cluster_checks(self):
if not should_send_status_and_logs_to_den:
break

if status_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs."
)
elif status_code != 200:
logger.error("Failed to send cluster status to Den")
else:
logger.debug("Successfully sent cluster status to Den")

prev_end_log_line = cluster_config.get("end_log_line", 0)
(
logs_resp_status_code,
new_start_log_line,
new_end_log_line,
) = await self.send_cluster_logs_to_den(
prev_end_log_line=prev_end_log_line,
)
if not logs_resp_status_code:
logger.debug(
f"No logs were generated in the past {interval_size} minute(s), logs were not sent to Den."
)

elif logs_resp_status_code == 200:
logger.debug("Successfully sent cluster logs to Den.")
await self.aset_cluster_config_value(
key="start_log_line", value=new_start_log_line
)
await self.aset_cluster_config_value(
key="end_log_line", value=new_end_log_line
)
# since we are setting a new values to the cluster_config, we need to reload it so the next
# cluster check iteration will reference to the updated cluster config.
cluster_config = await self.aget_cluster_config()
if den_resp_code == 200:
await self.acheck_cluster_logs(interval_size=interval_size)

except Exception:
logger.error(
Expand Down Expand Up @@ -515,6 +557,10 @@ def status(self, send_to_den: bool = False):
# Save cluster logs to Den
##############################################
def _get_logs(self):

if not Path(SERVER_LOGFILE).exists():
return ""

with open(SERVER_LOGFILE) as log_file:
log_lines = log_file.readlines()
cleaned_log_lines = [ColoredFormatter.format_log(line) for line in log_lines]
Expand All @@ -526,7 +572,7 @@ def _generate_logs_file_name(self):

async def send_cluster_logs_to_den(
self, prev_end_log_line: int
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
) -> Tuple[Optional[requests.Response], Optional[int], Optional[int]]:
"""Load the most recent logs from the server's log file and send them to Den."""
# setting to a list, so it will be easier to get the end line num + the logs delta to send to den.
latest_logs = self._get_logs().split("\n")
Expand All @@ -549,7 +595,9 @@ async def send_cluster_logs_to_den(
"start_line": prev_end_log_line,
"end_line": new_end_log_line,
}

logger.info(
f'uri is: {f"{self._api_server_url}/resource/{self._cluster_uri}/logs"}'
)
post_logs_resp = requests.post(
f"{self._api_server_url}/resource/{self._cluster_uri}/logs",
data=json.dumps(logs_data),
Expand All @@ -562,4 +610,4 @@ async def send_cluster_logs_to_den(
f"{resp_status_code}: Failed to send cluster logs to Den: {post_logs_resp.json()}"
)

return resp_status_code, prev_end_log_line, new_end_log_line
return post_logs_resp, prev_end_log_line, new_end_log_line
4 changes: 2 additions & 2 deletions runhouse/servers/env_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):

if not cluster_config.get("resource_subtype") == "Cluster":
stable_internal_external_ips = cluster_config.get(
"stable_internal_external_ips"
"stable_internal_external_ips", []
)
for ips_set in stable_internal_external_ips:
internal_ip, external_ip = ips_set[0], ips_set[1]
Expand All @@ -209,7 +209,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):
node_name = f"worker_{stable_internal_external_ips.index(ips_set)} ({external_ip})"
else:
# a case it is a BYO cluster, assume that first ip in the ips list is the head.
ips = cluster_config.get("ips")
ips = cluster_config.get("ips", [])
if len(ips) == 1 or node_ip == ips[0]:
node_name = f"head ({node_ip})"
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_resources/test_clusters/test_on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import time

import pytest
import requests

import runhouse as rh

from runhouse.globals import rns_client
from runhouse.resources.hardware.utils import ResourceServerStatus

import tests.test_resources.test_clusters.test_cluster
from tests.utils import friend_account

Expand Down Expand Up @@ -197,3 +201,35 @@ def test_docker_container_reqs(self, ondemand_aws_cluster):
def test_fn_to_docker_container(self, ondemand_aws_cluster):
remote_torch_exists = rh.function(torch_exists).to(ondemand_aws_cluster)
assert remote_torch_exists()

####################################################################################################
# Status tests
####################################################################################################

    @pytest.mark.level("minimal")
    def test_set_status_after_teardown(self, cluster, mocker):
        """Verify that tearing down a cluster records a 'terminated' status in Den.

        ``sky.down`` is mocked out, so no real infrastructure is touched:
        ``teardown()`` should still report the terminated status to Den, and
        the cluster/status endpoint should then return that status as the
        most recent history entry.
        """
        # Patch the actual teardown call so only the Den status update runs.
        mock_function = mocker.patch("sky.down")
        response = cluster.teardown()
        assert isinstance(response, int)
        assert (
            response == 200
        )  # that means that the call to post status endpoint in den was successful
        mock_function.assert_called_once()

        cluster_config = cluster.config()
        cluster_uri = rns_client.format_rns_address(cluster.rns_address)
        api_server_url = cluster_config.get("api_server_url", rns_client.api_server_url)
        # NOTE(review): second teardown() call — looks redundant given the
        # assertion-checked call above; confirm whether it is intentional
        # (e.g. checking idempotence) or leftover.
        cluster.teardown()
        get_status_data_resp = requests.get(
            f"{api_server_url}/resource/{cluster_uri}/cluster/status",
            headers=rns_client.request_headers(),
        )

        assert get_status_data_resp.status_code == 200
        # For UI displaying purposes, the cluster/status endpoint returns cluster status history.
        # The latest status info is the first element in the list returned by the endpoint.
        get_status_data = get_status_data_resp.json()["data"][0]
        assert get_status_data["resource_type"] == cluster_config.get("resource_type")
        assert get_status_data["status"] == ResourceServerStatus.terminated

        # Since sky.down was mocked, the cluster was never actually torn down,
        # so it should still be reachable — presumably why is_up() holds here.
        assert cluster.is_up()
20 changes: 15 additions & 5 deletions tests/test_servers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from runhouse.servers.http.certs import TLSCertConfig
from runhouse.servers.http.http_server import app, HTTPServer

from tests.utils import friend_account, get_ray_servlet_and_obj_store
from tests.utils import (
friend_account,
get_ray_cluster_servlet,
get_ray_env_servlet_and_obj_store,
)

logger = get_logger(__name__)

Expand Down Expand Up @@ -109,17 +113,23 @@ def local_client_with_den_auth(logged_in_account):


@pytest.fixture(scope="session")
def test_servlet():
servlet, _ = get_ray_servlet_and_obj_store("test_servlet")
yield servlet
def test_env_servlet():
env_servlet, _ = get_ray_env_servlet_and_obj_store("test_env_servlet")
yield env_servlet


@pytest.fixture(scope="session")
def test_cluster_servlet():
    """Session-scoped fixture yielding the Ray cluster servlet actor.

    The ``request`` parameter in the original signature was unused, so it has
    been dropped; pytest injects fixtures by parameter name, so consumers of
    ``test_cluster_servlet`` are unaffected.
    """
    cluster_servlet = get_ray_cluster_servlet()
    yield cluster_servlet


@pytest.fixture(scope="function")
def obj_store(request):

# Use the parameter to set the name of the servlet actor to use
env_servlet_name = request.param
_, test_obj_store = get_ray_servlet_and_obj_store(env_servlet_name)
_, test_obj_store = get_ray_env_servlet_and_obj_store(env_servlet_name)

# Clears everything, not just what's in this env servlet
test_obj_store.clear()
Expand Down
8 changes: 4 additions & 4 deletions tests/test_servers/test_server_obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from runhouse.servers.obj_store import ObjStoreError

from tests.utils import friend_account, get_ray_servlet_and_obj_store
from tests.utils import friend_account, get_ray_env_servlet_and_obj_store


def list_compare(list1, list2):
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_clear(self, obj_store):
def test_many_env_servlets(self, obj_store):
assert obj_store.keys() == []

_, obj_store_2 = get_ray_servlet_and_obj_store("other")
_, obj_store_2 = get_ray_env_servlet_and_obj_store("other")
assert obj_store_2.keys() == []

obj_store.put("k1", "v1")
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_many_env_servlets(self, obj_store):
assert obj_store.keys_for_env_servlet_name(obj_store_2.servlet_name) == []

# Testing of maintaining envs
_, obj_store_3 = get_ray_servlet_and_obj_store("third")
_, obj_store_3 = get_ray_env_servlet_and_obj_store("third")
assert obj_store_3.keys() == ["k1"]
obj_store_3.put("k2", "v2")
obj_store_3.put("k3", "v3")
Expand All @@ -312,7 +312,7 @@ def test_many_env_servlets(self, obj_store):

@pytest.mark.level("unit")
def test_delete_env_servlet(self, obj_store):
_, obj_store_2 = get_ray_servlet_and_obj_store("obj_store_2")
_, obj_store_2 = get_ray_env_servlet_and_obj_store("obj_store_2")

assert obj_store.keys() == []
assert obj_store_2.keys() == []
Expand Down
Loading

0 comments on commit 74be517

Please sign in to comment.