diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 29f3f5134..4e15a9dde 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -200,7 +200,7 @@ def save(self, name: str = None, overwrite: bool = True, folder: str = None): if on_this_cluster: obj_store.set_cluster_config_value("name", self.rns_address) elif self._http_client: - self.client.set_cluster_name(self.rns_address) + self.call_client_method("set_cluster_name", self.rns_address) return self @@ -483,7 +483,8 @@ def get(self, key: str, default: Any = None, remote=False): if self.on_this_cluster(): return obj_store.get(key, default=default, remote=remote) try: - res = self.client.get( + res = self.call_client_method( + "get", key, default=default, remote=remote, @@ -513,7 +514,9 @@ def put(self, key: str, obj: Any, env=None): self.check_server() if self.on_this_cluster(): return obj_store.put(key, obj, env=env) - return self.client.put_object(key, obj, env=env or self.default_env.name) + return self.call_client_method( + "put_object", key, obj, env=env or self.default_env.name + ) def put_resource( self, resource: Resource, state: Dict = None, dryrun: bool = False, env=None @@ -549,8 +552,12 @@ def put_resource( if self.on_this_cluster(): data = (resource.config(condensed=False), state, dryrun) return obj_store.put_resource(serialized_data=data, env_name=env_name) - return self.client.put_resource( - resource, state=state or {}, env_name=env_name, dryrun=dryrun + return self.call_client_method( + "put_resource", + resource, + state=state or {}, + env_name=env_name, + dryrun=dryrun, ) def rename(self, old_key: str, new_key: str): @@ -558,14 +565,14 @@ def rename(self, old_key: str, new_key: str): self.check_server() if self.on_this_cluster(): return obj_store.rename(old_key, new_key) - return self.client.rename_object(old_key, new_key) + return self.call_client_method("rename_object", old_key, new_key) def keys(self, env=None): """List all keys in the cluster's object store.""" self.check_server() if self.on_this_cluster(): return obj_store.keys() - res = self.client.keys(env=env) + res = self.call_client_method("keys", env=env) return res def delete(self, keys: Union[None, str, List[str]]): @@ -575,14 +582,14 @@ def delete(self, keys: Union[None, str, List[str]]): keys = [keys] if self.on_this_cluster(): return obj_store.delete(keys) - return self.client.delete(keys) + return self.call_client_method("delete", keys) def clear(self): """Clear the cluster's object store.""" self.check_server() if self.on_this_cluster(): return obj_store.clear() - return self.client.delete() + return self.call_client_method("delete") def on_this_cluster(self): """Whether this function is being called on the same cluster.""" @@ -593,6 +600,10 @@ def on_this_cluster(self): # ----------------- RPC Methods ----------------- # + def call_client_method(self, method_name, *args, **kwargs): + method = getattr(self.client, method_name) + return method(*args, **kwargs) + def connect_tunnel(self, force_reconnect=False): if self._rpc_tunnel and force_reconnect: self._rpc_tunnel.terminate() @@ -705,8 +716,8 @@ def status(self, resource_address: str = None): if self.on_this_cluster(): status = obj_store.status() else: - status = self.client.status( - resource_address=resource_address or self.rns_address + status = self.call_client_method( + "status", resource_address=resource_address or self.rns_address ) return status @@ -1525,7 +1536,7 @@ def remove_conda_env( def download_cert(self): """Download certificate from the cluster (Note: user must have access to the cluster)""" self.check_server() - self.client.get_certificate() + self.call_client_method("get_certificate") logger.info( f"Latest TLS certificate for {self.name} saved to local path: {self.cert_config.cert_path}" ) @@ -1537,7 +1548,9 @@ def enable_den_auth(self, flush=True): raise ValueError("Cannot toggle Den Auth live on the cluster.") else: self.den_auth = True - self.client.set_settings({"den_auth": True, "flush_auth_cache": flush}) + self.call_client_method( + "set_settings", {"den_auth": True, "flush_auth_cache": flush} + ) return self def disable_den_auth(self): @@ -1546,7 +1559,7 @@ def disable_den_auth(self): raise ValueError("Cannot toggle Den Auth live on the cluster.") else: self.den_auth = False - self.client.set_settings({"den_auth": False}) + self.call_client_method("set_settings", {"den_auth": False}) return self def set_connection_defaults(self, **kwargs): @@ -1668,7 +1681,7 @@ def _disable_status_check(self): ) return self.check_server() - self.client.set_settings({"status_check_interval": -1}) + self.call_client_method("set_settings", {"status_check_interval": -1}) def _enable_or_update_status_check( self, new_interval: int = DEFAULT_STATUS_CHECK_INTERVAL @@ -1684,4 +1697,4 @@ def _enable_or_update_status_check( ) return self.check_server() - self.client.set_settings({"status_check_interval": new_interval}) + self.call_client_method("set_settings", {"status_check_interval": new_interval}) diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index 8239ea016..0add74575 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -140,7 +140,7 @@ def autostop_mins(self, mins): raise ImportError( "Skypilot must be installed on the cluster in order to set autostop." ) - self.client.set_settings({"autostop_mins": mins}) + self.call_client_method("set_settings", {"autostop_mins": mins}) sky.autostop(self.name, mins, down=True) @property