Skip to content

Commit

Permalink
modify cluster status checks tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexandra Belousov authored and Alexandra Belousov committed Sep 4, 2024
1 parent 5507636 commit cc6cf61
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 36 deletions.
85 changes: 75 additions & 10 deletions runhouse/servers/cluster_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,11 @@ async def __init__(
)
collect_gpu_thread.start()

logger.info("Creating periodic_cluster_checks thread.")
cluster_checks_thread = threading.Thread(
logger.debug("Creating periodic_cluster_checks thread.")
self.cluster_checks_thread = threading.Thread(
target=self.periodic_cluster_checks, daemon=True
)
cluster_checks_thread.start()
self.cluster_checks_thread.start()

##############################################
# Cluster config state storage methods
Expand Down Expand Up @@ -259,7 +259,9 @@ async def asave_status_metrics_to_den(self, status: dict):

status_data = {
"status": ResourceServerStatus.running,
"resource_type": status_copy.get("cluster_config").pop("resource_type"),
"resource_type": status_copy.get("cluster_config").pop(
"resource_type", "cluster"
),
"resource_info": status_copy,
"env_servlet_processes": env_servlet_processes,
}
Expand All @@ -275,6 +277,56 @@ async def asave_status_metrics_to_den(self, status: dict):
def save_status_metrics_to_den(self, status: dict):
return sync_function(self.asave_status_metrics_to_den)(status)

async def acheck_cluster_status(self, send_to_den: bool = True):

logger.debug("Performing cluster status checks")
status, den_resp_status_code = await self.astatus(send_to_den=send_to_den)

if not send_to_den:
return status, den_resp_status_code

if den_resp_status_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs"
)
elif den_resp_status_code != 200:
logger.error("Failed to send cluster status to Den")
else:
logger.debug("Successfully sent cluster status to Den")

return status, den_resp_status_code

async def acheck_cluster_logs(self, interval_size: int):

logger.debug("Performing logs checks")

cluster_config = await self.aget_cluster_config()
prev_end_log_line = cluster_config.get("end_log_line", 0)
(
logs_den_resp,
new_start_log_line,
new_end_log_line,
) = await self.send_cluster_logs_to_den(
prev_end_log_line=prev_end_log_line,
)
if not logs_den_resp:
logger.debug(
f"No logs were generated in the past {interval_size} minute(s), logs were not sent to Den"
)

elif logs_den_resp.status_code == 200:
logger.debug("Successfully sent cluster logs to Den")
await self.aset_cluster_config_value(
key="start_log_line", value=new_start_log_line
)
await self.aset_cluster_config_value(
key="end_log_line", value=new_end_log_line
)
else:
logger.error("Failed to send logs to Den")

return logs_den_resp, new_start_log_line, new_end_log_line

async def aperiodic_cluster_checks(self):
"""Periodically check the status of the cluster, gather metrics about the cluster's utilization & memory,
and save it to Den."""
Expand All @@ -284,6 +336,16 @@ async def aperiodic_cluster_checks(self):
"status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL
)
while True:
should_send_status_and_logs_to_den: bool = (
configs.token is not None
and interval_size != -1
and self._cluster_uri is not None
)
should_update_autostop: bool = self.autostop_helper is not None

if not should_send_status_and_logs_to_den and not should_update_autostop:
break

