[docs] The recommended way to use early stopping #5196

Closed
c60evaporator opened this issue May 4, 2022 · 4 comments

c60evaporator commented May 4, 2022

Description

I suppose there are three ways to enable early stopping in the Python Training API:

  1. Setting the early_stopping_rounds argument of the train() function.
  2. Setting early_stopping_round in the params argument of the train() function.
  3. Passing the early_stopping() callback via the callbacks argument of the train() function.

I know that the first way is deprecated and will be removed according to the following update.

#4908

I'd like to know which way is the most recommended way to use early stopping.

Reproducible example

The Python Quick Start page says that the third way (early_stopping() callback) is recommended.

The Parameters Tuning page says that the second way (early_stopping_round in params) is recommended.

Motivation

I would appreciate it if the documentation recommended a single way throughout, to prevent confusion.

For example, if we choose the third way (early_stopping() callback) as the recommended one, I'd like to modify the Parameters Tuning page as follows.

Before

That "number of consecutive rounds" is controlled by the parameter ``early_stopping_round``. For example, ``early_stopping_round=1`` says "the first time accuracy on the validation set does not improve, stop training".

Set ``early_stopping_round`` and provide a validation set to possibly reduce training time.

After

That "number of consecutive rounds" is controlled by the parameter ``stopping_rounds`` in ``early_stopping`` callback constructor. For example, ``stopping_rounds=1`` says "the first time accuracy on the validation set does not improve, stop training". Also, early stopping requires at least one validation set in ``valid_sets``.

Use ``early_stopping`` callback and provide a validation set to possibly reduce training time.

ddelange commented Sep 1, 2023

fyi: option 1 (early_stopping_rounds) has been removed from train(), cv(), and fit() as of v4.0.0: #4908 and #4846

ddelange added a commit to ddelange/vaex that referenced this issue Sep 1, 2023
@jameslamb
Collaborator

Thanks for using LightGBM and for the thorough report. Sorry it took so long for someone to answer you here!

As of v4.0.0, you can use either approach 2 or 3 from your original post.

Consider the following example, with a metric that improves on each iteration and then starts getting worse after the 4th iteration.

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

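# NOTE: this trains on the full dataset (X, y), which is why the logs below report 569 rows;
# in real use, pass X_train / y_train so the validation rows stay held out
dtrain = lgb.Dataset(X, label=y)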
dvalid = lgb.Dataset(X_test, label=y_test)


class CustomMetric:
    """Metric that starts getting worse after the 4th iteration"""
    def __init__(self):
        self.scores = [1.0, 0.9, 0.8, 0.7, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]

    def __call__(self, y_true, y_pred):
        is_higher_better = False
        return "decreasing_metric", self.scores.pop(0), is_higher_better

params = {
    "verbosity": 1,
    "max_depth": 3,
    "num_leaves": 7,
    "min_data_in_leaf": 5,
}

# approach 1 (option 3 in the original post): use the early_stopping() callback
bst_1 = lgb.train(
    params=params,
    num_boost_round=10,
    train_set=dtrain,
    valid_sets=[dvalid],
    feval=CustomMetric(),
    callbacks=[
        lgb.early_stopping(stopping_rounds=3),
        lgb.log_evaluation(1)
    ]
)

# approach 2 (option 2 in the original post): pass "early_stopping_round" through params
bst_2 = lgb.train(
    params={**params, "early_stopping_round": 3},
    num_boost_round=10,
    train_set=dtrain,
    valid_sets=[dvalid],
    feval=CustomMetric(),
    callbacks=[
        lgb.log_evaluation(1)
    ]
)

For both, you'll see logs indicating that 7 rounds of training happened (4 where the performance improved + the 3 it took to trigger early stopping):

[LightGBM] [Info] Total Bins 5676
[LightGBM] [Info] Number of data points in the train set: 569, number of used features: 30
[LightGBM] [Info] Start training from score 0.627417
Training until validation scores don't improve for 3 rounds
[1]	valid_0's l2: 0.194477	valid_0's decreasing_metric: 1
[2]	valid_0's l2: 0.161398	valid_0's decreasing_metric: 0.9
[3]	valid_0's l2: 0.134389	valid_0's decreasing_metric: 0.8
[4]	valid_0's l2: 0.112313	valid_0's decreasing_metric: 0.7
[5]	valid_0's l2: 0.0941232	valid_0's decreasing_metric: 0.8
[6]	valid_0's l2: 0.0800353	valid_0's decreasing_metric: 0.8
[7]	valid_0's l2: 0.0676462	valid_0's decreasing_metric: 0.8
Early stopping, best iteration is:
[4]	valid_0's l2: 0.112313	valid_0's decreasing_metric: 0.7

And that each Booster's best_iteration is set to 4.

bst_1.best_iteration
# 4

bst_2.best_iteration
# 4
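
To see why the run above stops at round 7 with a best iteration of 4, here is a small pure-Python sketch of the patience rule. It is a simplified illustration, not LightGBM's actual implementation; the function name simulate_early_stopping is hypothetical, and it assumes a single metric where only a strict improvement resets the counter.

```python
def simulate_early_stopping(scores, stopping_rounds, higher_is_better=False):
    """Return (rounds_run, best_iteration), both 1-based.

    Hypothetical helper: stops once `stopping_rounds` consecutive
    rounds have passed without a strict improvement of the metric.
    """
    best_score = None
    best_iteration = 0
    for iteration, score in enumerate(scores, start=1):
        if best_score is None:
            improved = True
        elif higher_is_better:
            improved = score > best_score
        else:
            improved = score < best_score

        if improved:
            best_score, best_iteration = score, iteration
        elif iteration - best_iteration >= stopping_rounds:
            # Patience exhausted: stop here, keep the best iteration
            return iteration, best_iteration
    # Ran out of rounds without triggering early stopping
    return len(scores), best_iteration


# Same metric values as CustomMetric in the example above
scores = [1.0, 0.9, 0.8, 0.7, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
rounds_run, best_iteration = simulate_early_stopping(scores, stopping_rounds=3)
print(rounds_run, best_iteration)
# 7 4
```

The metric improves through round 4, then rounds 5, 6, and 7 bring no improvement, so the third stale round ends training.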

> I'd like to know which way is the most recommended way to use early stopping.

Both are equally valid, and there are no plans to remove support for either.

I personally tend to prefer the approach where early_stopping_round is passed via params, because then all params are in one place, but that's a minor personal preference. Since v4.0.0 made all callbacks serializable (#5080), I can't think of a strong reason that you might prefer one to the other.

> I'd like to modify Parameters Tuning page as follows

I don't support those changes to the documentation.

Please note that LightGBM is not just a Python package... it is also a command-line interface (CLI), a redistributable C++ shared library with a C API, an R package, a Java jar, and more. Most of those other interfaces don't have a concept of "callbacks", and the Parameters Tuning doc you linked to is intended to be generically useful for all LightGBM users.

The Python Quickstart is intended to be just that... a quickstart. I also wouldn't support adding language to it documenting the multiple different ways to enable early stopping. The purpose of that document is to get users from "just installed LightGBM" to "trained a model" as fast as possible.

c60evaporator commented Sep 8, 2023

Thank you for the detailed explanation.

I understand that neither of the following methods is preferred over the other, and that there are no plans to remove either of them.
I also understand that which one to use comes down to personal preference.

  1. Setting early_stopping_round in the params argument of the train() function.
  2. Passing the early_stopping() callback via the callbacks argument of the train() function.

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Dec 13, 2023