From 12915d58139c8d25ccc8c4edb0bb2e2d38144291 Mon Sep 17 00:00:00 2001 From: Nikita Titov Date: Sun, 5 Dec 2021 08:10:28 +0300 Subject: [PATCH] [python][sklearn] unify values of `best_iteration` for sklearn and standard APIs (#4845) * unify values of `best_iteration` for sklearn and standard APIs * update Dask test --- python-package/lightgbm/sklearn.py | 8 ++------ tests/python_package_test/test_dask.py | 2 +- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index 504917601b8f..35cb059de024 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -785,11 +785,7 @@ def _get_meta_data(collection, name, i): else: # reset after previous call to fit() self._evals_result = None - if self._Booster.best_iteration != 0: - self._best_iteration = self._Booster.best_iteration - else: # reset after previous call to fit() - self._best_iteration = None - + self._best_iteration = self._Booster.best_iteration self._best_score = self._Booster.best_score self.fitted_ = True @@ -872,7 +868,7 @@ def best_score_(self): @property def best_iteration_(self): - """:obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping()`` callback has been specified.""" + """:obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified.""" if not self.__sklearn_is_fitted__(): raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.') return self._best_iteration diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 6b2327ec70a6..6a1e4afe567a 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -925,7 +925,7 @@ def test_eval_set_no_early_stopping(task, output, eval_sizes, eval_names_prefix, # check that early stopping was not applied. assert dask_model.booster_.num_trees() == model_trees - assert dask_model.best_iteration_ is None + assert dask_model.best_iteration_ == 0 # checks that evals_result_ and best_score_ contain expected data and eval_set names. evals_result = dask_model.evals_result_