Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/litdata/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added

- Introduced `CHANGELOG.md` to track changes across releases ([#733](https://github.com/lightning-ai/litdata/pull/733))
- Added `configure_logging` to `src/litdata/debugger.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))
- Added tests for `configure_logging` in `tests/test_debugger.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))
- Added `configure_logging` import to `src/litdata/__init__.py` ([#685](https://github.com/Lightning-AI/litData/issues/685))

### Changed

Expand Down
56 changes: 56 additions & 0 deletions src/litdata/debugger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
import logging
import os
import re
import sys
import threading
import time
from functools import lru_cache
from typing import TYPE_CHECKING, Any, TextIO, Union

from litdata.utilities.env import _DistributedEnv, _is_in_dataloader_worker, _WorkerEnv

# Default record layout used by `configure_logging`: timestamp, process name and pid, logger name,
# level name, source file and line number, then the message itself.
_DEFAULT_LOG_FORMAT = (
    "%(asctime)s - %(processName)s[%(process)d] - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s"
)

class TimedFlushFileHandler(logging.FileHandler):
"""FileHandler that flushes every N seconds in a background thread."""
Expand Down Expand Up @@ -79,6 +84,57 @@ def get_logger_level(level: str) -> int:
raise ValueError(f"Invalid log level: {level}")


def _get_default_handler(stream, format):
handler = logging.StreamHandler(stream)
formatter = logging.Formatter(format)
handler.setFormatter(formatter)
return handler


def configure_logging(
    level: Union[str, int] = logging.INFO,
    format: str = _DEFAULT_LOG_FORMAT,
    stream: TextIO = sys.stdout,
    use_rich: bool = False,
) -> None:
    """Configure logging for the entire library with sensible defaults.

    Every handler currently attached to the ``litdata`` logger is removed and replaced by a single new
    handler. The logger level is set to ``level`` and propagation to ancestor loggers is disabled.

    Args:
        level (int | str): Logging level, given either as a numeric level or as a case-insensitive level
            name such as ``"debug"`` (default: logging.INFO)
        format (str): Log message format string. Only used by the default stream handler.
        stream (file-like): Output stream for logs. Only used by the default stream handler.
        use_rich (bool): Makes the logs more readable by using rich, useful for debugging. Falls back to the
            default stream handler when rich is not installed. Defaults to False.

    Raises:
        ValueError: If ``level`` is a string that does not name a logging level.

    """
    if isinstance(level, str):
        resolved_level = getattr(logging, level.upper(), None)
        # The ``logging`` module also exposes non-level attributes (e.g. ``BASIC_FORMAT``), so only an
        # integer constant counts as a valid level name.
        if not isinstance(resolved_level, int) or isinstance(resolved_level, bool):
            raise ValueError(f"Invalid log level: {level}")
        level = resolved_level

    # Clear any existing handlers to prevent duplicates
    library_logger = logging.getLogger("litdata")
    for existing_handler in library_logger.handlers[:]:
        library_logger.removeHandler(existing_handler)

    handler = None
    rich_missing = False
    if use_rich:
        try:
            from rich.logging import RichHandler
            from rich.traceback import install
        except ImportError:
            rich_missing = True
        else:
            install(show_locals=True)
            handler = RichHandler(rich_tracebacks=True, show_time=True, show_path=True)

    if handler is None:
        handler = _get_default_handler(stream, format)

    # Configure library logger
    library_logger.setLevel(level)
    library_logger.addHandler(handler)
    library_logger.propagate = False

    if rich_missing:
        # Report through the library logger once its handler is attached. The module-level
        # ``logging.warning`` would log on the root logger and implicitly call ``basicConfig()``,
        # installing a root handler for the whole application as a side effect.
        library_logger.warning("Rich is not installed, using default logging")


class LitDataLogger:
_instance = None
_lock = threading.Lock()
Expand Down
37 changes: 37 additions & 0 deletions tests/test_debugger.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
import io
import logging
from unittest import mock

import pytest

from litdata.debugger import configure_logging


@pytest.fixture
def log_stream():
    """Provide a fresh in-memory text buffer that tests can pass as the logging stream."""
    buffer = io.StringIO()
    return buffer


def test_get_logger_level():
from litdata.debugger import get_logger_level
Expand All @@ -13,3 +22,31 @@ def test_get_logger_level():
assert get_logger_level("CRITICAL") == logging.CRITICAL
with pytest.raises(ValueError, match="Invalid log level"):
get_logger_level("INVALID")


def test_configure_logging(log_stream):
    """DEBUG records from the ``litdata`` logger reach the configured stream and do not propagate."""
    expected_text = "Test debug message"

    # Route library logging into the in-memory stream, then emit one record
    configure_logging(level=logging.DEBUG, stream=log_stream)
    litdata_logger = logging.getLogger("litdata")
    litdata_logger.debug(expected_text)

    # Both the message and its level name must show up in the captured output
    captured = log_stream.getvalue()
    for fragment in (expected_text, "DEBUG"):
        assert fragment in captured
    assert litdata_logger.propagate is False


def test_configure_logging_2():
    """With ``use_rich=False`` the ``litdata`` logger gets a plain ``StreamHandler`` first."""
    configure_logging(use_rich=False)
    first_handler = logging.getLogger("litdata").handlers[0]
    assert first_handler.__class__.__name__ == "StreamHandler"


def test_configure_logging_rich_not_installed():
    """``configure_logging(use_rich=True)`` falls back to a ``StreamHandler`` when rich cannot be imported."""
    # A ``None`` entry in ``sys.modules`` makes importing that module raise ImportError. Blocking only the
    # rich modules, instead of patching ``builtins.__import__`` for everything, keeps unrelated imports
    # working while the patch is active.
    blocked_modules = {"rich": None, "rich.logging": None, "rich.traceback": None}
    with mock.patch.dict("sys.modules", blocked_modules):
        configure_logging(use_rich=True)

    # Assert outside the patch so a failure is reported with imports behaving normally
    handlers = logging.getLogger("litdata").handlers
    assert len(handlers) == 1
    assert handlers[0].__class__.__name__ == "StreamHandler"
Loading