Model checkpointing for sub-epoch frequency #1758
Comments
Agreed, when training on very large datasets (as you do with large transformer models) it would be very nice to be able to save checkpoints on validation!
I thought this was already the case?
This is a bit of a hack, but it seems to work:
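(The original snippet was not preserved here. As a rough, hypothetical sketch of one such workaround, assuming a custom Callback that saves from the `on_validation_end` hook, and not the code the commenter actually posted:)

```python
import os
from pytorch_lightning import Callback

class SaveOnValidationEnd(Callback):
    """Hypothetical hack: write a checkpoint every time validation finishes,
    even when that happens in the middle of a training epoch."""

    def __init__(self, dirpath: str):
        self.dirpath = dirpath

    def on_validation_end(self, trainer, pl_module):
        os.makedirs(self.dirpath, exist_ok=True)
        # trainer.save_checkpoint writes model weights plus trainer state to disk.
        path = os.path.join(self.dirpath, f"step-{trainer.global_step}.ckpt")
        trainer.save_checkpoint(path)
```

Such a callback would be passed to the Trainer via its `callbacks` argument; as the rest of the thread notes, checkpoints written this way are not restored correctly mid-epoch.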
I think @artidoro has struck on the issue. The second condition is always …. Perhaps the easiest fix is to set ….
Currently `period` is an integer value, and the behaviour @Anjum48 described is expected: if we check validation multiple times during one epoch but `period=1`, we expect a checkpoint to be saved only once. So I don't see any bug here (but perhaps we could document the edge cases better).

In PL it currently only makes sense to checkpoint at epoch intervals, because training cannot be restored mid-epoch anyway. So the edge case where we have `val_check_interval < 1` is unfortunate. In my opinion this case should not save checkpoints mid-epoch at all, and only on ….

Maybe @jeremyjordan also has comments about this.
I'm confused by this terminology. As I understand it, epoch end refers to when we have made a full pass through the training dataset, while validation end refers to when we have made a full pass through the validation set, which, depending on the trainer settings, could happen multiple times per training epoch. I think I understand what you're saying, though: that we should only be saving one checkpoint at the end of an epoch?

I'm also confused about mid-epoch checkpoints. The docs reference …, which implies that we can restart training mid-epoch. However, if you look at the training loop, it does appear that we always start at the beginning of an epoch and don't respect the global step loaded from a checkpoint.
I used the terminology incorrectly. What I meant to say is simply that currently we save checkpoints on validation end (for a good reason), but it might be undesirable to do so when validation end happens mid-epoch, because restoring from such a checkpoint leads to an incorrectly restored Trainer state (as you pointed out with `global_step`). So we need to solve this problem first.
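To make the concern concrete, here is a small illustrative sketch (not from the issue) of inspecting a checkpoint written mid-epoch; it assumes the Lightning checkpoint dictionary exposes `epoch` and `global_step` keys, as checkpoints of this era do:

```python
import torch

# Load the raw checkpoint dictionary and look at the stored trainer state.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
print("epoch:", ckpt["epoch"])
print("global_step:", ckpt["global_step"])

# global_step may point to the middle of an epoch, but resuming from this
# checkpoint restarts at an epoch boundary, so that position is not honoured.
```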
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
What is the status on this? We do support sub-epoch checkpointing, but as @awaelchli mentioned, it doesn't correctly restore the global step?
Yeah... I don't think we can actually restore sub-epoch state, though. Know of a way to do that? We'd have to pull the shuffle state out of the loaders.
Idk, we could save the whole dataloader, but I don't think that makes sense.
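(As a purely illustrative sketch, not something proposed in the thread: one way to capture shuffle state without saving the whole dataloader is to drive shuffling with an explicit `torch.Generator` and checkpoint its RNG state.)

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))

# Drive shuffling with an explicit generator so its RNG state can be checkpointed.
g = torch.Generator()
g.manual_seed(42)
loader = DataLoader(dataset, batch_size=8, shuffle=True, generator=g)

# When writing a mid-epoch checkpoint, store the generator state with everything else.
shuffle_state = g.get_state()

# On resume, restore the state before rebuilding the dataloader.
g.set_state(shuffle_state)
```

Restoring the generator state only reproduces the shuffle order; to truly resume mid-epoch you would still have to skip the batches already consumed, which is the harder part the thread is pointing at.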
This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!
❓ Questions and Help
What is your question?
In my `Trainer` I have `val_check_interval=0.25`, which is great, and I can see that the validation loop is run 4 times per epoch as expected. Is there a way for `ModelCheckpoint` to use the validation checks to trigger model saving at a sub-epoch frequency? I can see (by setting `verbose=True`) that the check is only done on epoch end.

What's your environment?
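For reference, a minimal sketch of the setup being asked about (not code from the issue; `filepath` and `checkpoint_callback` were the argument names in Lightning versions current at the time, while newer releases use `dirpath`/`filename` and `callbacks` instead):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Validation runs 4 times per training epoch with val_check_interval=0.25,
# but ModelCheckpoint (verbose=True logs each save) only writes at epoch end.
checkpoint_callback = ModelCheckpoint(filepath="checkpoints/{epoch}-{val_loss:.2f}", verbose=True)
trainer = Trainer(val_check_interval=0.25, checkpoint_callback=checkpoint_callback)
```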