Skip to content

[wip]: Exp/Add TimeWindowFilter for controlled logging #550

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 40 additions & 11 deletions src/litdata/debugger.py
Original file line number Diff line number Diff line change
@@ -14,6 +14,7 @@
import logging
import os
import sys
from datetime import datetime, timedelta
from functools import lru_cache
from typing import Tuple

@@ -32,6 +33,26 @@ def get_logger_level(level: str) -> int:
raise ValueError(f"Invalid log level: {level}. Valid levels: {list(logging._nameToLevel.keys())}.")


class TimeWindowFilter(logging.Filter):
"""Filter log records based on a time window."""

def __init__(self, start_time: int = None, end_time: int = None):
"""Initialize the filter with start and end time in seconds."""
super().__init__()
self.start = datetime.now()
self.start_time = timedelta(seconds=start_time) if start_time else timedelta(seconds=0)
self.end_time = timedelta(seconds=end_time) if end_time else None
if self.end_time and self.start_time > self.end_time:
raise ValueError("Start time must be less than or equal to end time.")

def filter(self, record: logging.LogRecord) -> bool:
now = datetime.now()
elapsed = now - self.start
if self.end_time:
return self.start_time <= elapsed <= self.end_time
return elapsed >= self.start_time


class LitDataLogger:
def __init__(self, name: str):
self.logger = logging.getLogger(name)
@@ -54,25 +75,31 @@ def setup_logger(self) -> None:

self.logger.setLevel(self.log_level)

# Console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(self.log_level)

# File handler
file_handler = logging.FileHandler(self.log_file)
file_handler.setLevel(self.log_level)

# Log format
formatter = logging.Formatter(
"ts:%(created)s; logger_name:%(name)s; level:%(levelname)s; PID:%(process)d; TID:%(thread)d; %(message)s"
)
# ENV - f"{WORLD_SIZE, GLOBAL_RANK, NNODES, LOCAL_RANK, NODE_RANK}"
console_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

# Attach handlers
# Time window filter
start_time = int(os.getenv("LITDATA_LOG_START_TIME", 0))
end_time = os.getenv("LITDATA_LOG_END_TIME")
end_time = int(end_time) if end_time else None
time_window_filter = TimeWindowFilter(start_time, end_time)

# Console handler
if _PRINT_DEBUG_LOGS:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(self.log_level)
console_handler.setFormatter(formatter)
console_handler.addFilter(time_window_filter)
self.logger.addHandler(console_handler)

# File handler
file_handler = logging.FileHandler(self.log_file)
file_handler.setLevel(self.log_level)
file_handler.setFormatter(formatter)
file_handler.addFilter(time_window_filter)
self.logger.addHandler(file_handler)


@@ -120,6 +147,8 @@ def env_info() -> dict:
# thread_state_runnable: {r: 133, g: 160, b: 210},
# ....
class ChromeTraceColors:
"""Predefined Chrome tracing colors."""

PINK = "thread_state_iowait"
GREEN = "thread_state_running"
LIGHT_BLUE = "thread_state_runnable"