try:
# Only if one of these is true, do we actually need to get the status from each EnvServlet
should_send_status_and_logs_to_den: bool = (
Expand All @@ -300,7 +362,7 @@ async def aperiodic_cluster_checks(self):
break

logger.debug("Performing cluster checks")
status, status_code = await self.astatus(
status, den_resp_code = await self.acheck_cluster_status(
send_to_den=should_send_status_and_logs_to_den
)

Expand All @@ -311,13 +373,13 @@ async def aperiodic_cluster_checks(self):
if not should_send_status_and_logs_to_den:
break

if status_code == 404:
if den_resp_code == 404:
logger.info(
"Cluster has not yet been saved to Den, cannot update status or logs."
)
elif status_code != 200:
elif den_resp_code != 200:
logger.error(
f"Failed to send cluster status to Den, status_code: {status_code}"
f"Failed to send cluster status to Den, status_code: {den_resp_code}"
)
else:
logger.debug("Successfully sent cluster status to Den")
Expand Down Expand Up @@ -346,6 +408,8 @@ async def aperiodic_cluster_checks(self):
# since we are setting a new values to the cluster_config, we need to reload it so the next
# cluster check iteration will reference to the updated cluster config.
cluster_config = await self.aget_cluster_config()
if den_resp_code == 200:
await self.acheck_cluster_logs(interval_size=interval_size)

except Exception:
logger.error(
Expand Down Expand Up @@ -585,6 +649,7 @@ def status(self, send_to_den: bool = False):
# Save cluster logs to Den
##############################################
def _get_logs(self):

with open(SERVER_LOGFILE) as log_file:
log_lines = log_file.readlines()
cleaned_log_lines = [ColoredFormatter.format_log(line) for line in log_lines]
Expand All @@ -596,7 +661,7 @@ def _generate_logs_file_name(self):

async def send_cluster_logs_to_den(
self, prev_end_log_line: int
) -> Tuple[Optional[int], Optional[int], Optional[int]]:
) -> Tuple[Optional[requests.Response], Optional[int], Optional[int]]:
"""Load the most recent logs from the server's log file and send them to Den."""
# setting to a list, so it will be easier to get the end line num + the logs delta to send to den.
latest_logs = self._get_logs().split("\n")
Expand Down Expand Up @@ -632,4 +697,4 @@ async def send_cluster_logs_to_den(
f"{resp_status_code}: Failed to send cluster logs to Den: {post_logs_resp.json()}"
)

return resp_status_code, prev_end_log_line, new_end_log_line
return post_logs_resp, prev_end_log_line, new_end_log_line
4 changes: 2 additions & 2 deletions runhouse/servers/env_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):

if not cluster_config.get("resource_subtype") == "Cluster":
stable_internal_external_ips = cluster_config.get(
"stable_internal_external_ips"
"stable_internal_external_ips", []
)
for ips_set in stable_internal_external_ips:
internal_ip, external_ip = ips_set[0], ips_set[1]
Expand All @@ -241,7 +241,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):
node_name = f"worker_{stable_internal_external_ips.index(ips_set)} ({external_ip})"
else:
# a case it is a BYO cluster, assume that first ip in the ips list is the head.
ips = cluster_config.get("ips")
ips = cluster_config.get("ips", [])
if len(ips) == 1 or node_ip == ips[0]:
node_name = f"head ({node_ip})"
else:
Expand Down
36 changes: 36 additions & 0 deletions tests/test_resources/test_clusters/test_on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
import time

import pytest
import requests

import runhouse as rh

from runhouse.globals import rns_client
from runhouse.resources.hardware.utils import ResourceServerStatus

import tests.test_resources.test_clusters.test_cluster
from tests.utils import friend_account

Expand Down Expand Up @@ -197,3 +201,35 @@ def test_docker_container_reqs(self, ondemand_aws_cluster):
def test_fn_to_docker_container(self, ondemand_aws_cluster):
remote_torch_exists = rh.function(torch_exists).to(ondemand_aws_cluster)
assert remote_torch_exists()

####################################################################################################
# Status tests
####################################################################################################

@pytest.mark.level("minimal")
def test_set_status_after_teardown(self, cluster, mocker):
mock_function = mocker.patch("sky.down")
response = cluster.teardown()
assert isinstance(response, int)
assert (
response == 200
) # that means that the call to post status endpoint in den was successful
mock_function.assert_called_once()

cluster_config = cluster.config()
cluster_uri = rns_client.format_rns_address(cluster.rns_address)
api_server_url = cluster_config.get("api_server_url", rns_client.api_server_url)
cluster.teardown()
get_status_data_resp = requests.get(
f"{api_server_url}/resource/{cluster_uri}/cluster/status",
headers=rns_client.request_headers(),
)

assert get_status_data_resp.status_code == 200
# For UI displaying purposes, the cluster/status endpoint returns cluster status history.
# The latest status info is the first element in the list returned by the endpoint.
get_status_data = get_status_data_resp.json()["data"][0]
assert get_status_data["resource_type"] == cluster_config.get("resource_type")
assert get_status_data["status"] == ResourceServerStatus.terminated

assert cluster.is_up()
20 changes: 15 additions & 5 deletions tests/test_servers/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@
from runhouse.servers.http.certs import TLSCertConfig
from runhouse.servers.http.http_server import app, HTTPServer

from tests.utils import friend_account, get_ray_servlet_and_obj_store
from tests.utils import (
friend_account,
get_ray_cluster_servlet,
get_ray_env_servlet_and_obj_store,
)

logger = get_logger(__name__)

Expand Down Expand Up @@ -109,17 +113,23 @@ def local_client_with_den_auth(logged_in_account):


@pytest.fixture(scope="session")
def test_servlet():
servlet, _ = get_ray_servlet_and_obj_store("test_servlet")
yield servlet
def test_env_servlet():
env_servlet, _ = get_ray_env_servlet_and_obj_store("test_env_servlet")
yield env_servlet


