Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] add type hints to logging functions in basic.py #4527

Merged
merged 3 commits into from
Aug 19, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 12 additions & 12 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from os.path import getsize
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import numpy as np
import scipy.sparse
Expand All @@ -34,17 +34,17 @@ def _get_sample_count(total_nrow: int, params: str):


class _DummyLogger:
def info(self, msg):
def info(self, msg: str) -> None:
print(msg)

def warning(self, msg):
def warning(self, msg: str) -> None:
warnings.warn(msg, stacklevel=3)


_LOGGER = _DummyLogger()
_LOGGER: Union[_DummyLogger, Logger] = _DummyLogger()


def register_logger(logger):
def register_logger(logger: Logger) -> None:
"""Register custom logger.

Parameters
Expand All @@ -58,12 +58,12 @@ def register_logger(logger):
_LOGGER = logger


def _normalize_native_string(func):
def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]:
"""Join log messages from native library which come by chunks."""
msg_normalized = []
msg_normalized: List[str] = []

@wraps(func)
def wrapper(msg):
def wrapper(msg: str) -> None:
nonlocal msg_normalized
if msg.strip() == '':
msg = ''.join(msg_normalized)
Expand All @@ -75,20 +75,20 @@ def wrapper(msg):
return wrapper


def _log_info(msg):
def _log_info(msg: str) -> None:
_LOGGER.info(msg)


def _log_warning(msg):
def _log_warning(msg: str) -> None:
_LOGGER.warning(msg)


@_normalize_native_string
def _log_native(msg):
def _log_native(msg: str) -> None:
_LOGGER.info(msg)


def _log_callback(msg):
def _log_callback(msg: bytes) -> None:
"""Redirect logs from native library into Python."""
_log_native(str(msg.decode('utf-8')))

Expand Down