Skip to content

Commit

Permalink
Add wrapper for running client methods
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Jul 10, 2024
1 parent 8e9b5e4 commit ecbece8
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 17 deletions.
45 changes: 29 additions & 16 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def save(self, name: str = None, overwrite: bool = True, folder: str = None):
if on_this_cluster:
obj_store.set_cluster_config_value("name", self.rns_address)
elif self.http_client:
self.client.set_cluster_name(self.rns_address)
self.call_client_method("set_cluster_name", self.rns_address)

return self

Expand Down Expand Up @@ -483,7 +483,8 @@ def get(self, key: str, default: Any = None, remote=False):
if self.on_this_cluster():
return obj_store.get(key, default=default, remote=remote)
try:
res = self.client.get(
res = self.call_client_method(
"get",
key,
default=default,
remote=remote,
Expand Down Expand Up @@ -513,7 +514,9 @@ def put(self, key: str, obj: Any, env=None):
self.check_server()
if self.on_this_cluster():
return obj_store.put(key, obj, env=env)
return self.client.put_object(key, obj, env=env or self.default_env.name)
return self.call_client_method(
"put_object", key, obj, env=env or self.default_env.name
)

def put_resource(
self, resource: Resource, state: Dict = None, dryrun: bool = False, env=None
Expand Down Expand Up @@ -549,23 +552,27 @@ def put_resource(
if self.on_this_cluster():
data = (resource.config(condensed=False), state, dryrun)
return obj_store.put_resource(serialized_data=data, env_name=env_name)
return self.client.put_resource(
resource, state=state or {}, env_name=env_name, dryrun=dryrun
return self.call_client_method(
"put_resource",
resource,
state=state or {},
env_name=env_name,
dryrun=dryrun,
)

def rename(self, old_key: str, new_key: str):
"""Rename a key in the cluster's object store."""
self.check_server()
if self.on_this_cluster():
return obj_store.rename(old_key, new_key)
return self.client.rename_object(old_key, new_key)
return self.call_client_method("rename_object", old_key, new_key)

def keys(self, env=None):
"""List all keys in the cluster's object store."""
self.check_server()
if self.on_this_cluster():
return obj_store.keys()
res = self.client.keys(env=env)
res = self.call_client_method("keys", env=env)
return res

def delete(self, keys: Union[None, str, List[str]]):
Expand All @@ -575,14 +582,14 @@ def delete(self, keys: Union[None, str, List[str]]):
keys = [keys]
if self.on_this_cluster():
return obj_store.delete(keys)
return self.client.delete(keys)
return self.call_client_method("delete", keys)

def clear(self):
"""Clear the cluster's object store."""
self.check_server()
if self.on_this_cluster():
return obj_store.clear()
return self.client.delete()
return self.call_client_method("delete")

def on_this_cluster(self):
"""Whether this function is being called on the same cluster."""
Expand All @@ -593,6 +600,10 @@ def on_this_cluster(self):

# ----------------- RPC Methods ----------------- #

def call_client_method(self, method_name, *args, **kwargs):
method = getattr(self.client, method_name)
return method(*args, **kwargs)

def connect_tunnel(self, force_reconnect=False):
if self._rpc_tunnel and force_reconnect:
self._rpc_tunnel.terminate()
Expand Down Expand Up @@ -705,8 +716,8 @@ def status(self, resource_address: str = None):
if self.on_this_cluster():
status = obj_store.status()
else:
status = self.client.status(
resource_address=resource_address or self.rns_address
status = self.call_client_method(
"status", resource_address=resource_address or self.rns_address
)
return status

Expand Down Expand Up @@ -1525,7 +1536,7 @@ def remove_conda_env(
def download_cert(self):
"""Download certificate from the cluster (Note: user must have access to the cluster)"""
self.check_server()
self.client.get_certificate()
self.call_client_method("get_certificate")
logger.info(
f"Latest TLS certificate for {self.name} saved to local path: {self.cert_config.cert_path}"
)
Expand All @@ -1537,7 +1548,9 @@ def enable_den_auth(self, flush=True):
raise ValueError("Cannot toggle Den Auth live on the cluster.")
else:
self.den_auth = True
self.client.set_settings({"den_auth": True, "flush_auth_cache": flush})
self.call_client_method(
"set_settings", {"den_auth": True, "flush_auth_cache": flush}
)
return self

def disable_den_auth(self):
Expand All @@ -1546,7 +1559,7 @@ def disable_den_auth(self):
raise ValueError("Cannot toggle Den Auth live on the cluster.")
else:
self.den_auth = False
self.client.set_settings({"den_auth": False})
self.call_client_method("set_settings", {"den_auth": False})
return self

def set_connection_defaults(self, **kwargs):
Expand Down Expand Up @@ -1668,7 +1681,7 @@ def _disable_status_check(self):
)
return
self.check_server()
self.client.set_settings({"status_check_interval": -1})
self.call_client_method("set_settings", {"status_check_interval": -1})

def _enable_or_update_status_check(
self, new_interval: int = DEFAULT_STATUS_CHECK_INTERVAL
Expand All @@ -1684,4 +1697,4 @@ def _enable_or_update_status_check(
)
return
self.check_server()
self.client.set_settings({"status_check_interval": new_interval})
self.call_client_method("set_settings", {"status_check_interval": new_interval})
2 changes: 1 addition & 1 deletion runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def autostop_mins(self, mins):
raise ImportError(
"Skypilot must be installed on the cluster in order to set autostop."
)
self.client.set_settings({"autostop_mins": mins})
self.call_client_method("set_settings", {"autostop_mins": mins})
sky.autostop(self.name, mins, down=True)

@property
Expand Down

0 comments on commit ecbece8

Please sign in to comment.