
Override load_from_checkpoint classmethod of TrainingPlan #91

Merged: 15 commits merged into main from from-ckpt on Oct 17, 2023
Conversation

@ordabayevy (Contributor) commented Oct 15, 2023

Addresses #86

@ordabayevy changed the title from "Add from_checkpoint classmethod to TrainingPlan" to "Override load_from_checkpoint classmethod of TrainingPlan" on Oct 15, 2023
@ordabayevy linked an issue on Oct 16, 2023 that may be closed by this pull request
@ordabayevy added the WIP label on Oct 16, 2023
@ordabayevy requested a review from mbabadi on Oct 17, 2023 at 16:47
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
    cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
else:
    cfg = parser.get_defaults()
@mbabadi (Member) commented:

Is this worth logging a warning?

@ordabayevy (Author) replied:

Ok, I removed reading defaults to be consistent with an upstream implementation and avoid confusion.
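For context, a minimal, self-contained sketch of the two behaviors this thread weighs: the warning-plus-fallback the reviewer asked about, and the strict read with no fallback kept in the merged version. The checkpoint dict, defaults, and key value below are illustrative stand-ins, not the PR's code.

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

# Illustrative stand-ins: Lightning stores hyperparameters under this key,
# but the checkpoint dict and defaults here are made up for the example.
CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
checkpoint = {"state_dict": {}}      # a checkpoint saved without hyperparameters
defaults = {"learning_rate": 1e-3}   # stand-in for parser.get_defaults()

# Variant raised in review: keep the fallback to defaults, but warn about it.
if CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
    cfg = checkpoint[CHECKPOINT_HYPER_PARAMS_KEY]
else:
    logger.warning("Checkpoint has no saved hyperparameters; falling back to defaults.")
    cfg = defaults

# Variant kept after this thread: read the key directly with no fallback,
# consistent with the upstream Lightning implementation (KeyError if absent).
# cfg = checkpoint[CHECKPOINT_HYPER_PARAMS_KEY]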

self._set_hparams(config)

@classmethod
@patch("lightning.pytorch.core.saving._load_state", new=_load_state)
@mbabadi (Member) commented:

I just learned something new! This mocking library is pretty awesome! It's funny that it lives in unittest. Is that a rare hack to use mocking for patching actual code, or a commonly used workaround?

@ordabayevy (Author) replied:

Yeah, it is pretty neat; I found it on Stack Exchange.

> Is that a rare hack to use mocking for patching actual code, or a commonly used workaround?

Honestly, I don't know, but it seems to do exactly what I want.
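For readers who haven't seen this trick: unittest.mock.patch from the standard library can be used as a decorator to temporarily replace any importable attribute, addressed by its dotted path, for the duration of a call. That is the mechanism the decorator above uses to swap out lightning.pytorch.core.saving._load_state. A tiny self-contained illustration, patching os.path.exists purely as an example:

import os.path
from unittest.mock import patch


def _always_true(path):
    # Replacement that stands in for the real os.path.exists while the patch is active.
    return True


@patch("os.path.exists", new=_always_true)
def check(path):
    # Inside this function, os.path.exists resolves to _always_true.
    return os.path.exists(path)


print(check("/no/such/path"))           # True: the patch is active during the call
print(os.path.exists("/no/such/path"))  # False: the patch is undone afterwards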

@mbabadi (Member) left a comment:

Beautiful! Just a few small comments. Approved to merge either way.

@ordabayevy (Author) left a comment:

Thanks for reviewing. I made a couple of small changes, one related to your comment. I will go ahead and merge it.

obj = cfg_init.model

# save the cfg to the :attr:`obj.hparams` to be able to load the model checkpoint
obj._set_hparams(cfg)
@ordabayevy (Author) commented:

Added this to save the config back to the object so that hyperparams are saved if the loaded model is checkpointed again.
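Putting the pieces of this review together, here is a rough, self-contained sketch of the overall pattern, not the PR's actual implementation: the cfg-based constructor, layer sizes, and checkpoint keys below are assumptions for illustration. The idea is to read the saved config out of the checkpoint, rebuild the module from it, restore the weights, and write the config back into hparams so that a re-saved checkpoint still carries the hyperparameters.

import torch
import torch.nn as nn


class TrainingPlan(nn.Module):
    # Lightning stores hyperparameters under this key in the checkpoint dict.
    CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"

    def __init__(self, cfg: dict):
        super().__init__()
        self.cfg = cfg
        self.layer = nn.Linear(cfg["in_features"], cfg["out_features"])
        self.hparams = {}

    def _set_hparams(self, cfg: dict) -> None:
        # Stash the config so it gets written out again if this object is re-checkpointed.
        self.hparams = dict(cfg)

    @classmethod
    def load_from_checkpoint(cls, path: str) -> "TrainingPlan":
        checkpoint = torch.load(path, map_location="cpu")
        # Strict read: no fallback to parser defaults (see the thread above).
        cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
        obj = cls(cfg)
        obj.load_state_dict(checkpoint["state_dict"])
        # Save cfg back onto the object so hyperparameters survive another round of checkpointing.
        obj._set_hparams(cfg)
        return obj

A round trip under these assumptions would save a dict containing state_dict plus the hyperparameters key with torch.save, and later rebuild the object with TrainingPlan.load_from_checkpoint(path).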

@ordabayevy merged commit df89ebf into main on Oct 17, 2023
@ordabayevy deleted the from-ckpt branch on October 17, 2023 at 20:46
Successfully merging this pull request may close this issue: Load model checkpoints from the Trainer checkpoint