class TimeWindowFilter(logging.Filter):
    """Filter log records based on a time window.

    The window is measured relative to the moment the filter is created:
    a record is allowed through only while the elapsed time since creation
    lies in ``[start_time, end_time]`` (both bounds inclusive). When no end
    time is given, the window has no upper bound.
    """

    def __init__(self, start_time: int = None, end_time: int = None):
        """Initialize the filter with start and end time in seconds.

        Args:
            start_time: Seconds after creation at which records start being
                accepted. ``None`` is treated as ``0``.
            end_time: Seconds after creation after which records are rejected.
                ``None`` means no upper bound. ``0`` is a real bound, not a
                synonym for "unbounded".

        Raises:
            ValueError: If an end time is given and ``start_time`` exceeds it.
        """
        super().__init__()
        self.start = datetime.now()
        self.start_time = timedelta(seconds=start_time if start_time is not None else 0)
        # Compare against None explicitly: both ``0`` and ``timedelta(0)`` are
        # falsy, so a truthiness test would silently drop an end bound of 0.
        self.end_time = timedelta(seconds=end_time) if end_time is not None else None
        if self.end_time is not None and self.start_time > self.end_time:
            raise ValueError("Start time must be less than or equal to end time.")

    def filter(self, record: logging.LogRecord) -> bool:
        """Return True when the record is emitted inside the time window.

        The decision uses the current wall-clock time at the moment of the
        call; the contents of ``record`` are not inspected.
        """
        elapsed = datetime.now() - self.start
        if elapsed < self.start_time:
            return False
        return self.end_time is None or elapsed <= self.end_time
window filter + start_time = int(os.getenv("LITDATA_LOG_START_TIME", 0)) + end_time = os.getenv("LITDATA_LOG_END_TIME") + end_time = int(end_time) if end_time else None + time_window_filter = TimeWindowFilter(start_time, end_time) + + # Console handler if _PRINT_DEBUG_LOGS: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(self.log_level) + console_handler.setFormatter(formatter) + console_handler.addFilter(time_window_filter) self.logger.addHandler(console_handler) + + # File handler + file_handler = logging.FileHandler(self.log_file) + file_handler.setLevel(self.log_level) + file_handler.setFormatter(formatter) + file_handler.addFilter(time_window_filter) self.logger.addHandler(file_handler) @@ -120,6 +147,8 @@ def env_info() -> dict: # thread_state_runnable: {r: 133, g: 160, b: 210}, # .... class ChromeTraceColors: + """Predefined Chrome tracing colors.""" + PINK = "thread_state_iowait" GREEN = "thread_state_running" LIGHT_BLUE = "thread_state_runnable"