-
Notifications
You must be signed in to change notification settings - Fork 3.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add resuming from specific checkpoint #516
Changes from 2 commits
d0a3e25
915b5c5
0db8785
bcebf60
70cbd8c
3477108
1d3dbf6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
import re | ||
import signal | ||
import warnings | ||
from pathlib import Path | ||
from subprocess import call | ||
import logging | ||
|
||
|
@@ -46,7 +47,9 @@ def restore_weights(self, model): | |
|
||
if not did_restore_hpc_weights: | ||
# restore weights if same exp version | ||
self.restore_state_if_checkpoint_exists(model) | ||
did_restore_last_checkpoint = self.restore_state_if_checkpoint_exists(model) | ||
if not did_restore_last_checkpoint and self.resume_from_checkpoint is not None: | ||
self.restore_state_from_checkpoint(self.resume_from_checkpoint) | ||
|
||
# wait for all models to restore weights | ||
if self.use_ddp or self.use_ddp2: | ||
|
@@ -93,6 +96,18 @@ def restore_state_if_checkpoint_exists(self, model): | |
|
||
return did_restore | ||
|
||
def restore_state_from_checkpoint(self, checkpoint_path): | ||
did_restore = False | ||
|
||
checkpoint_path = Path(checkpoint_path) | ||
if not checkpoint_path.exists(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does not work for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. checkpoint_path is what torch.load expects as input. It can be file-like object or str containing a file name. |
||
return did_restore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can simply return True/False no need for extra variable There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay. I'll change the code accordingly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I simplified the code and this function no longer exists |
||
|
||
self.restore(checkpoint_path, self.on_gpu) | ||
did_restore = True | ||
|
||
return did_restore | ||
|
||
# -------------------- | ||
# HPC SIGNAL HANDLING | ||
# -------------------- | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add doc what data type the
checkpoint_path
isThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I removed this function. Please review the updated code :)