
enh: enable loading model weights from training checkpoint #3969

Merged
geoffreyangus merged 5 commits into master from load_weights_from_checkpoint on Mar 20, 2024

Conversation

geoffreyangus
Contributor

Title says it all. This should enable users to load models from a PyTorch training checkpoint instead of only from finalized models. Useful if a job errors midway through training.
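A minimal usage sketch of what this enables (the paths are illustrative, and it assumes the flag is exposed on `LudwigModel.load` as in the docstring excerpt quoted further down this thread):

```python
from ludwig.api import LudwigModel

# Default behavior (unchanged): load the finalized weights saved at the end of training.
# NOTE: the model directory path here is an illustrative assumption.
model = LudwigModel.load("results/api_experiment_run/model")

# New behavior: recover the latest checkpoint from training_checkpoints/ instead,
# e.g. when a job errored out before the final model weights were written.
model = LudwigModel.load("results/api_experiment_run/model", from_checkpoint=True)
```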

Contributor

@Infernaught left a comment

LGTM, but why do we set this to False by default? Is there a reason why this shouldn't be the default behavior?


github-actions bot commented Mar 18, 2024

Unit Test Results

6 files ±0   6 suites ±0   duration: 14m 24s ⏱️ (-2m 39s)
12 tests ±0: 9 passed ✔️ ±0, 3 skipped 💤 ±0, 0 failed ±0
60 runs ±0: 42 passed ✔️ ±0, 18 skipped 💤 ±0, 0 failed ±0

Results for commit d242b5e. Comparison against base commit c09d5dc.

♻️ This comment has been updated with latest results.

@geoffreyangus
Contributor Author

@Infernaught it's an interesting point. We want it False by default just because that preserves the existing user behavior. We can consider changing that if it's an overall better experience, but for now we don't want to introduce confusion.

Comment on lines +1796 to +1798
:param from_checkpoint: (bool, default: `False`) if `True`, the model
will be loaded from the latest checkpoint (training_checkpoints/)
instead of the final model weights.
Contributor

Yeah, I guess this is a fair thing to do. I'm wondering, though, what the reason would be to load from_checkpoint if model/model_weights is already present? Perhaps it could be a no-op in that case and we could always make this True? Okay with keeping it like this for now as well, just wanted to call it out.

Contributor Author

Yeah, I think it's okay to have it be explicit and not make assumptions on behalf of the user for now.

We can revisit this if people get confused, but because it defaults to False, I'm hopeful this won't interrupt anyone's experience (until, of course, someone really needs it, at which point we can direct them to this flag).
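To make the scenario concrete, here's a rough sketch of the two artifacts being discussed (the directory names follow the docstring's `training_checkpoints/` and the `model/model_weights` path mentioned above; the surrounding paths are assumptions):

```python
import os

# Hypothetical output directory from a training run (path is illustrative).
model_dir = "results/api_experiment_run/model"

# Finalized weights saved at the end of training; per the discussion above,
# these may be missing if the job errored midway through training.
has_final_weights = os.path.exists(os.path.join(model_dir, "model_weights"))

# Intermediate checkpoints written during the course of training.
has_checkpoints = os.path.isdir(os.path.join(model_dir, "training_checkpoints"))

# from_checkpoint=True is mainly useful when a run died before the final weights
# were written, i.e. has_final_weights is False but has_checkpoints is True.
```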

Collaborator

@alexsherstinsky left a comment

@geoffreyangus ✅ -- code LGTM (and it was illuminating -- thanks!). I was just wondering: how did/do we test something like this, where the call is distributed? Thank you!

@geoffreyangus
Contributor Author

I added a test for this @alexsherstinsky, if you want to take a look!

@geoffreyangus merged commit 25e4ac1 into master on Mar 20, 2024
18 checks passed
@geoffreyangus deleted the load_weights_from_checkpoint branch on March 20, 2024 00:15
@@ -32,6 +32,79 @@
)


def test_model_load_from_checkpoint(tmpdir, csv_filename, tmp_path):
Collaborator

❤️ @geoffreyangus Thank you -- so cool! For my edification, which checkpoint are we comparing? It looks like the loaded model from storage (ludwig_model_2) is the result of training ludwig_model_1 all the way through the 1 epoch -- as opposed to some intermediate checkpoint (e.g., saved after a few steps). Is this correct? Thank you!

Contributor Author

Yeah, that is correct. It's just the latest checkpoint deposited into training_checkpoints/ during the course of training. In the 1-epoch case, it is equivalent to the model weights at the end of training.

For some reason on my local machine, the models weren't equivalent after 2 epochs. My hunch is that this is because there was a difference between the "best" checkpoint (loaded at the end of training) and the "latest" checkpoint.
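For anyone following along, a rough sketch of the comparison described above (not the PR's actual test; the config, output paths, and the `from_checkpoint` flag on `LudwigModel.load` are assumptions for illustration):

```python
import os

import torch
from ludwig.api import LudwigModel


def test_model_load_from_checkpoint_sketch(tmpdir, csv_filename):
    # Illustrative config; the real test derives its features from the csv_filename fixture.
    config = {
        "input_features": [{"name": "text", "type": "text"}],
        "output_features": [{"name": "label", "type": "category"}],
        "trainer": {"epochs": 1},
    }

    # Train ludwig_model_1 for a single epoch.
    ludwig_model_1 = LudwigModel(config)
    ludwig_model_1.train(dataset=csv_filename, output_directory=str(tmpdir))

    # Load ludwig_model_2 from the latest entry in training_checkpoints/
    # instead of the finalized model weights (output path is an assumption).
    model_dir = os.path.join(str(tmpdir), "api_experiment_run", "model")
    ludwig_model_2 = LudwigModel.load(model_dir, from_checkpoint=True)

    # With one epoch, the latest checkpoint should match the final weights exactly.
    state_1 = ludwig_model_1.model.state_dict()
    state_2 = ludwig_model_2.model.state_dict()
    assert state_1.keys() == state_2.keys()
    for key in state_1:
        assert torch.equal(state_1[key], state_2[key])
```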

Collaborator

🙇
