Skip to content

Commit

Permalink
Create separate ssh tunnel class
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Aug 29, 2024
1 parent ab924b0 commit 291206d
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 199 deletions.
4 changes: 2 additions & 2 deletions runhouse/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

configs = Defaults()

sky_ssh_runner_cache = {}
ssh_tunnel_cache = {}


def clean_up_ssh_connections():
for _, v in sky_ssh_runner_cache.items():
for _, v in ssh_tunnel_cache.items():
v.terminate()


Expand Down
11 changes: 9 additions & 2 deletions runhouse/resources/hardware/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,8 +781,14 @@ def status(self, resource_address: str = None, send_to_den: bool = False):

def ssh_tunnel(
self, local_port, remote_port=None, num_ports_to_try: int = 0
) -> "SkySSHRunner":
from runhouse.resources.hardware.sky_ssh_runner import ssh_tunnel
) -> "SshTunnel":
from runhouse.resources.hardware.ssh_tunnel import ssh_tunnel

cloud = (
self.launched_properties["cloud"]
if hasattr(self, "launched_properties")
else None
)

return ssh_tunnel(
address=self.address,
Expand All @@ -792,6 +798,7 @@ def ssh_tunnel(
ssh_port=self.ssh_port,
remote_port=remote_port,
num_ports_to_try=num_ports_to_try,
cloud=cloud,
)

@property
Expand Down
192 changes: 1 addition & 191 deletions runhouse/resources/hardware/sky_ssh_runner.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
import copy
import logging
import os
import pathlib
import shlex
import subprocess
import time
from typing import Dict, List, Optional, Tuple, Union

from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST, TUNNEL_TIMEOUT
from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME

from runhouse.globals import sky_ssh_runner_cache
from runhouse.logger import get_logger

from runhouse.resources.hardware.sky import common_utils, log_lib, subprocess_utils
Expand Down Expand Up @@ -37,13 +34,6 @@
pass


def is_port_in_use(port: int) -> bool:
    """Report whether a TCP connection to ``localhost:port`` succeeds.

    Args:
        port: Local TCP port number to probe.

    Returns:
        True if the connection attempt succeeds, otherwise False.
    """
    import socket

    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # connect_ex reports failure through its return code instead of raising;
        # 0 means the connection was accepted.
        status = probe.connect_ex(("localhost", port))
    return status == 0


def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str:
"""Find docker container username."""
runner = SkySSHRunner(
Expand Down Expand Up @@ -96,7 +86,6 @@ def __init__(

# RH modified
self.docker_user = docker_user
self.tunnel_proc = None
self.local_bind_port = local_bind_port
self.remote_bind_port = None

Expand Down Expand Up @@ -244,73 +233,6 @@ def run(
**kwargs,
)

def tunnel(self, local_port, remote_port):
    """Start an ssh port-forwarding process and wait until the local port is reachable.

    Args:
        local_port: Local port the forward is bound to.
        remote_port: Port on the remote side that ``local_port`` is forwarded to.

    Raises:
        ConnectionError: If ``local_port`` does not accept connections within
            ``TUNNEL_TIMEOUT`` seconds. The spawned ssh process is stopped
            before the error is raised.
    """
    base_cmd = self._ssh_base_command(
        ssh_mode=SshMode.NON_INTERACTIVE, port_forward=[(local_port, remote_port)]
    )
    command = " ".join(base_cmd)
    logger.info(f"Running forwarding command: {command}")
    proc = subprocess.Popen(
        shlex.split(command),
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )

    # Wait until tunnel is formed by trying to create a socket in a loop
    start_time = time.time()
    while not is_port_in_use(local_port):
        time.sleep(0.1)
        if time.time() - start_time > TUNNEL_TIMEOUT:
            # The process is not yet tracked on self, so nothing else would
            # ever clean it up: stop and reap it here instead of leaking it.
            proc.terminate()
            try:
                proc.wait(timeout=5)
            except subprocess.TimeoutExpired:
                proc.kill()
            raise ConnectionError(
                f"Failed to create tunnel from {local_port} to {remote_port} on {self.ip}"
            )

    # Set the tunnel process and ports to be cleaned up later
    self.tunnel_proc = proc
    self.local_bind_port = local_port
    self.remote_bind_port = remote_port

def tunnel_is_up(self):
    """Return whether this runner has a bound local port that is still in use.

    This is a best-effort liveness check: it only confirms that something is
    accepting connections on the recorded local port.
    """
    if self.local_bind_port is None:
        return False
    return is_port_in_use(self.local_bind_port)

def __del__(self):
    # Tear down any tunnel held by this object when it is finalized.
    self.terminate()

def terminate(self):
    """Shut down the tunnel held by this runner, if any.

    Closes the tunnel process's stdin and, when the ssh command uses a
    ControlMaster, issues an ``-O cancel`` request to remove the port
    forward. A failed cancel is logged as a warning rather than raised.

    The tunnel process and bind ports are always reset, even if closing stdin
    or running the cancel command raises, so a failed teardown does not leave
    a stale handle behind. The exception itself still propagates.
    """
    if self.tunnel_proc is None:
        return

    try:
        # Process keeping tunnel alive can only be killed with EOF
        self.tunnel_proc.stdin.close()

        # Remove port forwarding
        port_fwd_cmd = " ".join(
            self._ssh_base_command(
                ssh_mode=SshMode.NON_INTERACTIVE,
                port_forward=[(self.local_bind_port, self.remote_bind_port)],
            )
        )

        if "ControlMaster" in port_fwd_cmd:
            cancel_port_fwd = port_fwd_cmd.replace("-T", "-O cancel")
            logger.debug(f"Running cancel command: {cancel_port_fwd}")
            completed_cancel_cmd = subprocess.run(
                shlex.split(cancel_port_fwd),
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )

            if completed_cancel_cmd.returncode != 0:
                # stderr is captured as bytes; decode it so the log shows
                # readable text instead of a bytes repr.
                cancel_stderr = completed_cancel_cmd.stderr.decode(
                    errors="replace"
                ).strip()
                logger.warning(
                    f"Failed to cancel port forwarding from {self.local_bind_port} to {self.remote_bind_port}. "
                    f"Error: {cancel_stderr}"
                )
    finally:
        self.tunnel_proc = None
        self.local_bind_port = None
        self.remote_bind_port = None

def rsync(
self,
source: str,
Expand Down Expand Up @@ -425,115 +347,3 @@ def rsync(
subprocess_utils.handle_returncode(
returncode, command, error_msg, stderr=stderr, stream_logs=stream_logs
)


####################################################################################################
# Cache and retrieve existing SSH Runners that are set up for a given address and port
####################################################################################################
# TODO: Shouldn't the control master prevent new ssh connections from being created?
def get_existing_sky_ssh_runner(address: str, ssh_port: int) -> Optional[SkySSHRunner]:
    """Look up a cached runner for ``(address, ssh_port)`` whose tunnel is still up.

    Args:
        address: Server address the runner was cached under.
        ssh_port: SSH port the runner was cached under.

    Returns:
        The cached runner if one exists and its tunnel is up; otherwise None.
        A cached runner whose tunnel is down is evicted from the cache.
    """
    cache_key = (address, ssh_port)
    if cache_key not in sky_ssh_runner_cache:
        return None

    cached_runner = sky_ssh_runner_cache.get(cache_key)
    if cached_runner.tunnel_is_up():
        return cached_runner

    # Stale entry: drop it so a fresh tunnel can be cached under this key.
    sky_ssh_runner_cache.pop(cache_key)
    return None


def cache_existing_sky_ssh_runner(
    address: str, ssh_port: int, runner: SkySSHRunner
) -> None:
    """Store ``runner`` in the module-level cache under ``(address, ssh_port)``.

    Any runner previously cached under the same key is replaced.
    """
    cache_key = (address, ssh_port)
    sky_ssh_runner_cache[cache_key] = runner


def ssh_tunnel(
    address: str,
    ssh_creds: Dict,
    local_port: int,
    ssh_port: int = 22,
    remote_port: Optional[int] = None,
    num_ports_to_try: int = 0,
    docker_user: Optional[str] = None,
) -> SkySSHRunner:
    """Initialize an ssh tunnel from a remote server to localhost.

    If a cached runner for ``(address, ssh_port)`` is still up and already
    forwards ``remote_port``, it is returned instead of opening a new tunnel.

    Args:
        address (str): The address of the server we are trying to port forward an address to our local machine with.
        ssh_creds (Dict): A dictionary of ssh credentials used to connect to the remote server.
            It is not mutated.
        local_port (int): The port locally where we are attempting to bind the remote server address to.
        ssh_port (int): The port on the machine where the ssh server is running.
            This is generally port 22, but occasionally
            we may forward a container's ssh port to a different port
            on the actual machine itself (for example on a Docker VM). Defaults to 22.
        remote_port (Optional[int], optional): The port of the remote server
            we're attempting to port forward. Defaults to ``local_port``.
        num_ports_to_try (int, optional): How many times to move on to the next
            local port (incrementing by 1) when the current one is in use. Defaults to 0.
        docker_user (Optional[str], optional): Docker user passed through to the runner.
            When set, a cached tunnel is only reused if its ip is "localhost".

    Returns:
        SkySSHRunner: The initialized tunnel.

    Raises:
        Exception: If no free local port is found within the allowed attempts.
        ConnectionError: Propagated from ``SkySSHRunner.tunnel`` if the tunnel
            does not come up in time.
    """

    # Debugging cmds (mac):
    # netstat -vanp tcp | grep 32300
    # lsof -i :32300
    # kill -9 <pid>

    # If remote_port isn't specified,
    # assume that the first attempted local port is
    # the same as the remote port on the server.
    remote_port = remote_port or local_port

    tunnel = get_existing_sky_ssh_runner(address, ssh_port)
    tunnel_address = address if not docker_user else "localhost"
    if (
        tunnel
        and tunnel.ip == tunnel_address
        and tunnel.remote_bind_port == remote_port
    ):
        logger.info(
            f"SSH tunnel on to server's port {remote_port} "
            f"via server's ssh port {ssh_port} already created with the cluster."
        )
        return tunnel

    # Remember where the search started so the error can report how many
    # ports were actually probed (num_ports_to_try is decremented below).
    first_port = local_port
    while is_port_in_use(local_port):
        if num_ports_to_try < 0:
            ports_checked = local_port - first_port + 1
            raise Exception(
                f"Failed to find open port after {ports_checked} attempts"
            )

        logger.info(f"Port {local_port} is already in use. Trying next port.")
        local_port += 1
        num_ports_to_try -= 1

    # Work on a copy so popping keys does not mutate the caller's credentials.
    ssh_credentials = copy.copy(ssh_creds)

    # Host could be a proxy specified in credentials or is the provided address
    host = ssh_credentials.pop("ssh_host", address)
    ssh_control_name = ssh_credentials.pop("ssh_control_name", f"{address}:{ssh_port}")

    runner = SkySSHRunner(
        ip=host,
        ssh_user=ssh_creds.get("ssh_user"),
        ssh_private_key=ssh_creds.get("ssh_private_key"),
        ssh_proxy_command=ssh_creds.get("ssh_proxy_command"),
        ssh_control_name=ssh_control_name,
        docker_user=docker_user,
        port=ssh_port,
    )
    runner.tunnel(local_port, remote_port)

    logger.debug(
        f"Successfully bound "
        f"{LOCALHOST}:{remote_port} via ssh port {ssh_port} "
        f"on remote server {address} "
        f"to {LOCALHOST}:{local_port} on local machine."
    )

    cache_existing_sky_ssh_runner(address, ssh_port, runner)
    return runner
15 changes: 13 additions & 2 deletions runhouse/resources/hardware/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
import json
import subprocess

Expand All @@ -13,7 +14,11 @@

from runhouse.logger import logger
from runhouse.resources.envs.utils import _get_env_from, run_setup_command
from runhouse.resources.hardware.sky.command_runner import ssh_options_list, SshMode
from runhouse.resources.hardware.sky.command_runner import (
_HASH_MAX_LENGTH,
ssh_options_list,
SshMode,
)


class ServerConnectionType(str, Enum):
Expand Down Expand Up @@ -149,10 +154,12 @@ def _run_ssh_command(
ip=address,
ssh_user=ssh_user,
ssh_private_key=ssh_private_key,
ssh_mode=SshMode.INTERACTIVE,
ssh_port=ssh_port,
docker_user=docker_user,
)
ssh_command = runner._ssh_base_command(
ssh_mode=SshMode.INTERACTIVE, port_forward=None
)
subprocess.run(ssh_command)


Expand Down Expand Up @@ -215,3 +222,7 @@ def _ssh_base_command(
)
+ [f"{ssh_user}@{address}"]
)


def _generate_ssh_control_hash(ssh_control_name):
    """Return the MD5 hex digest of ``ssh_control_name``, cut to ``_HASH_MAX_LENGTH`` characters."""
    full_digest = hashlib.md5(ssh_control_name.encode()).hexdigest()
    return full_digest[:_HASH_MAX_LENGTH]
4 changes: 2 additions & 2 deletions tests/test_resources/test_clusters/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,14 +163,14 @@ def test_cluster_factory_and_properties(self, cluster):
def test_cluster_recreate(self, cluster):
# Create underlying ssh connection if not already
cluster.run(["echo hello"])
num_open_tunnels = len(rh.globals.sky_ssh_runner_cache)
num_open_tunnels = len(rh.globals.ssh_tunnel_cache)

# Create a new cluster object for the same remote cluster
cluster.save()
new_cluster = rh.cluster(cluster.rns_address)
new_cluster.run(["echo hello"])
# Check that the same underlying ssh connection was used
assert len(rh.globals.sky_ssh_runner_cache) == num_open_tunnels
assert len(rh.globals.ssh_tunnel_cache) == num_open_tunnels

@pytest.mark.level("local")
@pytest.mark.clustertest
Expand Down

0 comments on commit 291206d

Please sign in to comment.