diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 3295b1ec5..7f4dcba55 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -1197,10 +1197,9 @@ def rsync( ) runner = SkySSHRunner( - node, + (node, self.ssh_port), **ssh_credentials, ssh_control_name=ssh_control_name, - port=self.ssh_port, docker_user=self.docker_user, ) if not pwd: @@ -1416,10 +1415,9 @@ def _run_commands_with_ssh( ) runner = SkySSHRunner( - host, + (host, self.ssh_port), **ssh_credentials, ssh_control_name=ssh_control_name, - port=self.ssh_port, docker_user=self.docker_user, ) diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index 41b8d3c42..abdec8092 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -666,9 +666,8 @@ def ssh(self, node: str = None): raise FileNotFoundError(f"Expected default sky key in path: {sky_key}") runner = SkySSHRunner( - ip=node or self.address, + (node or self.address, self.ssh_port), ssh_user=ssh_user, - port=self.ssh_port, ssh_private_key=str(sky_key), docker_user=self.docker_user, ) diff --git a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py index 84ff0255e..286d2705a 100644 --- a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py +++ b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py @@ -705,8 +705,7 @@ def _run_commands_with_ssh( from runhouse.resources.hardware.sky_ssh_runner import SkySSHRunner runner = SkySSHRunner( - self.name, - port=self.ssh_port, + (self.name, self.ssh_port), ssh_user=self.DEFAULT_USER, ssh_private_key=self._abs_ssh_key_path, ssh_control_name=f"{self.name}:{self.ssh_port}", diff --git a/runhouse/resources/hardware/sky/command_runner.py b/runhouse/resources/hardware/sky/command_runner.py index f47e905e2..330584dfa 100644 --- a/runhouse/resources/hardware/sky/command_runner.py +++ b/runhouse/resources/hardware/sky/command_runner.py @@ -1,4 +1,4 @@ -# Source: https://github.com/skypilot-org/skypilot/blob/feb52cf/sky/utils/command_runner.py +# Source: https://github.com/skypilot-org/skypilot/blob/465d36/sky/utils/command_runner.py """Runner for commands to be executed on the cluster.""" import enum @@ -7,10 +7,7 @@ import pathlib import shlex import time -from typing import List, Optional, Tuple, Union - - -# from sky.skylet import log_lib +from typing import Any, Callable, Iterable, List, Optional, Tuple, Type, Union from runhouse.resources.hardware.sky import ( constants, @@ -25,10 +22,13 @@ ##### RH modification ##### - # The git exclude file to support. GIT_EXCLUDE = '.git/info/exclude' # Rsync options +# TODO(zhwu): This will print a per-file progress bar (with -P), +# shooting a lot of messages to the output. --info=progress2 is used +# to get a total progress bar, but it requires rsync>=3.1.0 and Mac +# OS has a default rsync==2.6.9 (16 years old). RSYNC_DISPLAY_OPTION = '-Pavz' # Legend # dir-merge: ignore file can appear in any subdir, applies to that @@ -40,6 +40,7 @@ RSYNC_EXCLUDE_OPTION = '--exclude-from={}' _HASH_MAX_LENGTH = 10 +_DEFAULT_CONNECT_TIMEOUT = 30 def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]: @@ -52,23 +53,31 @@ def _ssh_control_path(ssh_control_filename: Optional[str]) -> Optional[str]: return path +# Disable sudo for root user. This is useful when the command is running in a +# docker container, i.e. image_id is a docker image. +ALIAS_SUDO_TO_EMPTY_FOR_ROOT_CMD = ( + '{ [ "$(whoami)" == "root" ] && function sudo() { "$@"; } || true; }') + + def ssh_options_list( ssh_private_key: Optional[str], ssh_control_name: Optional[str], *, ssh_proxy_command: Optional[str] = None, docker_ssh_proxy_command: Optional[str] = None, - timeout: int = 30, + connect_timeout: Optional[int] = None, port: int = 22, disable_control_master: Optional[bool] = False, ) -> List[str]: """Returns a list of sane options for 'ssh'.""" + if connect_timeout is None: + connect_timeout = _DEFAULT_CONNECT_TIMEOUT # Forked from Ray SSHOptions: # https://github.com/ray-project/ray/blob/master/python/ray/autoscaler/_private/command_runner.py arg_dict = { # SSH port 'Port': port, - # Supresses initial fingerprint verification. + # Suppresses initial fingerprint verification. 'StrictHostKeyChecking': 'no', # SSH IP and fingerprint pairs no longer added to known_hosts. # This is to remove a 'REMOTE HOST IDENTIFICATION HAS CHANGED' @@ -76,6 +85,10 @@ def ssh_options_list( # deleted node, because the fingerprints will not match in # that case. 'UserKnownHostsFile': os.devnull, + # Suppresses the warning messages, such as: + # Warning: Permanently added 'xx.xx.xx.xx' (EDxxx) to the list of + # known hosts. + 'LogLevel': 'ERROR', # Try fewer extraneous key pairs. 'IdentitiesOnly': 'yes', # Abort if port forwarding fails (instead of just printing to @@ -86,7 +99,7 @@ def ssh_options_list( 'ServerAliveInterval': 5, 'ServerAliveCountMax': 3, # ConnectTimeout. - 'ConnectTimeout': f'{timeout}s', + 'ConnectTimeout': f'{connect_timeout}s', # Agent forwarding for git. 'ForwardAgent': 'yes', } @@ -143,17 +156,255 @@ class SshMode(enum.Enum): LOGIN = 2 -class SSHCommandRunner: +class CommandRunner: + """Runner for commands to be executed on the cluster.""" + + def __init__(self, node: Tuple[Any, Any], **kwargs): + del kwargs # Unused. + self.node = node + + @property + def node_id(self) -> str: + return '-'.join(str(x) for x in self.node) + + def _get_command_to_run( + self, + cmd: Union[str, List[str]], + process_stream: bool, + separate_stderr: bool, + skip_lines: int, + source_bashrc: bool = False, + ) -> str: + """Returns the command to run.""" + if isinstance(cmd, list): + cmd = ' '.join(cmd) + + # We need this to correctly run the cmd, and get the output. + command = [ + '/bin/bash', + '--login', + '-c', + ] + if source_bashrc: + command += [ + # Need this `-i` option to make sure `source ~/.bashrc` work. + # Sourcing bashrc may take a few seconds causing overheads. + '-i', + shlex.quote( + f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' + f'PYTHONWARNINGS=ignore && ({cmd})'), + ] + else: + # Optimization: this reduces the time for connecting to the remote + # cluster by 1 second. + # sourcing ~/.bashrc is not required for internal executions + command += [ + shlex.quote('true && export OMP_NUM_THREADS=1 ' + f'PYTHONWARNINGS=ignore && ({cmd})') + ] + if not separate_stderr: + command.append('2>&1') + if not process_stream and skip_lines: + command += [ + # A hack to remove the following bash warnings (twice): + # bash: cannot set terminal process group + # bash: no job control in this shell + f'| stdbuf -o0 tail -n +{skip_lines}', + # This is required to make sure the executor of command can get + # correct returncode, since linux pipe is used. + '; exit ${PIPESTATUS[0]}' + ] + + command_str = ' '.join(command) + return command_str + + def _rsync( + self, + source: str, + target: str, + node_destination: str, + up: bool, + rsh_option: str, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = True, + max_retry: int = 1, + prefix_command: Optional[str] = None, + get_remote_home_dir: Callable[[], str] = lambda: '~') -> None: + """Builds the rsync command.""" + # Build command. + rsync_command = [] + if prefix_command is not None: + rsync_command.append(prefix_command) + rsync_command += ['rsync', RSYNC_DISPLAY_OPTION] + + # --filter + rsync_command.append(RSYNC_FILTER_OPTION) + + if up: + # Build --exclude-from argument. + # The source is a local path, so we need to resolve it. + resolved_source = pathlib.Path(source).expanduser().resolve() + if (resolved_source / GIT_EXCLUDE).exists(): + # Ensure file exists; otherwise, rsync will error out. + # + # We shlex.quote() because the path may contain spaces: + # 'my dir/.git/info/exclude' + # Without quoting rsync fails. + rsync_command.append( + RSYNC_EXCLUDE_OPTION.format( + shlex.quote(str(resolved_source / GIT_EXCLUDE)))) + + rsync_command.append(f'-e {shlex.quote(rsh_option)}') + + if up: + resolved_target = target + if target.startswith('~'): + remote_home_dir = get_remote_home_dir() + resolved_target = target.replace('~', remote_home_dir) + full_source_str = str(resolved_source) + if resolved_source.is_dir(): + full_source_str = os.path.join(full_source_str, '') + rsync_command.extend([ + f'{full_source_str!r}', + f'{node_destination}:{resolved_target!r}', + ]) + else: + resolved_source = source + if source.startswith('~'): + remote_home_dir = get_remote_home_dir() + resolved_source = source.replace('~', remote_home_dir) + rsync_command.extend([ + f'{node_destination}:{resolved_source!r}', + f'{os.path.expanduser(target)!r}', + ]) + command = ' '.join(rsync_command) + logger.debug(f'Running rsync command: {command}') + + backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) + assert max_retry > 0, f'max_retry {max_retry} must be positive.' + while max_retry >= 0: + returncode, stdout, stderr = log_lib.run_with_log( + command, + log_path=log_path, + stream_logs=stream_logs, + shell=True, + require_outputs=True) + if returncode == 0: + break + max_retry -= 1 + time.sleep(backoff.current_backoff()) + + direction = 'up' if up else 'down' + error_msg = (f'Failed to rsync {direction}: {source} -> {target}. ' + 'Ensure that the network is stable, then retry.') + subprocess_utils.handle_returncode(returncode, + command, + error_msg, + stderr=stdout + stderr, + stream_logs=stream_logs) + + # @timeline.event + def run( + self, + cmd: Union[str, List[str]], + *, + require_outputs: bool = False, + # Advanced options. + log_path: str = os.devnull, + # If False, do not redirect stdout/stderr to optimize performance. + process_stream: bool = True, + stream_logs: bool = True, + ssh_mode: SshMode = SshMode.NON_INTERACTIVE, + separate_stderr: bool = False, + connect_timeout: Optional[int] = None, + source_bashrc: bool = False, + skip_lines: int = 0, + **kwargs) -> Union[int, Tuple[int, str, str]]: + """Runs the command on the cluster. + + Args: + cmd: The command to run. + require_outputs: Whether to return the stdout/stderr of the command. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + ssh_mode: The mode to use for ssh. + See SSHMode for more details. + separate_stderr: Whether to separate stderr from stdout. + connect_timeout: timeout in seconds for the ssh connection. + source_bashrc: Whether to source the ~/.bashrc before running the + command. + skip_lines: The number of lines to skip at the beginning of the + output. This is used when the output is not processed by + SkyPilot but we still want to get rid of some warning messages, + such as SSH warnings. + + + Returns: + returncode + or + A tuple of (returncode, stdout, stderr). + """ + raise NotImplementedError + + # @timeline.event + def rsync( + self, + source: str, + target: str, + *, + up: bool, + # Advanced options. + log_path: str = os.devnull, + stream_logs: bool = True, + max_retry: int = 1, + ) -> None: + """Uses 'rsync' to sync 'source' to 'target'. + + Args: + source: The source path. + target: The target path. + up: The direction of the sync, True for local to cluster, False + for cluster to local. + log_path: Redirect stdout/stderr to the log_path. + stream_logs: Stream logs to the stdout/stderr. + max_retry: The maximum number of retries for the rsync command. + This value should be non-negative. + + Raises: + exceptions.CommandError: rsync command failed. + """ + raise NotImplementedError + + @classmethod + def make_runner_list( + cls: Type['CommandRunner'], + node_list: Iterable[Any], + **kwargs, + ) -> List['CommandRunner']: + """Helper function for creating runners with the same credentials""" + return [cls(node, **kwargs) for node in node_list] + + def check_connection(self) -> bool: + """Check if the connection to the remote machine is successful.""" + returncode = self.run('true', connect_timeout=5, stream_logs=False) + return returncode == 0 + + def close_cached_connection(self) -> None: + """Close the cached connection to the remote machine.""" + pass + + +class SSHCommandRunner(CommandRunner): """Runner for SSH commands.""" def __init__( self, - ip: str, + node: Tuple[str, int], ssh_user: str, ssh_private_key: str, ssh_control_name: Optional[str] = '__default__', ssh_proxy_command: Optional[str] = None, - port: int = 22, docker_user: Optional[str] = None, disable_control_master: Optional[bool] = False, ): @@ -165,7 +416,7 @@ def __init__( runner.rsync(source, target, up=True) Args: - ip: The IP address of the remote machine. + node: (ip, port) The IP address and port of the remote machine. ssh_private_key: The path to the private key to use for ssh. ssh_user: The user to use for ssh. ssh_control_name: The files name of the ssh_control to use. This is @@ -178,11 +429,13 @@ def __init__( port: The port to use for ssh. docker_user: The docker user to use for ssh. If specified, the command will be run inside a docker container which have a ssh - server running at port DEFAULT_DOCKER_PORT + server running at port sky.skylet.constants.DEFAULT_DOCKER_PORT disable_control_master: bool; specifies either or not the ssh command will utilize ControlMaster. We currently disable it for k8s instance. """ + super().__init__(node) + ip, port = node self.ssh_private_key = ssh_private_key self.ssh_control_name = ( None if ssh_control_name is None else hashlib.md5( @@ -207,29 +460,9 @@ def __init__( self.port = port self._docker_ssh_proxy_command = None - @staticmethod - def make_runner_list( - ip_list: List[str], - ssh_user: str, - ssh_private_key: str, - ssh_control_name: Optional[str] = None, - ssh_proxy_command: Optional[str] = None, - disable_control_master: Optional[bool] = False, - port_list: Optional[List[int]] = None, - docker_user: Optional[str] = None, - ) -> List['SSHCommandRunner']: - """Helper function for creating runners with the same ssh credentials""" - if not port_list: - port_list = [22] * len(ip_list) - return [ - SSHCommandRunner(ip, ssh_user, ssh_private_key, ssh_control_name, - ssh_proxy_command, port, docker_user, - disable_control_master) - for ip, port in zip(ip_list, port_list) - ] - def _ssh_base_command(self, *, ssh_mode: SshMode, - port_forward: Optional[List[int]]) -> List[str]: + port_forward: Optional[List[int]], + connect_timeout: Optional[int]) -> List[str]: ssh = ['ssh'] if ssh_mode == SshMode.NON_INTERACTIVE: # Disable pseudo-terminal allocation. Otherwise, the output of @@ -254,10 +487,32 @@ def _ssh_base_command(self, *, ssh_mode: SshMode, ssh_proxy_command=self._ssh_proxy_command, docker_ssh_proxy_command=docker_ssh_proxy_command, port=self.port, + connect_timeout=connect_timeout, disable_control_master=self.disable_control_master) + [ f'{self.ssh_user}@{self.ip}' ] + def close_cached_connection(self) -> None: + """Close the cached connection to the remote machine. + + This is useful when we need to make the permission update effective of a + ssh user, e.g. usermod -aG docker $USER. + """ + if self.ssh_control_name is not None: + control_path = _ssh_control_path(self.ssh_control_name) + if control_path is not None: + cmd = (f'ssh -O exit -S {control_path}/%C ' + f'{self.ssh_user}@{self.ip}') + logger.debug(f'Closing cached connection {control_path!r} with ' + f'cmd: {cmd}') + log_lib.run_with_log(cmd, + log_path=os.devnull, + require_outputs=False, + stream_logs=False, + process_stream=False, + shell=True) + + # @timeline.event def run( self, cmd: Union[str, List[str]], @@ -271,11 +526,13 @@ def run( stream_logs: bool = True, ssh_mode: SshMode = SshMode.NON_INTERACTIVE, separate_stderr: bool = False, + connect_timeout: Optional[int] = None, + source_bashrc: bool = False, + skip_lines: int = 0, **kwargs) -> Union[int, Tuple[int, str, str]]: """Uses 'ssh' to run 'cmd' on a node with ip. Args: - ip: The IP address of the node. cmd: The command to run. port_forward: A list of ports to forward from the localhost to the remote host. @@ -289,53 +546,38 @@ def run( ssh_mode: The mode to use for ssh. See SSHMode for more details. separate_stderr: Whether to separate stderr from stdout. - + connect_timeout: timeout in seconds for the ssh connection. + source_bashrc: Whether to source the bashrc before running the + command. + skip_lines: The number of lines to skip at the beginning of the + output. This is used when the output is not processed by + SkyPilot but we still want to get rid of some warning messages, + such as SSH warnings. Returns: returncode or A tuple of (returncode, stdout, stderr). """ - base_ssh_command = self._ssh_base_command(ssh_mode=ssh_mode, - port_forward=port_forward) + base_ssh_command = self._ssh_base_command( + ssh_mode=ssh_mode, + port_forward=port_forward, + connect_timeout=connect_timeout) if ssh_mode == SshMode.LOGIN: assert isinstance(cmd, list), 'cmd must be a list for login mode.' command = base_ssh_command + cmd proc = subprocess_utils.run(command, shell=False, check=False) return proc.returncode, '', '' - if isinstance(cmd, list): - cmd = ' '.join(cmd) + + command_str = self._get_command_to_run(cmd, + process_stream, + separate_stderr, + skip_lines=skip_lines, + source_bashrc=source_bashrc) + command = base_ssh_command + [shlex.quote(command_str)] log_dir = os.path.expanduser(os.path.dirname(log_path)) os.makedirs(log_dir, exist_ok=True) - # We need this to correctly run the cmd, and get the output. - command = [ - 'bash', - '--login', - '-c', - # Need this `-i` option to make sure `source ~/.bashrc` work. - '-i', - ] - - command += [ - shlex.quote(f'true && source ~/.bashrc && export OMP_NUM_THREADS=1 ' - f'PYTHONWARNINGS=ignore && ({cmd})'), - ] - if not separate_stderr: - command.append('2>&1') - if not process_stream and ssh_mode == SshMode.NON_INTERACTIVE: - command += [ - # A hack to remove the following bash warnings (twice): - # bash: cannot set terminal process group - # bash: no job control in this shell - '| stdbuf -o0 tail -n +5', - # This is required to make sure the executor of command can get - # correct returncode, since linux pipe is used. - '; exit ${PIPESTATUS[0]}' - ] - - command_str = ' '.join(command) - command = base_ssh_command + [shlex.quote(command_str)] executable = None if not process_stream: @@ -358,6 +600,7 @@ def run( executable=executable, **kwargs) + # @timeline.event def rsync( self, source: str, @@ -384,30 +627,6 @@ def rsync( Raises: exceptions.CommandError: rsync command failed. """ - # Build command. - # TODO(zhwu): This will print a per-file progress bar (with -P), - # shooting a lot of messages to the output. --info=progress2 is used - # to get a total progress bar, but it requires rsync>=3.1.0 and Mac - # OS has a default rsync==2.6.9 (16 years old). - rsync_command = ['rsync', RSYNC_DISPLAY_OPTION] - - # --filter - rsync_command.append(RSYNC_FILTER_OPTION) - - if up: - # The source is a local path, so we need to resolve it. - # --exclude-from - resolved_source = pathlib.Path(source).expanduser().resolve() - if (resolved_source / GIT_EXCLUDE).exists(): - # Ensure file exists; otherwise, rsync will error out. - # - # We shlex.quote() because the path may contain spaces: - # 'my dir/.git/info/exclude' - # Without quoting rsync fails. - rsync_command.append( - RSYNC_EXCLUDE_OPTION.format( - shlex.quote(str(resolved_source / GIT_EXCLUDE)))) - if self._docker_ssh_proxy_command is not None: docker_ssh_proxy_command = self._docker_ssh_proxy_command(['ssh']) else: @@ -420,43 +639,12 @@ def rsync( docker_ssh_proxy_command=docker_ssh_proxy_command, port=self.port, disable_control_master=self.disable_control_master)) - rsync_command.append(f'-e "ssh {ssh_options}"') - # To support spaces in the path, we need to quote source and target. - # rsync doesn't support '~' in a quoted local path, but it is ok to - # have '~' in a quoted remote path. - if up: - full_source_str = str(resolved_source) - if resolved_source.is_dir(): - full_source_str = os.path.join(full_source_str, '') - rsync_command.extend([ - f'{full_source_str!r}', - f'{self.ssh_user}@{self.ip}:{target!r}', - ]) - else: - rsync_command.extend([ - f'{self.ssh_user}@{self.ip}:{source!r}', - f'{os.path.expanduser(target)!r}', - ]) - command = ' '.join(rsync_command) - - backoff = common_utils.Backoff(initial_backoff=5, max_backoff_factor=5) - while max_retry >= 0: - returncode, _, stderr = log_lib.run_with_log( - command, - log_path=log_path, - stream_logs=stream_logs, - shell=True, - require_outputs=True) - if returncode == 0: - break - max_retry -= 1 - time.sleep(backoff.current_backoff()) - - direction = 'up' if up else 'down' - error_msg = (f'Failed to rsync {direction}: {source} -> {target}. ' - 'Ensure that the network is stable, then retry.') - subprocess_utils.handle_returncode(returncode, - command, - error_msg, - stderr=stderr, - stream_logs=stream_logs) + rsh_option = f'ssh {ssh_options}' + self._rsync(source, + target, + node_destination=f'{self.ssh_user}@{self.ip}', + up=up, + rsh_option=rsh_option, + log_path=log_path, + stream_logs=stream_logs, + max_retry=max_retry) diff --git a/runhouse/resources/hardware/sky_ssh_runner.py b/runhouse/resources/hardware/sky_ssh_runner.py index 5a5dd9a5b..5d237a751 100644 --- a/runhouse/resources/hardware/sky_ssh_runner.py +++ b/runhouse/resources/hardware/sky_ssh_runner.py @@ -36,9 +36,8 @@ def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str: """Find docker container username.""" runner = SkySSHRunner( - ip=cluster.address, + node=(cluster.address, cluster.ssh_port), ssh_user=ssh_creds.get("ssh_user", None), - port=cluster.ssh_port, ssh_private_key=ssh_creds.get("ssh_private_key", None), ssh_control_name=ssh_creds.get( "ssh_control_name", f"{cluster.address}:{cluster.ssh_port}" @@ -62,9 +61,9 @@ def get_docker_user(cluster: "Cluster", ssh_creds: Dict) -> str: class SkySSHRunner(SSHCommandRunner): def __init__( self, - ip, - ssh_user=None, - ssh_private_key=None, + node: Tuple[str, int], + ssh_user: Optional[str] = None, + ssh_private_key: Optional[str] = None, ssh_control_name: Optional[str] = "__default__", ssh_proxy_command: Optional[str] = None, port: int = 22, @@ -73,12 +72,11 @@ def __init__( local_bind_port: Optional[int] = None, ): super().__init__( - ip, + node, ssh_user, ssh_private_key, ssh_control_name, ssh_proxy_command, - port, docker_user, disable_control_master, ) @@ -89,7 +87,11 @@ def __init__( self.remote_bind_port = None def _ssh_base_command( - self, *, ssh_mode: SshMode, port_forward: Optional[List[int]] + self, + *, + ssh_mode: SshMode, + port_forward: Optional[List[int]], + connect_timeout: Optional[int] = None, ) -> List[str]: return _ssh_base_command( address=self.ip, @@ -102,6 +104,7 @@ def _ssh_base_command( disable_control_master=self.disable_control_master, ssh_mode=ssh_mode, port_forward=port_forward, + connect_timeout=connect_timeout, ) def run( @@ -117,6 +120,9 @@ def run( stream_logs: bool = True, ssh_mode: SshMode = SshMode.NON_INTERACTIVE, separate_stderr: bool = False, + connect_timeout: Optional[int] = None, + source_bashrc: bool = True, # RH MODIFIED + skip_lines: int = 0, return_cmd: bool = False, # RH MODIFIED quiet_ssh: bool = False, # RH MODIFIED **kwargs, @@ -138,6 +144,13 @@ def run( ssh_mode: The mode to use for ssh. See SSHMode for more details. separate_stderr: Whether to separate stderr from stdout. + connect_timeout: timeout in seconds for the ssh connection. + source_bashrc: Whether to source the bashrc before running the + command. + skip_lines: The number of lines to skip at the beginning of the + output. This is used when the output is not processed by + SkyPilot but we still want to get rid of some warning messages, + such as SSH warnings. return_cmd: If True, return the command string instead of running it. quiet_ssh: If True, do not print the OpenSSH outputs (i.e. add "-q" option to ssh). @@ -148,15 +161,24 @@ def run( A tuple of (returncode, stdout, stderr). """ base_ssh_command = self._ssh_base_command( - ssh_mode=ssh_mode, port_forward=port_forward + ssh_mode=ssh_mode, + port_forward=port_forward, + connect_timeout=connect_timeout, ) if ssh_mode == SshMode.LOGIN: assert isinstance(cmd, list), "cmd must be a list for login mode." command = base_ssh_command + cmd proc = subprocess_utils.run(command, shell=False, check=False) return proc.returncode, "", "" - if isinstance(cmd, list): - cmd = " ".join(cmd) + + command_str = self._get_command_to_run( + cmd, + process_stream, + separate_stderr, + skip_lines=skip_lines, + source_bashrc=source_bashrc, + ) + command = base_ssh_command + [shlex.quote(command_str)] # RH MODIFIED: Add quiet_ssh option if quiet_ssh: @@ -164,38 +186,6 @@ def run( log_dir = os.path.expanduser(os.path.dirname(log_path)) os.makedirs(log_dir, exist_ok=True) - # We need this to correctly run the cmd, and get the output. - command = [ - "bash", - "--login", - "-c", - # Need this `-i` option to make sure `source ~/.bashrc` work. - "-i", - ] - - cmd = f"conda deactivate && {cmd}" if self.docker_user else cmd - - command += [ - shlex.quote( - f"true && source ~/.bashrc && export OMP_NUM_THREADS=1 " - f"PYTHONWARNINGS=ignore && ({cmd})" - ), - ] - if not separate_stderr: - command.append("2>&1") - if not process_stream and ssh_mode == SshMode.NON_INTERACTIVE: - command += [ - # A hack to remove the following bash warnings (twice): - # bash: cannot set terminal process group - # bash: no job control in this shell - "| stdbuf -o0 tail -n +5", - # This is required to make sure the executor of command can get - # correct returncode, since linux pipe is used. - "; exit ${PIPESTATUS[0]}", - ] - - command_str = " ".join(command) - command = base_ssh_command + [shlex.quote(command_str)] executable = None if not process_stream: @@ -237,7 +227,7 @@ def rsync( log_path: str = os.devnull, stream_logs: bool = True, max_retry: int = 1, - return_cmd: bool = False, + return_cmd: bool = False, # RH MODIFIED ) -> None: """Uses 'rsync' to sync 'source' to 'target'. @@ -274,7 +264,9 @@ def rsync( if (resolved_source / GIT_EXCLUDE).exists(): # Ensure file exists; otherwise, rsync will error out. rsync_command.append( - RSYNC_EXCLUDE_OPTION.format(str(resolved_source / GIT_EXCLUDE)) + RSYNC_EXCLUDE_OPTION.format( + shlex.quote(str(resolved_source / GIT_EXCLUDE)) + ) ) if self._docker_ssh_proxy_command is not None: diff --git a/runhouse/resources/hardware/utils.py b/runhouse/resources/hardware/utils.py index 422def12c..f6af971fb 100644 --- a/runhouse/resources/hardware/utils.py +++ b/runhouse/resources/hardware/utils.py @@ -166,6 +166,7 @@ def _docker_ssh_proxy_command( ) +# Adapted from SkyPilot Command Runner def _ssh_base_command( address: str, ssh_user: str, @@ -177,6 +178,7 @@ def _ssh_base_command( disable_control_master: Optional[bool] = False, ssh_mode: SshMode = SshMode.INTERACTIVE, port_forward: Optional[List[int]] = None, + connect_timeout: Optional[int] = None, ): ssh = ["ssh"] if ssh_mode == SshMode.NON_INTERACTIVE: @@ -213,6 +215,7 @@ def _ssh_base_command( docker_ssh_proxy_command=docker_ssh_proxy_command, # TODO change to None like before? port=ssh_port, + connect_timeout=connect_timeout, disable_control_master=disable_control_master, ) + [f"{ssh_user}@{address}"]