Skip to content

Commit

Permalink
support starting server in docker container for on-demand clusters
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed May 23, 2024
1 parent d73d66d commit 0004acb
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 8 deletions.
1 change: 1 addition & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
TEST_ORG = "test-org"

EMPTY_DEFAULT_ENV_NAME = "_cluster_default_env"
DEFAULT_DOCKER_CONTAINER_NAME = "sky_container"

# cluster status constants
DOUBLE_SPACE_UNICODE = "\u00A0\u00A0"
Expand Down
4 changes: 2 additions & 2 deletions runhouse/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,7 +617,7 @@ def _start_server(
try:
# Open and read the lines of the server logfile so we only print the most recent lines after starting
f = None
if screen and Path(SERVER_LOGFILE).exists():
if (screen or nohup) and Path(SERVER_LOGFILE).exists():
f = open(SERVER_LOGFILE, "r")
f.readlines() # Discard these, they're from the previous times the server was started

Expand All @@ -637,7 +637,7 @@ def _start_server(

server_started_str = "Uvicorn running on"
# Read and print the server logs until the server has started
if screen:
if screen or nohup:
while not Path(SERVER_LOGFILE).exists():
time.sleep(1)
f = f or open(SERVER_LOGFILE, "r")
Expand Down
8 changes: 8 additions & 0 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def creds_values(self) -> Dict:

return self._creds.values

@property
def docker_user(self) -> Optional[str]:
    """Docker container username forwarded to the SSH runner and tunnel helpers.

    This implementation has no container user to report and always
    returns ``None``.
    """
    return None

@property
def default_env(self):
from runhouse.resources.envs import Env
Expand Down Expand Up @@ -712,6 +716,7 @@ def ssh_tunnel(
return ssh_tunnel(
address=self.address,
ssh_creds=self.creds_values,
docker_user=self.docker_user,
local_port=local_port,
ssh_port=self.ssh_port,
remote_port=remote_port,
Expand Down Expand Up @@ -1071,6 +1076,7 @@ def _rsync(
**ssh_credentials,
ssh_control_name=ssh_control_name,
port=self.ssh_port,
docker_user=self.docker_user,
)
if not pwd:
if up:
Expand Down Expand Up @@ -1145,6 +1151,7 @@ def ssh(self):
ssh_user=creds["ssh_user"],
port=self.ssh_port,
ssh_private_key=creds["ssh_private_key"],
docker_user=self.docker_user,
)
subprocess.run(
runner._ssh_base_command(ssh_mode=SshMode.INTERACTIVE, port_forward=None)
Expand Down Expand Up @@ -1323,6 +1330,7 @@ def _run_commands_with_ssh(
**ssh_credentials,
ssh_control_name=ssh_control_name,
port=self.ssh_port,
docker_user=self.docker_user,
)

env_var_prefix = (
Expand Down
3 changes: 2 additions & 1 deletion runhouse/resources/hardware/cluster_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,8 @@ def ondemand_cluster(
autostop_mins (int, optional): Number of minutes to keep the cluster up after inactivity,
or ``-1`` to keep cluster up indefinitely.
use_spot (bool, optional): Whether or not to use spot instance.
image_id (str, optional): Custom image ID for the cluster.
image_id (str, optional): Custom image ID for the cluster. If using a docker image, please use the following
string format: "docker:<registry>/<image>:<tag>".
region (str, optional): The region to use for the cluster.
memory (int or str, optional): Amount of memory to use for the cluster, e.g. "16" or "16+".
disk_size (int or str, optional): Amount of disk space to use for the cluster, e.g. "100" or "100+".
Expand Down
30 changes: 26 additions & 4 deletions runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __init__(
self.stable_internal_external_ips = kwargs.get(
"stable_internal_external_ips", None
)
self._docker_user = None

# Checks if state info is in local sky db, populates if so.
if not dryrun and not self.ips and not self.creds_values:
Expand All @@ -123,6 +124,22 @@ def autostop_mins(self, mins):
sky.autostop(self.name, mins, down=True)
self._autostop_mins = mins

@property
def docker_user(self) -> str:
    """Return the username used inside the cluster's docker container.

    The value is resolved by ``get_docker_user`` and cached on
    ``self._docker_user`` so the lookup is not repeated once it has
    produced a non-empty result.

    Returns:
        The container username, or ``None`` when the cluster was not
        launched from a docker image or no credentials are available yet.
        (NOTE: the ``str`` annotation is kept as-is for interface stability,
        but callers must be prepared for ``None``.)
    """
    if self._docker_user:
        return self._docker_user

    # Only image IDs of the form "docker:<registry>/<image>:<tag>" run inside
    # a container. Any other custom image ID has no container to query, and
    # attempting the lookup would fail on every SSH/rsync/run call.
    image_id = self.image_id
    if not (isinstance(image_id, str) and image_id.startswith("docker:")):
        return None

    if not self._creds:
        return None

    # Imported lazily, and only once a lookup is actually required.
    from runhouse.resources.hardware.sky_ssh_runner import get_docker_user

    self._docker_user = get_docker_user(self, self._creds.values)
    return self._docker_user

def config(self, condensed=True):
config = super().config(condensed)
self.save_attrs_to_config(
Expand Down Expand Up @@ -555,10 +572,15 @@ def ssh(self, node: str = None):
ip=node or self.address,
ssh_user=ssh_user,
port=self.ssh_port,
ssh_private_key=sky_key,
ssh_private_key=str(sky_key),
docker_user=self.docker_user,
)
ssh_command = runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
if self.docker_user:
ssh_command += ["&& conda deactivate"]
subprocess.run(
runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
" ".join(ssh_command),
shell=True,
)
35 changes: 34 additions & 1 deletion runhouse/resources/hardware/sky_ssh_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from typing import Dict, List, Optional, Tuple, Union

from runhouse.constants import LOCALHOST
from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST

from runhouse.globals import sky_ssh_runner_cache

Expand Down Expand Up @@ -45,6 +45,32 @@ def is_port_in_use(port: int) -> bool:
return s.connect_ex(("localhost", port)) == 0


def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str:
    """Find docker container username.

    Runs ``whoami`` inside the default docker container on the cluster host
    (via ``sudo docker exec``) and returns the stripped output.

    Args:
        cluster: Cluster whose ``address`` and ``ssh_port`` are used to reach
            the host.
        ssh_creds: SSH credential mapping. The ``ssh_user``,
            ``ssh_private_key`` and ``ssh_control_name`` keys are read; the
            control name defaults to ``"<address>:<ssh_port>"``.

    Returns:
        The username reported by ``whoami`` inside the container.

    Raises:
        AssertionError: If the ``whoami`` command exits with a non-zero code.
    """
    runner = SkySSHRunner(
        ip=cluster.address,
        ssh_user=ssh_creds.get("ssh_user", None),
        port=cluster.ssh_port,
        ssh_private_key=ssh_creds.get("ssh_private_key", None),
        ssh_control_name=ssh_creds.get(
            "ssh_control_name", f"{cluster.address}:{cluster.ssh_port}"
        ),
    )
    container_name = DEFAULT_DOCKER_CONTAINER_NAME
    whoami_returncode, whoami_stdout, whoami_stderr = runner.run(
        f"sudo docker exec {container_name} whoami",
        stream_logs=False,
        require_outputs=True,
    )
    # Raise explicitly instead of using ``assert`` so the check is not
    # stripped when Python runs with -O. The exception type is kept as
    # AssertionError so existing callers see the same failure.
    if whoami_returncode != 0:
        raise AssertionError(
            f"Failed to get docker container user. Return "
            f"code: {whoami_returncode}, Error: {whoami_stderr}"
        )
    docker_user = whoami_stdout.strip()
    logger.debug(f"Docker container user: {docker_user}")
    return docker_user


class SkySSHRunner(SSHCommandRunner):
def __init__(
self,
Expand All @@ -68,6 +94,9 @@ def __init__(
docker_user,
disable_control_master,
)

# RH modified
self.docker_user = docker_user
self.tunnel_proc = None
self.local_bind_port = local_bind_port
self.remote_bind_port = None
Expand Down Expand Up @@ -179,6 +208,8 @@ def run(
"-i",
]

cmd = f"conda deactivate && {cmd}" if self.docker_user else cmd

command += [
shlex.quote(
f"true && source ~/.bashrc && export OMP_NUM_THREADS=1 "
Expand Down Expand Up @@ -434,6 +465,7 @@ def ssh_tunnel(
ssh_port: int = 22,
remote_port: Optional[int] = None,
num_ports_to_try: int = 0,
docker_user: Optional[str] = None,
) -> SkySSHRunner:
"""Initialize an ssh tunnel from a remote server to localhost
Expand Down Expand Up @@ -495,6 +527,7 @@ def ssh_tunnel(
ssh_private_key=ssh_creds.get("ssh_private_key"),
ssh_proxy_command=ssh_creds.get("ssh_proxy_command"),
ssh_control_name=ssh_control_name,
docker_user=docker_user,
port=ssh_port,
)
runner.tunnel(local_port, remote_port)
Expand Down

0 comments on commit 0004acb

Please sign in to comment.