diff --git a/examples/airflow-torch-training/callables.py b/examples/airflow-torch-training/callables.py new file mode 100644 index 000000000..c9d8c9a4e --- /dev/null +++ b/examples/airflow-torch-training/callables.py @@ -0,0 +1,58 @@ +import runhouse as rh +from runhouse.logger import get_logger + +from torch_example_for_airflow import DownloadData, SimpleTrainer + + +logger = get_logger(name=__name__) + + +def bring_up_cluster_callable(**kwargs): + logger.info("Connecting to remote cluster") + rh.ondemand_cluster( + name="a10g-cluster", instance_type="A10G:1", provider="aws" + ).up_if_not() + # cluster.save() ## Use if you have a Runhouse Den account to save and monitor the resource. + + +def access_data_callable(**kwargs): + logger.info("Step 2: Access data") + env = rh.env(name="test_env", reqs=["torch", "torchvision"]) + + cluster = rh.cluster(name="a10g-cluster").up_if_not() + remote_download = rh.function(DownloadData).to(cluster, env=env) + logger.info("Download function sent to remote") + remote_download() + logger.info("Downloaded") + + +def train_model_callable(**kwargs): + cluster = rh.cluster(name="a10g-cluster").up_if_not() + + env = rh.env(name="test_env", reqs=["torch", "torchvision"]) + + remote_torch_example = rh.module(SimpleTrainer).to( + cluster, env=env, name="torch-basic-training" + ) + + model = remote_torch_example() + + batch_size = 64 + epochs = 5 + learning_rate = 0.01 + + model.load_train("./data", batch_size) + model.load_test("./data", batch_size) + + for epoch in range(epochs): + model.train_model(learning_rate=learning_rate) + model.test_model() + model.save_model( + bucket_name="my-simple-torch-model-example", + s3_file_path=f"checkpoints/model_epoch_{epoch + 1}.pth", + ) + + +def down_cluster(**kwargs): + cluster = rh.cluster(name="a10g-cluster") + cluster.teardown() diff --git a/runhouse/logger.py b/runhouse/logger.py index 432c80f77..ade888c31 100644 --- a/runhouse/logger.py +++ b/runhouse/logger.py @@ -1,148 +1,25 @@ import logging -import logging.config -import re -from datetime import datetime, timezone -from typing import List, Union +import os -from runhouse.constants import DEFAULT_LOG_LEVEL +def get_logger(name: str = __name__): + logger = logging.getLogger(name) -class ColoredFormatter: - COLORS = { - "black": "\u001b[30m", - "red": "\u001b[31m", - "green": "\u001b[32m", - "yellow": "\u001b[33m", - "blue": "\u001b[34m", - "magenta": "\u001b[35m", - "cyan": "\u001b[36m", - "white": "\u001b[37m", - "reset": "\u001b[0m", - } + level = os.getenv("RH_LOG_LEVEL") + if level: + # Set the logging level + logger.setLevel(level.upper()) - @classmethod - def get_color(cls, color: str): - return cls.COLORS.get(color, "") - - # TODO: This method is a temp solution, until we'll update logging architecture. Remove once logging is cleaned up. - @classmethod - def format_log(cls, text): - ansi_escape = re.compile(r"(?:\x1B[@-_][0-?]*[ -/]*[@-~])") - return ansi_escape.sub("", text) - - -class ClusterLogsFormatter: - def __init__(self, system): - self.system = system - self._display_title = False - - def format(self, output_type): - from runhouse import Resource - from runhouse.servers.http.http_utils import OutputType - - system_color = ColoredFormatter.get_color("cyan") - reset_color = ColoredFormatter.get_color("reset") - - prettify_logs = output_type in [ - OutputType.STDOUT, - OutputType.EXCEPTION, - OutputType.STDERR, - ] - - if ( - isinstance(self.system, Resource) - and prettify_logs - and not self._display_title - ): - # Display the system name before subsequent logs only once - system_name = self.system.name - dotted_line = "-" * len(system_name) - print(dotted_line) - print(f"{system_color}{system_name}{reset_color}") - print(dotted_line) - - # Only display the system name once - self._display_title = True - - return system_color, reset_color - - -class FunctionLogHandler(logging.Handler): - def __init__(self): - super().__init__() - self.log_records = [] - - def emit(self, record): - self.log_records.append(record) - - @staticmethod - def log_records_to_stdout(log_records: List[logging.LogRecord]) -> str: - """Convert the log records to a string repr of the stdout output""" - captured_logs = [ - f"{log_record.levelname} | {log_record.asctime} | {log_record.msg}" - for log_record in log_records - ] - return "\n".join(captured_logs) - - -class UTCFormatter(logging.Formatter): - """Ensure logs are always in UTC time""" - - @staticmethod - def converter(timestamp): - return datetime.fromtimestamp(timestamp, tz=timezone.utc) + # Apply a custom formatter + formatter = logging.Formatter( + fmt="%(levelname)s | %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" + ) - def formatTime(self, record, datefmt=None): - dt = self.converter(record.created) - if datefmt: - return dt.strftime(datefmt) - else: - return dt.isoformat(timespec="milliseconds") + # Apply the formatter to each handler + for handler in logger.handlers: + handler.setFormatter(formatter) + # Prevent the logger from propagating to the root logger + logger.propagate = False -def get_logger(name: str = __name__, log_level: Union[int, str] = logging.INFO): - level_name = ( - logging.getLevelName(log_level) if isinstance(log_level, int) else log_level - ) - LOGGING_CONFIG = { - "version": 1, - "disable_existing_loggers": True, - "formatters": { - "utc_formatter": { - "()": UTCFormatter, - "format": "%(levelname)s | %(asctime)s | %(message)s", - "datefmt": "%Y-%m-%d %H:%M:%S.%f", - }, - }, - "handlers": { - "default": { - "level": level_name, - "formatter": "utc_formatter", - "class": "logging.StreamHandler", - "stream": "ext://sys.stderr", # Default is stderr - }, - }, - "loggers": { - "": { # root logger - "handlers": ["default"], - "level": level_name, - "propagate": False, - }, - "my.packg": { - "handlers": ["default"], - "level": level_name, - "propagate": False, - }, - "__main__": { # if __name__ == '__main__' - "handlers": ["default"], - "level": level_name, - "propagate": False, - }, - }, - } - logging.config.dictConfig(LOGGING_CONFIG) - logger = logging.getLogger(name=name) return logger - - -logger = get_logger(name=__name__, log_level=DEFAULT_LOG_LEVEL) diff --git a/runhouse/main.py b/runhouse/main.py index 90c09b7a4..69c1d0c03 100644 --- a/runhouse/main.py +++ b/runhouse/main.py @@ -34,18 +34,21 @@ START_SCREEN_CMD, ) from runhouse.globals import obj_store, rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.hardware.ray_utils import ( check_for_existing_ray_instance, kill_actors, ) + # create an explicit Typer application app = typer.Typer(add_completion=False) # For printing with typer console = Console() +logger = get_logger(name=__name__) + @app.command() def login( diff --git a/runhouse/resources/envs/conda_env.py b/runhouse/resources/envs/conda_env.py index 274e5bab9..c8e5b2c90 100644 --- a/runhouse/resources/envs/conda_env.py +++ b/runhouse/resources/envs/conda_env.py @@ -7,14 +7,15 @@ from runhouse.constants import ENVS_DIR from runhouse.globals import obj_store +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.envs.utils import install_conda, run_setup_command - from runhouse.resources.packages import Package from .env import Env +logger = get_logger(name=__name__) + class CondaEnv(Env): RESOURCE_TYPE = "env" diff --git a/runhouse/resources/envs/env.py b/runhouse/resources/envs/env.py index c18b978e6..384f43d22 100644 --- a/runhouse/resources/envs/env.py +++ b/runhouse/resources/envs/env.py @@ -1,19 +1,20 @@ import copy -import logging import os import shlex from pathlib import Path from typing import Dict, List, Optional, Union from runhouse.globals import obj_store +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.envs.utils import _process_env_vars, run_setup_command from runhouse.resources.hardware import _get_cluster_from, Cluster from runhouse.resources.packages import InstallTarget, Package from runhouse.resources.resource import Resource from runhouse.utils import run_with_logs +logger = get_logger(name=__name__) + class Env(Resource): RESOURCE_TYPE = "env" diff --git a/runhouse/resources/folders/folder.py b/runhouse/resources/folders/folder.py index 0c71b1996..2736fa7ce 100644 --- a/runhouse/resources/folders/folder.py +++ b/runhouse/resources/folders/folder.py @@ -7,13 +7,15 @@ from typing import List, Optional, Union from runhouse.globals import rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.hardware import _current_cluster, _get_cluster_from, Cluster from runhouse.resources.resource import Resource from runhouse.rns.utils.api import generate_uuid, relative_file_path from runhouse.utils import locate_working_dir +logger = get_logger(name=__name__) + class Folder(Resource): RESOURCE_TYPE = "folder" diff --git a/runhouse/resources/folders/folder_factory.py b/runhouse/resources/folders/folder_factory.py index 77cb800ca..b977f7d24 100644 --- a/runhouse/resources/folders/folder_factory.py +++ b/runhouse/resources/folders/folder_factory.py @@ -1,11 +1,13 @@ from pathlib import Path from typing import Optional, Union -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.folders.folder import Folder from runhouse.resources.hardware.utils import _get_cluster_from +logger = get_logger(name=__name__) + def folder( name: Optional[str] = None, diff --git a/runhouse/resources/folders/gcs_folder.py b/runhouse/resources/folders/gcs_folder.py index c0351fbd4..0a1ebcdc8 100644 --- a/runhouse/resources/folders/gcs_folder.py +++ b/runhouse/resources/folders/gcs_folder.py @@ -4,10 +4,12 @@ from pathlib import Path from typing import List, Optional -from runhouse.logger import logger +from runhouse.logger import get_logger from .folder import Folder +logger = get_logger(name=__name__) + class GCSFolder(Folder): RESOURCE_TYPE = "folder" diff --git a/runhouse/resources/folders/s3_folder.py b/runhouse/resources/folders/s3_folder.py index 8222fab2e..33f5a2f8d 100644 --- a/runhouse/resources/folders/s3_folder.py +++ b/runhouse/resources/folders/s3_folder.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import List, Optional -from runhouse.logger import logger +from runhouse.logger import get_logger from .folder import Folder @@ -13,6 +13,8 @@ POLL_INTERVAL = 1 TIMEOUT_SECONDS = 3600 +logger = get_logger(name=__name__) + class S3Folder(Folder): RESOURCE_TYPE = "folder" diff --git a/runhouse/resources/functions/aws_lambda.py b/runhouse/resources/functions/aws_lambda.py index a91f2c57f..c74de599e 100644 --- a/runhouse/resources/functions/aws_lambda.py +++ b/runhouse/resources/functions/aws_lambda.py @@ -17,16 +17,16 @@ pass from runhouse.globals import rns_client - -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.envs import _get_env_from, Env - from runhouse.resources.functions.function import Function CRED_PATH = f"{Path.home()}/.aws/credentials" LOG_GROUP_PREFIX = "/aws/lambda/" +logger = get_logger(name=__name__) + class LambdaFunction(Function): RESOURCE_TYPE = "lambda_function" diff --git a/runhouse/resources/functions/function.py b/runhouse/resources/functions/function.py index 43f304a83..907f804ee 100644 --- a/runhouse/resources/functions/function.py +++ b/runhouse/resources/functions/function.py @@ -1,17 +1,18 @@ import inspect -import logging from pathlib import Path from typing import Any, List, Optional, Tuple, Union from runhouse import globals +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.envs import Env from runhouse.resources.hardware import Cluster from runhouse.resources.module import Module from runhouse.resources.resource import Resource +logger = get_logger(name=__name__) + class Function(Module): RESOURCE_TYPE = "function" diff --git a/runhouse/resources/functions/function_factory.py b/runhouse/resources/functions/function_factory.py index ad849706f..ad015cfda 100644 --- a/runhouse/resources/functions/function_factory.py +++ b/runhouse/resources/functions/function_factory.py @@ -2,12 +2,14 @@ from pathlib import Path from typing import Callable, List, Optional, Union -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.envs import _get_env_from, Env from runhouse.resources.functions.function import Function from runhouse.resources.packages import git_package +logger = get_logger(name=__name__) + def function( fn: Optional[Union[str, Callable]] = None, diff --git a/runhouse/resources/hardware/cluster.py b/runhouse/resources/hardware/cluster.py index 4554e0be2..e7e5fae18 100644 --- a/runhouse/resources/hardware/cluster.py +++ b/runhouse/resources/hardware/cluster.py @@ -43,8 +43,8 @@ RESERVED_SYSTEM_NAMES, ) from runhouse.globals import configs, obj_store, rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.envs.utils import _get_env_from from runhouse.resources.hardware.utils import ( _current_cluster, @@ -56,6 +56,8 @@ from runhouse.servers.http import HTTPClient +logger = get_logger(name=__name__) + class Cluster(Resource): RESOURCE_TYPE = "cluster" diff --git a/runhouse/resources/hardware/cluster_factory.py b/runhouse/resources/hardware/cluster_factory.py index a4a3a9557..a44454aba 100644 --- a/runhouse/resources/hardware/cluster_factory.py +++ b/runhouse/resources/hardware/cluster_factory.py @@ -7,7 +7,7 @@ from runhouse.constants import DEFAULT_SERVER_PORT, LOCAL_HOSTS, RESERVED_SYSTEM_NAMES from runhouse.globals import rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.hardware.utils import ServerConnectionType from runhouse.rns.utils.api import relative_file_path @@ -15,6 +15,7 @@ from .on_demand_cluster import OnDemandCluster from .sagemaker.sagemaker_cluster import SageMakerCluster +logger = get_logger(name=__name__) # Cluster factory method def cluster( diff --git a/runhouse/resources/hardware/on_demand_cluster.py b/runhouse/resources/hardware/on_demand_cluster.py index af64dc3c9..e4c3181e2 100644 --- a/runhouse/resources/hardware/on_demand_cluster.py +++ b/runhouse/resources/hardware/on_demand_cluster.py @@ -26,12 +26,13 @@ ) from runhouse.globals import configs, obj_store, rns_client - -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.hardware.utils import ResourceServerStatus, ServerConnectionType from .cluster import Cluster +logger = get_logger(name=__name__) + class OnDemandCluster(Cluster): RESOURCE_TYPE = "cluster" diff --git a/runhouse/resources/hardware/ray_utils.py b/runhouse/resources/hardware/ray_utils.py index 5dfcbfbf7..23451bc6a 100644 --- a/runhouse/resources/hardware/ray_utils.py +++ b/runhouse/resources/hardware/ray_utils.py @@ -4,7 +4,9 @@ import ray from ray.experimental.state.api import list_actors -from runhouse.logger import logger +from runhouse.logger import get_logger + +logger = get_logger(name=__name__) def check_for_existing_ray_instance(): diff --git a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py index fda2e035b..e398260d1 100644 --- a/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py +++ b/runhouse/resources/hardware/sagemaker/sagemaker_cluster.py @@ -32,8 +32,8 @@ from runhouse.constants import LOCAL_HOSTS from runhouse.globals import configs, rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.hardware.cluster import Cluster from runhouse.resources.hardware.utils import ServerConnectionType from runhouse.rns.utils.api import ( @@ -44,6 +44,7 @@ from runhouse.utils import generate_default_name +logger = get_logger(name=__name__) #################################################################################################### # Caching mechanisms for SSHTunnelForwarder #################################################################################################### diff --git a/runhouse/resources/hardware/sky/command_runner.py b/runhouse/resources/hardware/sky/command_runner.py index f47e905e2..8bd2f7a8d 100644 --- a/runhouse/resources/hardware/sky/command_runner.py +++ b/runhouse/resources/hardware/sky/command_runner.py @@ -20,9 +20,9 @@ ) ##### RH modification ##### -import logging -from runhouse.logger import logger +from runhouse.logger import get_logger +logger = get_logger(name=__name__) ##### RH modification ##### diff --git a/runhouse/resources/hardware/sky/subprocess_utils.py b/runhouse/resources/hardware/sky/subprocess_utils.py index fd6238b5e..4a823b4af 100644 --- a/runhouse/resources/hardware/sky/subprocess_utils.py +++ b/runhouse/resources/hardware/sky/subprocess_utils.py @@ -3,9 +3,9 @@ import psutil from typing import Callable, List, Optional, Union -import logging -from runhouse.logger import logger +from runhouse.logger import get_logger +logger = get_logger(name=__name__) class CommandError(Exception): """Raised when a command fails. diff --git a/runhouse/resources/hardware/sky_ssh_runner.py b/runhouse/resources/hardware/sky_ssh_runner.py index 35f593843..f3eba5c8b 100644 --- a/runhouse/resources/hardware/sky_ssh_runner.py +++ b/runhouse/resources/hardware/sky_ssh_runner.py @@ -1,4 +1,5 @@ import copy +import logging import os import pathlib import shlex @@ -9,8 +10,7 @@ from runhouse.constants import DEFAULT_DOCKER_CONTAINER_NAME, LOCALHOST, TUNNEL_TIMEOUT from runhouse.globals import sky_ssh_runner_cache - -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.hardware.sky import common_utils, log_lib, subprocess_utils @@ -25,10 +25,10 @@ ) +logger = get_logger(name=__name__) + # Get rid of the constant "Found credentials in shared credentials file: ~/.aws/credentials" message try: - import logging - import boto3 boto3.set_stream_logger(name="botocore.credentials", level=logging.ERROR) diff --git a/runhouse/resources/module.py b/runhouse/resources/module.py index 693b0bcee..b8b39efff 100644 --- a/runhouse/resources/module.py +++ b/runhouse/resources/module.py @@ -13,7 +13,7 @@ from pydantic import create_model from runhouse.globals import obj_store, rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.envs import _get_env_from, Env from runhouse.resources.hardware import ( _current_cluster, @@ -52,6 +52,8 @@ "_dumb_signature_cache", ] +logger = get_logger(name=__name__) + class Module(Resource): RESOURCE_TYPE = "module" diff --git a/runhouse/resources/packages/package.py b/runhouse/resources/packages/package.py index 5351fdbb4..960fe0fb2 100644 --- a/runhouse/resources/packages/package.py +++ b/runhouse/resources/packages/package.py @@ -5,6 +5,8 @@ from pathlib import Path from typing import Dict, Optional, Union +from runhouse.logger import get_logger + from runhouse.resources.envs.utils import install_conda, run_setup_command from runhouse.resources.hardware.cluster import Cluster from runhouse.resources.hardware.utils import ( @@ -22,7 +24,7 @@ INSTALL_METHODS = {"local", "reqs", "pip", "conda", "rh"} -from runhouse.logger import logger +logger = get_logger(name=__name__) class CodeSyncError(Exception): diff --git a/runhouse/resources/provenance.py b/runhouse/resources/provenance.py index 464b9b8df..570dff16f 100644 --- a/runhouse/resources/provenance.py +++ b/runhouse/resources/provenance.py @@ -9,6 +9,7 @@ from runhouse.constants import LOGS_DIR from runhouse.globals import configs, rns_client +from runhouse.logger import get_logger from runhouse.resources.blobs import file # Need to alias so it doesn't conflict with the folder property @@ -19,8 +20,7 @@ from runhouse.rns.utils.api import log_timestamp, resolve_absolute_path from runhouse.utils import StreamTee -# Load the root logger -logger = logging.getLogger("") +logger = get_logger(name=__name__) class RunStatus(str, Enum): diff --git a/runhouse/resources/resource.py b/runhouse/resources/resource.py index 44008a6b5..eb9456e7b 100644 --- a/runhouse/resources/resource.py +++ b/runhouse/resources/resource.py @@ -5,8 +5,8 @@ from typing import Dict, List, Optional, Tuple, Union from runhouse.globals import obj_store, rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.rns.top_level_rns_fns import ( resolve_rns_path, save, @@ -19,6 +19,8 @@ ResourceVisibility, ) +logger = get_logger(name=__name__) + class Resource: RESOURCE_TYPE = "resource" diff --git a/runhouse/resources/secrets/provider_secrets/ssh_secret.py b/runhouse/resources/secrets/provider_secrets/ssh_secret.py index ebab04847..895e05acf 100644 --- a/runhouse/resources/secrets/provider_secrets/ssh_secret.py +++ b/runhouse/resources/secrets/provider_secrets/ssh_secret.py @@ -5,11 +5,13 @@ from typing import Any, Dict, Optional, Union from runhouse.globals import rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.resources.blobs.file import File from runhouse.resources.hardware.cluster import Cluster from runhouse.resources.secrets.provider_secrets.provider_secret import ProviderSecret +logger = get_logger(name=__name__) + class SSHSecret(ProviderSecret): """ diff --git a/runhouse/resources/secrets/secret.py b/runhouse/resources/secrets/secret.py index 1a30e255e..b88b10779 100644 --- a/runhouse/resources/secrets/secret.py +++ b/runhouse/resources/secrets/secret.py @@ -6,14 +6,16 @@ from typing import Dict, List, Optional, Union from runhouse.globals import configs, rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.resources.hardware import _get_cluster_from, Cluster from runhouse.resources.resource import Resource from runhouse.resources.secrets.utils import _delete_vault_secrets, load_config from runhouse.rns.utils.api import load_resp_content, read_resp_data from runhouse.utils import generate_default_name +logger = get_logger(name=__name__) + class Secret(Resource): RESOURCE_TYPE = "secret" diff --git a/runhouse/resources/secrets/utils.py b/runhouse/resources/secrets/utils.py index 439ed764b..d8dca10e4 100644 --- a/runhouse/resources/secrets/utils.py +++ b/runhouse/resources/secrets/utils.py @@ -4,14 +4,16 @@ from typing import Dict, Optional, Union from runhouse.globals import rns_client - -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.utils.api import load_resp_content, read_resp_data USER_ENDPOINT = "user/secret" +logger = get_logger(name=__name__) + + def load_config(name: str, endpoint: str = USER_ENDPOINT): if "/" not in name: name = f"{rns_client.current_folder}/{name}" diff --git a/runhouse/rns/defaults.py b/runhouse/rns/defaults.py index 39f341023..850216f7f 100644 --- a/runhouse/rns/defaults.py +++ b/runhouse/rns/defaults.py @@ -10,13 +10,15 @@ import requests import yaml -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.utils.api import read_resp_data, to_bool req_ctx = contextvars.ContextVar("rh_ctx", default={}) +logger = get_logger(name=__name__) + class Defaults: """Class to handle defaults for Runhouse. Defaults are stored in a json file in the user's home directory.""" diff --git a/runhouse/rns/login.py b/runhouse/rns/login.py index 3149ebb49..018a97338 100644 --- a/runhouse/rns/login.py +++ b/runhouse/rns/login.py @@ -2,12 +2,12 @@ from typing import Dict, Optional import requests - import typer from runhouse.globals import configs, rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger +logger = get_logger(name=__name__) def is_interactive(): diff --git a/runhouse/rns/rns_client.py b/runhouse/rns/rns_client.py index 7545149ae..3e077a948 100644 --- a/runhouse/rns/rns_client.py +++ b/runhouse/rns/rns_client.py @@ -11,7 +11,7 @@ import requests from pydantic import BaseModel -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.utils.api import ( generate_uuid, @@ -23,6 +23,8 @@ from runhouse.utils import locate_working_dir +logger = get_logger(name=__name__) + # This is a copy of the Pydantic model that we use to validate in Den class ResourceStatusData(BaseModel): diff --git a/runhouse/rns/top_level_rns_fns.py b/runhouse/rns/top_level_rns_fns.py index 80a259b21..8d92891ed 100644 --- a/runhouse/rns/top_level_rns_fns.py +++ b/runhouse/rns/top_level_rns_fns.py @@ -5,12 +5,12 @@ from runhouse.constants import EMPTY_DEFAULT_ENV_NAME from runhouse.globals import configs, obj_store, rns_client - +from runhouse.logger import get_logger from runhouse.servers.obj_store import ClusterServletSetupOption -# Configure the logger once -logging.getLogger("numexpr").setLevel(logging.WARNING) +logger = get_logger(name=__name__) +logging.getLogger("numexpr").setLevel(logging.WARNING) collect_data: bool = configs.data_collection_enabled() if collect_data: diff --git a/runhouse/servers/autostop_helper.py b/runhouse/servers/autostop_helper.py index 389225ab0..c4f06954c 100644 --- a/runhouse/servers/autostop_helper.py +++ b/runhouse/servers/autostop_helper.py @@ -1,7 +1,9 @@ import shlex import subprocess -from runhouse.logger import logger +from runhouse.logger import get_logger + +logger = get_logger(name=__name__) class AutostopHelper: diff --git a/runhouse/servers/caddy/config.py b/runhouse/servers/caddy/config.py index a1c9c2f9f..3862f1612 100644 --- a/runhouse/servers/caddy/config.py +++ b/runhouse/servers/caddy/config.py @@ -3,9 +3,9 @@ from pathlib import Path from runhouse.constants import DEFAULT_SERVER_PORT +from runhouse.logger import get_logger -from runhouse.logger import logger - +logger = get_logger(name=__name__) SYSTEMCTL_ERROR = "systemctl: command not found" diff --git a/runhouse/servers/cluster_servlet.py b/runhouse/servers/cluster_servlet.py index 6dc3b5a57..dbfcfba3a 100644 --- a/runhouse/servers/cluster_servlet.py +++ b/runhouse/servers/cluster_servlet.py @@ -19,15 +19,16 @@ ) from runhouse.globals import configs, obj_store, rns_client -from runhouse.logger import ColoredFormatter, logger +from runhouse.logger import get_logger from runhouse.resources.hardware import load_cluster_config_from_file from runhouse.resources.hardware.utils import detect_cuda_version_or_cpu from runhouse.rns.rns_client import ResourceStatusData from runhouse.rns.utils.api import ResourceAccess from runhouse.servers.autostop_helper import AutostopHelper from runhouse.servers.http.auth import AuthCache +from runhouse.utils import ColoredFormatter, sync_function -from runhouse.utils import sync_function +logger = get_logger(name=__name__) class ClusterServletError(Exception): diff --git a/runhouse/servers/env_servlet.py b/runhouse/servers/env_servlet.py index e7a24ea8e..c7b8fac9b 100644 --- a/runhouse/servers/env_servlet.py +++ b/runhouse/servers/env_servlet.py @@ -5,7 +5,7 @@ from runhouse.constants import DEFAULT_LOG_LEVEL from runhouse.globals import obj_store -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.servers.http.http_utils import ( deserialize_data, @@ -15,9 +15,10 @@ serialize_data, ) from runhouse.servers.obj_store import ClusterServletSetupOption - from runhouse.utils import arun_in_thread, get_node_ip +logger = get_logger(name=__name__) + def error_handling_decorator(func): @wraps(func) diff --git a/runhouse/servers/http/auth.py b/runhouse/servers/http/auth.py index a4500e7f7..b61f34fef 100644 --- a/runhouse/servers/http/auth.py +++ b/runhouse/servers/http/auth.py @@ -1,10 +1,12 @@ from typing import Optional, Union from runhouse.globals import rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.utils.api import load_resp_content, ResourceAccess from runhouse.servers.http.http_utils import username_from_token +logger = get_logger(name=__name__) + class AuthCache: # Maps a user's token to all the resources they have access to diff --git a/runhouse/servers/http/certs.py b/runhouse/servers/http/certs.py index 5306b654d..971526f81 100644 --- a/runhouse/servers/http/certs.py +++ b/runhouse/servers/http/certs.py @@ -8,9 +8,12 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID -from runhouse.logger import logger +from runhouse.logger import get_logger + from runhouse.rns.utils.api import resolve_absolute_path +logger = get_logger(name=__name__) + class TLSCertConfig: """Handler for creating and managing the TLS certs needed to enable HTTPS on the Runhouse API server.""" diff --git a/runhouse/servers/http/http_client.py b/runhouse/servers/http/http_client.py index dc3cd70ca..6133461ca 100644 --- a/runhouse/servers/http/http_client.py +++ b/runhouse/servers/http/http_client.py @@ -12,8 +12,7 @@ import requests from runhouse.globals import rns_client - -from runhouse.logger import ClusterLogsFormatter, logger +from runhouse.logger import get_logger from runhouse.resources.envs.utils import _get_env_from @@ -36,13 +35,15 @@ serialize_data, ) -from runhouse.utils import generate_default_name +from runhouse.utils import ClusterLogsFormatter, generate_default_name # Make this global so connections are pooled across instances of HTTPClient session = requests.Session() session.timeout = None +logger = get_logger(name=__name__) + def retry_with_exponential_backoff(func): @wraps(func) diff --git a/runhouse/servers/http/http_server.py b/runhouse/servers/http/http_server.py index 0b1cd5030..e1d49173e 100644 --- a/runhouse/servers/http/http_server.py +++ b/runhouse/servers/http/http_server.py @@ -27,7 +27,7 @@ RH_LOGFILE_PATH, ) from runhouse.globals import configs, obj_store, rns_client -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.utils.api import resolve_absolute_path, ResourceAccess from runhouse.servers.caddy.config import CaddyConfig from runhouse.servers.http.auth import averify_cluster_access @@ -69,6 +69,8 @@ app = FastAPI(docs_url=None, redoc_url=None) +logger = get_logger(name=__name__) + def validate_cluster_access(func): """If using Den auth, validate the user's cluster subtoken and access to the cluster before continuing.""" diff --git a/runhouse/servers/http/http_utils.py b/runhouse/servers/http/http_utils.py index 786285f8f..b9dca7be8 100644 --- a/runhouse/servers/http/http_utils.py +++ b/runhouse/servers/http/http_utils.py @@ -13,8 +13,12 @@ from ray import cloudpickle as pickle from ray.exceptions import RayTaskError -from runhouse.logger import ClusterLogsFormatter, logger +from runhouse.logger import get_logger + from runhouse.servers.obj_store import RunhouseStopIteration +from runhouse.utils import ClusterLogsFormatter + +logger = get_logger(name=__name__) class RequestContext(BaseModel): diff --git a/runhouse/servers/obj_store.py b/runhouse/servers/obj_store.py index 859ae8a65..f9b4260c6 100644 --- a/runhouse/servers/obj_store.py +++ b/runhouse/servers/obj_store.py @@ -13,8 +13,7 @@ from pydantic import BaseModel from runhouse.constants import DEFAULT_LOG_LEVEL - -from runhouse.logger import logger +from runhouse.logger import get_logger from runhouse.rns.defaults import req_ctx from runhouse.rns.utils.api import ResourceVisibility @@ -25,6 +24,8 @@ sync_function, ) +logger = get_logger(name=__name__) + class RaySetupOption(str, Enum): GET_OR_FAIL = "get_or_fail" diff --git a/runhouse/utils.py b/runhouse/utils.py index 60c6b04cb..e698d52fe 100644 --- a/runhouse/utils.py +++ b/runhouse/utils.py @@ -31,9 +31,9 @@ import pexpect from runhouse.constants import LOGS_DIR +from runhouse.logger import get_logger -logger = logging.getLogger(__name__) - +logger = get_logger(name=__name__) #################################################################################################### # Python package utilities #################################################################################################### @@ -475,3 +475,66 @@ def generate_default_name(prefix: str = None, precision: str = "s", sep="_") -> if prefix is None: return timestamp_key return f"{prefix}{sep}{timestamp_key}" + + +#################################################################################################### +# Logger utils +#################################################################################################### +class ColoredFormatter: + COLORS = { + "black": "\u001b[30m", + "red": "\u001b[31m", + "green": "\u001b[32m", + "yellow": "\u001b[33m", + "blue": "\u001b[34m", + "magenta": "\u001b[35m", + "cyan": "\u001b[36m", + "white": "\u001b[37m", + "reset": "\u001b[0m", + } + + @classmethod + def get_color(cls, color: str): + return cls.COLORS.get(color, "") + + # TODO: This method is a temp solution, until we'll update logging architecture. Remove once logging is cleaned up. + @classmethod + def format_log(cls, text): + ansi_escape = re.compile(r"(?:\x1B[@-_][0-?]*[ -/]*[@-~])") + return ansi_escape.sub("", text) + + +class ClusterLogsFormatter: + def __init__(self, system): + self.system = system + self._display_title = False + + def format(self, output_type): + from runhouse import Resource + from runhouse.servers.http.http_utils import OutputType + + system_color = ColoredFormatter.get_color("cyan") + reset_color = ColoredFormatter.get_color("reset") + + prettify_logs = output_type in [ + OutputType.STDOUT, + OutputType.EXCEPTION, + OutputType.STDERR, + ] + + if ( + isinstance(self.system, Resource) + and prettify_logs + and not self._display_title + ): + # Display the system name before subsequent logs only once + system_name = self.system.name + dotted_line = "-" * len(system_name) + print(dotted_line) + print(f"{system_color}{system_name}{reset_color}") + print(dotted_line) + + # Only display the system name once + self._display_title = True + + return system_color, reset_color diff --git a/tests/test_obj_store.py b/tests/test_obj_store.py index f20e8088f..842422fcf 100644 --- a/tests/test_obj_store.py +++ b/tests/test_obj_store.py @@ -5,6 +5,7 @@ import pytest import runhouse as rh +from runhouse.logger import get_logger from tests.test_resources.test_modules.test_functions.test_function import ( multiproc_np_sum, @@ -13,8 +14,7 @@ TEMP_FILE = "my_file.txt" TEMP_FOLDER = "~/runhouse-tests" -from runhouse.logger import logger - +logger = get_logger(name=__name__) UNIT = {"cluster": []} LOCAL = { diff --git a/tests/test_performance.py b/tests/test_performance.py index 236a71d56..1dc0e5faa 100644 --- a/tests/test_performance.py +++ b/tests/test_performance.py @@ -3,8 +3,9 @@ import requests from runhouse.globals import rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger +logger = get_logger(name=__name__) def profile(func, reps=10): diff --git a/tests/test_resources/test_clusters/test_on_demand_cluster.py b/tests/test_resources/test_clusters/test_on_demand_cluster.py index 284168050..6a9e58c50 100644 --- a/tests/test_resources/test_clusters/test_on_demand_cluster.py +++ b/tests/test_resources/test_clusters/test_on_demand_cluster.py @@ -8,8 +8,8 @@ import runhouse as rh from runhouse.constants import SERVER_LOGFILE_PATH from runhouse.globals import rns_client -from runhouse.logger import ColoredFormatter from runhouse.resources.hardware.utils import ResourceServerStatus +from runhouse.utils import ColoredFormatter import tests.test_resources.test_clusters.test_cluster from tests.utils import friend_account diff --git a/tests/test_resources/test_modules/test_functions/test_function.py b/tests/test_resources/test_modules/test_functions/test_function.py index d49a0201b..e79e8bafc 100644 --- a/tests/test_resources/test_modules/test_functions/test_function.py +++ b/tests/test_resources/test_modules/test_functions/test_function.py @@ -7,11 +7,12 @@ import runhouse as rh from runhouse.globals import rns_client - -from runhouse.logger import logger +from runhouse.logger import get_logger from tests.utils import friend_account +logger = get_logger(name=__name__) + def get_remote_func_name(test_folder): return f"@/{test_folder}/remote_function" diff --git a/tests/test_resources/test_modules/test_module.py b/tests/test_resources/test_modules/test_module.py index fb1ad0a35..148b7f8f0 100644 --- a/tests/test_resources/test_modules/test_module.py +++ b/tests/test_resources/test_modules/test_module.py @@ -13,9 +13,9 @@ import runhouse as rh from runhouse import Package from runhouse.constants import TEST_ORG +from runhouse.logger import get_logger -from runhouse.logger import logger - +logger = get_logger(name=__name__) """ Tests for runhouse.Module. Structure: - Test call_module_method rpc, with various envs diff --git a/tests/test_servers/conftest.py b/tests/test_servers/conftest.py index 135a2366a..fb836c4fe 100644 --- a/tests/test_servers/conftest.py +++ b/tests/test_servers/conftest.py @@ -9,13 +9,15 @@ import runhouse as rh from runhouse.globals import rns_client +from runhouse.logger import get_logger -from runhouse.logger import logger from runhouse.servers.http.certs import TLSCertConfig from runhouse.servers.http.http_server import app, HTTPServer from tests.utils import friend_account, get_ray_servlet_and_obj_store +logger = get_logger(name=__name__) + # -------- HELPERS ----------- # def summer(a, b): return a + b