Skip to content

Commit

Permalink
Introduce cluster.a_up (#1247)
Browse files Browse the repository at this point in the history
  • Loading branch information
dongreenberg authored Sep 11, 2024
1 parent a7fa9cd commit 0a6a31f
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 2 deletions.
36 changes: 35 additions & 1 deletion runhouse/resources/hardware/on_demand_cluster.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import contextlib
import json
import subprocess
import time
import warnings
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import Any, Dict, List, Union

Expand All @@ -27,7 +29,11 @@

from runhouse.globals import configs, obj_store, rns_client
from runhouse.logger import get_logger
from runhouse.resources.hardware.utils import ResourceServerStatus, ServerConnectionType
from runhouse.resources.hardware.utils import (
ResourceServerStatus,
ServerConnectionType,
up_cluster_helper,
)

from .cluster import Cluster

Expand Down Expand Up @@ -478,6 +484,34 @@ def num_cpus(self):

return None

async def a_up(self, capture_output: Union[bool, str] = True):
"""Up the cluster async in another process, so it can be parallelized and logs can be captured sanely.
capture_output: If True, supress the output of the cluster creation process. If False, print the output
normally. If a string, write the output to the file at that path.
"""

with ProcessPoolExecutor() as executor:
loop = asyncio.get_running_loop()
future = loop.run_in_executor(
executor, up_cluster_helper, self, capture_output
)

# Await the result from the separate process
result = await future
if isinstance(capture_output, str):
with open(capture_output, "w") as f:
f.write(result)

return self

async def a_up_if_not(self, capture_output: Union[bool, str] = True):
if not self.is_up():
# Don't store stale IPs
self.ips = None
await self.a_up(capture_output=capture_output)
return self

def up(self):
"""Up the cluster.
Expand Down
11 changes: 11 additions & 0 deletions runhouse/resources/hardware/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,14 @@ def _ssh_base_command(

def _generate_ssh_control_hash(ssh_control_name):
return hashlib.md5(ssh_control_name.encode()).hexdigest()[:_HASH_MAX_LENGTH]


def up_cluster_helper(cluster, suppress_output=True):
from runhouse.utils import SuppressStd

if suppress_output:
with SuppressStd() as outfile:
cluster.up()
return outfile.output
else:
cluster.up()
55 changes: 54 additions & 1 deletion runhouse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import contextvars
import functools
import logging
from io import StringIO
import tempfile
from io import SEEK_SET, StringIO

try:
import importlib.metadata as metadata
Expand Down Expand Up @@ -516,6 +517,58 @@ def _path_to_file_by_ext(self, ext: str) -> str:
return path_to_ext


class SuppressStd(object):
"""Context to capture stderr and stdout at C-level."""

def __init__(self, outfile=None):
self.orig_stdout_fileno = sys.__stdout__.fileno()
self.orig_stderr_fileno = sys.__stderr__.fileno()
self.output = None

def __enter__(self):
# Redirect the stdout/stderr fd to temp file
self.orig_stdout_dup = os.dup(self.orig_stdout_fileno)
self.orig_stderr_dup = os.dup(self.orig_stderr_fileno)
self.tfile = tempfile.TemporaryFile(mode="w+b")
os.dup2(self.tfile.fileno(), self.orig_stdout_fileno)
os.dup2(self.tfile.fileno(), self.orig_stderr_fileno)

# Store the stdout object and replace it by the temp file.
self.stdout_obj = sys.stdout
self.stderr_obj = sys.stderr
sys.stdout = sys.__stdout__
sys.stderr = sys.__stderr__

return self

def __exit__(self, exc_class, value, traceback):

# Make sure to flush stdout
print(flush=True)

# Restore the stdout/stderr object.
sys.stdout = self.stdout_obj
sys.stderr = self.stderr_obj

# Close capture file handle
os.close(self.orig_stdout_fileno)
os.close(self.orig_stderr_fileno)

# Restore original stderr and stdout
os.dup2(self.orig_stdout_dup, self.orig_stdout_fileno)
os.dup2(self.orig_stderr_dup, self.orig_stderr_fileno)

# Close duplicate file handle.
os.close(self.orig_stdout_dup)
os.close(self.orig_stderr_dup)

# Copy contents of temporary file to the given stream
self.tfile.flush()
self.tfile.seek(0, SEEK_SET)
self.output = self.tfile.read().decode()
self.tfile.close()


####################################################################################################
# Name generation
####################################################################################################
Expand Down

0 comments on commit 0a6a31f

Please sign in to comment.