Skip to content

Commit

Permalink
[python-package] fix mypy errors in engine.py (#4839)
Browse files Browse the repository at this point in the history
* [python-package] fix mypy errors in engine.py

* Apply suggestions from code review

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>

* allow for stdv

* whitespace

Co-authored-by: Nikita Titov <nekit94-08@mail.ru>
  • Loading branch information
jameslamb and StrikerRUS authored Dec 2, 2021
1 parent de23b56 commit 946817a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
15 changes: 10 additions & 5 deletions python-package/lightgbm/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
"""Callbacks library."""
import collections
from functools import partial
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Tuple, Union

from .basic import _ConfigAliases, _log_info, _log_warning

_EvalResultTuple = Union[
List[Tuple[str, str, float, bool]],
List[Tuple[str, str, float, bool, float]]
]


def _gt_delta(curr_score: float, best_score: float, delta: float) -> bool:
return curr_score > best_score + delta
Expand All @@ -18,15 +23,15 @@ def _lt_delta(curr_score: float, best_score: float, delta: float) -> bool:
class EarlyStopException(Exception):
"""Exception of early stopping."""

def __init__(self, best_iteration: int, best_score: float) -> None:
def __init__(self, best_iteration: int, best_score: _EvalResultTuple) -> None:
"""Create early stopping exception.
Parameters
----------
best_iteration : int
The best iteration stopped.
best_score : float
The score of the best iteration.
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
Scores for each metric, on each validation set, as of the best iteration.
"""
super().__init__()
self.best_iteration = best_iteration
Expand All @@ -44,7 +49,7 @@ def __init__(self, best_iteration: int, best_score: float) -> None:
"evaluation_result_list"])


def _format_eval_result(value: list, show_stdv: bool = True) -> str:
def _format_eval_result(value: _EvalResultTuple, show_stdv: bool = True) -> str:
"""Format metric string."""
if len(value) == 4:
return f"{value[0]}'s {value[1]}: {value[2]:g}"
Expand Down
22 changes: 11 additions & 11 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,38 +223,38 @@ def train(
name_valid_sets.append(f'valid_{i}')
# process callbacks
if callbacks is None:
callbacks = set()
callbacks_set = set()
else:
for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks)
callbacks_set = set(callbacks)

# Most of legacy advanced options becomes callbacks
if verbose_eval != "warn":
_log_warning("'verbose_eval' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'log_evaluation()' callback via 'callbacks' argument instead.")
else:
if callbacks: # assume user has already specified log_evaluation callback
if callbacks_set: # assume user has already specified log_evaluation callback
verbose_eval = False
else:
verbose_eval = True
if verbose_eval is True:
callbacks.add(callback.log_evaluation())
callbacks_set.add(callback.log_evaluation())
elif isinstance(verbose_eval, int):
callbacks.add(callback.log_evaluation(verbose_eval))
callbacks_set.add(callback.log_evaluation(verbose_eval))

if early_stopping_rounds is not None and early_stopping_rounds > 0:
callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))
callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=bool(verbose_eval)))

if evals_result is not None:
_log_warning("'evals_result' argument is deprecated and will be removed in a future release of LightGBM. "
"Pass 'record_evaluation()' callback via 'callbacks' argument instead.")
callbacks.add(callback.record_evaluation(evals_result))
callbacks_set.add(callback.record_evaluation(evals_result))

callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
callbacks_after_iter = callbacks - callbacks_before_iter
callbacks_before_iter = sorted(callbacks_before_iter, key=attrgetter('order'))
callbacks_after_iter = sorted(callbacks_after_iter, key=attrgetter('order'))
callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)}
callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
callbacks_before_iter = sorted(callbacks_before_iter_set, key=attrgetter('order'))
callbacks_after_iter = sorted(callbacks_after_iter_set, key=attrgetter('order'))

# construct booster
try:
Expand Down

0 comments on commit 946817a

Please sign in to comment.