Commit
update typing and docs
martinkim0 committed Jun 17, 2024
1 parent c85f98d commit 61dfc04
Showing 5 changed files with 89 additions and 98 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test.yaml
@@ -4,6 +4,7 @@ on:
   push:
     branches: [main]
   pull_request:
+    branches: [main]
   schedule:
     - cron: "0 10 * * *" # runs at 10:00 UTC (03:00 PST) every day
   workflow_dispatch:
24 changes: 11 additions & 13 deletions src/simple_scvi/_mymodel.py
@@ -15,36 +15,35 @@
 from scvi.model.base import BaseModelClass, UnsupervisedTrainingMixin, VAEMixin
 from scvi.utils import setup_anndata_dsp
 
-from ._mymodule import MyModule
+from simple_scvi._mymodule import MyModule
 
 logger = logging.getLogger(__name__)
 
 
 class MyModel(VAEMixin, UnsupervisedTrainingMixin, BaseModelClass):
-    """
-    Skeleton for an scvi-tools model.
+    """Skeleton for an scvi-tools model.
 
-    Please use this skeleton to create new models. This is a simple
-    implementation of the scVI model :cite:p:`Lopez18`.
+    Please use this skeleton to create new models. This is a simple implementation of the scVI
+    model :cite:p:`Lopez18`.
 
     Parameters
     ----------
     adata
-        AnnData object that has been registered via :meth:`~mypackage.MyModel.setup_anndata`.
+        AnnData object that has been registered via :meth:`~simple_scvi.MyModel.setup_anndata`.
     n_hidden
         Number of nodes per hidden layer.
     n_latent
         Dimensionality of the latent space.
     n_layers
         Number of hidden layers used for encoder and decoder NNs.
     **model_kwargs
-        Keyword args for :class:`~mypackage.MyModule`
+        Keyword args for :class:`~simple_scvi.MyModule`
 
     Examples
     --------
     >>> adata = anndata.read_h5ad(path_to_anndata)
-    >>> mypackage.MyModel.setup_anndata(adata, batch_key="batch")
-    >>> vae = mypackage.MyModel(adata)
+    >>> simple_scvi.MyModel.setup_anndata(adata, batch_key="batch")
+    >>> vae = simple_scvi.MyModel(adata)
     >>> vae.train()
    >>> adata.obsm["X_mymodel"] = vae.get_latent_representation()
     """
@@ -56,7 +55,7 @@ def __init__(
         n_latent: int = 10,
         n_layers: int = 1,
         **model_kwargs,
-    ):
+    ) -> None:
         super().__init__(adata)
 
         library_log_means, library_log_vars = _init_library_size(
@@ -93,9 +92,8 @@ def setup_anndata(
         categorical_covariate_keys: list[str] | None = None,
         continuous_covariate_keys: list[str] | None = None,
         **kwargs,
-    ) -> AnnData | None:
-        """
-        %(summary)s.
+    ) -> None:
+        """%(summary)s.
 
         Parameters
         ----------
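As a quick sanity check of the API these docstring updates describe, here is a minimal usage sketch. It assumes MyModel is importable from the simple_scvi package root, as the updated docstring's simple_scvi.MyModel references imply; the my_data.h5ad path and the "batch" column are hypothetical placeholders, not files from this repository.

import anndata

from simple_scvi import MyModel

adata = anndata.read_h5ad("my_data.h5ad")        # hypothetical input file
MyModel.setup_anndata(adata, batch_key="batch")  # register fields with scvi-tools

vae = MyModel(adata)  # __init__ is now annotated -> None
vae.train(max_epochs=20)
adata.obsm["X_mymodel"] = vae.get_latent_representation()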
107 changes: 51 additions & 56 deletions src/simple_scvi/_mymodule.py
@@ -1,29 +1,29 @@
 from __future__ import annotations
 
 import numpy as np
+import numpy.typing as npt
 import torch
 import torch.nn.functional as F
 from scvi import REGISTRY_KEYS
 from scvi.distributions import ZeroInflatedNegativeBinomial
+from scvi.module._constants import MODULE_KEYS
 from scvi.module.base import BaseModuleClass, LossOutput, auto_move_data
 from scvi.nn import DecoderSCVI, Encoder, one_hot
+from torch import Tensor
 from torch.distributions import Normal
 from torch.distributions import kl_divergence as kl
 
-TensorDict = dict[str, torch.Tensor]
-
 
 class MyModule(BaseModuleClass):
-    """
-    Skeleton Variational auto-encoder model.
+    """Skeleton variational auto-encoder (VAE) model.
 
-    Here we implement a basic version of scVI's underlying VAE :cite:p:`Lopez18`.
-    This implementation is for instructional purposes only.
+    Here we implement a basic version of scVI's underlying VAE :cite:p:`Lopez18`. This
+    implementation is for instructional purposes only.
 
     Parameters
     ----------
     n_input
-        Number of input genes
+        Number of input genes.
     library_log_means
         1 x n_batch array of means of the log library sizes. Parameterizes prior on library size if
         not using observed library size.
@@ -33,20 +33,20 @@ class MyModule(BaseModuleClass):
     n_batch
         Number of batches, if 0, no batch correction is performed.
     n_hidden
-        Number of nodes per hidden layer
+        Number of nodes per hidden layer.
     n_latent
-        Dimensionality of the latent space
+        Dimensionality of the latent space.
     n_layers
-        Number of hidden layers used for encoder and decoder NNs
+        Number of hidden layers used for encoder and decoder NNs.
     dropout_rate
-        Dropout rate for neural networks
+        Dropout rate for neural networks.
     """
 
     def __init__(
         self,
         n_input: int,
-        library_log_means: np.ndarray,
-        library_log_vars: np.ndarray,
+        library_log_means: npt.NDArray,
+        library_log_vars: npt.NDArray,
         n_batch: int = 0,
         n_hidden: int = 128,
         n_latent: int = 10,
@@ -89,48 +89,44 @@ def __init__(
             n_hidden=n_hidden,
         )
 
-    def _get_inference_input(self, tensors):
-        """Parse the dictionary to get appropriate args"""
-        x = tensors[REGISTRY_KEYS.X_KEY]
-
-        input_dict = {"x": x}
-        return input_dict
+    def _get_inference_input(self, tensors: dict[str, Tensor]) -> dict[str, Tensor]:
+        """Parse the dictionary to get appropriate args."""
+        return {MODULE_KEYS.X_KEY: tensors[REGISTRY_KEYS.X_KEY]}
 
-    def _get_generative_input(self, tensors, inference_outputs):
-        z = inference_outputs["z"]
-        library = inference_outputs["library"]
-
-        input_dict = {
-            "z": z,
-            "library": library,
+    def _get_generative_input(
+        self,
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor],
+    ) -> dict[str, Tensor]:
+        return {
+            MODULE_KEYS.Z_KEY: inference_outputs["z"],
+            MODULE_KEYS.LIBRARY_KEY: inference_outputs["library"],
         }
-        return input_dict
 
     @auto_move_data
-    def inference(self, x):
+    def inference(self, x: Tensor) -> dict[str, Tensor]:
         """
         High level inference method.
 
         Runs the inference (encoder) model.
         """
         # log the input to the variational distribution for numerical stability
-        x_ = torch.log(1 + x)
+        x_ = torch.log1p(x)
         # get variational parameters via the encoder networks
         qz_m, qz_v, z = self.z_encoder(x_)
         ql_m, ql_v, library = self.l_encoder(x_)
 
-        outputs = {
-            "z": z,
-            "qz_m": qz_m,
-            "qz_v": qz_v,
+        return {
+            MODULE_KEYS.Z_KEY: z,
+            MODULE_KEYS.QZM_KEY: qz_m,
+            MODULE_KEYS.QZV_KEY: qz_v,
             "ql_m": ql_m,
             "ql_v": ql_v,
-            "library": library,
+            MODULE_KEYS.LIBRARY_KEY: library,
         }
-        return outputs
 
     @auto_move_data
-    def generative(self, z, library):
+    def generative(self, z: Tensor, library: Tensor) -> dict[str, Tensor]:
         """Runs the generative model."""
         # form the parameters of the ZINB likelihood
         px_scale, _, px_rate, px_dropout = self.decoder("gene", z, library)
@@ -145,15 +141,16 @@ def generative(self, z, library):
 
     def loss(
         self,
-        tensors,
-        inference_outputs,
-        generative_outputs,
+        tensors: dict[str, Tensor],
+        inference_outputs: dict[str, Tensor],
+        generative_outputs: dict[str, Tensor],
         kl_weight: float = 1.0,
-    ):
+    ) -> LossOutput:
         """Loss function."""
         x = tensors[REGISTRY_KEYS.X_KEY]
-        qz_m = inference_outputs["qz_m"]
-        qz_v = inference_outputs["qz_v"]
+        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
+        qz_m = inference_outputs[MODULE_KEYS.QZM_KEY]
+        qz_v = inference_outputs[MODULE_KEYS.QZV_KEY]
         ql_m = inference_outputs["ql_m"]
         ql_v = inference_outputs["ql_v"]
         px_rate = generative_outputs["px_rate"]
@@ -165,7 +162,6 @@ def loss(
 
         kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(dim=1)
 
-        batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
         n_batch = self.library_log_means.shape[1]
         local_library_log_means = F.linear(one_hot(batch_index, n_batch), self.library_log_means)
         local_library_log_vars = F.linear(one_hot(batch_index, n_batch), self.library_log_vars)
@@ -189,20 +185,19 @@ def loss(
         loss = torch.mean(reconst_loss + weighted_kl_local)
 
         kl_local = {
-            "kl_divergence_l": kl_divergence_l,
-            "kl_divergence_z": kl_divergence_z,
+            MODULE_KEYS.KL_L_KEY: kl_divergence_l,
+            MODULE_KEYS.KL_Z_KEY: kl_divergence_z,
         }
         return LossOutput(loss=loss, reconstruction_loss=reconst_loss, kl_local=kl_local)
 
     @torch.no_grad()
     def sample(
         self,
-        tensors,
-        n_samples=1,
-        library_size=1,
-    ) -> torch.Tensor:
-        r"""
-        Generate observation samples from the posterior predictive distribution.
+        tensors: dict[str, Tensor],
+        n_samples: int = 1,
+        library_size: int = 1,
+    ) -> Tensor:
+        r"""Generate observation samples from the posterior predictive distribution.
 
         The posterior predictive distribution is written as :math:`p(\hat{x} \mid x)`.
@@ -245,7 +240,7 @@ def sample(
 
     @torch.no_grad()
     @auto_move_data
-    def marginal_ll(self, tensors: TensorDict, n_mc_samples: int):
+    def marginal_ll(self, tensors: dict[str, Tensor], n_mc_samples: int) -> float:
         """Marginal ll."""
         sample_batch = tensors[REGISTRY_KEYS.X_KEY]
         batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]
@@ -255,12 +250,12 @@ def marginal_ll(self, tensors: TensorDict, n_mc_samples: int):
 
         for i in range(n_mc_samples):
             # Distribution parameters and sampled variables
             inference_outputs, _, losses = self.forward(tensors)
-            qz_m = inference_outputs["qz_m"]
-            qz_v = inference_outputs["qz_v"]
-            z = inference_outputs["z"]
+            qz_m = inference_outputs[MODULE_KEYS.QZM_KEY]
+            qz_v = inference_outputs[MODULE_KEYS.QZV_KEY]
+            z = inference_outputs[MODULE_KEYS.Z_KEY]
             ql_m = inference_outputs["ql_m"]
             ql_v = inference_outputs["ql_v"]
-            library = inference_outputs["library"]
+            library = inference_outputs[MODULE_KEYS.LIBRARY_KEY]
 
             # Reconstruction Loss
             reconst_loss = losses.dict_sum(losses.reconstruction_loss)
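Two of the changes above are easy to verify in isolation. The standalone sketch below (toy values, plain torch only) shows why torch.log1p(x) is a safer spelling of torch.log(1 + x) for small counts, and how F.linear over a one-hot batch encoding gathers per-cell library-size priors as in loss(); torch.nn.functional.one_hot stands in here for the one_hot helper imported from scvi.nn.

import torch
import torch.nn.functional as F

# log1p stays accurate where 1 + x rounds to exactly 1.0 in float32.
x = torch.tensor([0.0, 1e-8, 3.0])
print(torch.log1p(x))    # tensor([0.0000e+00, 1.0000e-08, 1.3863e+00])
print(torch.log(1 + x))  # tensor([0.0000e+00, 0.0000e+00, 1.3863e+00])

# Per-batch library priors: each one-hot row selects the matching column of
# the (1, n_batch) parameter array, yielding one prior mean per cell.
n_batch = 3
library_log_means = torch.tensor([[4.0, 5.0, 6.0]])  # shape (1, n_batch)
batch_index = torch.tensor([[0], [2], [1]])          # shape (n_cells, 1)
enc = F.one_hot(batch_index.squeeze(-1), n_batch).float()
print(F.linear(enc, library_log_means))              # tensor([[4.], [6.], [5.]])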
25 changes: 11 additions & 14 deletions src/simple_scvi/_mypyromodel.py
@@ -4,6 +4,7 @@
 from collections.abc import Sequence
 
 import numpy as np
+import numpy.typing as npt
 import torch
 from anndata import AnnData
 from scvi import REGISTRY_KEYS
@@ -26,23 +27,22 @@
 
 
 class MyPyroModel(BaseModelClass):
-    """
-    Skeleton for a pyro version of a scvi-tools model.
+    """Skeleton for a Pyro version of a scvi-tools model.
 
     Please use this skeleton to create new models.
 
     Parameters
     ----------
     adata
-        AnnData object that has been registered via :meth:`~mypackage.MyPyroModel.setup_anndata`.
+        AnnData object that has been registered via :meth:`~simple_scvi.MyPyroModel.setup_anndata`.
     n_hidden
         Number of nodes per hidden layer.
     n_latent
         Dimensionality of the latent space.
     n_layers
         Number of hidden layers used for encoder and decoder NNs.
     **model_kwargs
-        Keyword args for :class:`~mypackage.MyModule`
+        Keyword args for :class:`~simple_scvi.MyModule`
 
     Examples
     --------
Expand All @@ -60,7 +60,7 @@ def __init__(
n_latent: int = 10,
n_layers: int = 1,
**model_kwargs,
):
) -> None:
super().__init__(adata)

# self.summary_stats provides information about anndata dimensions and other tensor info
@@ -85,9 +85,8 @@ def get_latent(
         adata: AnnData | None = None,
         indices: Sequence[int] | None = None,
         batch_size: int | None = None,
-    ):
-        """
-        Return the latent representation for each cell.
+    ) -> npt.NDArray:
+        """Return the latent representation for each cell.
 
         This is denoted as :math:`z_n` in our manuscripts.
@@ -125,9 +124,8 @@ def train(
         batch_size: int = 128,
         plan_kwargs: dict | None = None,
         **trainer_kwargs,
-    ):
-        """
-        Train the model.
+    ) -> None:
+        """Train the model.
 
         Parameters
         ----------
@@ -184,9 +182,8 @@ def setup_anndata(
         categorical_covariate_keys: list[str] | None = None,
         continuous_covariate_keys: list[str] | None = None,
         **kwargs,
-    ) -> AnnData | None:
-        """
-        %(summary)s.
+    ) -> None:
+        """%(summary)s.
 
         Parameters
         ----------
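The commit's other recurring change, numpy.typing annotations in place of bare np.ndarray, works the same way outside scvi-tools. A minimal, self-contained sketch; the standardize helper is hypothetical and not part of this repository:

from __future__ import annotations

import numpy as np
import numpy.typing as npt


def standardize(x: npt.NDArray) -> npt.NDArray:
    """Z-score an array; npt.NDArray annotates both the argument and the return."""
    return (x - x.mean()) / x.std()


print(standardize(np.array([1.0, 2.0, 3.0])))  # [-1.22474487  0.          1.22474487]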
[The diff for the fifth changed file did not load and is not shown here.]
