
[python-package] Early Stopping does not work as expected #5354

Closed
Tracked by #5153
ZhiZhongWan opened this issue Jul 6, 2022 · 3 comments · Fixed by #5412

@ZhiZhongWan

Description

I found that when early stopping is enabled and there are multiple validation sets in valid_sets, LightGBM might not save the best model as expected.

Reproducible example

import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.model_selection import KFold

np.random.seed(2022)
toy_pd = pd.DataFrame({'feature1': np.random.randn(1000), 'feature2': np.random.randn(1000), 'label': np.concatenate((np.ones(500), np.zeros(500)))})
feats = [f for f in toy_pd if f not in ['label']]
print(feats)


params = {
    'learning_rate': 0.07,
    'boosting_type': 'gbdt',
    'objective': 'binary',
    'metric': ['auc'],
    'is_unbalance': False,
    'lambda_l1': 0.5,
    'lambda_l2': 10,
    'num_leaves': 32,
    'max_depth': -1,
    'min_child_samples': 15,
    'verbose': -1,
    'seed': 42,
    'n_jobs': 8,
}

fold_num = 5
kf = KFold(n_splits=fold_num, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(toy_pd[feats], toy_pd['label'])):
    print('-----------', fold, '-----------')
    train = lgb.Dataset(toy_pd.loc[train_idx, feats],
                        toy_pd.loc[train_idx,'label'])
    val = lgb.Dataset(toy_pd.loc[val_idx, feats],
                      toy_pd.loc[val_idx, 'label'])
    model = lgb.train(params, train, valid_sets=[train, val], num_boost_round=4,
                      callbacks=[lgb.early_stopping(2), lgb.log_evaluation(1)])


output:

----------- 0 -----------
[1]	training's auc: 0.63893	valid_1's auc: 0.582732
Training until validation scores don't improve for 2 rounds
[2]	training's auc: 0.639539	valid_1's auc: 0.580629
[3]	training's auc: 0.652984	valid_1's auc: 0.588742
[4]	training's auc: 0.678193	valid_1's auc: 0.570563
Did not meet early stopping. Best iteration is:
[4]	training's auc: 0.678193	valid_1's auc: 0.570563
...
----------- 3 -----------
[1]	training's auc: 0.649097	valid_1's auc: 0.575294
Training until validation scores don't improve for 2 rounds
[2]	training's auc: 0.657224	valid_1's auc: 0.566957
[3]	training's auc: 0.666932	valid_1's auc: 0.565524
Early stopping, best iteration is:
[1]	training's auc: 0.649097	valid_1's auc: 0.575294

As you can see, I expect LightGBM to save the model with the best AUC on val, since val is the last entry in the valid_sets list. However, in this example it sometimes returns the model with the best performance on valid_sets[0] and sometimes the one with the best performance on valid_sets[-1].

It seems that something goes wrong when training does not trigger early stopping.

I'm very confused about this. I fixed all random seeds so you can easily reproduce it.
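
A hedged workaround sketch (not part of the original report): until this is fixed, one can keep the training Dataset out of valid_sets so the training metric can never be mistaken for the early-stopping target, then read the kept iteration back from best_iteration.

# Workaround sketch (assumption, not from the original report): pass only the
# real validation set to valid_sets, so early stopping and the "best model"
# bookkeeping can only ever look at val.
model = lgb.train(
    params,
    train,
    num_boost_round=4,
    valid_sets=[val],
    valid_names=['val'],
    callbacks=[lgb.early_stopping(2), lgb.log_evaluation(1)],
)
print(model.best_iteration, model.best_score)  # iteration and scores actually kept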

Environment info

LightGBM version or commit hash:

'3.3.2'

Command(s) you used to install LightGBM

pip install lightgbm

Additional Comments

jameslamb changed the title from "Early Stopping does not work as expected" to "[python-package] Early Stopping does not work as expected" on Jul 8, 2022
@jmoralez (Collaborator)

Hi @ZhiZhongWan, thanks for raising this and for the excellent example. I confirm this isn't working as expected. The problem is that, even though the documentation says the training set is ignored, it actually isn't. The scores are saved here:

if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
    self.best_score[i] = score
    self.best_iter[i] = env.iteration
    self.best_score_list[i] = env.evaluation_result_list

And the training set is always the first one in the validation sets:

if valid_sets is not None:
    if is_valid_contain_train:
        evaluation_result_list.extend(booster.eval_train(feval))
    evaluation_result_list.extend(booster.eval_valid(feval))
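
For context, a gloss (my assumption, not quoted from the thread): each entry in evaluation_result_list is a 4-tuple, and because eval_train(feval) is extended first, the training entries sit at the front of the list.

# Gloss (assumption; verify against your LightGBM version): each entry looks like
#   (dataset_name, eval_name, result, is_higher_better)
# so for iteration 4 of fold 0 above the list is roughly
#   [('training', 'auc', 0.678193, True), ('valid_1', 'auc', 0.570563, True)]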

So in the final-iteration check, the training set is checked first:

if ((env.evaluation_result_list[i][0] == "cv_agg" and eval_name_splitted[0] == "train"
        or env.evaluation_result_list[i][0] == env.model._train_data_name)):
    self._final_iteration_check(env, eval_name_splitted, i)

This means that if the training score improved on the last iteration it will be saved as the best one, even if the validation score didn't. That is exactly what happens in fold 0 above: training's AUC keeps improving through iteration 4 while valid_1's AUC drops, yet iteration 4 is reported as the best.

I'll try to come up with a fix for this. One possible fix I see is actually ignoring the training set by removing this line:

self._final_iteration_check(env, eval_name_splitted, i)
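
A minimal sketch of that direction (my assumption about the shape of the fix, not the patch that was merged in #5412): skip any entry coming from the training data before the best-score bookkeeping, so it can neither become the "best" score nor trigger the final-iteration check.

# Hedged sketch, not the merged #5412 diff: in the callback's per-iteration loop,
# ignore entries that come from the training data.
for i in range(len(env.evaluation_result_list)):
    dataset_name, _, score, _ = env.evaluation_result_list[i]
    if dataset_name == env.model._train_data_name:
        continue  # the docs say the training set is ignored; actually ignore it
    if self.best_score_list[i] is None or self.cmp_op[i](score, self.best_score[i]):
        self.best_score[i] = score
        self.best_iter[i] = env.iteration
        self.best_score_list[i] = env.evaluation_result_list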

jmoralez added bug and removed question labels Jul 12, 2022
jmoralez mentioned this issue Jul 12, 2022
@StrikerRUS (Collaborator)

Maybe this old discussion can help: #2371 (comment).

StrikerRUS pushed a commit that referenced this issue Aug 28, 2022
…#5354) (#5412)

* ignore training set on early stopping callback

* fixes

* lint

* Apply suggestions from code review

Co-authored-by: James Lamb <jaylamb20@gmail.com>

* trigger ci

Co-authored-by: James Lamb <jaylamb20@gmail.com>
@github-actions

This issue has been automatically locked since there has not been any recent activity after it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

github-actions bot locked as resolved and limited conversation to collaborators Aug 19, 2023