Introduce cluster.a_up #1247

Merged 1 commit on Sep 11, 2024
36 changes: 35 additions & 1 deletion runhouse/resources/hardware/on_demand_cluster.py
@@ -1,8 +1,10 @@
import asyncio
import contextlib
import json
import subprocess
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Union

@@ -27,7 +29,11 @@

from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import ResourceServerStatus, ServerConnectionType
from runhouse.resources.hardware.utils import (
    ResourceServerStatus,
    ServerConnectionType,
    up_cluster_helper,
)

from .cluster import Cluster

@@ -478,6 +484,34 @@ def num_cpus(self):

        return None

    async def a_up(self, capture_output: Union[bool, str] = True):
        """Up the cluster asynchronously in another process, so it can be parallelized and logs can be captured sanely.

        capture_output: If True, suppress the output of the cluster creation process. If False, print the output
            normally. If a string, write the output to the file at that path.
        """

        with ProcessPoolExecutor() as executor:
            loop = asyncio.get_running_loop()
            future = loop.run_in_executor(
                executor, up_cluster_helper, self, capture_output
            )

            # Await the result from the separate process
            result = await future
            if isinstance(capture_output, str):
                with open(capture_output, "w") as f:
                    f.write(result)

        return self

    async def a_up_if_not(self, capture_output: Union[bool, str] = True):
        if not self.is_up():
            # Don't store stale IPs
            self.ips = None
            await self.a_up(capture_output=capture_output)
        return self

    def up(self):
        """Up the cluster.

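For context on how the new methods are meant to be called, here is a minimal usage sketch (not part of the diff). It assumes the standard rh.ondemand_cluster factory; the cluster names, instance type, and log file paths are placeholders, and a_up_if_not could be substituted for a_up to skip clusters that are already up.

import asyncio

import runhouse as rh


async def up_all():
    # Hypothetical clusters for illustration; names and instance_type are placeholders.
    clusters = [
        rh.ondemand_cluster(name=f"demo-cluster-{i}", instance_type="CPU:2")
        for i in range(3)
    ]
    # Each a_up call launches `up()` in its own process, so the clusters come up
    # concurrently; passing a path captures each launch's logs to its own file.
    await asyncio.gather(
        *(c.a_up(capture_output=f"up_{c.name}.log") for c in clusters)
    )
    return clusters


if __name__ == "__main__":
    asyncio.run(up_all())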
11 changes: 11 additions & 0 deletions runhouse/resources/hardware/utils.py
@@ -227,3 +227,14 @@ def _ssh_base_command(

def _generate_ssh_control_hash(ssh_control_name):
    return hashlib.md5(ssh_control_name.encode()).hexdigest()[:_HASH_MAX_LENGTH]


def up_cluster_helper(cluster, suppress_output=True):
    from runhouse.utils import SuppressStd

    if suppress_output:
        with SuppressStd() as outfile:
            cluster.up()
        return outfile.output
    else:
        cluster.up()
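A hedged sketch of what the helper returns, for illustration only; my_cluster stands in for any already-constructed cluster object and is not defined in this PR.

from runhouse.resources.hardware.utils import up_cluster_helper

# With suppress_output=True (the default), cluster.up() runs under SuppressStd
# and the captured stdout/stderr comes back as a string.
launch_logs = up_cluster_helper(my_cluster, suppress_output=True)
with open("launch.log", "w") as f:
    f.write(launch_logs)

# With suppress_output=False, the launch logs print normally and nothing is returned.
up_cluster_helper(my_cluster, suppress_output=False)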
55 changes: 54 additions & 1 deletion runhouse/utils.py
@@ -2,7 +2,8 @@
import contextvars
import functools
import logging
from io import StringIO
import tempfile
from io import SEEK_SET, StringIO

try:
    import importlib.metadata as metadata
@@ -516,6 +517,58 @@ def _path_to_file_by_ext(self, ext: str) -> str:
        return path_to_ext


class SuppressStd(object):
    """Context to capture stderr and stdout at C-level."""

    def __init__(self, outfile=None):
        self.orig_stdout_fileno = sys.__stdout__.fileno()
        self.orig_stderr_fileno = sys.__stderr__.fileno()
        self.output = None

    def __enter__(self):
        # Redirect the stdout/stderr fd to temp file
        self.orig_stdout_dup = os.dup(self.orig_stdout_fileno)
        self.orig_stderr_dup = os.dup(self.orig_stderr_fileno)
        self.tfile = tempfile.TemporaryFile(mode="w+b")
        os.dup2(self.tfile.fileno(), self.orig_stdout_fileno)
        os.dup2(self.tfile.fileno(), self.orig_stderr_fileno)

        # Store the stdout object and replace it by the temp file.
        self.stdout_obj = sys.stdout
        self.stderr_obj = sys.stderr
        sys.stdout = sys.__stdout__
        sys.stderr = sys.__stderr__

        return self

    def __exit__(self, exc_class, value, traceback):

        # Make sure to flush stdout
        print(flush=True)

        # Restore the stdout/stderr object.
        sys.stdout = self.stdout_obj
        sys.stderr = self.stderr_obj

        # Close capture file handle
        os.close(self.orig_stdout_fileno)
        os.close(self.orig_stderr_fileno)

        # Restore original stderr and stdout
        os.dup2(self.orig_stdout_dup, self.orig_stdout_fileno)
        os.dup2(self.orig_stderr_dup, self.orig_stderr_fileno)

        # Close duplicate file handle.
        os.close(self.orig_stdout_dup)
        os.close(self.orig_stderr_dup)

        # Copy contents of temporary file to the given stream
        self.tfile.flush()
        self.tfile.seek(0, SEEK_SET)
        self.output = self.tfile.read().decode()
        self.tfile.close()
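To illustrate why SuppressStd duplicates file descriptors rather than only swapping sys.stdout: output written at the C level or by child processes is captured too. A small sketch, assuming a Unix environment where echo is available:

import subprocess

from runhouse.utils import SuppressStd

with SuppressStd() as captured:
    print("hello from Python")                       # Python-level stdout
    subprocess.run(["echo", "hello from a child"])   # inherits the redirected fd

# Both lines were redirected into the temporary file and read back on exit.
assert "hello from Python" in captured.output
assert "hello from a child" in captured.output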


####################################################################################################
# Name generation
####################################################################################################