diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index de408a4be4db..5d8efb950f00 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -12,7 +12,7 @@ from os.path import getsize from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import numpy as np import scipy.sparse @@ -34,17 +34,17 @@ def _get_sample_count(total_nrow: int, params: str): class _DummyLogger: - def info(self, msg): + def info(self, msg: str) -> None: print(msg) - def warning(self, msg): + def warning(self, msg: str) -> None: warnings.warn(msg, stacklevel=3) -_LOGGER = _DummyLogger() +_LOGGER: Union[_DummyLogger, Logger] = _DummyLogger() -def register_logger(logger): +def register_logger(logger: Logger) -> None: """Register custom logger. Parameters @@ -58,12 +58,12 @@ def register_logger(logger): _LOGGER = logger -def _normalize_native_string(func): +def _normalize_native_string(func: Callable[[str], None]) -> Callable[[str], None]: """Join log messages from native library which come by chunks.""" - msg_normalized = [] + msg_normalized: List[str] = [] @wraps(func) - def wrapper(msg): + def wrapper(msg: str) -> None: nonlocal msg_normalized if msg.strip() == '': msg = ''.join(msg_normalized) @@ -75,20 +75,20 @@ def wrapper(msg): return wrapper -def _log_info(msg): +def _log_info(msg: str) -> None: _LOGGER.info(msg) -def _log_warning(msg): +def _log_warning(msg: str) -> None: _LOGGER.warning(msg) @_normalize_native_string -def _log_native(msg): +def _log_native(msg: str) -> None: _LOGGER.info(msg) -def _log_callback(msg): +def _log_callback(msg: bytes) -> None: """Redirect logs from native library into Python.""" _log_native(str(msg.decode('utf-8')))