-
Notifications
You must be signed in to change notification settings - Fork 152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[DataLoader2] Adding support for naive checkpointing #1119
Conversation
[ghstack-poisoned]
ghstack-source-id: 64cc77bc98f0f0045b9126e05b3a23e731832904 Pull Request resolved: #1119
[ghstack-poisoned]
ghstack-source-id: f9091dd7497871623e349018ffab472ddbeefdf4 Pull Request resolved: #1119
@NivekT has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
Differential Revision: [D44712802](https://our.internmc.facebook.com/intern/diff/D44712802) [ghstack-poisoned]
ghstack-source-id: 9b238819db4308c3ed51a67fe415cce83eda5a06 Pull Request resolved: #1119
Differential Revision: [D44712802](https://our.internmc.facebook.com/intern/diff/D44712802) [ghstack-poisoned]
ghstack-source-id: 2b1dcb174810ee6d622d7bf6a9cab0327341d432 Pull Request resolved: #1119
Differential Revision: [D44712802](https://our.internmc.facebook.com/intern/diff/D44712802) [ghstack-poisoned]
ghstack-source-id: 42e9901684b0d9d3b756d3a14f4bd29782d1dcb0 Pull Request resolved: #1119
@NivekT has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator. |
num_of_previously_yielded_batches = state_dict[NUM_PREV_YIELDED_BATCH_KEY_NAME] | ||
self._num_prev_yielded_batches = num_of_previously_yielded_batches |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we combine them into a single line?
@@ -365,6 +381,28 @@ def _restore_checkpoint_beginning_of_epoch(self) -> None: | |||
""" | |||
self._seed_generator = self._initial_seed_generator | |||
|
|||
def _restore_naive_checkpoint(self, num_prev_yielded_batches: Optional[int] = None) -> DataLoader2Iterator[T_co]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When do we need to specify num_prev_yielded_batches
?
I have a noob question: Why don't we make it automatically choosing the restoring option between naive or advanced based on if NUM_PREV_YIELDED_BATCH_KEY_NAME
in the state_dict
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When do we need to specify
num_prev_yielded_batches
?
I see it as an option to override or to restore even if self._num_prev_yielded_batches
has not been set. We can definitely take it out if you don't find it useful (and add it later if we see a need).
automatically choose
I can imagine there are situations where users only want to restore the randomness state:
- the model is only saved once per epoch (so you want DataLoader2 to be in sync)
- maybe "naive" restoration is too slow and users just want to restore the randomness state, then do something custom
In these cases, a separate API to only restore randomness state would be good. Let me know if that is not what you are asking.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see it as an option to override or to restore even if
self._num_prev_yielded_batches
has not been set.
If needed, users can always skipping number of iterations after doing _restore_checkpoint_beginning_of_epoch
, right? I would like to keep the API minimum at first until a solid use case is required.
- the model is only saved once per epoch (so you want DataLoader2 to be in sync)
In this case, we only need to save dataloader state once per epoch as well, right? If we want to support fault-tolerant, we need to make sure model has been stored at the same time.
- maybe "naive" restoration is too slow and users just want to restore the randomness state, then do something custom
Can we list all the expected scenarios and corresponding API calls in the summary of the PR?
@@ -222,7 +227,8 @@ def __iter__(self) -> DataLoader2Iterator[T_co]: | |||
self._reset_iter = False | |||
|
|||
self.valid_iterator_id = 0 if self.valid_iterator_id is None else self.valid_iterator_id + 1 | |||
return DataLoader2Iterator(self, self.valid_iterator_id) | |||
self._iterator = DataLoader2Iterator(self, self.valid_iterator_id) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we try our best to prevent circular referencing?
Hi @NivekT! Thank you for your pull request. We require contributors to sign our Contributor License Agreement, and yours needs attention. You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted. ProcessIn order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks! |
Closing. Feel free to re-open if someone else would like to work on this. |
Stack from ghstack:
Differential Revision: D44712802