`LearningRateFinder` defining max validation batches for entire training loop #17412
Bug description
When the `LearningRateFinder` callback is used, the `num_training_steps` parameter that is passed on init (default: 100) ends up defining how many validation batches to run for the entire length of training. This means that if `num_training_steps` is less than the total number of batches in your validation set, every validation loop during training will only see a subset of the validation data.
What version are you seeing the problem on?
2.0+
How to reproduce the bug
This code will fail because `trainer.num_val_batches[0]` ends up being 5.
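The reproduction snippet itself did not survive formatting, so below is a minimal sketch of the kind of setup the report describes. The model, the dataset and loader sizes, and `num_training_steps=5` are illustrative assumptions, chosen so that the asserted value of 5 lines up with the text.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import LearningRateFinder


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.lr = 0.1  # attribute the LR finder reads and updates
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.mse_loss(self.layer(x), y)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self.layer(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)


def make_loader(n_samples):
    ds = TensorDataset(torch.randn(n_samples, 32), torch.randn(n_samples, 2))
    return DataLoader(ds, batch_size=2)


model = BoringModel()
train_loader = make_loader(64)
val_loader = make_loader(64)  # 32 validation batches, well above 5

trainer = Trainer(
    max_epochs=1,
    callbacks=[LearningRateFinder(num_training_steps=5)],
)
trainer.fit(model, train_loader, val_loader)

# Expected: the full validation set (32 batches).
# Observed: the 5 batches configured by the LR finder.
assert trainer.num_val_batches[0] == len(val_loader), trainer.num_val_batches
```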
Using the same base code from above but removing the `LearningRateFinder`, this code passes. However, `num_val_batches` does get updated once `.validate()` is called. Putting the `LearningRateFinder` back in but moving the assert statement, this code passes, as sketched below.
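Again hypothetical, reusing the model and loaders from the sketch above, with the assert moved to after a `.validate()` call:

```python
trainer = Trainer(
    max_epochs=1,
    callbacks=[LearningRateFinder(num_training_steps=5)],
)
trainer.fit(model, train_loader, val_loader)

# .validate() redoes the data setup, which recomputes the batch limits
trainer.validate(model, val_loader)

assert trainer.num_val_batches[0] == len(val_loader)  # now passes
```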
More info
I did some more digging into why this might be happening, and it looks like the problem is likely coming from the fact that `trainer.fit_loop.epoch_loop.val_loop.setup_data()` is getting called for the first time while the learning rate finder is running, so `trainer.fit_loop.epoch_loop.val_loop._max_batches` gets set according to the parameters that the learning rate finder has passed in. Even though the learning rate finder restores the parameters that the trainer initially set once it is done, the `setup_data()` method never runs a full setup again, so the `_max_batches` attribute never gets updated. One solution might be to redo the data setup once the learning rate finder has completed, the same way setup is redone when `.validate()` is called.
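For illustration, with the same hypothetical setup as in the reproduction sketch above, the cached value can be observed directly. This inspects private loop attributes, so it is a debugging sketch rather than supported API:

```python
trainer = Trainer(
    max_epochs=1,
    callbacks=[LearningRateFinder(num_training_steps=5)],
)
trainer.fit(model, train_loader, val_loader)

# Cached by the one full setup_data() call, which ran while the LR finder's
# temporary settings were active and is never redone after they are restored:
print(trainer.fit_loop.epoch_loop.val_loop._max_batches)  # [5]

trainer.validate(model, val_loader)  # triggers a fresh data setup
print(trainer.num_val_batches)  # [32], the full validation set again
```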