NOTE(review): re-expanded from a whitespace-collapsed copy of this patch. Blank context lines were
placed by inference from the hunk line counts, and the `index` hashes predate the edits below, so
regenerate the patch with `git diff` before applying it.

diff --git a/src/litdata/CHANGELOG.md b/src/litdata/CHANGELOG.md
index fd392861..0af1eead 100644
--- a/src/litdata/CHANGELOG.md
+++ b/src/litdata/CHANGELOG.md
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Added
 - Introduced `CHANGELOG.md` to track changes across releases ([#733](https://github.com/lightning-ai/litdata/pull/733))
+- Added `configure_logging` to `src/litdata/debugger.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))
+- Added tests for `configure_logging` in `tests/test_debugger.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))
+- Added `configure_logging` import to `src/litdata/__init__.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))
 
 ### Changed
 
diff --git a/src/litdata/debugger.py b/src/litdata/debugger.py
index bc2b1e6c..321dd004 100644
--- a/src/litdata/debugger.py
+++ b/src/litdata/debugger.py
@@ -14,12 +14,17 @@
 import logging
 import os
 import re
+import sys
 import threading
 import time
 from functools import lru_cache
+from typing import TYPE_CHECKING, Any, Optional, TextIO, Union
 
 from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv
 
+_DEFAULT_LOG_FORMAT = (
+    "%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
+)
 
 class TimedFlushFileHandler(logging.FileHandler):
     """FileHandler that flushes every N seconds in a background thread."""
@@ -79,6 +84,76 @@ def get_logger_level(level: str) -> int:
     raise ValueError(f"Invalid log level: {level}")
 
 
+def _get_default_handler(stream: Optional[TextIO], format: str) -> logging.StreamHandler:
+    """Return a ``StreamHandler`` that writes records to ``stream`` using ``format``."""
+    handler = logging.StreamHandler(stream)
+    handler.setFormatter(logging.Formatter(format))
+    return handler
+
+
+def configure_logging(
+    level: Union[str, int] = logging.INFO,
+    format: str = _DEFAULT_LOG_FORMAT,
+    stream: Optional[TextIO] = None,
+    use_rich: bool = False,
+) -> None:
+    """Configure logging for the entire library with sensible defaults.
+
+    Handlers already attached to the ``litdata`` logger are removed first, so repeated calls do
+    not duplicate output, and propagation to ancestor loggers is turned off.
+
+    Args:
+        level (str | int): Logging level constant or its name in any case (default: logging.INFO)
+        format (str): Log message format string (not applied to the rich handler)
+        stream (file-like): Output stream for logs (not applied to the rich handler). ``None``
+            means ``sys.stdout`` as it is bound when this function is called.
+        use_rich (bool): Makes the logs more readable by using rich, useful for debugging. Falls back
+            to the default handler when rich cannot be imported. Defaults to False.
+
+    Raises:
+        ValueError: If ``level`` is a string that ``get_logger_level`` rejects.
+
+    """
+    if isinstance(level, str):
+        # Validate through the module's own helper so a misspelt name raises ``ValueError``
+        # rather than the ``AttributeError`` a bare ``getattr(logging, ...)`` would leak.
+        level = get_logger_level(level.upper())
+    if stream is None:
+        # Resolved per call: a ``sys.stdout`` default in the signature is frozen at import time.
+        stream = sys.stdout
+
+    # Clear any existing handlers to prevent duplicates
+    library_logger = logging.getLogger("litdata")
+    for existing in library_logger.handlers[:]:
+        library_logger.removeHandler(existing)
+
+    handler = None
+    rich_missing = False
+    if use_rich:
+        try:
+            from rich.logging import RichHandler
+            from rich.traceback import install
+        except ImportError:
+            rich_missing = True
+        else:
+            # NOTE: ``show_locals=True`` prints local variables in tracebacks; they may hold secrets.
+            install(show_locals=True)
+            handler = RichHandler(rich_tracebacks=True, show_time=True, show_path=True)
+    if handler is None:
+        handler = _get_default_handler(stream, format)
+
+    # Configure library logger
+    library_logger.setLevel(level)
+    library_logger.addHandler(handler)
+    library_logger.propagate = False
+
+    if rich_missing:
+        # Report through the library logger: the module-level ``logging.warning`` calls
+        # ``logging.basicConfig()`` when the root logger has no handlers, which would silently
+        # install a handler on the host application's root logger.
+        library_logger.warning("Rich is not installed, using default logging")
+
+
 class LitDataLogger:
     _instance = None
     _lock = threading.Lock()
diff --git a/tests/test_debugger.py b/tests/test_debugger.py
index d90303bd..fa43f1ce 100644
--- a/tests/test_debugger.py
+++ b/tests/test_debugger.py
@@ -1,7 +1,28 @@
+import io
 import logging
+import sys
+from unittest import mock
 
 import pytest
 
+from litdata.debugger import configure_logging
+
+
+@pytest.fixture
+def log_stream():
+    return io.StringIO()
+
+
+@pytest.fixture(autouse=True)
+def _restore_litdata_logger():
+    """Put the global ``litdata`` logger back as it was so tests do not leak state into each other."""
+    logger = logging.getLogger("litdata")
+    handlers, level, propagate = logger.handlers[:], logger.level, logger.propagate
+    yield
+    logger.handlers[:] = handlers
+    logger.setLevel(level)
+    logger.propagate = propagate
+
 
 def test_get_logger_level():
     from litdata.debugger import get_logger_level
@@ -13,3 +34,48 @@ def test_get_logger_level():
     assert get_logger_level("CRITICAL") == logging.CRITICAL
     with pytest.raises(ValueError, match="Invalid log level"):
         get_logger_level("INVALID")
+
+
+def test_configure_logging(log_stream):
+    # Configure logging with test stream
+    configure_logging(level=logging.DEBUG, stream=log_stream)
+
+    # Get logger and log a test message
+    logger = logging.getLogger("litdata")
+    test_message = "Test debug message"
+    logger.debug(test_message)
+
+    # Verify log output
+    log_contents = log_stream.getvalue()
+    assert test_message in log_contents
+    assert "DEBUG" in log_contents
+    assert logger.propagate is False
+
+
+def test_configure_logging_level_name(log_stream):
+    configure_logging(level="debug", stream=log_stream)
+    assert logging.getLogger("litdata").level == logging.DEBUG
+
+    with pytest.raises(ValueError, match="Invalid log level"):
+        configure_logging(level="invalid", stream=log_stream)
+
+
+def test_configure_logging_default_handler():
+    # Calling twice must replace the handler rather than stack a second one.
+    configure_logging(use_rich=False)
+    configure_logging(use_rich=False)
+    handlers = logging.getLogger("litdata").handlers
+    assert len(handlers) == 1
+    assert type(handlers[0]) is logging.StreamHandler
+
+
+def test_configure_logging_rich_not_installed(log_stream):
+    # A ``None`` entry in ``sys.modules`` makes importing that name raise ImportError, so only
+    # rich is hidden instead of every import as with a patched ``builtins.__import__``.
+    hidden = {"rich": None, "rich.logging": None, "rich.traceback": None}
+    with mock.patch.dict(sys.modules, hidden):
+        configure_logging(use_rich=True, stream=log_stream)
+
+    handlers = logging.getLogger("litdata").handlers
+    assert type(handlers[0]) is logging.StreamHandler
+    assert "Rich is not installed" in log_stream.getvalue()