diff --git a/runhouse/globals.py b/runhouse/globals.py index 84985dcf0..975daaf19 100644 --- a/runhouse/globals.py +++ b/runhouse/globals.py @@ -13,11 +13,11 @@ configs = Defaults() -sky_ssh_runner_cache = {} +ssh_tunnel_cache = {} def clean_up_ssh_connections(): - for _, v in sky_ssh_runner_cache.items(): + for _, v in ssh_tunnel_cache.items(): v.terminate() diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 1f12f5bda..3295b1ec5 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -756,8 +756,8 @@ def status(self, resource_address: str = None): def ssh_tunnel( self, local_port, remote_port=None, num_ports_to_try: int = 0 - ) -> "SkySSHRunner": - from runhouse.resources.hardware.sky_ssh_runner import ssh_tunnel + ) -> "SshTunnel": + from runhouse.resources.hardware.ssh_tunnel import ssh_tunnel return ssh_tunnel( address=self.address, diff --git a/runhouse/resources/hardware/sky_ssh_runner.py b/runhouse/resources/hardware/sky_ssh_runner.py index 861846b35..5a5dd9a5b 100644 --- a/runhouse/resources/hardware/sky_ssh_runner.py +++ b/runhouse/resources/hardware/sky_ssh_runner.py @@ -1,14 +1,10 @@ -import copy import os import pathlib import shlex -import subprocess import time from typing import Dict, List, Optional, Tuple, Union -from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST, TUNNEL_TIMEOUT - -from runhouse.globals import sky_ssh_runner_cache +from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME from runhouse.logger import logger @@ -37,13 +33,6 @@ pass -def is_port_in_use(port: int) -> bool: - import socket - - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - return s.connect_ex(("localhost", port)) == 0 - - def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str: """Find docker container username.""" runner = SkySSHRunner( @@ -96,7 +85,6 @@ def __init__( # RH modified self.docker_user = docker_user - self.tunnel_proc = None self.local_bind_port = local_bind_port self.remote_bind_port = None @@ -238,73 +226,6 @@ def run( **kwargs, ) - def tunnel(self, local_port, remote_port): - base_cmd = self._ssh_base_command( - ssh_mode=SshMode.NON_INTERACTIVE, port_forward=[(local_port, remote_port)] - ) - command = " ".join(base_cmd) - logger.info(f"Running forwarding command: {command}") - proc = subprocess.Popen( - shlex.split(command), - stdin=subprocess.PIPE, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - # Wait until tunnel is formed by trying to create a socket in a loop - - start_time = time.time() - while not is_port_in_use(local_port): - time.sleep(0.1) - if time.time() - start_time > TUNNEL_TIMEOUT: - raise ConnectionError( - f"Failed to create tunnel from {local_port} to {remote_port} on {self.ip}" - ) - - # Set the tunnel process and ports to be cleaned up later - self.tunnel_proc = proc - self.local_bind_port = local_port - self.remote_bind_port = remote_port - - def tunnel_is_up(self): - # Try and do as much as we can to check that this is still alive and the port is still forwarded - return self.local_bind_port is not None and is_port_in_use(self.local_bind_port) - - def __del__(self): - self.terminate() - - def terminate(self): - if self.tunnel_proc is not None: - - # Process keeping tunnel alive can only be killed with EOF - self.tunnel_proc.stdin.close() - - # Remove port forwarding - port_fwd_cmd = " ".join( - self._ssh_base_command( - ssh_mode=SshMode.NON_INTERACTIVE, - port_forward=[(self.local_bind_port, self.remote_bind_port)], - ) - ) - - if "ControlMaster" in port_fwd_cmd: - cancel_port_fwd = port_fwd_cmd.replace("-T", "-O cancel") - logger.debug(f"Running cancel command: {cancel_port_fwd}") - completed_cancel_cmd = subprocess.run( - shlex.split(cancel_port_fwd), - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - - if completed_cancel_cmd.returncode != 0: - logger.warning( - f"Failed to cancel port forwarding from {self.local_bind_port} to {self.remote_bind_port}. " - f"Error: {completed_cancel_cmd.stderr}" - ) - - self.tunnel_proc = None - self.local_bind_port = None - self.remote_bind_port = None - def rsync( self, source: str, @@ -419,115 +340,3 @@ def rsync( subprocess_utils.handle_returncode( returncode, command, error_msg, stderr=stderr, stream_logs=stream_logs ) - - -#################################################################################################### -# Cache and retrieve existing SSH Runners that are set up for a given address and port -#################################################################################################### -# TODO: Shouldn't the control master prevent new ssh connections from being created? -def get_existing_sky_ssh_runner(address: str, ssh_port: int) -> Optional[SkySSHRunner]: - if (address, ssh_port) in sky_ssh_runner_cache: - existing_runner = sky_ssh_runner_cache.get((address, ssh_port)) - if existing_runner.tunnel_is_up(): - return existing_runner - else: - sky_ssh_runner_cache.pop((address, ssh_port)) - - else: - return None - - -def cache_existing_sky_ssh_runner( - address: str, ssh_port: int, runner: SkySSHRunner -) -> None: - sky_ssh_runner_cache[(address, ssh_port)] = runner - - -def ssh_tunnel( - address: str, - ssh_creds: Dict, - local_port: int, - ssh_port: int = 22, - remote_port: Optional[int] = None, - num_ports_to_try: int = 0, - docker_user: Optional[str] = None, -) -> SkySSHRunner: - """Initialize an ssh tunnel from a remote server to localhost - - Args: - address (str): The address of the server we are trying to port forward an address to our local machine with. - ssh_creds (Dict): A dictionary of ssh credentials used to connect to the remote server. - local_port (int): The port locally where we are attempting to bind the remote server address to. - ssh_port (int): The port on the machine where the ssh server is running. - This is generally port 22, but occasionally - we may forward a container's ssh port to a different port - on the actual machine itself (for example on a Docker VM). Defaults to 22. - remote_port (Optional[int], optional): The port of the remote server - we're attempting to port forward. Defaults to None. - num_ports_to_try (int, optional): The number of local ports to attempt to bind to, - starting at local_port and incrementing by 1 till we hit the max. Defaults to 0. - - Returns: - SkySSHRunner: The initialized tunnel. - """ - - # Debugging cmds (mac): - # netstat -vanp tcp | grep 32300 - # lsof -i :32300 - # kill -9 - - # If remote_port isn't specified, - # assume that the first attempted local port is - # the same as the remote port on the server. - remote_port = remote_port or local_port - - tunnel = get_existing_sky_ssh_runner(address, ssh_port) - tunnel_address = address if not docker_user else "localhost" - if ( - tunnel - and tunnel.ip == tunnel_address - and tunnel.remote_bind_port == remote_port - ): - logger.info( - f"SSH tunnel on to server's port {remote_port} " - f"via server's ssh port {ssh_port} already created with the cluster." - ) - return tunnel - - while is_port_in_use(local_port): - if num_ports_to_try < 0: - raise Exception( - f"Failed to create find open port after {num_ports_to_try} attempts" - ) - - logger.info(f"Port {local_port} is already in use. Trying next port.") - local_port += 1 - num_ports_to_try -= 1 - - # Start a tunnel using self.run in a thread, instead of ssh_tunnel - ssh_credentials = copy.copy(ssh_creds) - - # Host could be a proxy specified in credentials or is the provided address - host = ssh_credentials.pop("ssh_host", address) - ssh_control_name = ssh_credentials.pop("ssh_control_name", f"{address}:{ssh_port}") - - runner = SkySSHRunner( - ip=host, - ssh_user=ssh_creds.get("ssh_user"), - ssh_private_key=ssh_creds.get("ssh_private_key"), - ssh_proxy_command=ssh_creds.get("ssh_proxy_command"), - ssh_control_name=ssh_control_name, - docker_user=docker_user, - port=ssh_port, - ) - runner.tunnel(local_port, remote_port) - - logger.debug( - f"Successfully bound " - f"{LOCALHOST}:{remote_port} via ssh port {ssh_port} " - f"on remote server {address} " - f"to {LOCALHOST}:{local_port} on local machine." - ) - - cache_existing_sky_ssh_runner(address, ssh_port, runner) - return runner diff --git a/runhouse/resources/hardware/ssh_tunnel.py b/runhouse/resources/hardware/ssh_tunnel.py new file mode 100644 index 000000000..c6862e9dc --- /dev/null +++ b/runhouse/resources/hardware/ssh_tunnel.py @@ -0,0 +1,259 @@ +import copy +import shlex +import subprocess +import time +from typing import Dict, Optional + +from runhouse.constants import LOCALHOST, TUNNEL_TIMEOUT +from runhouse.globals import ssh_tunnel_cache + +from runhouse.logger import logger +from runhouse.resources.hardware.sky.command_runner import SshMode +from runhouse.resources.hardware.utils import ( + _generate_ssh_control_hash, + _ssh_base_command, +) + + +class SshTunnel: + def __init__( + self, + ip: str, + ssh_user: str = None, + ssh_private_key: str = None, + ssh_control_name: Optional[str] = "__default__", + ssh_proxy_command: Optional[str] = None, + ssh_port: int = 22, + disable_control_master: Optional[bool] = False, + docker_user: Optional[str] = None, + ): + """Initialize an ssh tunnel from a remote server to localhost + + Args: + address (str): The address of the server we are trying to port forward an address to our local machine with. + ssh_creds (Dict): A dictionary of ssh credentials used to connect to the remote server. + local_port (int): The port locally where we are attempting to bind the remote server address to. + ssh_port (int): The port on the machine where the ssh server is running. + This is generally port 22, but occasionally + we may forward a container's ssh port to a different port + on the actual machine itself (for example on a Docker VM). Defaults to 22. + remote_port (Optional[int], optional): The port of the remote server + we're attempting to port forward. Defaults to None. + num_ports_to_try (int, optional): The number of local ports to attempt to bind to, + starting at local_port and incrementing by 1 till we hit the max. Defaults to 0. + """ + self.ip = ip + self.ssh_user = ssh_user + self.ssh_private_key = ssh_private_key + self.ssh_control_name = ( + None + if ssh_control_name is None + else _generate_ssh_control_hash(ssh_control_name) + ) + self.ssh_proxy_command = ssh_proxy_command + self.ssh_port = ssh_port + self.docker_user = docker_user + self.disable_control_master = disable_control_master + + self.tunnel_proc = None + + def tunnel(self, local_port, remote_port): + base_cmd = _ssh_base_command( + address=self.ip, + ssh_user=self.ssh_user, + ssh_private_key=self.ssh_private_key, + ssh_control_name=self.ssh_control_name, + ssh_proxy_command=self.ssh_proxy_command, + ssh_port=self.ssh_port, + docker_user=self.docker_user, + disable_control_master=self.disable_control_master, + ssh_mode=SshMode.NON_INTERACTIVE, + port_forward=[(local_port, remote_port)], + ) + command = " ".join(base_cmd) + logger.info(f"Running forwarding command: {command}") + proc = subprocess.Popen( + shlex.split(command), + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + # Wait until tunnel is formed by trying to create a socket in a loop + + start_time = time.time() + while not is_port_in_use(local_port): + time.sleep(0.1) + if time.time() - start_time > TUNNEL_TIMEOUT: + raise ConnectionError( + f"Failed to create tunnel from {local_port} to {remote_port} on {self.ip}" + ) + + # Set the tunnel process and ports to be cleaned up later + self.tunnel_proc = proc + self.local_bind_port = local_port + self.remote_bind_port = remote_port + + def tunnel_is_up(self): + # Try and do as much as we can to check that this is still alive and the port is still forwarded + return self.local_bind_port is not None and is_port_in_use(self.local_bind_port) + + def __del__(self): + self.terminate() + + def terminate(self): + if self.tunnel_proc is not None: + + # Process keeping tunnel alive can only be killed with EOF + self.tunnel_proc.stdin.close() + + # Remove port forwarding + port_fwd_cmd = " ".join( + _ssh_base_command( + address=self.ip, + ssh_user=self.ssh_user, + ssh_private_key=self.ssh_private_key, + ssh_control_name=self.ssh_control_name, + ssh_proxy_command=self.ssh_proxy_command, + ssh_port=self.ssh_port, + docker_user=self.docker_user, + disable_control_master=self.disable_control_master, + ssh_mode=SshMode.NON_INTERACTIVE, + port_forward=[(self.local_bind_port, self.remote_bind_port)], + ) + ) + + if "ControlMaster" in port_fwd_cmd: + cancel_port_fwd = port_fwd_cmd.replace("-T", "-O cancel") + logger.debug(f"Running cancel command: {cancel_port_fwd}") + completed_cancel_cmd = subprocess.run( + shlex.split(cancel_port_fwd), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + + if completed_cancel_cmd.returncode != 0: + logger.warning( + f"Failed to cancel port forwarding from {self.local_bind_port} to {self.remote_bind_port}. " + f"Error: {completed_cancel_cmd.stderr}" + ) + + self.tunnel_proc = None + self.local_bind_port = None + self.remote_bind_port = None + + +#################################################################################################### +# Cache and retrieve existing SSH Runners that are set up for a given address and port +#################################################################################################### +# TODO: Shouldn't the control master prevent new ssh connections from being created? + + +def is_port_in_use(port: int) -> bool: + import socket + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(("localhost", port)) == 0 + + +def cache_existing_ssh_tunnel(address: str, ssh_port: int, tunnel: SshTunnel) -> None: + ssh_tunnel_cache[(address, ssh_port)] = tunnel + + +def get_existing_sky_ssh_runner(address: str, ssh_port: int) -> Optional[SshTunnel]: + if (address, ssh_port) in ssh_tunnel_cache: + existing_tunnel = ssh_tunnel_cache.get((address, ssh_port)) + if existing_tunnel.tunnel_is_up(): + return existing_tunnel + else: + ssh_tunnel_cache.pop((address, ssh_port)) + else: + return None + + +def ssh_tunnel( + address: str, + ssh_creds: Dict, + local_port: int, + ssh_port: int = 22, + remote_port: Optional[int] = None, + num_ports_to_try: int = 0, + docker_user: Optional[str] = None, +) -> SshTunnel: + """Initialize an ssh tunnel from a remote server to localhost + + Args: + address (str): The address of the server we are trying to port forward an address to our local machine with. + ssh_creds (Dict): A dictionary of ssh credentials used to connect to the remote server. + local_port (int): The port locally where we are attempting to bind the remote server address to. + ssh_port (int): The port on the machine where the ssh server is running. + This is generally port 22, but occasionally + we may forward a container's ssh port to a different port + on the actual machine itself (for example on a Docker VM). Defaults to 22. + remote_port (Optional[int], optional): The port of the remote server + we're attempting to port forward. Defaults to None. + num_ports_to_try (int, optional): The number of local ports to attempt to bind to, + starting at local_port and incrementing by 1 till we hit the max. Defaults to 0. + + Returns: + SshTunnel: The initialized tunnel. + """ + + # Debugging cmds (mac): + # netstat -vanp tcp | grep 32300 + # lsof -i :32300 + # kill -9 + + # If remote_port isn't specified, + # assume that the first attempted local port is + # the same as the remote port on the server. + remote_port = remote_port or local_port + + tunnel = get_existing_sky_ssh_runner(address, ssh_port) + tunnel_address = address if not docker_user else "localhost" + if ( + tunnel + and tunnel.ip == tunnel_address + and tunnel.remote_bind_port == remote_port + ): + logger.info( + f"SSH tunnel on to server's port {remote_port} " + f"via server's ssh port {ssh_port} already created with the cluster." + ) + return tunnel + + while is_port_in_use(local_port): + if num_ports_to_try < 0: + raise Exception( + f"Failed to create find open port after {num_ports_to_try} attempts" + ) + + logger.info(f"Port {local_port} is already in use. Trying next port.") + local_port += 1 + num_ports_to_try -= 1 + + ssh_credentials = copy.copy(ssh_creds) + + # Host could be a proxy specified in credentials or is the provided address + host = ssh_credentials.pop("ssh_host", address) + ssh_control_name = ssh_credentials.pop("ssh_control_name", f"{address}:{ssh_port}") + + tunnel = SshTunnel( + ip=host, + ssh_user=ssh_creds.get("ssh_user"), + ssh_private_key=ssh_creds.get("ssh_private_key"), + ssh_proxy_command=ssh_creds.get("ssh_proxy_command"), + ssh_control_name=ssh_control_name, + docker_user=docker_user, + ssh_port=ssh_port, + ) + tunnel.tunnel(local_port, remote_port) + + logger.debug( + f"Successfully bound " + f"{LOCALHOST}:{remote_port} via ssh port {ssh_port} " + f"on remote server {address} " + f"to {LOCALHOST}:{local_port} on local machine." + ) + + cache_existing_ssh_tunnel(address, ssh_port, tunnel) + return tunnel diff --git a/runhouse/resources/hardware/utils.py b/runhouse/resources/hardware/utils.py index f5fd25a27..f818ca30e 100644 --- a/runhouse/resources/hardware/utils.py +++ b/runhouse/resources/hardware/utils.py @@ -220,3 +220,7 @@ def _ssh_base_command( ) + [f"{ssh_user}@{address}"] ) + + +def _generate_ssh_control_hash(ssh_control_name): + return hashlib.md5(ssh_control_name.encode()).hexdigest()[:_HASH_MAX_LENGTH] diff --git a/tests/test_resources/test_clusters/test_cluster.py b/tests/test_resources/test_clusters/test_cluster.py index 1efa400b5..a3c3d4b60 100644 --- a/tests/test_resources/test_clusters/test_cluster.py +++ b/tests/test_resources/test_clusters/test_cluster.py @@ -166,14 +166,14 @@ def test_cluster_factory_and_properties(self, cluster): def test_cluster_recreate(self, cluster): # Create underlying ssh connection if not already cluster.run(["echo hello"]) - num_open_tunnels = len(rh.globals.sky_ssh_runner_cache) + num_open_tunnels = len(rh.globals.ssh_tunnel_cache) # Create a new cluster object for the same remote cluster cluster.save() new_cluster = rh.cluster(cluster.rns_address) new_cluster.run(["echo hello"]) # Check that the same underlying ssh connection was used - assert len(rh.globals.sky_ssh_runner_cache) == num_open_tunnels + assert len(rh.globals.ssh_tunnel_cache) == num_open_tunnels @pytest.mark.level("local") @pytest.mark.clustertest