Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LogisticRegression #100

Merged
merged 21 commits into from
Nov 21, 2023
Merged

LogisticRegression #100

merged 21 commits into from
Nov 21, 2023

Conversation

ordabayevy
Copy link
Contributor

No description provided.

@ordabayevy ordabayevy changed the base branch from main to anndata-field November 2, 2023 18:05
@ordabayevy ordabayevy requested a review from mbabadi November 3, 2023 18:21
@ordabayevy ordabayevy added enhancement New feature or request awaiting review labels Nov 3, 2023
@ordabayevy ordabayevy linked an issue Nov 3, 2023 that may be closed by this pull request
@@ -24,17 +30,38 @@ def _get_fn_args_from_batch(tensor_dict: dict[str, np.ndarray | torch.Tensor]) -
Get forward method arguments from batch.
"""

def __getattr__(self, name: str) -> Any:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems like you have brought in some parts of the PyroModule code into BaseModule and got rid of BasePyroModule. What's the logic?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to have a single base model both for regular and pyro models. The only feature needed from PyroModule is the ability to have constrained PyroParams. PyroModule also has a code that syncs parameters with ParamStoreDict. This is the stripped down version of PyroModule that only handles constrained params.

Base automatically changed from anndata-field to main November 21, 2023 02:12
Copy link
Member

@mbabadi mbabadi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! You have to rebase this and reorganize it according to the new codebase structure once you merge that PR.

Also, two remarks:

  1. Could you include the callback for monitoring the distribution of W_gc? (here or in a separate PR)

  2. (definitely for a separate PR) Could you comment on what would it take to run LogisticRegression on on-the-fly PCA-transformed data? Can we make a generic Transform that loads a model from a checkpoint and runs predict on a batch as the forward call, followed by another generic Transform that rectifies the predict output to look like a proper batch? One needs to cook up a dummy feature_g, like EMBEDDING_DIM_0, EMBEDDING_DIM_1, ... and call the embeddings x_ng. Everything else in the batch (e.g. y_n) should also pass through...

def guide(self, x_ng: torch.Tensor, y_n: torch.Tensor) -> None:
pyro.sample("W", dist.Delta(self.W_gc).to_event(2))

def on_batch_end(self, trainer: pl.Trainer) -> None:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a callback to log the W_gc histogram. Here I thought that it might be better to have a callback bundled together with the model unlike VarianceMonitor which was separate (change that too in the future?). The reason is that these callbacks are specific to particular models and having them as a separate callback add more code and config.yaml boilerplate.

@ordabayevy ordabayevy requested a review from mbabadi November 21, 2023 15:48
@ordabayevy
Copy link
Contributor Author

Could you comment on what would it take to run LogisticRegression on on-the-fly PCA-transformed data?

#99 should take care of it once it is ready

Copy link
Member

@mbabadi mbabadi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! You have to rebase this and reorganize it according to the new codebase structure once you merge that PR.

Also, two remarks:

  1. Could you include the callback for monitoring the distribution of W_gc? (here or in a separate PR)

  2. (definitely for a separate PR) Could you comment on what would it take to run LogisticRegression on on-the-fly PCA-transformed data? Can we make a generic Transform that loads a model from a checkpoint and runs predict on a batch as the forward call, followed by another generic Transform that rectifies the predict output to look like a proper batch? One needs to cook up a dummy feature_g, like EMBEDDING_DIM_0, EMBEDDING_DIM_1, ... and call the embeddings x_ng. Everything else in the batch (e.g. y_n) should also pass through...

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
awaiting review enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Add LogisticRegression model
2 participants