Update CellariumAnnDataDataModule and CellariumModule #125

Conversation
cellarium/ml/core/module.py (Outdated)

```python
_transforms, _model = transforms, model
transforms = [uninitialize_object(transform) for transform in _transforms] if _transforms is not None else None
```
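For context, here is a minimal sketch of the contract `uninitialize_object` presumably fulfills; the PR does not show its body, so everything below is inferred from the surrounding comments about class paths and init args:

```python
def uninitialize_object(obj):
    # Hypothetical sketch, not the repository's implementation: map a live
    # instance back to a {"class_path": ..., "init_args": ...} spec so that
    # save_hyperparameters records a plain, picklable description instead of
    # the instance itself. Dicts (already un-initialized) pass through.
    if isinstance(obj, dict) or obj is None:
        return obj
    cls = type(obj)
    init_args = dict(getattr(obj, "hparams", {}))  # recovery of init args is assumed here
    return {"class_path": f"{cls.__module__}.{cls.__qualname__}", "init_args": init_args}
```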
Can you explain this to me?
Perhaps also write two lines of comment about why you're uninitializing and reinitializing?
Added a comment above explaining the logic behind un-initializing and initializing.
`model` and `transforms` are not attributes of `self` (yet) -- does uninitializing them have any effect on the behavior of `self.save_hyperparameters`?
Good work, nothing strikes my eye :) Small request for a comment on a magical-looking part of the code.
```python
# In order to achieve this, we temporarily re-assign `dadc` to its un-initialized state
# and then call `save_hyperparameters` which will save these values as hparams.
# Then, we re-assign `dadc` back to its initialized state.
# `initialize_object` handles the case when the object was passed as a dictionary of class path and init args.
_dadc = dadc
dadc = uninitialize_object(_dadc)
self.save_hyperparameters(logger=False)
```
Can this be done more neatly with a context manager construct? Also, `dadc` is not an attribute of `self` yet -- how would uninitializing it on line 93 have an effect on the `self.save_hyperparameters` call on line 94?
Having a context manager is a good idea, let me think about it more. `save_hyperparameters` looks up the variable from the `locals()` in the current frame, using an argument name from the function signature.
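To illustrate both the mechanism and the reviewer's suggestion, here is a hedged sketch (not code from the PR, and reusing the hypothetical `uninitialize_object` sketch above): because `save_hyperparameters` reads the caller's frame locals by argument name, any construct that rebinds the local name before the call changes what gets recorded.

```python
from contextlib import contextmanager

@contextmanager
def uninitialized(obj):
    # Hypothetical helper: yield the un-initialized (class_path/init_args)
    # form of `obj`. Using it as `with uninitialized(dadc) as dadc: ...`
    # rebinds the local name `dadc` in the caller's frame, which is exactly
    # the binding save_hyperparameters inspects.
    yield uninitialize_object(obj)

# Sketch of usage inside __init__ (assumes initialize_object restores the
# instance afterwards, as the PR's own comment describes):
#
#     with uninitialized(dadc) as dadc:
#         self.save_hyperparameters(logger=False)
#     dadc = initialize_object(dadc)
```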
Another approach is to have the `model` and `transforms` args of `CellariumModule` as dicts of class name and init args, and delegate their initialization to `CellariumModule` instead of `LightningCLI`. In fact, there is a `configure_model` method of `LightningModule` that is used to initialize large models with FSDP and DeepSpeed strategies (https://lightning.ai/docs/pytorch/stable/advanced/model_init.html#model-parallel-training-fsdp-and-deepspeed). We can just always use it to initialize all our models. WDYT?
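As a hedged illustration of that proposal, here is a minimal sketch following the linked Lightning docs; the module name and the `Linear` layer are hypothetical, not the PR's code:

```python
import torch
import lightning.pytorch as pl

class LazyInitModule(pl.LightningModule):
    # Hypothetical module: __init__ stores only the cheap, picklable spec,
    # and the heavy torch.nn.Module is built in configure_model, which
    # Lightning calls at a point where FSDP/DeepSpeed can shard it and a
    # meta-device/empty-weights context can apply.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.save_hyperparameters()
        self.model = None

    def configure_model(self):
        if self.model is None:  # guard: configure_model may be called more than once
            self.model = torch.nn.Linear(self.hparams.in_features, self.hparams.out_features)
```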
Good stuff all around! I like initialization in meta device context + configure_model.
Summary of changes:

- `save_hyperparameters` pulls init args from the `locals()`. Therefore, we can uninitialize `torch.nn.Module`s before calling `save_hyperparameters`.
- `torch.nn.Module`s are initialized in the `__init__` method if they were provided as `class_path` and `init_args`.
- Removed the `config` arg from the `CellariumModule` that was used to store hyper-parameters.
- Updated `CellariumAnnDataDataModule` so that both `DistributedAnnDataCollection` and `AnnData` can be passed.
- For `IncrementalPCA` or `OnePassMeanVarStd` an optimizer is not needed.
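For instance, under this scheme a model can be handed to `CellariumModule` as a plain spec and built internally; a hypothetical sketch, where the class path and init args shown are illustrative rather than taken from the PR:

```python
model = {
    "class_path": "cellarium.ml.models.OnePassMeanVarStd",  # illustrative path
    "init_args": {},  # illustrative; actual args depend on the model
}
# CellariumModule would then call something like initialize_object(model)
# in __init__ (or configure_model) to build the torch.nn.Module.
```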