-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
use simplified default logger (#1177)
- Loading branch information
Showing
48 changed files
with
269 additions
and
211 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import runhouse as rh | ||
from runhouse.logger import get_logger | ||
|
||
from torch_example_for_airflow import DownloadData, SimpleTrainer | ||
|
||
|
||
logger = get_logger(name=__name__) | ||
|
||
|
||
def bring_up_cluster_callable(**kwargs): | ||
logger.info("Connecting to remote cluster") | ||
rh.ondemand_cluster( | ||
name="a10g-cluster", instance_type="A10G:1", provider="aws" | ||
).up_if_not() | ||
# cluster.save() ## Use if you have a Runhouse Den account to save and monitor the resource. | ||
|
||
|
||
def access_data_callable(**kwargs): | ||
logger.info("Step 2: Access data") | ||
env = rh.env(name="test_env", reqs=["torch", "torchvision"]) | ||
|
||
cluster = rh.cluster(name="a10g-cluster").up_if_not() | ||
remote_download = rh.function(DownloadData).to(cluster, env=env) | ||
logger.info("Download function sent to remote") | ||
remote_download() | ||
logger.info("Downloaded") | ||
|
||
|
||
def train_model_callable(**kwargs): | ||
cluster = rh.cluster(name="a10g-cluster").up_if_not() | ||
|
||
env = rh.env(name="test_env", reqs=["torch", "torchvision"]) | ||
|
||
remote_torch_example = rh.module(SimpleTrainer).to( | ||
cluster, env=env, name="torch-basic-training" | ||
) | ||
|
||
model = remote_torch_example() | ||
|
||
batch_size = 64 | ||
epochs = 5 | ||
learning_rate = 0.01 | ||
|
||
model.load_train("./data", batch_size) | ||
model.load_test("./data", batch_size) | ||
|
||
for epoch in range(epochs): | ||
model.train_model(learning_rate=learning_rate) | ||
model.test_model() | ||
model.save_model( | ||
bucket_name="my-simple-torch-model-example", | ||
s3_file_path=f"checkpoints/model_epoch_{epoch + 1}.pth", | ||
) | ||
|
||
|
||
def down_cluster(**kwargs): | ||
cluster = rh.cluster(name="a10g-cluster") | ||
cluster.teardown() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,148 +1,25 @@ | ||
import logging | ||
import logging.config | ||
import re | ||
from datetime import datetime, timezone | ||
from typing import List, Union | ||
import os | ||
|
||
from runhouse.constants import DEFAULT_LOG_LEVEL | ||
|
||
def get_logger(name: str = __name__): | ||
logger = logging.getLogger(name) | ||
|
||
class ColoredFormatter: | ||
COLORS = { | ||
"black": "\u001b[30m", | ||
"red": "\u001b[31m", | ||
"green": "\u001b[32m", | ||
"yellow": "\u001b[33m", | ||
"blue": "\u001b[34m", | ||
"magenta": "\u001b[35m", | ||
"cyan": "\u001b[36m", | ||
"white": "\u001b[37m", | ||
"reset": "\u001b[0m", | ||
} | ||
level = os.getenv("RH_LOG_LEVEL") | ||
if level: | ||
# Set the logging level | ||
logger.setLevel(level.upper()) | ||
|
||
@classmethod | ||
def get_color(cls, color: str): | ||
return cls.COLORS.get(color, "") | ||
|
||
# TODO: This method is a temp solution, until we'll update logging architecture. Remove once logging is cleaned up. | ||
@classmethod | ||
def format_log(cls, text): | ||
ansi_escape = re.compile(r"(?:\x1B[@-_][0-?]*[ -/]*[@-~])") | ||
return ansi_escape.sub("", text) | ||
|
||
|
||
class ClusterLogsFormatter: | ||
def __init__(self, system): | ||
self.system = system | ||
self._display_title = False | ||
|
||
def format(self, output_type): | ||
from runhouse import Resource | ||
from runhouse.servers.http.http_utils import OutputType | ||
|
||
system_color = ColoredFormatter.get_color("cyan") | ||
reset_color = ColoredFormatter.get_color("reset") | ||
|
||
prettify_logs = output_type in [ | ||
OutputType.STDOUT, | ||
OutputType.EXCEPTION, | ||
OutputType.STDERR, | ||
] | ||
|
||
if ( | ||
isinstance(self.system, Resource) | ||
and prettify_logs | ||
and not self._display_title | ||
): | ||
# Display the system name before subsequent logs only once | ||
system_name = self.system.name | ||
dotted_line = "-" * len(system_name) | ||
print(dotted_line) | ||
print(f"{system_color}{system_name}{reset_color}") | ||
print(dotted_line) | ||
|
||
# Only display the system name once | ||
self._display_title = True | ||
|
||
return system_color, reset_color | ||
|
||
|
||
class FunctionLogHandler(logging.Handler): | ||
def __init__(self): | ||
super().__init__() | ||
self.log_records = [] | ||
|
||
def emit(self, record): | ||
self.log_records.append(record) | ||
|
||
@staticmethod | ||
def log_records_to_stdout(log_records: List[logging.LogRecord]) -> str: | ||
"""Convert the log records to a string repr of the stdout output""" | ||
captured_logs = [ | ||
f"{log_record.levelname} | {log_record.asctime} | {log_record.msg}" | ||
for log_record in log_records | ||
] | ||
return "\n".join(captured_logs) | ||
|
||
|
||
class UTCFormatter(logging.Formatter): | ||
"""Ensure logs are always in UTC time""" | ||
|
||
@staticmethod | ||
def converter(timestamp): | ||
return datetime.fromtimestamp(timestamp, tz=timezone.utc) | ||
# Apply a custom formatter | ||
formatter = logging.Formatter( | ||
fmt="%(levelname)s | %(asctime)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" | ||
) | ||
|
||
def formatTime(self, record, datefmt=None): | ||
dt = self.converter(record.created) | ||
if datefmt: | ||
return dt.strftime(datefmt) | ||
else: | ||
return dt.isoformat(timespec="milliseconds") | ||
# Apply the formatter to each handler | ||
for handler in logger.handlers: | ||
handler.setFormatter(formatter) | ||
|
||
# Prevent the logger from propagating to the root logger | ||
logger.propagate = False | ||
|
||
def get_logger(name: str = __name__, log_level: Union[int, str] = logging.INFO): | ||
level_name = ( | ||
logging.getLevelName(log_level) if isinstance(log_level, int) else log_level | ||
) | ||
LOGGING_CONFIG = { | ||
"version": 1, | ||
"disable_existing_loggers": True, | ||
"formatters": { | ||
"utc_formatter": { | ||
"()": UTCFormatter, | ||
"format": "%(levelname)s | %(asctime)s | %(message)s", | ||
"datefmt": "%Y-%m-%d %H:%M:%S.%f", | ||
}, | ||
}, | ||
"handlers": { | ||
"default": { | ||
"level": level_name, | ||
"formatter": "utc_formatter", | ||
"class": "logging.StreamHandler", | ||
"stream": "ext://sys.stderr", # Default is stderr | ||
}, | ||
}, | ||
"loggers": { | ||
"": { # root logger | ||
"handlers": ["default"], | ||
"level": level_name, | ||
"propagate": False, | ||
}, | ||
"my.packg": { | ||
"handlers": ["default"], | ||
"level": level_name, | ||
"propagate": False, | ||
}, | ||
"__main__": { # if __name__ == '__main__' | ||
"handlers": ["default"], | ||
"level": level_name, | ||
"propagate": False, | ||
}, | ||
}, | ||
} | ||
logging.config.dictConfig(LOGGING_CONFIG) | ||
logger = logging.getLogger(name=name) | ||
return logger | ||
|
||
|
||
logger = get_logger(name=__name__, log_level=DEFAULT_LOG_LEVEL) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.