Add early stopping callback to pytorch trainer #8581
Conversation
… to prevent early stopping to pytorch trainer
Hi there. Thanks for your PR! When I was designing the callbacks, the idea was for them to be small, independent pieces of code. I would prefer if early stopping had its own callback that the user could then choose to add or not. Do you think you could amend your PR in that direction?
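A minimal sketch of the opt-in usage pattern being described, assuming the callback ends up named EarlyStoppingCallback and is registered through the Trainer's callbacks argument (the class name, constructor argument, and the model/dataset variables are illustrative here, not settled by this thread):

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

training_args = TrainingArguments(output_dir="out")

trainer = Trainer(
    model=model,                  # model and datasets assumed to be defined elsewhere
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # the user explicitly opts into early stopping by adding the callback
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)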
Hello, thank you for your feedback! I will amend the PR in that direction. Could you clarify which pieces of early stopping should be in the callback? For example:

from typing import Optional

from transformers import TrainerCallback


class EarlyStoppingCallback(TrainerCallback):
    best_metric: Optional[float] = None  # maybe not this
    best_model_checkpoint: Optional[str] = None  # maybe not this either
    early_stopping_patience: Optional[int] = None
    early_stopping_patience_counter: Optional[int] = None

    def on_evaluate(self, args, state, control, **kwargs):
        # Keep track of patience
        # End training via early stopping
        if (
            self.early_stopping_patience is not None
            and self.early_stopping_patience_counter >= self.early_stopping_patience
        ):
            control.should_training_stop = True
Or do you mean I just move the if statement I added to its own callback and keep …
The …
That makes sense. I think this block of code (to line 933) could be a callback because it's all about the best metric. Then users could customize the best model calculations. Is that desirable? If you think that's out of scope, I'll keep the early stopping callback simple and separate from the best metric calculation.
I had put it in …
Sounds good, you know best! I keep …
…est metric for early stopping callback to trigger on.
A few more things to change, but we're close to getting this into a good state. Thanks a lot for your work on this!
metric_value = metrics.get(metric_to_check)

if metric_value is None:
    logger.warning(
Good warning!
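For context, a sketch of what the surrounding check could look like as a free-standing function; only the metrics.get(...) lookup and the logger.warning(...) call are taken from the diff excerpt above, while the function name, the eval_ prefixing, and the message text are assumptions:

import logging

logger = logging.getLogger(__name__)

def check_metric_value(args, metrics):
    # Hypothetical free-standing version of the check referenced above.
    metric_to_check = args.metric_for_best_model
    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"
    metric_value = metrics.get(metric_to_check)
    if metric_value is None:
        logger.warning(
            f"early stopping requires metric_for_best_model ({metric_to_check}) to be present in the "
            "evaluation metrics; it was not found, so early stopping is disabled"
        )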
self.early_stopping_patience_counter += 1

def on_train_begin(self, args, state, control, **kwargs):
    assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
I still don't understand why this line is necessary. I feel we should be able to use this callback without the option load_best_model_at_end. The other sanity checks are perfectly ok.
This is necessary because we require control.should_save=True for _save_checkpoint to update the best metric. Should I move the best metric calculation into its own function and place it in the should_evaluate block?
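A minimal sketch of the refactor being proposed, assuming a hypothetical update_best_metric helper that could be called from the evaluation path instead of from _save_checkpoint (the name, signature, and exact conditions are illustrative):

import numpy as np

def update_best_metric(args, state, metrics, checkpoint_dir):
    # Hypothetical standalone helper: refresh state.best_metric and
    # state.best_model_checkpoint right after an evaluation, so the update
    # no longer depends on control.should_save being True.
    metric_to_check = args.metric_for_best_model
    if not metric_to_check.startswith("eval_"):
        metric_to_check = f"eval_{metric_to_check}"
    metric_value = metrics.get(metric_to_check)
    if metric_value is None:
        return
    operator = np.greater if args.greater_is_better else np.less
    if state.best_metric is None or operator(metric_value, state.best_metric):
        state.best_metric = metric_value
        state.best_model_checkpoint = checkpoint_dir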
I agree that it's not fully intuitive to need load_best_model_at_end, but it makes sense to me: if we don't load the best model, early stopping will stop training, but the model we receive back will not be the one early stopping deemed best.
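To make the coupling concrete, a sketch of the TrainingArguments a user would pair with the callback under this requirement (the specific values are illustrative):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    evaluation_strategy="steps",     # evaluate periodically so there is a metric to compare
    eval_steps=500,
    save_steps=500,
    load_best_model_at_end=True,     # required here: the best checkpoint is reloaded when early stopping fires
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)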
Ok let's leave it as is for now then, and we will re-evaluate if some users complain!
Saw this issue while debugging something. It doesn't seem intuitive how these two are related, so can we please do what @cbrochtrup suggested above?
Thanks for your thorough and affable review!
Great addition, LGTM!
Summary
Addresses the PyTorch half of #4894 by adding early stopping patience and a minimum threshold by which metrics must improve to avoid triggering early stopping. I piggybacked heavily off of #7431 since the two features are very similar.
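A sketch of how the two knobs described above could be used together, keeping the early_stopping_patience name from the draft earlier in the thread and assuming a hypothetical early_stopping_threshold parameter for the minimum-improvement threshold:

from transformers import EarlyStoppingCallback

# Stop training when the monitored metric fails to improve by at least the
# threshold for early_stopping_patience consecutive evaluations
# (threshold parameter name assumed, values illustrative).
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01,
)

The callback would then be passed to the Trainer via its callbacks argument, as shown earlier in the thread.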
Since #4186 seems to be abandoned and behind master, I figured I'd take a crack at this.
Who can review?
Anyone! But @julien-c and @sgugger seem the most appropriate.