enh: enable loading model weights from training checkpoint #3969
Conversation
LGTM, but why do we set this to False by default? Is there a reason why this shouldn't be the default behavior?
@Infernaught it's an interesting point. We want it False by default just because that preserves the existing user behavior. We can consider changing it if that turns out to be an overall better experience, but for now we don't want to introduce confusion.
:param from_checkpoint: (bool, default: `False`) if `True`, the model
    will be loaded from the latest checkpoint (training_checkpoints/)
    instead of the final model weights.
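To make the behavior concrete, here is a minimal usage sketch of the parameter documented above. It assumes the Ludwig Python API (`LudwigModel.train` / `LudwigModel.load`) and the standard results layout (`<output_directory>/model/`); the config and file names are illustrative and not taken from the PR.

```python
import os

from ludwig.api import LudwigModel

# Illustrative config; any valid Ludwig config works the same way.
config = {
    "input_features": [{"name": "text", "type": "text"}],
    "output_features": [{"name": "label", "type": "binary"}],
    "trainer": {"epochs": 5},
}

model = LudwigModel(config)
# During training, checkpoints are written to
# <output_directory>/model/training_checkpoints/, even if the run later errors out.
_, _, output_directory = model.train(dataset="train.csv", output_directory="results")

# Default behavior (from_checkpoint=False): load the final model_weights.
final_model = LudwigModel.load(os.path.join(output_directory, "model"))

# New behavior: load from the latest training checkpoint instead, e.g. when
# a run died before the final weights were written.
recovered_model = LudwigModel.load(
    os.path.join(output_directory, "model"), from_checkpoint=True
)
```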
Yeah, I guess this is a fair thing to do. Wondering, though, what the reason would be to load from_checkpoint if model/model_weights is already present? Perhaps it could be a no-op in that case and we could always make this True? Okay keeping it like this for now as well, just wanted to call it out.
Yeah, I think it's okay to have it be explicit and not make assumptions on behalf of the user for now.
We can revisit this if people get confused, but because it defaults to False, I'm hopeful this won't interrupt anyone's experience (until, of course, someone really needs it, at which point we can direct them to this flag).
@geoffreyangus ✅ -- code LGTM (and it was illuminating -- thanks!) -- I was just wondering, how did/do we test something like this, where we call distributed? Thank you!
I added a test for this @alexsherstinsky, if you want to take a look!
@@ -32,6 +32,79 @@
)


def test_model_load_from_checkpoint(tmpdir, csv_filename, tmp_path):
❤️ @geoffreyangus Thank you -- so cool! For my edification, which checkpoint are we comparing? It looks like the loaded model from storage (ludwig_model_2) is the result of training ludwig_model_1 all the way through the 1 epoch -- as opposed to some intermediate checkpoint (e.g., one saved after a few steps). Is this correct? Thank you!
Yeah, that is correct -- it's just the latest checkpoint deposited into training_checkpoints/ during the course of training. In the 1-epoch case, it is equivalent to the model weights at the end of training.
For some reason, on my local machine the models weren't equivalent after 2 epochs -- my hunch is that this is because there was a difference between the "best" checkpoint (loaded at the end of training) and the "latest" checkpoint.
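For reference, a rough sketch of the kind of comparison being discussed (latest checkpoint vs. fully trained weights). It assumes the Ludwig Python API, that the underlying torch module is exposed as `.model`, and a toy config; none of this is the actual test code from the PR.

```python
import os

import torch

from ludwig.api import LudwigModel

config = {
    "input_features": [{"name": "x", "type": "number"}],
    "output_features": [{"name": "y", "type": "binary"}],
    "trainer": {"epochs": 1},  # with 1 epoch, latest checkpoint == final weights
}

# Train a model; output_directory ends up containing model/model_weights
# as well as model/training_checkpoints/.
ludwig_model_1 = LudwigModel(config)
_, _, output_directory = ludwig_model_1.train(dataset="train.csv", output_directory="results")

# Load from the latest training checkpoint rather than model_weights.
ludwig_model_2 = LudwigModel.load(
    os.path.join(output_directory, "model"), from_checkpoint=True
)

# Compare parameter tensors of the two underlying torch modules.
sd1 = ludwig_model_1.model.state_dict()
sd2 = ludwig_model_2.model.state_dict()
assert all(torch.equal(sd1[k], sd2[k]) for k in sd1)
```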
🙇
See title. This should enable users to use models from a PyTorch training checkpoint instead of only finalized model weights. Useful if a job errors midway through training.