diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 29fd34226..1059e4ba2 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -119,6 +119,10 @@ def address(self, addr): @property def client(self): + if not self._http_client: + if not self.address: + raise ValueError(f"No address set for cluster <{self.name}>. Is it up?") + self.connect_server_client() return self._http_client @property @@ -195,7 +199,7 @@ def save(self, name: str = None, overwrite: bool = True, folder: str = None): # self.on_this_cluster() will still work as expected. if on_this_cluster: obj_store.set_cluster_config_value("name", self.rns_address) - elif self.client: + elif self._http_client: self.client.set_cluster_name(self.rns_address) return self @@ -322,7 +326,7 @@ def _client(self, restart_server=True): if self.on_this_cluster(): # Previously (before calling within the same cluster worked) returned None return obj_store - if not self.client: + if not self._http_client: self.check_server(restart_server=restart_server) return self.client @@ -589,14 +593,22 @@ def on_this_cluster(self): # ----------------- RPC Methods ----------------- # - def connect_server_client(self, force_reconnect=False): - if not self.address: - raise ValueError(f"No address set for cluster <{self.name}>. Is it up?") - + def connect_tunnel(self, force_reconnect=False): if self._rpc_tunnel and force_reconnect: self._rpc_tunnel.terminate() self._rpc_tunnel = None + if not self._rpc_tunnel: + self._rpc_tunnel = self.ssh_tunnel( + local_port=self.server_port, + remote_port=self.server_port, + num_ports_to_try=10, + ) + + def connect_server_client(self, force_reconnect=False): + if not self.address: + raise ValueError(f"No address set for cluster <{self.name}>. Is it up?") + if self.server_connection_type in [ ServerConnectionType.SSH, ServerConnectionType.AWS_SSM, @@ -609,12 +621,8 @@ def connect_server_client(self, force_reconnect=False): if self.creds_values.get("password") is not None: self._run_commands_with_ssh(["echo 'Initiating password connection.'"]) - # Case 1: Server connection requires SSH tunnel, but we don't have one up yet - self._rpc_tunnel = self.ssh_tunnel( - local_port=self.server_port, - remote_port=self.server_port, - num_ports_to_try=10, - ) + # Case 1: Server connection requires SSH tunnel + self.connect_tunnel(force_reconnect=force_reconnect) self.client_port = self._rpc_tunnel.local_bind_port # Connecting to localhost because it's tunneled into the server at the specified port. @@ -657,52 +665,35 @@ def check_server(self, restart_server=True): if self.on_this_cluster(): return - # For OnDemandCluster, this initial check doesn't trigger a sky.status, which is slow. - # If cluster simply doesn't have an address we likely need to up it. - if not self.address and not self.is_up(): - if not hasattr(self, "up"): - raise ValueError( - "Cluster must have a host address (i.e. be up) or have a reup_cluster method " - "(e.g. OnDemandCluster)." + try: + logger.debug(f"Checking server {self.name}") + self.client.check_server() + logger.info(f"Server {self.name} is up.") + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + requests.exceptions.ChunkedEncodingError, + ): + if restart_server: + logger.info( + f"Server {self.name} is up, but the Runhouse API server may not be up." ) - # If this is a OnDemandCluster, before we up the cluster, run a sky.status to see if the cluster - # is already up but doesn't have an address assigned yet. - self.up_if_not() - - if not self.client: - try: - self.connect_server_client() - logger.debug(f"Checking server {self.name}") - self.client.check_server() - logger.info(f"Server {self.name} is up.") - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ReadTimeout, - requests.exceptions.ChunkedEncodingError, - ): - # It's possible that the cluster went down while we were trying to install packages. - if not self.is_up(): - logger.info(f"Server {self.name} is down.") - self.up_if_not() - elif restart_server: - logger.info( - f"Server {self.name} is up, but the Runhouse API server may not be up." - ) - self.restart_server() - for i in range(5): - logger.info(f"Checking server {self.name} again [{i + 1}/5].") - try: - self.client.check_server() - logger.info(f"Server {self.name} is up.") - return - except ( - requests.exceptions.ConnectionError, - requests.exceptions.ReadTimeout, - ) as error: - if i == 5: - logger.error(error) - time.sleep(5) - raise ValueError(f"Could not connect to server {self.name}") + self._http_client = None + self.restart_server() + for i in range(5): + logger.info(f"Checking server {self.name} again [{i + 1}/5].") + try: + self.client.check_server() + logger.info(f"Server {self.name} is up.") + return + except ( + requests.exceptions.ConnectionError, + requests.exceptions.ReadTimeout, + ) as error: + if i == 5: + logger.error(error) + time.sleep(5) + raise ValueError(f"Could not connect to server {self.name}") return @@ -917,7 +908,7 @@ def restart_server( if not rns_address: raise ValueError("Cluster must have a name in order to enable HTTPS.") - if not self.client: + if not self._http_client: logger.debug("Reconnecting server client. Server restarted with HTTPS.") self.connect_server_client() diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index 9fe8173ad..8239ea016 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -110,6 +110,20 @@ def __init__( # Cluster status is set to INIT in the Sky DB right after starting, so we need to refresh once self._update_from_sky_status(dryrun=False) + @property + def client(self): + if not self._http_client: + if not self.address: + # Try loading in from local Sky DB + self._update_from_sky_status(dryrun=True) + if not self.address: + raise ValueError( + f"Could not determine address for ondemand cluster <{self.name}>. " + "Up the cluster with `cluster.up_if_not`." + ) + self.connect_server_client() + return self._http_client + @property def autostop_mins(self): return self._autostop_mins diff --git a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py index 6a6acbda5..c3c41a1fe 100644 --- a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py +++ b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py @@ -429,7 +429,7 @@ def check_server(self, restart_server=True): logger.info(f"Cluster {self.name} is not up, bringing it up now.") self.up_if_not() - if not self.client: + if not self._http_client: try: self.connect_server_client() logger.info( @@ -1300,7 +1300,7 @@ def _sync_runhouse_to_cluster(self, node: str = None, _install_url=None, env=Non if not self.instance_id: raise ValueError(f"No instance ID set for cluster {self.name}. Is it up?") - if not self.client: + if not self._http_client: self.connect_server_client() # Sync the local ~/.rh directory to the cluster @@ -1421,7 +1421,7 @@ def _base_image_uri(self): def _update_autostop(self, autostop_mins: int = None): cluster_config = self.config() cluster_config["autostop_mins"] = autostop_mins or -1 - if not self.client: + if not self._http_client: self.connect_server_client() # Update the config on the server with the new autostop time self.client.check_server()