[docs] The recommended way to use early stopping #5196
Comments
Thanks for using LightGBM and for the thorough report. Sorry it took so long for someone to answer you here!

As of v4.0.0, you can use either approach 2 or 3 from your original post.

Consider the following example, with a metric that improves on each iteration and then starts getting worse after the 4th iteration.

import lightgbm as lgb
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
dtrain = lgb.Dataset(X_train, label=y_train)
dvalid = lgb.Dataset(X_test, label=y_test)
class CustomMetric:
    """Metric that starts getting worse after the 4th iteration"""
    def __init__(self):
        self.scores = [1.0, 0.9, 0.8, 0.7, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]
    def __call__(self, y_true, y_pred):
        is_higher_better = False
        return "decreasing_metric", self.scores.pop(0), is_higher_better
params = {
    "verbosity": 1,
    "max_depth": 3,
    "num_leaves": 7,
    "min_data_in_leaf": 5,
}
# approach 1: use early_stopping() callback
bst_1 = lgb.train(
    params=params,
    num_boost_round=10,
    train_set=dtrain,
    valid_sets=[dvalid],
    feval=CustomMetric(),
    callbacks=[
        lgb.early_stopping(stopping_rounds=3),
        lgb.log_evaluation(1)
    ]
)
# approach 2: pass "early_stopping_round" through params
bst_2 = lgb.train(
    params={**params, "early_stopping_round": 3},
    num_boost_round=10,
    train_set=dtrain,
    valid_sets=[dvalid],
    feval=CustomMetric(),
    callbacks=[
        lgb.log_evaluation(1)
    ]
)

For both, you'll see logs indicating that 7 rounds of training happened (4 where the performance improved + the 3 it took to trigger early stopping), and that each Booster's best_iteration is 4:

bst_1.best_iteration
# 4
bst_2.best_iteration
# 4
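
The "7 rounds, best iteration 4" arithmetic can be checked without LightGBM at all. The following is only a rough sketch of the patience rule, not LightGBM's actual implementation: `simulate_early_stopping` is a hypothetical helper written for illustration, and it assumes a single validation metric with no minimum-improvement threshold. It also shows why each `lgb.train()` call above gets a fresh `CustomMetric()`: the object is stateful and hands out one scripted score per call.

```python
class CustomMetric:
    """Same stateful metric as in the example: one scripted score per call."""

    def __init__(self):
        self.scores = [1.0, 0.9, 0.8, 0.7, 0.8, 0.8, 0.8, 0.8, 0.8, 0.8]

    def __call__(self, y_true, y_pred):
        is_higher_better = False
        return "decreasing_metric", self.scores.pop(0), is_higher_better


def simulate_early_stopping(metric, num_boost_round, stopping_rounds):
    """Hypothetical helper: returns (rounds_run, best_iteration), both 1-based."""
    best_score = None
    best_iteration = 0
    for iteration in range(1, num_boost_round + 1):
        # The metric ignores its inputs here, so placeholders are enough.
        _, score, is_higher_better = metric(None, None)
        improved = best_score is None or (
            score > best_score if is_higher_better else score < best_score
        )
        if improved:
            best_score, best_iteration = score, iteration
        elif iteration - best_iteration >= stopping_rounds:
            # No improvement for `stopping_rounds` consecutive rounds: stop.
            return iteration, best_iteration
    return num_boost_round, best_iteration


rounds_run, best_iteration = simulate_early_stopping(
    CustomMetric(), num_boost_round=10, stopping_rounds=3
)
print(rounds_run, best_iteration)  # 7 4
```

The score bottoms out at 0.7 on round 4, and rounds 5, 6 and 7 fail to beat it, so the third non-improving round ends the loop.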
Both are equally valid, and there are no plans to remove support for either. I personally tend to prefer the approach where
I don't support those changes to the documentation. Please note that LightGBM is not just a Python package... it is also a command-line interface (CLI), a redistributable C++ shared library with a C API, an R package, a Java jar, and more. Most of those other interfaces don't have a concept of "callbacks", and the The
Thank you for the detailed explanation. I understand that neither of the two methods takes priority over the other and that there is no plan to remove either of them.
This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.
Description
I suppose there are three ways to enable early stopping in the Python Training API:

1. The early_stopping_rounds argument of the train() function.
2. early_stopping_round in the params argument of the train() function.
3. The early_stopping() callback via the 'callbacks' argument of the train() function.

I know that the first way is deprecated and will be removed according to the following update.
#4908
I'd like to know which way is the most recommended way to use early stopping.
Reproducible example
The Python Quick Start page says that the third way (early_stopping() callback) is recommended.
The Parameters Tuning page says that the second way (early_stopping_round in params) is recommended.

Motivation
I would appreciate it if we could unify these into one recommended way throughout the documentation to prevent confusion.
For example, if we choose the third way (early_stopping() callback) as the most recommended way, I'd like to modify the Parameters Tuning page as follows.
Before
After