LogisticRegression
#100
Conversation
cellarium/ml/module/base_module.py (Outdated)

```diff
@@ -24,17 +30,38 @@ def _get_fn_args_from_batch(tensor_dict: dict[str, np.ndarray | torch.Tensor]) -
         Get forward method arguments from batch.
         """

     def __getattr__(self, name: str) -> Any:
```
It seems like you have brought in some parts of the `PyroModule` code into `BaseModule` and got rid of `BasePyroModule`. What's the logic?
I wanted to have a single base module for both regular and Pyro models. The only feature needed from `PyroModule` is the ability to have constrained `PyroParam`s. `PyroModule` also has code that syncs parameters with `ParamStoreDict`. This is a stripped-down version of `PyroModule` that only handles constrained params.
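The constrained-param idea can be sketched roughly as follows. This is a hypothetical minimal analogue, not the actual `BaseModule` code: a module stores an unconstrained raw value and applies the constraint transform on attribute access, which is in the same spirit as what `PyroModule.__getattr__` does for a constrained `PyroParam` (there via `torch.distributions.transform_to`; here mimicked with a plain softplus so the sketch stays dependency-free).

```python
import math

def softplus(x: float) -> float:
    # maps any real number to a strictly positive value
    return math.log1p(math.exp(x))

# Hypothetical sketch of constrained-parameter handling; the class and
# method names are illustrative, not the cellarium-ml API.
class ConstrainedParamModule:
    def __init__(self):
        # maps param name -> (unconstrained raw value, constraint transform)
        object.__setattr__(self, "_constrained", {})

    def set_positive_param(self, name: str, raw: float) -> None:
        self._constrained[name] = (raw, softplus)

    def __getattr__(self, name: str):
        # only called when normal attribute lookup fails, so constrained
        # params resolve here to their transformed (constrained) value
        constrained = object.__getattribute__(self, "_constrained")
        if name in constrained:
            raw, transform = constrained[name]
            return transform(raw)
        raise AttributeError(name)

m = ConstrainedParamModule()
m.set_positive_param("scale", -3.0)  # the stored raw value may be negative
assert m.scale > 0  # but the accessed value always satisfies the constraint
```

The point of routing access through `__getattr__` is that optimization happens in unconstrained space while every read of the attribute sees a value that satisfies the constraint.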
Looks great! You have to rebase this and reorganize it according to the new codebase structure once you merge that PR.
Also, two remarks:

- Could you include the callback for monitoring the distribution of W_gc? (here or in a separate PR)
- (definitely for a separate PR) Could you comment on what it would take to run LogisticRegression on on-the-fly PCA-transformed data? Can we make a generic Transform that loads a model from a checkpoint and runs predict on a batch as the forward call, followed by another generic Transform that rectifies the predict output to look like a proper batch? One needs to cook up a dummy `feature_g`, like `EMBEDDING_DIM_0`, `EMBEDDING_DIM_1`, ..., and call the embeddings `x_ng`. Everything else in the batch (e.g. `y_n`) should also pass through...
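A rough sketch of what such a two-Transform pipeline could look like. Every name here (`PredictTransform`, `RectifyTransform`, `TinyPCA`) is hypothetical and only follows the conventions used in the comment above (`x_ng`, `y_n`, `feature_g`, `EMBEDDING_DIM_*`); this is not the actual cellarium-ml Transform API.

```python
# Hypothetical sketch: one transform runs a checkpointed model's predict()
# as its forward call; a second rewrites the output into a proper batch
# with dummy feature names, passing everything else through.

class PredictTransform:
    """Runs a loaded model's predict() on the batch."""

    def __init__(self, model):
        self.model = model  # e.g. restored from a checkpoint

    def __call__(self, batch: dict) -> dict:
        batch["prediction"] = self.model.predict(batch["x_ng"])
        return batch


class RectifyTransform:
    """Makes predict output look like a proper batch: embeddings become
    x_ng with dummy feature names EMBEDDING_DIM_0, EMBEDDING_DIM_1, ...;
    everything else in the batch (e.g. y_n) passes through untouched."""

    def __call__(self, batch: dict) -> dict:
        embedding = batch.pop("prediction")
        n_dims = len(embedding[0])
        out = dict(batch)  # pass-through of y_n etc.
        out["x_ng"] = embedding
        out["feature_g"] = [f"EMBEDDING_DIM_{i}" for i in range(n_dims)]
        return out


class TinyPCA:
    """Stand-in for a checkpointed PCA model with a predict() method."""

    def predict(self, x_ng):
        # pretend 2-D embedding: just keep the first two features
        return [row[:2] for row in x_ng]


batch = {"x_ng": [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], "y_n": [0, 1]}
for transform in (PredictTransform(TinyPCA()), RectifyTransform()):
    batch = transform(batch)
```

After the pipeline, `batch["x_ng"]` holds the embeddings, `batch["feature_g"]` holds the dummy names, and `batch["y_n"]` is unchanged, which is the pass-through behavior the comment asks for.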
```python
def guide(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
    pyro.sample("W", dist.Delta(self.W_gc).to_event(2))

def on_batch_end(self, trainer: pl.Trainer) -> None:
```
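For context on the guide above: a `Delta` guide puts all posterior mass at the single point `self.W_gc`, so maximizing the ELBO reduces to MAP estimation of W. A minimal, dependency-free sketch of what that optimization amounts to for a one-weight logistic regression (purely illustrative, not the module's actual training loop):

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

# MAP estimate of a single logistic-regression weight via gradient ascent
# on the log-posterior (Gaussian prior + Bernoulli likelihood). With a
# Delta (point-mass) guide, ELBO maximization is exactly this MAP problem.
def fit_map(xs, ys, prior_var=1.0, lr=0.1, steps=200):
    w = 0.0
    for _ in range(steps):
        grad = -w / prior_var  # gradient of the Gaussian prior term
        for x, y in zip(xs, ys):
            grad += (y - sigmoid(w * x)) * x  # likelihood gradient
        w += lr * grad
    return w

# labels increase with x, so the fitted weight should be positive
w = fit_map([1.0, 2.0, 3.0, -1.0, -2.0], [1, 1, 1, 0, 0])
assert w > 0
```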
This is a callback to log the `W_gc` histogram. Here I thought that it might be better to have the callback bundled together with the model, unlike `VarianceMonitor`, which was separate (change that too in the future?). The reason is that these callbacks are specific to particular models, and having them as separate callbacks adds more code and config.yaml boilerplate.
#99 should take care of it once it is ready.
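The model-bundled-callback idea could look roughly like this. This is a hypothetical, dependency-free sketch: real cellarium-ml models would receive a `pl.Trainer` with a logger, which is mocked here by a plain recorder, and all names are illustrative.

```python
from collections import Counter

class HistogramLogger:
    """Stand-in for a trainer's logger in this sketch."""

    def __init__(self):
        self.logged = []

    def log_histogram(self, name: str, hist: Counter) -> None:
        self.logged.append((name, hist))


# Sketch of a callback bundled with the model: the model itself implements
# on_batch_end and logs a coarse histogram of its own W_gc values, instead
# of requiring a separate callback class plus config.yaml wiring.
class LogisticRegressionSketch:
    def __init__(self, W_gc):
        self.W_gc = W_gc  # a flat list of weights, for the sketch

    def on_batch_end(self, logger: HistogramLogger) -> None:
        # bin weights to one decimal place as a coarse histogram
        hist = Counter(round(w, 1) for w in self.W_gc)
        logger.log_histogram("W_gc", hist)


logger = HistogramLogger()
model = LogisticRegressionSketch([0.11, 0.12, -0.5, 0.49])
model.on_batch_end(logger)
```

The trade-off being discussed: bundling keeps model-specific monitoring next to the model it monitors, at the cost of mixing logging concerns into the model class.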