Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cluster status scheduler, move config.yaml creation to restart server() #868

Merged
merged 1 commit into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
RESERVED_SYSTEM_NAMES: List[str] = ["file", "s3", "gs", "azure", "here", "ssh", "sftp"]
CLUSTER_CONFIG_PATH: str = "~/.rh/cluster_config.json"
CONFIG_YAML_PATH: str = "~/.rh/config.yaml"
SERVER_LOGFILE_PATH = "~/.rh/server.log"
LOCALHOST: str = "127.0.0.1"
LOCAL_HOSTS: List[str] = ["localhost", LOCALHOST]

Expand Down
47 changes: 35 additions & 12 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union

import yaml

from runhouse.resources.envs.utils import run_with_logs

from runhouse.rns.utils.api import ResourceAccess, ResourceVisibility
Expand Down Expand Up @@ -770,6 +772,17 @@ def _start_ray_workers(self, ray_port, env):
env=env,
)

def _run_cli_commands_on_cluster_helper(self, commands: list[str]):
if self.on_this_cluster():
return self.run(commands=commands, env=self._default_env, node=self.address)
else:
return self._run_commands_with_ssh(
commands=commands,
cmd_prefix=self._default_env._run_cmd if self._default_env else "",
env_vars=self._default_env.env_vars if self._default_env else {},
node=self.address,
)

def restart_server(
self,
_rh_install_url: str = None,
Expand Down Expand Up @@ -838,7 +851,27 @@ def restart_server(
# Update the cluster config on the cluster
self.save_config_to_cluster()

cmd = (
# Save a limited version of the local ~/.rh config to the cluster with the user's hashed token,
# if such does not exist on the cluster

if rns_client.token:
user_config = yaml.safe_dump(
{
"token": rns_client.cluster_token(
rns_client.token, rns_client.username
),
"username": rns_client.username,
"default_folder": rns_client.default_folder,
}
)

create_config_yaml_cmd = [
f"if [ ! -f ~/.rh/config.yaml ] ; then echo '{user_config}' > ~/.rh/config.yaml ; else echo 'Did not change config.yaml' ; fi"
]
self._run_cli_commands_on_cluster_helper(commands=create_config_yaml_cmd)
logger.debug("Saved user config to cluster")

restart_cmd = (
CLI_RESTART_CMD
+ (" --restart-ray" if restart_ray else "")
+ (" --use-https" if https_flag else "")
Expand All @@ -859,17 +892,7 @@ def restart_server(
+ " --from-python"
)

if self.on_this_cluster():
status_codes = self.run(
commands=[cmd], env=self._default_env, node=self.address
)
else:
status_codes = self._run_commands_with_ssh(
commands=[cmd],
cmd_prefix=self._default_env._run_cmd if self._default_env else "",
env_vars=self._default_env.env_vars if self._default_env else {},
node=self.address,
)
status_codes = self._run_cli_commands_on_cluster_helper(commands=[restart_cmd])

if not status_codes[0][0] == 0:
raise ValueError(f"Failed to restart server {self.name}")
Expand Down
13 changes: 0 additions & 13 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,19 +446,6 @@ def up(self):

self.restart_server()

# Save a limited version of the local ~/.rh config to the cluster with the user's hashed token
user_config = yaml.safe_dump(
{
"token": rns_client.cluster_token(
rns_client.token, rns_client.username
),
"username": rns_client.username,
"default_folder": rns_client.default_folder,
}
)
self.run([f"echo '{user_config}' > ~/.rh/config.yaml"])
logger.debug("Saved user config to cluster")

return self

def keep_warm(self, autostop_mins: int = -1):
Expand Down
2 changes: 1 addition & 1 deletion runhouse/rns/rns_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def request_headers(
return {"Authorization": f"Bearer {hashed_token}"}

def cluster_token(self, den_token: str, resource_address: str):
if "/" in resource_address:
if resource_address and "/" in resource_address:
# If provided as a full rns address, extract the top level directory
resource_address = self.base_folder(resource_address)

Expand Down
Loading