Skip to content

Commit

Permalink
Fix race condition connecting to docker ondemand_cluster (#957)
Browse files Browse the repository at this point in the history
  • Loading branch information
dongreenberg authored Jul 2, 2024
1 parent 221f3b6 commit 5eefb20
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 5 deletions.
1 change: 1 addition & 0 deletions runhouse/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SERVER_LOGFILE_PATH = "~/.rh/server.log"
LOCALHOST: str = "127.0.0.1"
LOCAL_HOSTS: List[str] = ["localhost", LOCALHOST]
TUNNEL_TIMEOUT = 5

LOGS_DIR = ".rh/logs"
RH_LOGFILE_PATH = Path.home() / LOGS_DIR
Expand Down
2 changes: 1 addition & 1 deletion runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def connect_server_client(self, force_reconnect=False):
# re-established, would require a password, and then fail. We should really figure out how to
# authenticate with a password in the SSH tunnel command. But, this is a fine hack for now.
if self.creds_values.get("password") is not None:
self._run_commands_with_ssh(["Initiating password connection."])
self._run_commands_with_ssh(["echo 'Initiating password connection.'"])

# Case 1: Server connection requires SSH tunnel, but we don't have one up yet
self._rpc_tunnel = self.ssh_tunnel(
Expand Down
15 changes: 11 additions & 4 deletions runhouse/resources/hardware/sky_ssh_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import time
from typing import Dict, List, Optional, Tuple, Union

from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST
from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST, TUNNEL_TIMEOUT

from runhouse.globals import sky_ssh_runner_cache

Expand Down Expand Up @@ -264,15 +264,22 @@ def tunnel(self, local_port, remote_port):
ssh_mode=SshMode.NON_INTERACTIVE, port_forward=[(local_port, remote_port)]
)
command = " ".join(base_cmd)
logger.debug(f"Running forwarding command: {command}")
logger.info(f"Running forwarding command: {command}")
proc = subprocess.Popen(
shlex.split(command),
stdin=subprocess.PIPE,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# Wait for the ssh connection to start
time.sleep(1)
# Wait until tunnel is formed by trying to create a socket in a loop

start_time = time.time()
while not is_port_in_use(local_port):
time.sleep(0.1)
if time.time() - start_time > TUNNEL_TIMEOUT:
raise ConnectionError(
f"Failed to create tunnel from {local_port} to {remote_port} on {self.ip}"
)

# Set the tunnel process and ports to be cleaned up later
self.tunnel_proc = proc
Expand Down

0 comments on commit 5eefb20

Please sign in to comment.