Commit

remove early_stopping_rounds argument of train() and cv() functions
StrikerRUS committed Dec 23, 2021
1 parent cace5bb commit b133600
Showing 6 changed files with 114 additions and 62 deletions.
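
The practical migration implied by this commit: code that passed ``early_stopping_rounds`` to ``train()`` or ``cv()`` now passes the ``early_stopping`` callback instead. A minimal before/after sketch (``params``, ``train_set``, and ``valid_set`` are placeholder names, not part of the commit):

.. code:: python

    import lightgbm as lgb

    # Before this commit (argument now removed):
    # bst = lgb.train(params, train_set, num_boost_round=100,
    #                 valid_sets=[valid_set], early_stopping_rounds=5)

    # After this commit (callback form):
    bst = lgb.train(params, train_set, num_boost_round=100,
                    valid_sets=[valid_set],
                    callbacks=[lgb.early_stopping(stopping_rounds=5)])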
4 changes: 2 additions & 2 deletions docs/Parameters-Tuning.rst
@@ -108,9 +108,9 @@ Use Early Stopping

If early stopping is enabled, after each boosting round the model's training accuracy is evaluated against a validation set that contains data not available to the training process. That accuracy is then compared to the accuracy as of the previous boosting round. If the model's accuracy fails to improve for some number of consecutive rounds, LightGBM stops the training process.

That "number of consecutive rounds" is controlled by the parameter ``early_stopping_rounds``. For example, ``early_stopping_rounds=1`` says "the first time accuracy on the validation set does not improve, stop training".
That "number of consecutive rounds" is controlled by the parameter ``early_stopping_round``. For example, ``early_stopping_round=1`` says "the first time accuracy on the validation set does not improve, stop training".

- Set ``early_stopping_rounds`` and provide a validation set to possibly reduce training time.
+ Set ``early_stopping_round`` and provide a validation set to possibly reduce training time.
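
For illustration, a minimal sketch of enabling early stopping via the ``early_stopping`` callback (the dataset and parameter values are assumed, not part of the original docs):

.. code:: python

    import lightgbm as lgb
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split

    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.1, random_state=42)
    train_set = lgb.Dataset(X_train, label=y_train)
    valid_set = lgb.Dataset(X_valid, label=y_valid, reference=train_set)

    # Stop if the validation metric fails to improve for 5 consecutive rounds.
    bst = lgb.train(
        {'objective': 'binary', 'metric': 'binary_logloss'},
        train_set,
        num_boost_round=100,
        valid_sets=[valid_set],
        callbacks=[lgb.early_stopping(stopping_rounds=5)]
    )
    print(bst.best_iteration)  # index of the best-performing iteration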

Consider Fewer Splits
'''''''''''''''''''''
8 changes: 4 additions & 4 deletions docs/Python-Intro.rst
@@ -228,18 +228,18 @@ Early stopping requires at least one set in ``valid_sets``. If there is more than

.. code:: python
- bst = lgb.train(param, train_data, num_round, valid_sets=valid_sets, early_stopping_rounds=5)
+ bst = lgb.train(param, train_data, num_round, valid_sets=valid_sets, callbacks=[lgb.early_stopping(stopping_rounds=5)])
bst.save_model('model.txt', num_iteration=bst.best_iteration)
The model will train until the validation score stops improving.
- Validation score needs to improve at least every ``early_stopping_rounds`` to continue training.
+ Validation score needs to improve at least every ``stopping_rounds`` to continue training.

- The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting ``early_stopping_rounds``.
+ The index of iteration that has the best performance will be saved in the ``best_iteration`` field if early stopping logic is enabled by setting the ``early_stopping`` callback.
Note that ``train()`` will return a model from the best iteration.

This works with both metrics to minimize (L2, log loss, etc.) and to maximize (NDCG, AUC, etc.).
Note that if you specify more than one evaluation metric, all of them will be used for early stopping.
- However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in ``param`` or ``early_stopping`` callback constructor.
+ However, you can change this behavior and make LightGBM check only the first metric for early stopping by passing ``first_metric_only=True`` in the ``early_stopping`` callback constructor.
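
A short sketch of that option, reusing ``train_data``, ``num_round``, and ``valid_sets`` from the example above (the two-metric setup is assumed for illustration):

.. code:: python

    # With two metrics, only the first one ('binary_logloss') drives early stopping.
    param = {'objective': 'binary', 'metric': ['binary_logloss', 'auc']}
    bst = lgb.train(param, train_data, num_round, valid_sets=valid_sets,
                    callbacks=[lgb.early_stopping(stopping_rounds=5, first_metric_only=True)])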

Prediction
----------
2 changes: 1 addition & 1 deletion examples/python-guide/simple_example.py
@@ -40,7 +40,7 @@
lgb_train,
num_boost_round=20,
valid_sets=lgb_eval,
-                 early_stopping_rounds=5)
+                 callbacks=[lgb.early_stopping(stopping_rounds=5)])

print('Saving model...')
# save model to file
3 changes: 3 additions & 0 deletions python-package/lightgbm/callback.py
@@ -239,6 +239,9 @@ def _init(env: CallbackEnv) -> None:
raise ValueError('For early stopping, '
'at least one dataset and eval metric is required for evaluation')

+ if stopping_rounds <= 0:
+     raise ValueError("stopping_rounds should be greater than zero.")

if verbose:
_log_info(f"Training until validation scores don't improve for {stopping_rounds} rounds")

84 changes: 44 additions & 40 deletions python-package/lightgbm/engine.py
@@ -9,7 +9,8 @@
import numpy as np

from . import callback
- from .basic import Booster, Dataset, LightGBMError, _ArrayLike, _ConfigAliases, _InnerPredictor, _log_warning
+ from .basic import (Booster, Dataset, LightGBMError, _ArrayLike, _ConfigAliases,
+                     _InnerPredictor, _choose_param_value, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMGroupKFold, _LGBMStratifiedKFold

_LGBM_CustomObjectiveFunction = Callable[
@@ -33,7 +34,6 @@ def train(
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[List[str], str] = 'auto',
categorical_feature: Union[List[str], List[int], str] = 'auto',
- early_stopping_rounds: Optional[int] = None,
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
) -> Booster:
@@ -109,15 +109,6 @@ def train(
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
- early_stopping_rounds : int or None, optional (default=None)
-     Activates early stopping. The model will train until the validation score stops improving.
-     Validation score needs to improve at least every ``early_stopping_rounds`` round(s)
-     to continue training.
-     Requires at least one validation data and one metric.
-     If there's more than one, will check all of them. But the training data is ignored anyway.
-     To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
-     The index of iteration that has the best performance will be saved in the ``best_iteration`` field
-     if early stopping logic is enabled by setting ``early_stopping_rounds``.
keep_training_booster : bool, optional (default=False)
Whether the returned Booster will be used to keep training.
If False, the returned value will be converted into _InnerPredictor before returning.
@@ -145,14 +136,14 @@
num_boost_round = params.pop(alias)
_log_warning(f"Found `{alias}` in params. Will use it instead of argument")
params["num_iterations"] = num_boost_round
- # show deprecation warning only for early stop argument, setting early stop via global params should still be possible
- if early_stopping_rounds is not None and early_stopping_rounds > 0:
-     _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
-                  "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
- for alias in _ConfigAliases.get("early_stopping_round"):
-     if alias in params:
-         early_stopping_rounds = params.pop(alias)
-         params["early_stopping_round"] = early_stopping_rounds
+ # setting early stopping via global params should be possible
+ params = _choose_param_value(
+     main_param_name="early_stopping_round",
+     params=params,
+     default_value=None
+ )
+ if params["early_stopping_round"] is None:
+     params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False)

if num_boost_round <= 0:
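
For illustration, the assumed behavior of the internal ``_choose_param_value`` helper used here: any known alias is collapsed into the main parameter name, with a fallback default when none is present. A hypothetical sketch:

.. code:: python

    # Illustration only: _choose_param_value is an internal helper in lightgbm.basic.
    from lightgbm.basic import _choose_param_value

    params = {'early_stopping_rounds': 5}   # user supplied an alias
    params = _choose_param_value(
        main_param_name="early_stopping_round",
        params=params,
        default_value=None
    )
    print(params)  # expected: {'early_stopping_round': 5}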
@@ -203,9 +194,18 @@
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks_set = set(callbacks)

- # Most of legacy advanced options becomes callbacks
- if early_stopping_rounds is not None and early_stopping_rounds > 0:
-     callbacks_set.add(callback.early_stopping(early_stopping_rounds, first_metric_only))
+ if "early_stopping_round" in params:
+     callbacks_set.add(
+         callback.early_stopping(
+             stopping_rounds=params["early_stopping_round"],
+             first_metric_only=first_metric_only,
+             verbose=_choose_param_value(
+                 main_param_name="verbosity",
+                 params=params,
+                 default_value=1
+             ).pop("verbosity") > 0
+         )
+     )

callbacks_before_iter_set = {cb for cb in callbacks_set if getattr(cb, 'before_iteration', False)}
callbacks_after_iter_set = callbacks_set - callbacks_before_iter_set
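
Assuming the ``verbosity > 0`` mapping above holds, requesting early stopping through global params should now be equivalent to passing the callback explicitly; a sketch with placeholder ``train_set``/``valid_set`` datasets:

.. code:: python

    # Early stopping requested via global params; the callback is built internally.
    bst = lgb.train({'objective': 'binary', 'early_stopping_round': 5, 'verbosity': -1},
                    train_set, num_boost_round=100, valid_sets=[valid_set])
    # Expected to behave like the explicit callback form:
    bst = lgb.train({'objective': 'binary', 'verbosity': -1},
                    train_set, num_boost_round=100, valid_sets=[valid_set],
                    callbacks=[lgb.early_stopping(stopping_rounds=5, verbose=False)])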
@@ -381,8 +381,7 @@ def cv(params, train_set, num_boost_round=100,
folds=None, nfold=5, stratified=True, shuffle=True,
metrics=None, fobj=None, feval=None, init_model=None,
feature_name='auto', categorical_feature='auto',
- early_stopping_rounds=None, fpreproc=None,
- seed=0, callbacks=None, eval_train_metric=False,
+ fpreproc=None, seed=0, callbacks=None, eval_train_metric=False,
return_cvbooster=False):
"""Perform the cross-validation with given parameters.
@@ -467,13 +466,6 @@
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
- early_stopping_rounds : int or None, optional (default=None)
-     Activates early stopping.
-     CV score needs to improve at least every ``early_stopping_rounds`` round(s)
-     to continue.
-     Requires at least one metric. If there's more than one, will check all of them.
-     To check only the first metric, set the ``first_metric_only`` parameter to ``True`` in ``params``.
-     Last entry in evaluation history is the one from the best iteration.
fpreproc : callable or None, optional (default=None)
Preprocessing function that takes (dtrain, dtest, params)
and returns transformed versions of those.
@@ -511,13 +503,14 @@
_log_warning(f"Found '{alias}' in params. Will use it instead of 'num_boost_round' argument")
num_boost_round = params.pop(alias)
params["num_iterations"] = num_boost_round
- if early_stopping_rounds is not None and early_stopping_rounds > 0:
-     _log_warning("'early_stopping_rounds' argument is deprecated and will be removed in a future release of LightGBM. "
-                  "Pass 'early_stopping()' callback via 'callbacks' argument instead.")
- for alias in _ConfigAliases.get("early_stopping_round"):
-     if alias in params:
-         early_stopping_rounds = params.pop(alias)
-         params["early_stopping_round"] = early_stopping_rounds
+ # setting early stopping via global params should be possible
+ params = _choose_param_value(
+     main_param_name="early_stopping_round",
+     params=params,
+     default_value=None
+ )
+ if params["early_stopping_round"] is None:
+     params.pop("early_stopping_round")
first_metric_only = params.get('first_metric_only', False)

if num_boost_round <= 0:
@@ -552,8 +545,19 @@
for i, cb in enumerate(callbacks):
cb.__dict__.setdefault('order', i - len(callbacks))
callbacks = set(callbacks)
- if early_stopping_rounds is not None and early_stopping_rounds > 0:
-     callbacks.add(callback.early_stopping(early_stopping_rounds, first_metric_only, verbose=False))

+ if "early_stopping_round" in params:
+     callbacks.add(
+         callback.early_stopping(
+             stopping_rounds=params["early_stopping_round"],
+             first_metric_only=first_metric_only,
+             verbose=_choose_param_value(
+                 main_param_name="verbosity",
+                 params=params,
+                 default_value=1
+             ).pop("verbosity") > 0
+         )
+     )

callbacks_before_iter = {cb for cb in callbacks if getattr(cb, 'before_iteration', False)}
callbacks_after_iter = callbacks - callbacks_before_iter
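
Likewise for ``cv()``: a sketch (placeholder ``train_set`` assumed) of early stopping picked up from global params; per the removed docstring, the last entry in the evaluation history corresponds to the best iteration:

.. code:: python

    cv_results = lgb.cv(
        {'objective': 'binary', 'metric': 'binary_logloss',
         'early_stopping_round': 5, 'verbosity': -1},
        train_set,
        num_boost_round=100,
        nfold=3
    )
    print(len(cv_results['binary_logloss-mean']))  # number of rounds actually run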
75 changes: 60 additions & 15 deletions tests/python_package_test/test_engine.py
@@ -741,7 +741,7 @@ def test_early_stopping():
num_boost_round=10,
valid_sets=lgb_eval,
valid_names=valid_set_name,
- early_stopping_rounds=5)
+ callbacks=[lgb.early_stopping(stopping_rounds=5)])
assert gbm.best_iteration == 10
assert valid_set_name in gbm.best_score
assert 'binary_logloss' in gbm.best_score[valid_set_name]
@@ -750,12 +750,42 @@
num_boost_round=40,
valid_sets=lgb_eval,
valid_names=valid_set_name,
- early_stopping_rounds=5)
+ callbacks=[lgb.early_stopping(stopping_rounds=5)])
assert gbm.best_iteration <= 39
assert valid_set_name in gbm.best_score
assert 'binary_logloss' in gbm.best_score[valid_set_name]


+ @pytest.mark.parametrize('first_metric_only', [True, False])
+ def test_early_stopping_via_global_params(first_metric_only):
+     X, y = load_breast_cancer(return_X_y=True)
+     num_trees = 5
+     params = {
+         'num_trees': num_trees,
+         'objective': 'binary',
+         'metric': 'None',
+         'verbose': -1,
+         'early_stopping_round': 2,
+         'first_metric_only': first_metric_only
+     }
+     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
+     lgb_train = lgb.Dataset(X_train, y_train)
+     lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)
+     valid_set_name = 'valid_set'
+     gbm = lgb.train(params,
+                     lgb_train,
+                     feval=[decreasing_metric, constant_metric],
+                     valid_sets=lgb_eval,
+                     valid_names=valid_set_name)
+     if first_metric_only:
+         assert gbm.best_iteration == num_trees
+     else:
+         assert gbm.best_iteration == 1
+     assert valid_set_name in gbm.best_score
+     assert 'decreasing_metric' in gbm.best_score[valid_set_name]
+     assert 'error' in gbm.best_score[valid_set_name]


@pytest.mark.parametrize('first_only', [True, False])
@pytest.mark.parametrize('single_metric', [True, False])
@pytest.mark.parametrize('greater_is_better', [True, False])
@@ -808,7 +838,7 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
# regular early stopping
evals_result = {}
train_kwargs['callbacks'] = [
- lgb.callback.early_stopping(10, first_only, verbose=0),
+ lgb.callback.early_stopping(10, first_only, verbose=False),
lgb.record_evaluation(evals_result)
]
bst = lgb.train(**train_kwargs)
@@ -817,7 +847,7 @@
# positive min_delta
delta_result = {}
train_kwargs['callbacks'] = [
- lgb.callback.early_stopping(10, first_only, verbose=0, min_delta=min_delta),
+ lgb.callback.early_stopping(10, first_only, verbose=False, min_delta=min_delta),
lgb.record_evaluation(delta_result)
]
delta_bst = lgb.train(**train_kwargs)
@@ -998,8 +1028,8 @@ def test_cvbooster():
# with early stopping
cv_res = lgb.cv(params, lgb_train,
num_boost_round=25,
- early_stopping_rounds=5,
nfold=3,
+ callbacks=[lgb.early_stopping(stopping_rounds=5)],
return_cvbooster=True)
assert 'cvbooster' in cv_res
cvb = cv_res['cvbooster']
@@ -2371,9 +2401,14 @@ def metrics_combination_train_regression(valid_sets, metric_list, assumed_iterat
'verbose': -1,
'seed': 123
}
- gbm = lgb.train(dict(params, first_metric_only=first_metric_only), lgb_train,
-                 num_boost_round=25, valid_sets=valid_sets, feval=feval,
-                 early_stopping_rounds=5)
+ gbm = lgb.train(
+     params,
+     lgb_train,
+     num_boost_round=25,
+     valid_sets=valid_sets,
+     feval=feval,
+     callbacks=[lgb.early_stopping(stopping_rounds=5, first_metric_only=first_metric_only)]
+ )
assert assumed_iteration == gbm.best_iteration

def metrics_combination_cv_regression(metric_list, assumed_iteration,
@@ -2387,11 +2422,15 @@ def metrics_combination_cv_regression(metric_list, assumed_iteration,
'seed': 123,
'gpu_use_dp': True
}
- ret = lgb.cv(dict(params, first_metric_only=first_metric_only),
-              train_set=lgb_train, num_boost_round=25,
-              stratified=False, feval=feval,
-              early_stopping_rounds=5,
-              eval_train_metric=eval_train_metric)
+ ret = lgb.cv(
+     params,
+     train_set=lgb_train,
+     num_boost_round=25,
+     stratified=False,
+     feval=feval,
+     callbacks=[lgb.early_stopping(stopping_rounds=5, first_metric_only=first_metric_only)],
+     eval_train_metric=eval_train_metric
+ )
assert assumed_iteration == len(ret[list(ret.keys())[0]])

X, y = load_boston(return_X_y=True)
@@ -2956,8 +2995,14 @@ def inner_test(X, y, params, early_stopping_rounds):
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
train_data = lgb.Dataset(X_train, label=y_train)
valid_data = lgb.Dataset(X_test, label=y_test)
- booster = lgb.train(params, train_data, num_boost_round=50, early_stopping_rounds=early_stopping_rounds,
-                     valid_sets=[valid_data])
+ callbacks = [lgb.early_stopping(early_stopping_rounds)] if early_stopping_rounds is not None else []
+ booster = lgb.train(
+     params,
+     train_data,
+     num_boost_round=50,
+     valid_sets=[valid_data],
+     callbacks=callbacks
+ )

# test that the predict once with all iterations equals summed results with start_iteration and num_iteration
all_pred = booster.predict(X, raw_score=True)
