Override load_from_checkpoint classmethod of TrainingPlan #91
Conversation
cellarium/ml/core/saving.py
Outdated
if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
    cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
else:
    cfg = parser.get_defaults()
Is this worth logging a warning?
OK, I removed reading the defaults to be consistent with the upstream implementation and to avoid confusion.
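For reference, a minimal sketch of the lookup the diff above performs. The file name is hypothetical and this is not the cellarium-ml implementation; it just shows where Lightning stores the hyperparameters that `cls.CHECKPOINT_HYPER_PARAMS_KEY` points to.

```python
import torch

# Hypothetical file name; any Lightning checkpoint has the same layout.
checkpoint = torch.load("model.ckpt", map_location="cpu")

# Lightning stores the module's hparams under CHECKPOINT_HYPER_PARAMS_KEY,
# which is the string "hyper_parameters".
HPARAMS_KEY = "hyper_parameters"
if HPARAMS_KEY in checkpoint:
    cfg = checkpoint[HPARAMS_KEY]
else:
    # After this change there is no fallback to parser defaults; a missing
    # key simply means the checkpoint was saved without hyperparameters.
    cfg = None
print(cfg)
```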
        self._set_hparams(config)

    @classmethod
    @patch("lightning.pytorch.core.saving._load_state", new=_load_state)
I just learned something new! This mocking library is pretty awesome! It's funny that it lives in unittest. Is that a rare hack to use mocking for patching actual code, or a commonly used workaround?
Yeah, it is pretty neat; I found it on Stack Exchange.
> Is that a rare hack to use mocking for patching actual code, or a commonly used workaround?
Honestly, I don't know, but it seems to do exactly what I want.
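A minimal sketch of the pattern under discussion: `unittest.mock.patch`, usually seen in tests, is applied as a decorator on regular code so that Lightning's private `_load_state` helper is swapped out for the duration of the call. Only the patched target, `lightning.pytorch.core.saving._load_state`, comes from the diff above; the function bodies below are placeholders, not the actual cellarium-ml implementation.

```python
from unittest.mock import patch

import lightning.pytorch as pl


def _load_state(cls, checkpoint, strict=None, **kwargs):
    # Placeholder for the custom loader; the real version rebuilds the
    # module from the config stored in the checkpoint.
    ...


class TrainingPlan(pl.LightningModule):
    # While ``load_from_checkpoint`` runs, the private helper inside
    # ``lightning.pytorch.core.saving`` is temporarily replaced by the
    # custom ``_load_state`` above and restored automatically afterwards.
    # Because ``new=`` is given, ``patch`` does not inject a mock argument.
    @classmethod
    @patch("lightning.pytorch.core.saving._load_state", new=_load_state)
    def load_from_checkpoint(cls, checkpoint_path, **kwargs):
        return super().load_from_checkpoint(checkpoint_path, **kwargs)
```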
Beautiful! Just a few small comments. Approved to merge either way.
Thanks for reviewing. I made a couple of small changes, one related to your comment. I will go ahead and merge it.
obj = cfg_init.model

# save the cfg to the :attr:`obj.hparams` to be able to load the model checkpoint
obj._set_hparams(cfg)
Added this to save the config back to the object so that hyperparams are saved if the loaded model is checkpointed again.
Addresses #86
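A small usage sketch of the round trip this enables; the import path and file names are assumptions for illustration, not the package's documented API.

```python
from cellarium.ml import TrainingPlan  # import path is an assumption

# The overridden classmethod rebuilds the model from the config stored in
# the checkpoint and then calls ``obj._set_hparams(cfg)``, so the loaded
# object has its hyperparameters populated again.
model = TrainingPlan.load_from_checkpoint("run1.ckpt")
print(model.hparams)

# When this model is checkpointed again during a later training run,
# Lightning writes ``model.hparams`` into the new checkpoint under
# CHECKPOINT_HYPER_PARAMS_KEY, so ``load_from_checkpoint`` also works on
# the re-saved file (the scenario from #86).
```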