Skip to content

Commit

Permalink
Add sky k8s command runner
Browse files Browse the repository at this point in the history
  • Loading branch information
carolineechen committed Aug 29, 2024
1 parent 0f9c71a commit 246f0a1
Show file tree
Hide file tree
Showing 2 changed files with 389 additions and 5 deletions.
190 changes: 190 additions & 0 deletions runhouse/resources/hardware/sky/command_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,3 +648,193 @@ def rsync(
log_path=log_path,
stream_logs=stream_logs,
max_retry=max_retry)


class KubernetesCommandRunner(CommandRunner):
    """Command runner that targets a Kubernetes pod through ``kubectl``.

    Commands are executed with ``kubectl exec`` and files are synced with
    rsync, using a helper script as the remote-shell option.
    """

    def __init__(
        self,
        node: Tuple[str, str],
        **kwargs,
    ):
        """Initialize KubernetesCommandRunner.

        Example Usage:
            runner = KubernetesCommandRunner((namespace, pod_name))
            runner.run('ls -l')
            runner.rsync(source, target, up=True)

        Args:
            node: The namespace and pod_name of the remote machine.
            **kwargs: Accepted and discarded; presumably kept so that this
                constructor can be called like the sibling runners.
        """
        del kwargs
        super().__init__(node)
        self.namespace, self.pod_name = node

    # @timeline.event
    def run(
            self,
            cmd: Union[str, List[str]],
            *,
            port_forward: Optional[List[int]] = None,
            require_outputs: bool = False,
            # Advanced options.
            log_path: str = os.devnull,
            # If False, do not redirect stdout/stderr to optimize performance.
            process_stream: bool = True,
            stream_logs: bool = True,
            ssh_mode: SshMode = SshMode.NON_INTERACTIVE,
            separate_stderr: bool = False,
            connect_timeout: Optional[int] = None,
            source_bashrc: bool = False,
            skip_lines: int = 0,
            **kwargs) -> Union[int, Tuple[int, str, str]]:
        """Uses 'kubectl exec' to run 'cmd' on a pod by its name and namespace.

        Args:
            cmd: The command to run. Must be a list when ``ssh_mode`` is
                ``SshMode.LOGIN``.
            port_forward: This should be None for k8s.

            Advanced options:

            require_outputs: Whether to return the stdout/stderr of the command.
            log_path: Redirect stdout/stderr to the log_path.
            process_stream: If False, the output is not processed in Python;
                it is redirected to ``log_path`` by the shell instead (via
                ``tee`` when ``stream_logs`` is True, via ``>`` otherwise).
            stream_logs: Stream logs to the stdout/stderr.
            ssh_mode: The mode to use for ssh.
                See SshMode for more details.
            separate_stderr: Whether to separate stderr from stdout.
            connect_timeout: timeout in seconds for the pod connection.
                Defaults to ``_DEFAULT_CONNECT_TIMEOUT`` when None.
            source_bashrc: Whether to source the bashrc before running the
                command.
            skip_lines: The number of lines to skip at the beginning of the
                output. This is used when the output is not processed by
                SkyPilot but we still want to get rid of some warning messages,
                such as SSH warnings.
            **kwargs: Forwarded unchanged to ``log_lib.run_with_log``.

        Returns:
            returncode
            or
            A tuple of (returncode, stdout, stderr). In login mode the tuple
            is always ``(returncode, '', '')``.
        """
        # TODO(zhwu): implement port_forward for k8s.
        assert port_forward is None, ('port_forward is not supported for k8s '
                                      f'for now, but got: {port_forward}')
        if connect_timeout is None:
            connect_timeout = _DEFAULT_CONNECT_TIMEOUT
        kubectl_args = [
            '--pod-running-timeout', f'{connect_timeout}s', '-n',
            self.namespace, self.pod_name
        ]
        if ssh_mode == SshMode.LOGIN:
            # Login mode hands the terminal to kubectl directly (-it), so the
            # command is run as an argv list without a shell and no output is
            # captured.
            assert isinstance(cmd, list), 'cmd must be a list for login mode.'
            base_cmd = ['kubectl', 'exec', '-it', *kubectl_args, '--']
            command = base_cmd + cmd
            proc = subprocess_utils.run(command, shell=False, check=False)
            return proc.returncode, '', ''

        kubectl_base_command = ['kubectl', 'exec']

        if ssh_mode == SshMode.INTERACTIVE:
            kubectl_base_command.append('-i')
        kubectl_base_command += [*kubectl_args, '--']

        command_str = self._get_command_to_run(cmd,
                                               process_stream,
                                               separate_stderr,
                                               skip_lines=skip_lines,
                                               source_bashrc=source_bashrc)
        command = kubectl_base_command + [
            # It is important to use /bin/bash -c here to make sure we quote the
            # command to be run properly. Otherwise, directly appending commands
            # after '--' will not work for some commands, such as '&&', '>' etc.
            '/bin/bash',
            '-c',
            shlex.quote(command_str)
        ]

        log_dir = os.path.expanduser(os.path.dirname(log_path))
        # A bare file name (e.g. 'run.log') has an empty dirname, and
        # os.makedirs('') raises FileNotFoundError; the current directory
        # already exists, so there is nothing to create in that case.
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        executable = None
        if not process_stream:
            if stream_logs:
                command += [
                    f'| tee {log_path}',
                    # This also requires the executor to be '/bin/bash' instead
                    # of the default '/bin/sh'.
                    '; exit ${PIPESTATUS[0]}'
                ]
            else:
                command += [f'> {log_path}']
            executable = '/bin/bash'
        # The pieces are joined into one shell string because the redirection
        # and pipe fragments above must be interpreted by the local shell.
        return log_lib.run_with_log(' '.join(command),
                                    log_path,
                                    require_outputs=require_outputs,
                                    stream_logs=stream_logs,
                                    process_stream=process_stream,
                                    shell=True,
                                    executable=executable,
                                    **kwargs)

    # @timeline.event
    def rsync(
        self,
        source: str,
        target: str,
        *,
        up: bool,
        # Advanced options.
        log_path: str = os.devnull,
        stream_logs: bool = True,
        max_retry: int = 1,
    ) -> None:
        """Uses 'rsync' to sync 'source' to 'target'.

        Args:
            source: The source path.
            target: The target path.
            up: The direction of the sync, True for local to cluster, False
                for cluster to local.
            log_path: Redirect stdout/stderr to the log_path.
            stream_logs: Stream logs to the stdout/stderr.
            max_retry: The maximum number of retries for the rsync command.
                This value should be non-negative.

        Raises:
            exceptions.CommandError: rsync command failed.
            ValueError: the remote home directory could not be determined
                (raised by the callback handed to ``_rsync``).
        """

        def get_remote_home_dir() -> str:
            # Use `echo ~` to get the remote home directory, instead of pwd or
            # echo $HOME, because pwd can be `/` when the remote user is root
            # and $HOME is not always set.
            rc, remote_home_dir, stderr = self.run('echo ~',
                                                   require_outputs=True,
                                                   separate_stderr=True,
                                                   stream_logs=False)
            if rc != 0:
                raise ValueError('Failed to get remote home directory: '
                                 f'{remote_home_dir + stderr}')
            remote_home_dir = remote_home_dir.strip()
            return remote_home_dir

        # Build command.
        # The helper script lives next to this module, under 'kubernetes/'.
        helper_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                   'kubernetes', 'rsync_helper.sh')
        self._rsync(
            source,
            target,
            node_destination=f'{self.pod_name}@{self.namespace}',
            up=up,
            rsh_option=helper_path,
            log_path=log_path,
            stream_logs=stream_logs,
            max_retry=max_retry,
            # Make sure the helper is executable before it is used as the
            # rsync remote-shell option.
            prefix_command=f'chmod +x {helper_path} && ',
            # rsync with `kubectl` as the rsh command will cause ~/xx parsed as
            # /~/xx, so we need to replace ~ with the remote home directory. We
            # only need to do this when ~ is at the beginning of the path.
            get_remote_home_dir=get_remote_home_dir)
Loading

0 comments on commit 246f0a1

Please sign in to comment.