Skip to content

Commit

Permalink
[python][sklearn] unify values of best_iteration for sklearn and st…
Browse files Browse the repository at this point in the history
…andard APIs (#4845)

* unify values of `best_iteration` for sklearn and standard APIs

* update Dask test
  • Loading branch information
StrikerRUS authored Dec 5, 2021
1 parent 7d90174 commit 12915d5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 7 deletions.
8 changes: 2 additions & 6 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,11 +785,7 @@ def _get_meta_data(collection, name, i):
else: # reset after previous call to fit()
self._evals_result = None

if self._Booster.best_iteration != 0:
self._best_iteration = self._Booster.best_iteration
else: # reset after previous call to fit()
self._best_iteration = None

self._best_iteration = self._Booster.best_iteration
self._best_score = self._Booster.best_score

self.fitted_ = True
Expand Down Expand Up @@ -872,7 +868,7 @@ def best_score_(self):

@property
def best_iteration_(self):
""":obj:`int` or :obj:`None`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
""":obj:`int`: The best iteration of fitted model if ``early_stopping()`` callback has been specified."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError('No best_iteration found. Need to call fit with early_stopping callback beforehand.')
return self._best_iteration
Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,7 @@ def test_eval_set_no_early_stopping(task, output, eval_sizes, eval_names_prefix,

# check that early stopping was not applied.
assert dask_model.booster_.num_trees() == model_trees
assert dask_model.best_iteration_ is None
assert dask_model.best_iteration_ == 0

# checks that evals_result_ and best_score_ contain expected data and eval_set names.
evals_result = dask_model.evals_result_
Expand Down

0 comments on commit 12915d5

Please sign in to comment.