sync rh to all workers in addition to head node
jlewitt1 committed Jan 2, 2024
1 parent 2f0a603 commit e2b587f
Showing 3 changed files with 79 additions and 18 deletions.
84 changes: 69 additions & 15 deletions runhouse/resources/hardware/cluster.py
@@ -1,3 +1,4 @@
import asyncio
import contextlib
import copy
import json
@@ -10,6 +11,7 @@
import threading
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

@@ -118,7 +120,7 @@ def address(self, addr):
self.ips = self.ips or [None]
self.ips[0] = addr

def save_config_to_cluster(self):
def save_config_to_cluster(self, node: str = None):
config = self.config_for_rns
if "live_state" in config.keys():
# a bunch of setup commands that mess up dumping
@@ -128,7 +130,8 @@ def save_config_to_cluster(self):
self.run(
[
f"mkdir -p ~/.rh; touch {CLUSTER_CONFIG_PATH}; echo '{json_config}' > {CLUSTER_CONFIG_PATH}"
]
],
node=node,
)

@staticmethod
@@ -253,10 +256,20 @@ def keep_warm(self):
)
return self

def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")
def _sync_to_nodes(self, _rh_install_url: str, address: str):
"""Sync to all nodes (head node + workers where relevant)"""
try:
# Assuming _sync_runhouse_to_cluster and save_config_to_cluster are process-safe
self._sync_runhouse_to_cluster(_install_url=_rh_install_url, node=address)

# TODO: Deprecate once config is stored via Ray
# Update the cluster config on the cluster
self.save_config_to_cluster(node=address)

except Exception as e:
raise e

def _sync_runhouse_to_cluster(self, node: str, _install_url=None, env=None):
local_rh_package_path = Path(pkgutil.get_loader("runhouse").path).parent

# Check if runhouse is installed from source and has setup.py
@@ -272,6 +285,7 @@ def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
self._rsync(
source=str(local_rh_package_path),
dest=dest_path,
node=node,
up=True,
contents=True,
filter_options="- docs/",
@@ -291,7 +305,7 @@ def _sync_runhouse_to_cluster(self, _install_url=None, env=None):

install_cmd = f"{env._run_cmd} {rh_install_cmd}" if env else rh_install_cmd

status_codes = self.run([install_cmd], stream_logs=True)
status_codes = self.run([install_cmd], node=node, stream_logs=True)

if status_codes[0][0] != 0:
raise ValueError(f"Error installing runhouse on cluster <{self.name}>")
@@ -703,10 +717,12 @@ def restart_server(
logger.info(f"Restarting Runhouse API server on {self.name}.")

if resync_rh:
self._sync_runhouse_to_cluster(_install_url=_rh_install_url)
if not self.address:
raise ValueError(f"No address set for cluster <{self.name}>. Is it up?")

# Update the cluster config on the cluster
self.save_config_to_cluster()
# Sync Runhouse & configs across all cluster nodes in parallel
self._sync_across_all_nodes(_rh_install_url)
logger.info("Finished syncing Runhouse to cluster nodes.")

use_custom_cert = self._use_custom_cert
if use_custom_cert:
@@ -842,6 +858,25 @@ def disconnect(self):
if self._rpc_tunnel:
self._rpc_tunnel.stop()

def _sync_across_all_nodes(self, _rh_install_url):
loop = asyncio.new_event_loop()
executor = ProcessPoolExecutor()

try:
tasks = [
loop.run_in_executor(
executor, self._sync_to_nodes, _rh_install_url, address
)
for address in self.ips
]
loop.run_until_complete(asyncio.gather(*tasks))

except Exception as e:
raise e

finally:
executor.shutdown() # Close the executor when done

def __getstate__(self):
"""Delete non-serializable elements (e.g. thread locks) before pickling."""
state = self.__dict__.copy()
@@ -861,6 +896,7 @@ def _rsync(
source: str,
dest: str,
up: bool,
node: str = None,
contents: bool = False,
filter_options: str = None,
stream_logs: bool = False,
@@ -872,16 +908,18 @@
Ending `source` with a slash will copy the contents of the directory into dest,
while omitting it will copy the directory itself (adding a directory layer).
"""
# If no address provided explicitly use the head node address
node = node or self.address
# FYI, could be useful: https://github.com/gchamon/sysrsync
if contents:
source = source + "/" if not source.endswith("/") else source
dest = dest + "/" if not dest.endswith("/") else dest

ssh_credentials = copy.copy(self.ssh_creds) or {}
ssh_credentials.pop("ssh_host", self.address)
ssh_credentials.pop("ssh_host", node)
pwd = ssh_credentials.pop("password", None)

runner = SkySSHRunner(self.address, **ssh_credentials, port=self.ssh_port)
runner = SkySSHRunner(node, **ssh_credentials, port=self.ssh_port)
if not pwd:
if up:
runner.run(["mkdir", "-p", dest], stream_logs=False)
@@ -970,6 +1008,7 @@ def run(
stream_logs: bool = True,
port_forward: Union[None, int, Tuple[int, int]] = None,
require_outputs: bool = True,
node: Optional[str] = None,
run_name: Optional[str] = None,
) -> list:
"""Run a list of shell commands on the cluster. If `run_name` is provided, the commands will be
@@ -994,13 +1033,23 @@
if not run_name:
# If not creating a Run then just run the commands via SSH and return
return self._run_commands_with_ssh(
commands, cmd_prefix, stream_logs, port_forward, require_outputs
commands,
cmd_prefix,
stream_logs,
port_forward,
require_outputs,
node,
)

# Create and save the Run locally
with run(name=run_name, cmds=commands, overwrite=True) as r:
return_codes = self._run_commands_with_ssh(
commands, cmd_prefix, stream_logs, port_forward, require_outputs
commands,
cmd_prefix,
stream_logs,
port_forward,
require_outputs,
node,
)

# Register the completed Run
@@ -1015,11 +1064,12 @@ def _run_commands_with_ssh(
stream_logs: bool,
port_forward: int = None,
require_outputs: bool = True,
node: str = None,
):
return_codes = []

ssh_credentials = copy.copy(self.ssh_creds)
host = ssh_credentials.pop("ssh_host", self.address)
host = ssh_credentials.pop("ssh_host", node or self.address)
pwd = ssh_credentials.pop("password", None)

runner = SkySSHRunner(host, **ssh_credentials, port=self.ssh_port)
@@ -1074,10 +1124,11 @@ def run_python(
commands: List[str],
env: Union["Env", str] = None,
stream_logs: bool = True,
node: str = None,
port_forward: Optional[int] = None,
run_name: Optional[str] = None,
):
"""Run a list of python commands on the cluster.
"""Run a list of python commands on the cluster, or a specific cluster node if its IP is provided.
Example:
>>> cpu.run_python(['import numpy', 'print(numpy.__version__)'])
@@ -1087,6 +1138,8 @@
Running Python commands with nested quotes can be finicky. If using nested quotes,
try to wrap the outer quote with double quotes (") and the inner quotes with a single quote (').
"""
# If no node provided, assume the commands are to be run on the head node
node = node or self.address
cmd_prefix = "python3 -c"
if env:
if isinstance(env, str):
@@ -1107,6 +1160,7 @@
[formatted_command],
env=env,
stream_logs=stream_logs,
node=node,
port_forward=port_forward,
run_name=run_name,
)
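The heart of this change is the fan-out in `_sync_across_all_nodes`: a fresh asyncio event loop dispatches one `_sync_to_nodes` call per node IP onto a `ProcessPoolExecutor`, so the rsync, install, and config write happen on the head node and every worker concurrently instead of serially. A minimal, self-contained sketch of the same pattern is below; `sync_to_node`, its body, and the IP list are illustrative placeholders, not the Runhouse implementation.

import asyncio
from concurrent.futures import ProcessPoolExecutor
from typing import List


def sync_to_node(install_url: str, address: str) -> str:
    # Stand-in for the per-node work (rsync the package, install it, write the
    # cluster config). It runs in a separate process, so it must be picklable
    # and process-safe, as the commit's inline comment notes.
    return f"synced {install_url} to {address}"


def sync_across_all_nodes(install_url: str, ips: List[str]) -> List[str]:
    loop = asyncio.new_event_loop()
    executor = ProcessPoolExecutor()
    try:
        # One executor task per node; the head node is ips[0], workers follow.
        tasks = [
            loop.run_in_executor(executor, sync_to_node, install_url, ip)
            for ip in ips
        ]
        # Block until every node finishes; any per-node exception re-raises here.
        return loop.run_until_complete(asyncio.gather(*tasks))
    finally:
        executor.shutdown()
        loop.close()


if __name__ == "__main__":
    print(sync_across_all_nodes("runhouse", ["10.0.0.1", "10.0.0.2", "10.0.0.3"]))
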
10 changes: 8 additions & 2 deletions runhouse/resources/hardware/on_demand_cluster.py
@@ -299,10 +299,16 @@ def status(self, refresh: bool = True):

def _populate_connection_from_status_dict(self, cluster_dict: Dict[str, Any]):
if cluster_dict and cluster_dict["status"].name in ["UP", "INIT"]:
self.address = cluster_dict["handle"].head_ip
yaml_path = cluster_dict["handle"].cluster_yaml
handle = cluster_dict["handle"]
self.address = handle.head_ip
yaml_path = handle.cluster_yaml
if Path(yaml_path).exists():
self._ssh_creds = backend_utils.ssh_credential_from_yaml(yaml_path)

# Add worker IPs if multi-node cluster - keep the head node as the first IP
for ip in handle.cached_external_ips:
if ip not in self.ips:
self.ips.append(ip)
else:
self.address = None
self._ssh_creds = None
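With `_populate_connection_from_status_dict` now appending worker IPs from the SkyPilot handle's `cached_external_ips` (head node kept first in `self.ips`), the new `node` argument on `run` and `run_python` can target individual machines. A hedged usage sketch, assuming a saved multi-node on-demand cluster; the cluster name and the `rh.cluster` loading call are illustrative:

import runhouse as rh

# Hypothetical saved multi-node cluster; substitute your own.
cluster = rh.cluster(name="my-multinode-cluster")

# cluster.ips keeps the head node first, followed by any worker IPs.
for ip in cluster.ips:
    # `node` is the kwarg added in this commit; omitting it preserves the old
    # behavior of running only on the head node (cluster.address).
    cluster.run(["hostname"], node=ip)

# run_python gained the same kwarg (example adapted from its docstring).
cluster.run_python(["import numpy", "print(numpy.__version__)"], node=cluster.ips[-1])
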
3 changes: 2 additions & 1 deletion runhouse/resources/hardware/sagemaker_cluster.py
@@ -608,6 +608,7 @@ def _run_commands_with_ssh(
stream_logs: bool,
port_forward: int = None,
require_outputs: bool = True,
node: str = None,
):
return_codes = []
for command in commands:
@@ -1221,7 +1222,7 @@ def _stop_instance(self, delete_configs=True):
rns_client.delete_configs(resource=self)
logger.info(f"Deleted {self.name} from configs")

def _sync_runhouse_to_cluster(self, _install_url=None, env=None):
def _sync_runhouse_to_cluster(self, node: str, _install_url=None, env=None):
if not self.instance_id:
raise ValueError(f"No instance ID set for cluster {self.name}. Is it up?")

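The `SageMakerCluster` edits mainly keep method signatures aligned with the base class: `_run_commands_with_ssh` gains the optional `node` argument and `_sync_runhouse_to_cluster` now takes `node` first, matching how `Cluster` calls them. Any other subclass overriding these hooks would likely need the same adjustment; a hypothetical override sketch (class name and bodies are illustrative only):

from typing import List


class MySingleNodeCluster:  # hypothetical subclass standing in for a Cluster child
    def _run_commands_with_ssh(
        self,
        commands: List[str],
        cmd_prefix: str,
        stream_logs: bool,
        port_forward: int = None,
        require_outputs: bool = True,
        node: str = None,  # accepted for signature parity; a single-node backend can ignore it
    ):
        ...

    def _sync_runhouse_to_cluster(self, node: str, _install_url=None, env=None):
        # `node` comes first without a default, mirroring the updated base-class signature.
        ...
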
