[DataLoader2] Adding support for naive checkpointing #1119
Conversation
@NivekT has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Differential Revision: [D44712802](https://our.internmc.facebook.com/intern/diff/D44712802)
```python
num_of_previously_yielded_batches = state_dict[NUM_PREV_YIELDED_BATCH_KEY_NAME]
self._num_prev_yielded_batches = num_of_previously_yielded_batches
```
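For context, "naive" checkpointing here means recording how many batches were already yielded and, on restore, replaying the pipeline from the start while discarding that many batches. A minimal sketch of the idea, using the standard-library `itertools` "consume" recipe (the key name and helper below are illustrative, not the actual torchdata implementation):

```python
from itertools import islice

# Hypothetical key, mirroring NUM_PREV_YIELDED_BATCH_KEY_NAME in the diff.
NUM_PREV_YIELDED_BATCH_KEY = "num_prev_yielded_batches"

def naive_restore(pipeline, state_dict):
    """Return a fresh iterator fast-forwarded past the batches that
    were already yielded before the checkpoint was taken."""
    n = state_dict[NUM_PREV_YIELDED_BATCH_KEY]
    it = iter(pipeline)
    # itertools "consume" recipe: advance the iterator by n without
    # keeping the skipped items. The pipeline is re-run from scratch,
    # which is why this restoration strategy is called "naive".
    next(islice(it, n, n), None)
    return it

it = naive_restore(range(10), {NUM_PREV_YIELDED_BATCH_KEY: 4})
print(next(it))  # 4
```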
Can we combine them into a single line?
```diff
@@ -365,6 +381,28 @@ def _restore_checkpoint_beginning_of_epoch(self) -> None:
         """
         self._seed_generator = self._initial_seed_generator

+    def _restore_naive_checkpoint(self, num_prev_yielded_batches: Optional[int] = None) -> DataLoader2Iterator[T_co]:
```
When do we need to specify `num_prev_yielded_batches`?

I have a noob question: why don't we make it automatically choose between the naive and advanced restoring options, based on whether `NUM_PREV_YIELDED_BATCH_KEY_NAME` is in the `state_dict`?
> When do we need to specify `num_prev_yielded_batches`?
I see it as an option to override, or to restore even if `self._num_prev_yielded_batches` has not been set. We can definitely take it out if you don't find it useful (and add it later if we see a need).
> automatically choose
I can imagine there are situations where users only want to restore the randomness state:
- the model is only saved once per epoch (so you want DataLoader2 to be in sync)
- maybe "naive" restoration is too slow and users just want to restore the randomness state, then do something custom
In these cases, a separate API to only restore randomness state would be good. Let me know if that is not what you are asking.
> I see it as an option to override or to restore even if `self._num_prev_yielded_batches` has not been set.
If needed, users can always skip the required number of iterations after doing `_restore_checkpoint_beginning_of_epoch`, right? I would like to keep the API minimal at first until a solid use case is required.
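The manual-skip alternative suggested here can be sketched as follows. The toy pipeline and names are illustrative stand-ins (not the torchdata API): restoring only the per-epoch randomness is modeled by recreating the epoch with the same seed, then the caller fast-forwards past the batches already consumed.

```python
import random
from itertools import islice

def epoch_batches(seed, n=8):
    """Toy pipeline: a deterministic shuffle driven by the epoch seed,
    standing in for what restoring the seed generator gives you."""
    rng = random.Random(seed)
    data = list(range(n))
    rng.shuffle(data)
    yield from data

already_yielded = 3  # batches consumed before the interruption

# Step 1: "restore randomness" by recreating the epoch with the same seed.
it = epoch_batches(seed=42)
# Step 2: manually skip the batches that were already yielded.
next(islice(it, already_yielded, already_yielded), None)
remaining = list(it)

# The result matches an uninterrupted epoch with the prefix dropped.
assert remaining == list(epoch_batches(seed=42))[already_yielded:]
```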
> - the model is only saved once per epoch (so you want DataLoader2 to be in sync)
In this case, we only need to save the dataloader state once per epoch as well, right? If we want to support fault tolerance, we need to make sure the model has been stored at the same time.
> - maybe "naive" restoration is too slow and users just want to restore the randomness state, then do something custom
Can we list all the expected scenarios and corresponding API calls in the summary of the PR?
```diff
@@ -222,7 +227,8 @@ def __iter__(self) -> DataLoader2Iterator[T_co]:
         self._reset_iter = False

         self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1
-        return DataLoader2Iterator(self, self.valid_iterator_id)
+        self._iterator = DataLoader2Iterator(self, self.valid_iterator_id)
```
Can we try our best to prevent circular references?
Hi @NivekT! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!
Closing. Feel free to re-open if someone else would like to work on this.