@pytest.fixture(scope="session")
def test_cluster_servlet(request):
cluster_servlet = get_ray_cluster_servlet()
yield cluster_servlet


@pytest.fixture(scope="function")
def obj_store(request):

# Use the parameter to set the name of the servlet actor to use
env_servlet_name = request.param
_, test_obj_store = get_ray_servlet_and_obj_store(env_servlet_name)
_, test_obj_store = get_ray_env_servlet_and_obj_store(env_servlet_name)

# Clears everything, not just what's in this env servlet
test_obj_store.clear()
Expand Down
8 changes: 4 additions & 4 deletions tests/test_servers/test_server_obj_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from runhouse.servers.obj_store import ObjStoreError

from tests.utils import friend_account, get_ray_servlet_and_obj_store
from tests.utils import friend_account, get_ray_env_servlet_and_obj_store


def list_compare(list1, list2):
Expand Down Expand Up @@ -126,7 +126,7 @@ def test_clear(self, obj_store):
def test_many_env_servlets(self, obj_store):
assert obj_store.keys() == []

_, obj_store_2 = get_ray_servlet_and_obj_store("other")
_, obj_store_2 = get_ray_env_servlet_and_obj_store("other")
assert obj_store_2.keys() == []

obj_store.put("k1", "v1")
Expand Down Expand Up @@ -298,7 +298,7 @@ def test_many_env_servlets(self, obj_store):
assert obj_store.keys_for_env_servlet_name(obj_store_2.servlet_name) == []

# Testing of maintaining envs
_, obj_store_3 = get_ray_servlet_and_obj_store("third")
_, obj_store_3 = get_ray_env_servlet_and_obj_store("third")
assert obj_store_3.keys() == ["k1"]
obj_store_3.put("k2", "v2")
obj_store_3.put("k3", "v3")
Expand All @@ -312,7 +312,7 @@ def test_many_env_servlets(self, obj_store):

@pytest.mark.level("unit")
def test_delete_env_servlet(self, obj_store):
_, obj_store_2 = get_ray_servlet_and_obj_store("obj_store_2")
_, obj_store_2 = get_ray_env_servlet_and_obj_store("obj_store_2")

assert obj_store.keys() == []
assert obj_store_2.keys() == []
Expand Down
20 changes: 10 additions & 10 deletions tests/test_servers/test_servlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
@pytest.mark.servertest
class TestServlet:
@pytest.mark.level("unit")
def test_put_resource(self, test_servlet):
def test_put_resource(self, test_env_servlet):
resource = Resource(name="local-resource")
state = {}
resp = ObjStore.call_actor_method(
test_servlet,
test_env_servlet,
"aput_resource_local",
data=serialize_data(
(resource.config(condensed=False), state, resource.dryrun), "pickle"
Expand All @@ -23,10 +23,10 @@ def test_put_resource(self, test_servlet):
assert deserialize_data(resp.data, resp.serialization) == resource.name

@pytest.mark.level("unit")
def test_put_obj_local(self, test_servlet):
def test_put_obj_local(self, test_env_servlet):
resource = Resource(name="local-resource")
resp = ObjStore.call_actor_method(
test_servlet,
test_env_servlet,
"aput_local",
key="key1",
data=serialize_data(resource, "pickle"),
Expand All @@ -35,9 +35,9 @@ def test_put_obj_local(self, test_servlet):
assert resp.output_type == "success"

@pytest.mark.level("unit")
def test_get_obj(self, test_servlet):
def test_get_obj(self, test_env_servlet):
resp = ObjStore.call_actor_method(
test_servlet,
test_env_servlet,
"aget_local",
key="key1",
default=KeyError,
Expand All @@ -49,9 +49,9 @@ def test_get_obj(self, test_servlet):
assert isinstance(resource, Resource)

@pytest.mark.level("unit")
def test_get_obj_remote(self, test_servlet):
def test_get_obj_remote(self, test_env_servlet):
resp = ObjStore.call_actor_method(
test_servlet,
test_env_servlet,
"aget_local",
key="key1",
default=KeyError,
Expand All @@ -63,9 +63,9 @@ def test_get_obj_remote(self, test_servlet):
assert isinstance(resource_config, dict)

@pytest.mark.level("unit")
def test_get_obj_does_not_exist(self, test_servlet):
def test_get_obj_does_not_exist(self, test_env_servlet):
resp = ObjStore.call_actor_method(
test_servlet,
test_env_servlet,
"aget_local",
key="abcdefg",
default=KeyError,
Expand Down
Loading

0 comments on commit cc6cf61

Please sign in to comment.