diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py
index eca709263..7dcf39437 100644
--- a/runhouse/servers/cluster_servlet.py
+++ b/runhouse/servers/cluster_servlet.py
@@ -2,6 +2,7 @@
 import copy
 import datetime
 import json
+import os.path
 import threading
 from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
@@ -67,10 +68,10 @@ async def __init__(
         )
         logger.info("Creating periodic_cluster_checks thread.")
-        cluster_checks_thread = threading.Thread(
+        self.cluster_checks_thread = threading.Thread(
             target=self.periodic_cluster_checks, daemon=True
         )
-        cluster_checks_thread.start()
+        self.cluster_checks_thread.start()
 
     ##############################################
     # Cluster config state storage methods
     ##############################################
@@ -84,6 +85,16 @@ async def aset_cluster_config(self, cluster_config: Dict[str, Any]):
         self.cluster_config = cluster_config
 
+        new_cluster_name = self.cluster_config.get("name", None)
+
+        if self._cluster_name != new_cluster_name:
+            self._cluster_name = new_cluster_name
+            self._cluster_uri = (
+                rns_client.format_rns_address(self._cluster_name)
+                if self._cluster_name
+                else None
+            )
+
         # Propagate the changes to all other process's obj_stores
         await asyncio.gather(
             *[
@@ -232,7 +243,9 @@ async def asave_status_metrics_to_den(self, status: dict):
         status_data = {
             "status": ResourceServerStatus.running,
-            "resource_type": status_copy.get("cluster_config").pop("resource_type"),
+            "resource_type": status_copy.get("cluster_config").pop(
+                "resource_type", "cluster"
+            ),
             "resource_info": status_copy,
             "env_servlet_processes": env_servlet_processes,
         }
@@ -248,6 +261,56 @@ def save_status_metrics_to_den(self, status: dict):
         return sync_function(self.asave_status_metrics_to_den)(status)
 
+    async def acheck_cluster_status(self, send_to_den: bool = True):
+        logger.debug("Performing cluster status checks")
+        status, den_resp_status_code = await self.astatus(send_to_den=send_to_den)
+
+        if not send_to_den:
+            return status, den_resp_status_code
+
+        if den_resp_status_code == 404:
+            logger.info(
+                "Cluster has not yet been saved to Den, cannot update status or logs"
+            )
+        elif den_resp_status_code != 200:
+            logger.error("Failed to send cluster status to Den")
+        else:
+            logger.debug("Successfully sent cluster status to Den")
+
+        return status, den_resp_status_code
+
+    async def acheck_cluster_logs(self, interval_size: int):
+        logger.debug("Performing log checks")
+
+        cluster_config = await self.aget_cluster_config()
+        prev_end_log_line = cluster_config.get("end_log_line", 0)
+        (
+            logs_den_resp,
+            new_start_log_line,
+            new_end_log_line,
+        ) = await self.send_cluster_logs_to_den(
+            prev_end_log_line=prev_end_log_line,
+        )
+        if not logs_den_resp:
+            logger.debug(
+                f"No logs were generated in the past {interval_size} minute(s), logs were not sent to Den"
+            )
+        elif logs_den_resp.status_code == 200:
+            logger.debug("Successfully sent cluster logs to Den")
+            await self.aset_cluster_config_value(
+                key="start_log_line", value=new_start_log_line
+            )
+            await self.aset_cluster_config_value(
+                key="end_log_line", value=new_end_log_line
+            )
+        else:
+            logger.error("Failed to send cluster logs to Den")
+
+        return logs_den_resp, new_start_log_line, new_end_log_line
+
     async def aperiodic_cluster_checks(self):
         """Periodically check the status of the cluster, gather metrics about the
         cluster's utilization & memory, and save it to Den."""
@@ -257,6 +320,16 @@
             "status_check_interval", DEFAULT_STATUS_CHECK_INTERVAL
         )
         while True:
+            should_send_status_and_logs_to_den: bool = (
+                configs.token is not None
+                and interval_size != -1
+                and self._cluster_uri is not None
+            )
+            should_update_autostop: bool = self.autostop_helper is not None
+
+            if not should_send_status_and_logs_to_den and not should_update_autostop:
+                break
+
             try:
                 # Only if one of these is true, do we actually need to get the status from each EnvServlet
                 should_send_status_and_logs_to_den: bool = (
@@ -273,7 +346,7 @@
                     break
 
                 logger.debug("Performing cluster checks")
-                status, status_code = await self.astatus(
+                status, den_resp_code = await self.acheck_cluster_status(
                     send_to_den=should_send_status_and_logs_to_den
                 )
 
@@ -284,39 +357,8 @@
                 if not should_send_status_and_logs_to_den:
                     break
 
-                if status_code == 404:
-                    logger.info(
-                        "Cluster has not yet been saved to Den, cannot update status or logs."
-                    )
-                elif status_code != 200:
-                    logger.error("Failed to send cluster status to Den")
-                else:
-                    logger.debug("Successfully sent cluster status to Den")
-
-                prev_end_log_line = cluster_config.get("end_log_line", 0)
-                (
-                    logs_resp_status_code,
-                    new_start_log_line,
-                    new_end_log_line,
-                ) = await self.send_cluster_logs_to_den(
-                    prev_end_log_line=prev_end_log_line,
-                )
-                if not logs_resp_status_code:
-                    logger.debug(
-                        f"No logs were generated in the past {interval_size} minute(s), logs were not sent to Den."
-                    )
-
-                elif logs_resp_status_code == 200:
-                    logger.debug("Successfully sent cluster logs to Den.")
-                    await self.aset_cluster_config_value(
-                        key="start_log_line", value=new_start_log_line
-                    )
-                    await self.aset_cluster_config_value(
-                        key="end_log_line", value=new_end_log_line
-                    )
-                    # since we are setting a new values to the cluster_config, we need to reload it so the next
-                    # cluster check iteration will reference to the updated cluster config.
-                    cluster_config = await self.aget_cluster_config()
+                if den_resp_code == 200:
+                    await self.acheck_cluster_logs(interval_size=interval_size)
 
             except Exception:
                 logger.error(
@@ -515,6 +557,10 @@ def status(self, send_to_den: bool = False):
     # Save cluster logs to Den
     ##############################################
     def _get_logs(self):
+
+        if not os.path.exists(SERVER_LOGFILE):
+            return ""
+
         with open(SERVER_LOGFILE) as log_file:
             log_lines = log_file.readlines()
         cleaned_log_lines = [ColoredFormatter.format_log(line) for line in log_lines]
@@ -526,7 +572,7 @@ def _generate_logs_file_name(self):
 
     async def send_cluster_logs_to_den(
         self, prev_end_log_line: int
-    ) -> Tuple[Optional[int], Optional[int], Optional[int]]:
+    ) -> Tuple[Optional[requests.Response], Optional[int], Optional[int]]:
        """Load the most recent logs from the server's log file and send them to Den."""
        # setting to a list, so it will be easier to get the end line num + the logs delta to send to den.
        latest_logs = self._get_logs().split("\n")
@@ -562,4 +608,7 @@ async def send_cluster_logs_to_den(
             f"{resp_status_code}: Failed to send cluster logs to Den: {post_logs_resp.json()}"
         )
 
-        return resp_status_code, prev_end_log_line, new_end_log_line
+        return post_logs_resp, prev_end_log_line, new_end_log_line
+
+    def _cluster_periodic_thread_alive(self):
+        return self.cluster_checks_thread.is_alive()
diff --git a/runhouse/servers/env_servlet.py b/runhouse/servers/env_servlet.py
index ff418343f..e5199afff 100644
--- a/runhouse/servers/env_servlet.py
+++ b/runhouse/servers/env_servlet.py
@@ -197,7 +197,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):
 
         if not cluster_config.get("resource_subtype") == "Cluster":
             stable_internal_external_ips = cluster_config.get(
-                "stable_internal_external_ips"
+                "stable_internal_external_ips", []
             )
             for ips_set in stable_internal_external_ips:
                 internal_ip, external_ip = ips_set[0], ips_set[1]
@@ -209,7 +209,7 @@ def _get_env_cpu_usage(self, cluster_config: dict = None):
                     node_name = f"worker_{stable_internal_external_ips.index(ips_set)} ({external_ip})"
         else:
             # a case it is a BYO cluster, assume that first ip in the ips list is the head.
-            ips = cluster_config.get("ips")
+            ips = cluster_config.get("ips", [])
             if len(ips) == 1 or node_ip == ips[0]:
                 node_name = f"head ({node_ip})"
             else:
diff --git a/tests/test_resources/test_clusters/test_on_demand_cluster.py b/tests/test_resources/test_clusters/test_on_demand_cluster.py
index 21599f7f0..87b9dfa7b 100644
--- a/tests/test_resources/test_clusters/test_on_demand_cluster.py
+++ b/tests/test_resources/test_clusters/test_on_demand_cluster.py
@@ -2,9 +2,13 @@
 import time
 
 import pytest
+import requests
 
 import runhouse as rh
 
+from runhouse.globals import rns_client
+from runhouse.resources.hardware.utils import ResourceServerStatus
+
 import tests.test_resources.test_clusters.test_cluster
 from tests.utils import friend_account
 
@@ -191,3 +195,35 @@ def test_docker_container_reqs(self, ondemand_aws_cluster):
     def test_fn_to_docker_container(self, ondemand_aws_cluster):
         remote_torch_exists = rh.function(torch_exists).to(ondemand_aws_cluster)
         assert remote_torch_exists()
+
+    ####################################################################################################
+    # Status tests
+    ####################################################################################################
+
+    @pytest.mark.level("minimal")
+    def test_set_status_after_teardown(self, cluster, mocker):
+        mock_function = mocker.patch("sky.down")
+        response = cluster.teardown()
+        assert isinstance(response, int)
+        assert (
+            response == 200
+        )  # a 200 response means the status was successfully posted to Den
+        mock_function.assert_called_once()
+
+        cluster_config = cluster.config()
+        cluster_uri = rns_client.format_rns_address(cluster.rns_address)
+        api_server_url = cluster_config.get("api_server_url", rns_client.api_server_url)
+        cluster.teardown()
+        get_status_data_resp = requests.get(
+            f"{api_server_url}/resource/{cluster_uri}/cluster/status",
+            headers=rns_client.request_headers(),
+        )
+
+        assert get_status_data_resp.status_code == 200
+        # For UI display purposes, the cluster/status endpoint returns the cluster status history.
+        # The latest status info is the first element in the list returned by the endpoint.
+        get_status_data = get_status_data_resp.json()["data"][0]
+        assert get_status_data["resource_type"] == cluster_config.get("resource_type")
+        assert get_status_data["status"] == ResourceServerStatus.terminated
+
+        assert cluster.is_up()
diff --git a/tests/test_servers/conftest.py b/tests/test_servers/conftest.py
index 6e8dc3a1b..501bec219 100644
--- a/tests/test_servers/conftest.py
+++ b/tests/test_servers/conftest.py
@@ -14,7 +14,11 @@
 from runhouse.servers.http.certs import TLSCertConfig
 from runhouse.servers.http.http_server import app, HTTPServer
 
-from tests.utils import friend_account, get_ray_servlet_and_obj_store
+from tests.utils import (
+    friend_account,
+    get_ray_cluster_servlet,
+    get_ray_env_servlet_and_obj_store,
+)
 
 logger = get_logger()
 
@@ -109,9 +113,15 @@ def local_client_with_den_auth(logged_in_account):
 
 @pytest.fixture(scope="session")
-def test_servlet():
-    servlet, _ = get_ray_servlet_and_obj_store("test_servlet")
-    yield servlet
+def test_env_servlet():
+    env_servlet, _ = get_ray_env_servlet_and_obj_store("test_env_servlet")
+    yield env_servlet
+
+
+@pytest.fixture(scope="session")
+def test_cluster_servlet(request):
+    cluster_servlet = get_ray_cluster_servlet()
+    yield cluster_servlet
 
 
 @pytest.fixture(scope="function")
@@ -119,7 +129,7 @@ def obj_store(request):
     # Use the parameter to set the name of the servlet actor to use
     env_servlet_name = request.param
 
-    _, test_obj_store = get_ray_servlet_and_obj_store(env_servlet_name)
+    _, test_obj_store = get_ray_env_servlet_and_obj_store(env_servlet_name)
 
     # Clears everything, not just what's in this env servlet
     test_obj_store.clear()
diff --git a/tests/test_servers/test_server_obj_store.py b/tests/test_servers/test_server_obj_store.py
index f80e675dd..de2fdb81f 100644
--- a/tests/test_servers/test_server_obj_store.py
+++ b/tests/test_servers/test_server_obj_store.py
@@ -2,7 +2,7 @@
 
 from runhouse.servers.obj_store import ObjStoreError
 
-from tests.utils import friend_account, get_ray_servlet_and_obj_store
+from tests.utils import friend_account, get_ray_env_servlet_and_obj_store
 
 
 def list_compare(list1, list2):
@@ -126,7 +126,7 @@ def test_clear(self, obj_store):
     def test_many_env_servlets(self, obj_store):
         assert obj_store.keys() == []
 
-        _, obj_store_2 = get_ray_servlet_and_obj_store("other")
+        _, obj_store_2 = get_ray_env_servlet_and_obj_store("other")
         assert obj_store_2.keys() == []
 
         obj_store.put("k1", "v1")
@@ -298,7 +298,7 @@ def test_many_env_servlets(self, obj_store):
         assert obj_store.keys_for_env_servlet_name(obj_store_2.servlet_name) == []
 
         # Testing of maintaining envs
-        _, obj_store_3 = get_ray_servlet_and_obj_store("third")
+        _, obj_store_3 = get_ray_env_servlet_and_obj_store("third")
         assert obj_store_3.keys() == ["k1"]
         obj_store_3.put("k2", "v2")
         obj_store_3.put("k3", "v3")
@@ -312,7 +312,7 @@ def test_many_env_servlets(self, obj_store):
 
     @pytest.mark.level("unit")
     def test_delete_env_servlet(self, obj_store):
-        _, obj_store_2 = get_ray_servlet_and_obj_store("obj_store_2")
+        _, obj_store_2 = get_ray_env_servlet_and_obj_store("obj_store_2")
 
         assert obj_store.keys() == []
         assert obj_store_2.keys() == []
diff --git a/tests/test_servers/test_servlet.py b/tests/test_servers/test_servlet.py
index bd2dd651b..b7c2f678d 100644
--- a/tests/test_servers/test_servlet.py
+++ b/tests/test_servers/test_servlet.py
@@ -11,7 +11,7 @@
 @pytest.mark.servertest
 class TestServlet:
     @pytest.mark.level("unit")
-    def test_put_resource(self, test_servlet, blob_data):
+    def test_put_resource(self, test_env_servlet, blob_data):
         with tempfile.TemporaryDirectory() as temp_dir:
             resource_path = Path(temp_dir, "local-blob")
             local_blob = rh.blob(blob_data, path=resource_path)
@@ -19,7 +19,7 @@ def test_put_resource(self, test_servlet, blob_data):
 
             state = {}
             resp = ObjStore.call_actor_method(
-                test_servlet,
+                test_env_servlet,
                 "aput_resource_local",
                 data=serialize_data(
                     (resource.config(condensed=False), state, resource.dryrun), "pickle"
@@ -31,12 +31,12 @@ def test_put_resource(self, test_servlet, blob_data):
         assert deserialize_data(resp.data, resp.serialization).startswith("file_")
 
     @pytest.mark.level("unit")
-    def test_put_obj_local(self, test_servlet, blob_data):
+    def test_put_obj_local(self, test_env_servlet, blob_data):
         with tempfile.TemporaryDirectory() as temp_dir:
             resource_path = Path(temp_dir, "local-blob")
             resource = rh.blob(blob_data, path=resource_path)
             resp = ObjStore.call_actor_method(
-                test_servlet,
+                test_env_servlet,
                 "aput_local",
                 key="key1",
                 data=serialize_data(resource, "pickle"),
@@ -45,9 +45,9 @@ def test_put_obj_local(self, test_servlet, blob_data):
         assert resp.output_type == "success"
 
     @pytest.mark.level("unit")
-    def test_get_obj(self, test_servlet):
+    def test_get_obj(self, test_env_servlet):
         resp = ObjStore.call_actor_method(
-            test_servlet,
+            test_env_servlet,
             "aget_local",
             key="key1",
             default=KeyError,
@@ -59,9 +59,9 @@ def test_get_obj(self, test_servlet):
         assert isinstance(blob, rh.Blob)
 
     @pytest.mark.level("unit")
-    def test_get_obj_remote(self, test_servlet):
+    def test_get_obj_remote(self, test_env_servlet):
         resp = ObjStore.call_actor_method(
-            test_servlet,
+            test_env_servlet,
             "aget_local",
             key="key1",
             default=KeyError,
@@ -73,9 +73,9 @@ def test_get_obj_remote(self, test_servlet):
         assert isinstance(blob_config, dict)
 
     @pytest.mark.level("unit")
-    def test_get_obj_does_not_exist(self, test_servlet):
+    def test_get_obj_does_not_exist(self, test_env_servlet):
         resp = ObjStore.call_actor_method(
-            test_servlet,
+            test_env_servlet,
             "aget_local",
             key="abcdefg",
             default=KeyError,
@@ -85,3 +85,61 @@ def test_get_obj_does_not_exist(self, test_servlet):
         assert resp.output_type == "exception"
         error = deserialize_data(resp.data["error"], "pickle")
         assert isinstance(error, KeyError)
+
+    @pytest.mark.level("local")
+    def test_cluster_checks_thread_is_running(self, test_cluster_servlet):
+        resp_status_checks_thread = ObjStore.call_actor_method(
+            test_cluster_servlet,
+            "_cluster_periodic_thread_alive",
+        )
+
+        assert resp_status_checks_thread
+
+    @pytest.mark.level("local")
+    def test_cluster_checks_thread_logic_status(self, cluster, test_cluster_servlet):
+        from runhouse.servers.obj_store import ObjStore
+
+        cluster.save()
+
+        status, den_status_resp = ObjStore.call_actor_method(
+            test_cluster_servlet,
+            "acheck_cluster_status",
+            send_to_den=True,
+        )
+
+        assert isinstance(den_status_resp, int)
+        assert den_status_resp == 200
+
+        status, den_status_resp = ObjStore.call_actor_method(
+            test_cluster_servlet,
+            "acheck_cluster_status",
+            send_to_den=False,
+        )
+
+        assert den_status_resp is None
+
+    @pytest.mark.level("local")
+    def test_cluster_checks_thread_logic_logs(self, cluster, test_cluster_servlet):
+        from runhouse.constants import SERVER_LOGFILE
+
+        cluster.save()
+
+        server_logfile_path = Path(SERVER_LOGFILE)
+        # Ensure the parent directory exists
+        server_logfile_path.parent.mkdir(parents=True, exist_ok=True)
+        # Create the file if it does not exist
+        server_logfile_path.touch(exist_ok=True)
+
+        try:
+            server_logfile_path.write_text("This is a demo log\n")
+
+            (logs_den_resp, _, _) = ObjStore.call_actor_method(
+                test_cluster_servlet, "acheck_cluster_logs", interval_size=60
+            )
+
+            assert logs_den_resp
+            assert logs_den_resp.status_code == 200
+
+        finally:
+            server_logfile_path.unlink()
diff --git a/tests/utils.py b/tests/utils.py
index b12d8011d..3414f8a8c 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -11,21 +11,34 @@
 from runhouse.constants import TESTING_LOG_LEVEL
 from runhouse.globals import rns_client
 
-from runhouse.servers.obj_store import ObjStore, RaySetupOption
+from runhouse.servers.obj_store import get_cluster_servlet, ObjStore, RaySetupOption
 
 
-def get_ray_servlet_and_obj_store(env_name):
-    """Helper method for getting auth servlet and base env servlet"""
+def get_ray_env_servlet_and_obj_store(env_name):
+    """Helper method for getting an env servlet and its object store"""
     test_obj_store = ObjStore()
     test_obj_store.initialize(env_name, setup_ray=RaySetupOption.GET_OR_FAIL)
 
-    servlet = test_obj_store.get_env_servlet(
+    test_env_servlet = test_obj_store.get_env_servlet(
         env_name=env_name,
         create=True,
     )
 
-    return servlet, test_obj_store
+    return test_env_servlet, test_obj_store
+
+
+def get_ray_cluster_servlet(cluster_config=None):
+    """Helper method for getting the base cluster servlet"""
+    cluster_servlet = get_cluster_servlet(create_if_not_exists=True)
+
+    if cluster_config:
+        ObjStore.call_actor_method(
+            cluster_servlet, "aset_cluster_config", cluster_config
+        )
+
+    return cluster_servlet
 
 
 def get_pid_and_ray_node(a=0):
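The refactor above reduces the body of aperiodic_cluster_checks to two calls: acheck_cluster_status runs first, and acheck_cluster_logs only runs once Den has acknowledged the status update with a 200. A minimal, runnable sketch of that control flow follows; post_status, post_logs, check_status, check_logs, and periodic_checks are hypothetical stand-ins for the Den calls, not Runhouse APIs.

import asyncio
from typing import Optional, Tuple


async def post_status() -> int:
    """Placeholder for the Den status POST; returns an HTTP status code."""
    return 200


async def post_logs(prev_end_line: int) -> Tuple[Optional[int], int]:
    """Placeholder for the Den logs POST.

    Returns (status_code, new_end_line); status_code is None when no new
    log lines were produced since prev_end_line."""
    return 200, prev_end_line


async def check_status() -> Optional[int]:
    code = await post_status()
    if code == 404:
        print("Cluster not yet saved to Den; cannot update status or logs")
    elif code != 200:
        print("Failed to send cluster status to Den")
    return code


async def check_logs(end_line: int) -> int:
    code, new_end_line = await post_logs(end_line)
    if code == 200:
        # Persist the log-line cursor only after Den accepts the upload,
        # mirroring how the PR stores start/end_log_line in the cluster config.
        end_line = new_end_line
    return end_line


async def periodic_checks(interval_s: float = 60.0) -> None:
    end_line = 0
    while True:
        code = await check_status()
        if code == 200:
            # Logs are only sent once Den has confirmed the status update.
            end_line = await check_logs(end_line)
        await asyncio.sleep(interval_s)


if __name__ == "__main__":
    # Run a single status check rather than the infinite loop.
    asyncio.run(check_status())

Keeping each check in its own coroutine is also what lets the new servlet tests exercise acheck_cluster_status and acheck_cluster_logs directly on the cluster servlet, without driving the whole periodic loop.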