use simplified default logger (#1177)
jlewitt1 authored Aug 28, 2024
1 parent 768db12 commit c2f7691
Showing 48 changed files with 269 additions and 211 deletions.
58 changes: 58 additions & 0 deletions examples/airflow-torch-training/callables.py
@@ -0,0 +1,58 @@
import runhouse as rh
from runhouse.logger import get_logger

from torch_example_for_airflow import DownloadData, SimpleTrainer


logger = get_logger(name=__name__)


def bring_up_cluster_callable(**kwargs):
    logger.info("Connecting to remote cluster")
    rh.ondemand_cluster(
        name="a10g-cluster", instance_type="A10G:1", provider="aws"
    ).up_if_not()
    # cluster.save() ## Use if you have a Runhouse Den account to save and monitor the resource.


def access_data_callable(**kwargs):
    logger.info("Step 2: Access data")
    env = rh.env(name="test_env", reqs=["torch", "torchvision"])

    cluster = rh.cluster(name="a10g-cluster").up_if_not()
    remote_download = rh.function(DownloadData).to(cluster, env=env)
    logger.info("Download function sent to remote")
    remote_download()
    logger.info("Downloaded")


def train_model_callable(**kwargs):
    cluster = rh.cluster(name="a10g-cluster").up_if_not()

    env = rh.env(name="test_env", reqs=["torch", "torchvision"])

    remote_torch_example = rh.module(SimpleTrainer).to(
        cluster, env=env, name="torch-basic-training"
    )

    model = remote_torch_example()

    batch_size = 64
    epochs = 5
    learning_rate = 0.01

    model.load_train("./data", batch_size)
    model.load_test("./data", batch_size)

    for epoch in range(epochs):
        model.train_model(learning_rate=learning_rate)
        model.test_model()
        model.save_model(
            bucket_name="my-simple-torch-model-example",
            s3_file_path=f"checkpoints/model_epoch_{epoch + 1}.pth",
        )


def down_cluster(**kwargs):
    cluster = rh.cluster(name="a10g-cluster")
    cluster.teardown()
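For context, these callables are meant to be scheduled from an Airflow DAG (the example lives under examples/airflow-torch-training). A minimal wiring sketch follows; it is an illustration only and not part of this commit — the DAG id, schedule, and import path are assumptions, and it presumes a recent Airflow 2.x with callables.py on the DAG's Python path.

# Hypothetical DAG that chains the callables above; not part of this commit.
# Assumes Airflow 2.x (PythonOperator) and that callables.py is importable.
from datetime import datetime

from airflow import DAG
from airflow.operators.python import PythonOperator

from callables import (
    access_data_callable,
    bring_up_cluster_callable,
    down_cluster,
    train_model_callable,
)

with DAG(
    dag_id="pytorch_training_pipeline",  # illustrative name
    start_date=datetime(2024, 8, 1),
    schedule=None,
    catchup=False,
) as dag:
    bring_up_cluster = PythonOperator(
        task_id="bring_up_cluster", python_callable=bring_up_cluster_callable
    )
    access_data = PythonOperator(
        task_id="access_data", python_callable=access_data_callable
    )
    train_model = PythonOperator(
        task_id="train_model", python_callable=train_model_callable
    )
    teardown = PythonOperator(task_id="down_cluster", python_callable=down_cluster)

    # Run the steps in order, tearing the cluster down last
    bring_up_cluster >> access_data >> train_model >> teardown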
155 changes: 16 additions & 139 deletions runhouse/logger.py
@@ -1,148 +1,25 @@
 import logging
-import logging.config
-import re
-from datetime import datetime, timezone
-from typing import List, Union
+import os

-from runhouse.constants import DEFAULT_LOG_LEVEL

+def get_logger(name: str = __name__):
+    logger = logging.getLogger(name)

-class ColoredFormatter:
-    COLORS = {
-        "black": "\u001b[30m",
-        "red": "\u001b[31m",
-        "green": "\u001b[32m",
-        "yellow": "\u001b[33m",
-        "blue": "\u001b[34m",
-        "magenta": "\u001b[35m",
-        "cyan": "\u001b[36m",
-        "white": "\u001b[37m",
-        "reset": "\u001b[0m",
-    }
+    level = os.getenv("RH_LOG_LEVEL")
+    if level:
+        # Set the logging level
+        logger.setLevel(level.upper())

-    @classmethod
-    def get_color(cls, color: str):
-        return cls.COLORS.get(color, "")

-    # TODO: This method is a temp solution, until we'll update logging architecture. Remove once logging is cleaned up.
-    @classmethod
-    def format_log(cls, text):
-        ansi_escape = re.compile(r"(?:\x1B[@-_][0-?]*[ -/]*[@-~])")
-        return ansi_escape.sub("", text)


-class ClusterLogsFormatter:
-    def __init__(self, system):
-        self.system = system
-        self._display_title = False

-    def format(self, output_type):
-        from runhouse import Resource
-        from runhouse.servers.http.http_utils import OutputType

-        system_color = ColoredFormatter.get_color("cyan")
-        reset_color = ColoredFormatter.get_color("reset")

-        prettify_logs = output_type in [
-            OutputType.STDOUT,
-            OutputType.EXCEPTION,
-            OutputType.STDERR,
-        ]

-        if (
-            isinstance(self.system, Resource)
-            and prettify_logs
-            and not self._display_title
-        ):
-            # Display the system name before subsequent logs only once
-            system_name = self.system.name
-            dotted_line = "-" * len(system_name)
-            print(dotted_line)
-            print(f"{system_color}{system_name}{reset_color}")
-            print(dotted_line)

-            # Only display the system name once
-            self._display_title = True

-        return system_color, reset_color


-class FunctionLogHandler(logging.Handler):
-    def __init__(self):
-        super().__init__()
-        self.log_records = []

-    def emit(self, record):
-        self.log_records.append(record)

-    @staticmethod
-    def log_records_to_stdout(log_records: List[logging.LogRecord]) -> str:
-        """Convert the log records to a string repr of the stdout output"""
-        captured_logs = [
-            f"{log_record.levelname} | {log_record.asctime} | {log_record.msg}"
-            for log_record in log_records
-        ]
-        return "\n".join(captured_logs)


-class UTCFormatter(logging.Formatter):
-    """Ensure logs are always in UTC time"""

-    @staticmethod
-    def converter(timestamp):
-        return datetime.fromtimestamp(timestamp, tz=timezone.utc)
+    # Apply a custom formatter
+    formatter = logging.Formatter(
+        fmt="%(levelname)s | %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
+    )

-    def formatTime(self, record, datefmt=None):
-        dt = self.converter(record.created)
-        if datefmt:
-            return dt.strftime(datefmt)
-        else:
-            return dt.isoformat(timespec="milliseconds")
+    # Apply the formatter to each handler
+    for handler in logger.handlers:
+        handler.setFormatter(formatter)

+    # Prevent the logger from propagating to the root logger
+    logger.propagate = False

-def get_logger(name: str = __name__, log_level: Union[int, str] = logging.INFO):
-    level_name = (
-        logging.getLevelName(log_level) if isinstance(log_level, int) else log_level
-    )
-    LOGGING_CONFIG = {
-        "version": 1,
-        "disable_existing_loggers": True,
-        "formatters": {
-            "utc_formatter": {
-                "()": UTCFormatter,
-                "format": "%(levelname)s | %(asctime)s | %(message)s",
-                "datefmt": "%Y-%m-%d %H:%M:%S.%f",
-            },
-        },
-        "handlers": {
-            "default": {
-                "level": level_name,
-                "formatter": "utc_formatter",
-                "class": "logging.StreamHandler",
-                "stream": "ext://sys.stderr",  # Default is stderr
-            },
-        },
-        "loggers": {
-            "": {  # root logger
-                "handlers": ["default"],
-                "level": level_name,
-                "propagate": False,
-            },
-            "my.packg": {
-                "handlers": ["default"],
-                "level": level_name,
-                "propagate": False,
-            },
-            "__main__": {  # if __name__ == '__main__'
-                "handlers": ["default"],
-                "level": level_name,
-                "propagate": False,
-            },
-        },
-    }
-    logging.config.dictConfig(LOGGING_CONFIG)
-    logger = logging.getLogger(name=name)
     return logger


-logger = get_logger(name=__name__, log_level=DEFAULT_LOG_LEVEL)
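With the simplified module, callers build per-module loggers instead of importing a shared logger instance, and the level can be driven by the RH_LOG_LEVEL environment variable. A minimal usage sketch based only on the code added above; the DEBUG value is just an example, and where records end up still depends on whatever handlers the application attaches.

import os

from runhouse.logger import get_logger

# Optional: the helper reads RH_LOG_LEVEL when the logger is created.
os.environ["RH_LOG_LEVEL"] = "DEBUG"  # example value

logger = get_logger(name=__name__)
logger.debug("per-module logger; level follows RH_LOG_LEVEL")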
5 changes: 4 additions & 1 deletion runhouse/main.py
@@ -34,18 +34,21 @@
    START_SCREEN_CMD,
)
from runhouse.globals import obj_store, rns_client
from runhouse.logger import logger
from runhouse.logger import get_logger
from runhouse.resources.hardware.ray_utils import (
    check_for_existing_ray_instance,
    kill_actors,
)


# create an explicit Typer application
app = typer.Typer(add_completion=False)

# For printing with typer
console = Console()

logger = get_logger(name=__name__)


@app.command()
def login(
5 changes: 3 additions & 2 deletions runhouse/resources/envs/conda_env.py
@@ -7,14 +7,15 @@

from runhouse.constants import ENVS_DIR
from runhouse.globals import obj_store
from runhouse.logger import get_logger

from runhouse.logger import logger
from runhouse.resources.envs.utils import install_conda, run_setup_command

from runhouse.resources.packages import Package

from .env import Env

logger = get_logger(name=__name__)


class CondaEnv(Env):
    RESOURCE_TYPE = "env"
5 changes: 3 additions & 2 deletions runhouse/resources/envs/env.py
@@ -1,19 +1,20 @@
import copy
import logging
import os
import shlex
from pathlib import Path
from typing import Dict, List, Optional, Union

from runhouse.globals import obj_store
from runhouse.logger import get_logger

from runhouse.logger import logger
from runhouse.resources.envs.utils import _process_env_vars, run_setup_command
from runhouse.resources.hardware import _get_cluster_from, Cluster
from runhouse.resources.packages import InstallTarget, Package
from runhouse.resources.resource import Resource
from runhouse.utils import run_with_logs

logger = get_logger(name=__name__)


class Env(Resource):
    RESOURCE_TYPE = "env"
4 changes: 3 additions & 1 deletion runhouse/resources/folders/folder.py
@@ -7,13 +7,15 @@
from typing import List, Optional, Union

from runhouse.globals import rns_client
from runhouse.logger import get_logger

from runhouse.logger import logger
from runhouse.resources.hardware import _current_cluster, _get_cluster_from, Cluster
from runhouse.resources.resource import Resource
from runhouse.rns.utils.api import generate_uuid, relative_file_path
from runhouse.utils import locate_working_dir

logger = get_logger(name=__name__)


class Folder(Resource):
    RESOURCE_TYPE = "folder"
4 changes: 3 additions & 1 deletion runhouse/resources/folders/folder_factory.py
@@ -1,11 +1,13 @@
from pathlib import Path
from typing import Optional, Union

from runhouse.logger import logger
from runhouse.logger import get_logger

from runhouse.resources.folders.folder import Folder
from runhouse.resources.hardware.utils import _get_cluster_from

logger = get_logger(name=__name__)


def folder(
    name: Optional[str] = None,
4 changes: 3 additions & 1 deletion runhouse/resources/folders/gcs_folder.py
@@ -4,10 +4,12 @@
from pathlib import Path
from typing import List, Optional

from runhouse.logger import logger
from runhouse.logger import get_logger

from .folder import Folder

logger = get_logger(name=__name__)


class GCSFolder(Folder):
    RESOURCE_TYPE = "folder"
4 changes: 3 additions & 1 deletion runhouse/resources/folders/s3_folder.py
@@ -5,14 +5,16 @@
from pathlib import Path
from typing import List, Optional

from runhouse.logger import logger
from runhouse.logger import get_logger

from .folder import Folder

MAX_POLLS = 120000
POLL_INTERVAL = 1
TIMEOUT_SECONDS = 3600

logger = get_logger(name=__name__)


class S3Folder(Folder):
    RESOURCE_TYPE = "folder"
6 changes: 3 additions & 3 deletions runhouse/resources/functions/aws_lambda.py
@@ -17,16 +17,16 @@
pass

from runhouse.globals import rns_client

from runhouse.logger import logger
from runhouse.logger import get_logger
from runhouse.resources.envs import _get_env_from, Env

from runhouse.resources.functions.function import Function


CRED_PATH = f"{Path.home()}/.aws/credentials"
LOG_GROUP_PREFIX = "/aws/lambda/"

logger = get_logger(name=__name__)


class LambdaFunction(Function):
    RESOURCE_TYPE = "lambda_function"
5 changes: 3 additions & 2 deletions runhouse/resources/functions/function.py
@@ -1,17 +1,18 @@
import inspect
import logging
from pathlib import Path
from typing import Any, List, Optional, Tuple, Union

from runhouse import globals
from runhouse.logger import get_logger

from runhouse.logger import logger
from runhouse.resources.envs import Env
from runhouse.resources.hardware import Cluster
from runhouse.resources.module import Module

from runhouse.resources.resource import Resource

logger = get_logger(name=__name__)


class Function(Module):
    RESOURCE_TYPE = "function"
