diff --git a/src/gallia/cli/gallia.py b/src/gallia/cli/gallia.py index 90f764a51..3dcdf721a 100644 --- a/src/gallia/cli/gallia.py +++ b/src/gallia/cli/gallia.py @@ -25,9 +25,8 @@ from gallia.plugins.plugin import CommandTree, load_commands, load_plugins from gallia.pydantic_argparse import ArgumentParser from gallia.pydantic_argparse import BaseCommand as PydanticBaseCommand -from gallia.utils import get_log_level -setup_logging(Loglevel.DEBUG) +setup_logging("gallia", Loglevel.DEBUG) defaults = dict[type, dict[str, Any]] @@ -123,7 +122,6 @@ def get_command(config: BaseCommandConfig) -> BaseCommand: def parse_and_run( commands: type[BaseCommand] | MutableMapping[str, CommandTree | type[BaseCommand]], auto_complete: bool = True, - setup_log: bool = True, top_level_options: Mapping[str, Callable[[], None]] | None = None, show_help_on_zero_args: bool = True, ) -> Never: @@ -136,7 +134,6 @@ def parse_and_run( :param commands: A hierarchy of commands. :param auto_complete: Turns auto-complete functionality on. - :param setup_log: Setup logging according to the parameters in the parsed config. :param top_level_options: Optional top-level actions, such as "--version", given by a mapping of arguments and functions. The program redirects control to the given function, once the program is called with the corresponding argument and terminates after it returns. @@ -182,12 +179,6 @@ def __call__( assert isinstance(config, BaseCommandConfig) - if setup_log: - setup_logging( - level=get_log_level(config.verbose), - no_volatile_info=not config.volatile_info, - ) - sys.exit(get_command(config).entry_point()) diff --git a/src/gallia/command/base.py b/src/gallia/command/base.py index 8731fbe66..5d61504b9 100644 --- a/src/gallia/command/base.py +++ b/src/gallia/command/base.py @@ -14,7 +14,6 @@ from collections.abc import MutableMapping from datetime import UTC, datetime from enum import Enum, unique -from logging import Handler from pathlib import Path from subprocess import CalledProcessError, run from tempfile import gettempdir @@ -26,12 +25,12 @@ from gallia.command.config import Field, GalliaBaseModel, Idempotent from gallia.db.handler import DBHandler from gallia.dumpcap import Dumpcap -from gallia.log import add_zst_log_handler, get_logger, tz +from gallia.log import LoggingSetupHandler, Loglevel, get_log_level, get_logger, setup_logging, tz from gallia.power_supply import PowerSupply from gallia.power_supply.uri import PowerSupplyURI from gallia.services.uds.core.exception import UDSException from gallia.transports import BaseTransport, TargetURI -from gallia.utils import camel_to_snake, get_file_log_level +from gallia.utils import camel_to_snake @unique @@ -180,9 +179,11 @@ class BaseCommand(FlockMixin, ABC): #: a log message with level critical is logged. CATCHED_EXCEPTIONS: list[type[Exception]] = [] - log_file_handlers: list[Handler] - - def __init__(self, config: BaseCommandConfig) -> None: + def __init__( + self, + config: BaseCommandConfig = BaseCommandConfig(), + logging_handler: LoggingSetupHandler | None = None, + ) -> None: self.id = camel_to_snake(self.__class__.__name__) self.config = config self.artifacts_dir = Path() @@ -195,7 +196,7 @@ def __init__(self, config: BaseCommandConfig) -> None: ) self._lock_file_fd: int | None = None self.db_handler: DBHandler | None = None - self.log_file_handlers = [] + self.provided_logging_handler = logging_handler @abstractmethod def run(self) -> int: ... @@ -323,15 +324,25 @@ def entry_point(self) -> int: if self.HAS_ARTIFACTS_DIR: self.artifacts_dir = self.prepare_artifactsdir( - self.config.artifacts_base, self.config.artifacts_dir + self.config.artifacts_base, + self.config.artifacts_dir, + ) + + if self.provided_logging_handler is None: + stderr_level = get_log_level(self.config.verbose) + logging_handler = setup_logging( + logger_name="gallia", + stderr_level=stderr_level, + close_on_exit=False, ) - self.log_file_handlers.append( - add_zst_log_handler( + if self.HAS_ARTIFACTS_DIR: + logging_handler.add_zst_file_handler( logger_name="gallia", filepath=self.artifacts_dir.joinpath(FileNames.LOGFILE.value), - file_log_level=get_file_log_level(self.config), + log_level=stderr_level if self.config.trace_log is False else Loglevel.TRACE, ) - ) + else: + logging_handler = self.provided_logging_handler if self.config.hooks: self.run_hook(HookVariant.PRE) @@ -380,6 +391,9 @@ def entry_point(self) -> int: if self._lock_file_fd is not None: self._release_flock() + if self.provided_logging_handler is None: + logging_handler.stop_logging() + return exit_code diff --git a/src/gallia/log.py b/src/gallia/log.py index fdcbdf96f..bdb7833f6 100644 --- a/src/gallia/log.py +++ b/src/gallia/log.py @@ -25,7 +25,18 @@ from pathlib import Path from queue import Queue from types import TracebackType -from typing import TYPE_CHECKING, Any, BinaryIO, Self, TextIO, TypeAlias, cast +from typing import ( + IO, + TYPE_CHECKING, + Any, + BinaryIO, + Literal, + Self, + TextIO, + TypeAlias, + cast, + overload, +) import zstandard @@ -37,41 +48,6 @@ tz = datetime.timezone(datetime.timedelta(seconds=gmt_offset)) -@unique -class ColorMode(Enum): - """ColorMode is used as an argument to :func:`set_color_mode`.""" - - #: Colors are always turned on. - ALWAYS = "always" - #: Colors are turned off if the target - #: stream (e.g. stderr) is not a tty. - AUTO = "auto" - #: No colors are used. In other words, - #: no ANSI escape codes are included. - NEVER = "never" - - -def resolve_color_mode(mode: ColorMode, stream: TextIO = sys.stderr) -> bool: - """Sets the color mode of the console log handler. - - :param mode: The available options are described in :class:`ColorMode`. - :param stream: Used as a reference for :attr:`ColorMode.AUTO`. - """ - if sys.platform == "win32": - return False - - match mode: - case ColorMode.ALWAYS: - return True - case ColorMode.AUTO: - if os.getenv("NO_COLOR") is not None: - return False - else: - return stream.isatty() - case ColorMode.NEVER: - return False - - # https://stackoverflow.com/a/35804945 def _add_logging_level(level_name: str, level_num: int) -> None: method_name = level_name.lower() @@ -108,6 +84,41 @@ def to_root(message, *args, **kwargs): # type: ignore _add_logging_level("NOTICE", 25) +@unique +class ColorMode(Enum): + """ColorMode is used as an argument to :func:`set_color_mode`.""" + + #: Colors are always turned on. + ALWAYS = "always" + #: Colors are turned off if the target + #: stream (e.g. stderr) is not a tty. + AUTO = "auto" + #: No colors are used. In other words, + #: no ANSI escape codes are included. + NEVER = "never" + + +def resolve_color_mode(mode: ColorMode, stream: TextIO = sys.stderr) -> bool: + """Sets the color mode of the console log handler. + + :param mode: The available options are described in :class:`ColorMode`. + :param stream: Used as a reference for :attr:`ColorMode.AUTO`. + """ + if sys.platform == "win32": + return False + + match mode: + case ColorMode.ALWAYS: + return True + case ColorMode.AUTO: + if os.getenv("NO_COLOR") is not None: + return False + else: + return stream.isatty() + case ColorMode.NEVER: + return False + + @unique class Loglevel(IntEnum): """A wrapper around the constants exposed by python's @@ -228,103 +239,163 @@ def to_level(self) -> Loglevel: raise ValueError("invalid value") +class LoggingSetupHandler: + def __init__(self) -> None: + self.listeners: list[QueueListener] = [] + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + tb: TracebackType | None, + ) -> None: + self.stop_logging() + + def add_stream_handler( + self, + logger_name: str, + level: Loglevel, + stream: IO[str], + volatile_info: bool, + colored: bool, + ) -> None: + queue: Queue[logging.LogRecord] = Queue() + logger = logging.getLogger(logger_name) + logger.addHandler(QueueHandler(queue)) + + handler = logging.StreamHandler(stream) + handler.setLevel(level) + + formatter: logging.Formatter + if stream.isatty(): + formatter = _ConsoleFormatter(colored=colored, volatile_info=volatile_info) + else: + formatter = _StreamFormatter() + + handler.terminator = "" # We manually handle the terminator while formatting + + handler.setFormatter(formatter) + + listener = QueueListener( + queue, + handler, + respect_handler_level=True, + ) + listener.start() + self.listeners.append(listener) + + def add_zst_file_handler( + self, + logger_name: str, + filepath: Path | str, + log_level: Loglevel, + ) -> None: + queue: Queue[Any] = Queue() + logger = get_logger(logger_name) + logger.addHandler(QueueHandler(queue)) + + handler = _ZstdFileHandler( + filepath, + level=log_level, + ) + handler.setLevel(log_level) + handler.setFormatter(_JSONFormatter()) + + queue_listener = QueueListener( + queue, + handler, + respect_handler_level=True, + ) + queue_listener.start() + self.listeners.append(queue_listener) + + def stop_logging(self) -> None: + for listener in self.listeners: + listener.stop() + + +def get_log_level(verbose: int) -> Loglevel: + level = Loglevel.INFO + if verbose == 1: + level = Loglevel.DEBUG + elif verbose >= 2: + level = Loglevel.TRACE + return level + + +@overload def setup_logging( - level: Loglevel | None = None, + logger_name: str, + stderr_level: Loglevel | None = ..., + color_mode: ColorMode = ..., + volatile_info: bool = ..., + close_on_exit: Literal[False] = False, + logfile: Path | str | None = ..., + logfile_level: Loglevel = ..., +) -> LoggingSetupHandler: ... + + +@overload +def setup_logging( + logger_name: str, + stderr_level: Loglevel | None = ..., + color_mode: ColorMode = ..., + volatile_info: bool = ..., + close_on_exit: Literal[True] = ..., + logfile: Path | str | None = ..., + logfile_level: Loglevel = ..., +) -> None: ... + + +def setup_logging( + logger_name: str, + stderr_level: Loglevel | None = Loglevel.INFO, color_mode: ColorMode = ColorMode.AUTO, - no_volatile_info: bool = False, - logger_name: str = "gallia", -) -> None: + volatile_info: bool = False, # deprecated: Introduce progress info + close_on_exit: bool = False, + logfile: Path | str | None = None, + logfile_level: Loglevel = Loglevel.INFO, +) -> LoggingSetupHandler | None: """Enable and configure gallia's logging system. If this fuction is not called as early as possible, the logging system is in an undefined state und might not behave as expected. Always use this function to - initialize gallia's logging. For instance, ``setup_logging()`` - initializes a QueueHandler to avoid blocking calls during - logging. - - :param level: The loglevel to enable for the console handler. - If this argument is None, the env variable - ``GALLIA_LOGLEVEL`` (see :doc:`../env`) is read. - :param file_level: The loglevel to enable for the file handler. - :param path: The path to the logfile containing json records. - :param color_mode: The color mode to use for the console. + initialize gallia's logging. """ - if level is None: - # FIXME: why is this here and not in config? - if (raw := os.getenv("GALLIA_LOGLEVEL")) is not None: - level = PenlogPriority.from_str(raw).to_level() - else: - level = Loglevel.DEBUG - # These are slow and not used by gallia. logging.logMultiprocessing = False logging.logThreads = False logging.logProcesses = False logger = logging.getLogger(logger_name) - # LogLevel cannot be 0 (NOTSET), because only the root logger sends it to its handlers then + + # FIXME: Randomly setting loglevels seems wrong. Address this better. logger.setLevel(1) - # Clean up potentially existing handlers and create a new async QueueHandler for stderr output - while len(logger.handlers) > 0: - logger.handlers[0].close() - logger.removeHandler(logger.handlers[0]) - colored = resolve_color_mode(color_mode) - add_stderr_log_handler(logger_name, level, no_volatile_info, colored) + # Clean up potentially existing handlers. + for h in logger.handlers[:]: + logger.removeHandler(h) + h.close() + handler = LoggingSetupHandler() -def add_stderr_log_handler( - logger_name: str, - level: Loglevel, - no_volatile_info: bool, - colored: bool, -) -> None: - queue: Queue[Any] = Queue() - logger = logging.getLogger(logger_name) - logger.addHandler(QueueHandler(queue)) - - stderr_handler = logging.StreamHandler(sys.stderr) - stderr_handler.setLevel(level) - console_formatter = _ConsoleFormatter() - - console_formatter.colored = colored - stderr_handler.terminator = "" # We manually handle the terminator while formatting - if no_volatile_info is False: - console_formatter.volatile_info = True - - stderr_handler.setFormatter(console_formatter) - - queue_listener = QueueListener( - queue, - *[stderr_handler], - respect_handler_level=True, - ) - queue_listener.start() - atexit.register(queue_listener.stop) - - -def add_zst_log_handler( - logger_name: str, filepath: Path, file_log_level: Loglevel -) -> logging.Handler: - queue: Queue[Any] = Queue() - logger = get_logger(logger_name) - logger.addHandler(QueueHandler(queue)) - - zstd_handler = _ZstdFileHandler( - filepath, - level=file_log_level, - ) - zstd_handler.setLevel(file_log_level) - zstd_handler.setFormatter(_JSONFormatter()) - - queue_listener = QueueListener( - queue, - *[zstd_handler], - respect_handler_level=True, - ) - queue_listener.start() - atexit.register(queue_listener.stop) - return zstd_handler + if stderr_level is not None: + colored = resolve_color_mode(color_mode, sys.stderr) + handler.add_stream_handler( + logger_name, stderr_level, sys.stderr, volatile_info, colored=colored + ) + + if logfile: + handler.add_zst_file_handler(logger_name, logfile, logfile_level) + + if close_on_exit: + atexit.register(handler.stop_logging) + return None + + return handler @dataclasses.dataclass @@ -705,9 +776,41 @@ def format(self, record: logging.LogRecord) -> str: return json.dumps(dataclasses.asdict(penlog_record)) +class _StreamFormatter(logging.Formatter): + def __init__(self) -> None: + pass + + def format( + self, + record: logging.LogRecord, + ) -> str: + stacktrace = None + + if record.exc_info: + exc_type, exc_value, exc_traceback = record.exc_info + assert exc_type + assert exc_value + assert exc_traceback + + stacktrace = "\n" + stacktrace += "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + + return _format_record( + dt=datetime.datetime.fromtimestamp(record.created), + name=record.name, + data=record.getMessage(), + levelno=record.levelno, + tags=record.__dict__["tags"] if "tags" in record.__dict__ else None, + stacktrace=stacktrace, + colored=False, + volatile_info=False, + ) + + class _ConsoleFormatter(logging.Formatter): - colored: bool = False - volatile_info: bool = False + def __init__(self, colored: bool, volatile_info: bool) -> None: + self.colored = colored + self.volatile_info = volatile_info # deprecated: will be removed def format( self, @@ -737,7 +840,7 @@ def format( class _ZstdFileHandler(logging.Handler): - def __init__(self, path: Path, level: int | str = logging.NOTSET) -> None: + def __init__(self, path: Path | str, level: int | str = logging.NOTSET) -> None: super().__init__(level) self.file = zstandard.open( filename=path, diff --git a/src/gallia/utils.py b/src/gallia/utils.py index 336bee5bc..8c1b56aa9 100644 --- a/src/gallia/utils.py +++ b/src/gallia/utils.py @@ -23,7 +23,7 @@ import pydantic from pydantic.networks import IPvAnyAddress -from gallia.log import Loglevel, get_logger +from gallia.log import get_logger if TYPE_CHECKING: from gallia.db.handler import DBHandler @@ -261,27 +261,6 @@ def dump_args(args: Any) -> dict[str, str | int | float]: return settings -def get_log_level(args: Any) -> Loglevel: - level = Loglevel.INFO - if hasattr(args, "verbose"): - if args.verbose == 1: - level = Loglevel.DEBUG - elif args.verbose >= 2: - level = Loglevel.TRACE - return level - - -def get_file_log_level(args: Any) -> Loglevel: - level = Loglevel.DEBUG - if hasattr(args, "trace_log"): - if args.trace_log: - level = Loglevel.TRACE - elif hasattr(args, "verbose"): - if args.verbose >= 2: - level = Loglevel.TRACE - return level - - CONTEXT_SHARED_VARIABLE = "logger_name" context: contextvars.ContextVar[tuple[str, str | None]] = contextvars.ContextVar( CONTEXT_SHARED_VARIABLE diff --git a/tests/pytest/test_helpers.py b/tests/pytest/test_helpers.py index 57b762120..60c667618 100644 --- a/tests/pytest/test_helpers.py +++ b/tests/pytest/test_helpers.py @@ -11,7 +11,7 @@ ) from gallia.utils import split_host_port -setup_logging() +setup_logging("gallia") def test_split_host_port_v4() -> None: diff --git a/tests/pytest/test_transports.py b/tests/pytest/test_transports.py index 7c2f2eb5b..24ea6f0b4 100644 --- a/tests/pytest/test_transports.py +++ b/tests/pytest/test_transports.py @@ -15,7 +15,7 @@ test_data = [b"hello" b"tcp"] -setup_logging() +setup_logging("gallia") class TCPServer: