diff --git a/.gitignore b/.gitignore index 72c2207e5..21ea229dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # /archive/ +example_notebooks/* # .DS_Store diff --git a/.vscode/settings.json b/.vscode/settings.json index 7d5509522..b41b8ad69 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -56,6 +56,7 @@ "search.followSymlinks": false, "terminal.integrated.fontSize": 14, "terminal.integrated.scrollback": 100000, + "python.terminal.activateEnvironment": false, "workbench.colorTheme": "Catppuccin Mocha", "workbench.iconTheme": "vscode-icons", // Passing --no-cov to pytestArgs is required to respect breakpoints diff --git a/src/pyrovelocity/models/__init__.py b/src/pyrovelocity/models/__init__.py index 4dff5d0ed..20fdb7542 100644 --- a/src/pyrovelocity/models/__init__.py +++ b/src/pyrovelocity/models/__init__.py @@ -7,13 +7,14 @@ from pyrovelocity.models._deterministic_simulation import ( solve_transcription_splicing_model_analytical, ) -from pyrovelocity.models._transcription_dynamics import mrna_dynamics +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics from pyrovelocity.models._velocity import PyroVelocity __all__ = [ deterministic_transcription_splicing_probabilistic_model, mrna_dynamics, + atac_mrna_dynamics, PyroVelocity, solve_transcription_splicing_model, solve_transcription_splicing_model_analytical, diff --git a/src/pyrovelocity/models/_trainer.py b/src/pyrovelocity/models/_trainer.py index fe3505cf2..5bd377b2c 100644 --- a/src/pyrovelocity/models/_trainer.py +++ b/src/pyrovelocity/models/_trainer.py @@ -271,8 +271,8 @@ def train_faster( if scipy.sparse.issparse(self.adata.layers["raw_spliced"]) else self.adata.layers["raw_spliced"], dtype=torch.float32, - ).to(device) - + ).to(device) + epsilon = 1e-6 log_u_library_size = np.log( @@ -335,60 +335,127 @@ def train_faster( losses = [] patience = patient_init - for step in range(max_epochs): - if cell_state is None: - elbos = ( - svi.step( - u, 
- s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - None, + + if not self.adata.uns['atac']: + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - else: - elbos = ( - svi.step( - u, - s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - cell_state.reshape(-1, 1), + else: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + + else: + + c = torch.tensor( + np.array( + self.adata.layers["atac"].toarray(), dtype="float32" + ) + if scipy.sparse.issparse(self.adata.layers["atac"]) + else self.adata.layers["atac"], + dtype=torch.float32, + ).to(device) + + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + c, + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + 
s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - if (step == 0) or ( - ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) - ): - mlflow.log_metric("-ELBO", -elbos, step=step + 1) - logger.info( - f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" - ) - if step > log_every: - if (losses[-1] - elbos) < losses[-1] * patient_improve: - patience -= 1 else: - patience = patient_init - if patience <= 0: - break - losses.append(elbos) + elbos = ( + svi.step( + c, + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + mlflow.log_metric("-ELBO", -elbos, step=step + 1) mlflow.log_metric("real_epochs", step + 1) logger.info( f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" ) - return losses + return losses def train_faster_with_batch( self, diff --git a/src/pyrovelocity/models/_transcription_dynamics.py b/src/pyrovelocity/models/_transcription_dynamics.py index 03c531828..aba2a317f 100644 --- a/src/pyrovelocity/models/_transcription_dynamics.py +++ b/src/pyrovelocity/models/_transcription_dynamics.py @@ -60,6 +60,289 @@ def mrna_dynamics( return ut, st +@beartype +def atac_mrna_dynamics( + tau: Tensor, + c0: Tensor, + u0: Tensor, + s0: Tensor, + k_c: Tensor, + alpha_c: Tensor, + alpha: Tensor, + beta: Tensor, + gamma: Tensor, 
+) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes the ATAC and mRNA dynamics given temporal coordinate, parameter values, and + initial conditions. + + `st_gamma_equals_beta` for the case where the gamma parameter is equal + to the beta parameter is taken from Equation 2.12 of + + Args: + tau (Tensor): Time points starting at last change in RNA transcription rate. + c0 (Tensor): Initial value of c. + u0 (Tensor): Initial value of u. + s0 (Tensor): Initial value of s. + k_c (Tensor): Chromatin state. + alpha_c (Tensor): Rate of chromatin opening/closing. + alpha (Tensor): Alpha parameter. + beta (Tensor): Beta parameter. + gamma (Tensor): Gamma parameter. + + Returns: + Tuple[Tensor, Tensor]: Tuple containing the final values of c, u and s. + + Examples: + >>> import torch + >>> tau = torch.tensor(2.0) + >>> c0 = torch.tensor(1.0) + >>> u0 = torch.tensor(1.0) + >>> s0 = torch.tensor(0.5) + >>> alpha_c = torch.tensor(0.45) + >>> alpha = torch.tensor(0.5) + >>> beta = torch.tensor(0.4) + >>> gamma = torch.tensor(0.3) + >>> k_c = torch.tensor(1.0) + >>> atac_mrna_dynamics(tau_c, tau, c0, u0, s0, k_c, alpha_c, alpha, beta, gamma) + >>> import torch + >>> input = [torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + >>> tau_vec = input[0] + >>> c0_vec = input[1] + >>> u0_vec = input[2] + >>> s0_vec = input[3] + >>> k_c_vec = input[4] + >>> alpha_c = input[5] + >>> alpha_vec = input[6] + >>> beta = input[7] + >>> gamma = 
input[8] + >>> atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + (tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + """ + + A = torch.exp(-alpha_c * tau) + B = torch.exp(-beta * tau) + C = torch.exp(-gamma * tau) + + ct = c0 * A + k_c * (1 - A) + ut = ( + u0 * B + + alpha * k_c / beta * (1 - B) + + (k_c - c0) * alpha / (beta - alpha_c) * (B - A) + ) + st = s0 * C + alpha * k_c / gamma * (1 - C) + +beta / (gamma - beta) * ( + (alpha * k_c) / beta - u0 - (k_c - c0) * alpha / (beta - alpha_c) + ) * (C - B) + +beta / (gamma - alpha_c) * (k_c - c0) * alpha / (beta - alpha_c) * (C - A) + + return ct, ut, st + +@beartype +def get_initial_states( + t0_state: Tensor, + k_c_state: Tensor, + alpha_c: Tensor, + alpha_state: Tensor, + beta: Tensor, + gamma: Tensor, + state: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes initial conditions of chromatin and mRNA in each cell. + + Args: + t0_state (Tensor): The switch times of each gene (1 for each state). + k_c_state (Tensor): The chromatin state in each state. + alpha_c (Tensor): The chromatin opening and closing rate. + alpha_state (Tensor): The transcription rate of each gene in each state. + beta (Tensor): The splicing rate of each gene. + gamma (Tensor): The degradation rate of each gene. + state (Tensor): The state of each cell. + + Returns: + Tuple[Tensor, Tensor, Tensor]: Tuple containing the initial conditions of + c, u and s for each cell. 
+ + Examples: + >>> import torch + >>> alpha_c = torch.tensor((0.1, 0.2)) + >>> beta = torch.tensor((0.09, 0.11)) + >>> gamma = torch.tensor((0.05, 0.06)) + >>> state = torch.tensor([[0, 0],[2, 2],[2, 2],[3, 3]]) + >>> k_c_state = torch.tensor([[0., 1., 1., 0., 0.], [0., 1., 1., 1., 0.]]) + >>> alpha_state = torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],[0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]) + >>> t0_state = torch.tensor([[ 0., 10., 25., 75., 102.],[ 0., 10., 25., 78., 95.]]) + >>> get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + (torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]])) + """ + + n_genes = t0_state.shape[0] + c0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + u0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + s0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + dt_state = t0_state - torch.stack([torch.zeros((2)), torch.zeros((2)), + t0_state[:,1], t0_state[:,2], t0_state[:,3]], dim = 1) # genes, states + for i in range(1, 4): + c0_i, u0_i, s0_i = atac_mrna_dynamics( + dt_state[:, i+1], c0_state_list[-1], u0_state_list[-1], s0_state_list[-1], k_c_state[:, i], + alpha_c, alpha_state[:, i], beta, gamma + ) + c0_state_list += [c0_i] + u0_state_list += [u0_i] + s0_state_list += [s0_i] + + c0_state = torch.stack(c0_state_list, dim = 1) + u0_state = torch.stack(u0_state_list, dim = 1) + s0_state = torch.stack(s0_state_list, dim = 1) + + c0_vec = c0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + u0_vec = u0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + s0_vec = s0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return c0_vec, u0_vec, s0_vec + +@beartype +def get_cell_parameters( 
+ t: Tensor, + t0_1: Tensor, + dt_1: Tensor, + dt_2: Tensor, + dt_3: Tensor, + alpha: Tensor, + alpha_off: Tensor, + k: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Gets the ODE parameters for each cell, by first assign each gene in each cell to a state + based on state switch times of a gene and then computes the transcription rate, chromatin state + and time since last state switch(tau) for each gene in each cell. + + Args: + t (Tensor): The time of each cell. + t0_1 (Tensor): Start time for chromatin opening. + dt_1 (Tensor): Time gap since chromatin opening for transcription start for each gene. + dt_2 (Tensor): Time gap since transcription start for chromatin closing for each gene. + dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. + alpha (Tensor): The transcription rate of each gene in the on state. + alpha_off (Tensor): The transcription rate of each gene in the off state. + k (Tensor): The activation state of each gene in each state. + + Returns: + Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: Tuple containing the state of each cell (state), + the switch time of each state (t0_state), the chromatin opening state (k_c_state), the transcription rate in each cell + (alpha_state) and cell-specific parameters for the chromatin state (k_c_vec) transcription rate (alpha_vec) and + time (tau_vec) since last state switch. 
+ + Examples: + >>> import torch + + >>> n_cells = 4 + >>> t = torch.arange(0, 120, 30).reshape(n_cells, 1) + >>> t0_1 = torch.tensor((10.0, 10.0)) + >>> dt_1 = torch.tensor((15.0, 15.0)) + >>> dt_2 = torch.tensor((77.0, 53.0)) + >>> dt_3 = torch.tensor((50.0, 70.0)) + >>> alpha = torch.tensor((0.5, 0.3)) + >>> alpha_off = torch.tensor(0.0) + >>> k = torch.tensor((1.0, 1.0),(1.0,1.0)) + >>> get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k + ) + (tensor([[0, 0], + [2, 2], + [2, 2], + [3, 3]]),tensor([[0., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]), tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]), tensor([[ 0., 10., 25., 75., 102.], + [ 0., 10., 25., 78., 95.]]), tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]])) + """ + + # Assign each gene in each cell to a state: + t0_2 = t0_1 + dt_1 + boolean = dt_2 >= dt_3 # True means chromatin starts closing, before transcription stops. 
+ t0_3 = torch.where(boolean, t0_2 + dt_3, t0_2 + dt_2) + t0_4 = torch.where(~boolean, t0_2 + dt_3, t0_2 + dt_2) + state = ((t0_1 <= t).int() + (t0_2 <= t).int() + (t0_3 <= t).int() + (t0_4 <= t).int()) # cells, genes + n_genes = state.shape[1] + state = state * (1-1*k) + + t0_state = torch.stack([torch.zeros_like(t0_1), t0_1, t0_2, t0_3, t0_4], dim=1) # genes, states + t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + tau_vec = t - t0_vec # cells, genes + + alpha_state = torch.stack([ + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha, + torch.where(boolean, torch.ones_like(t0_1) * alpha, torch.ones_like(t0_1) * alpha_off), + torch.ones_like(t0_1) * alpha_off + ], dim=1) # genes, states + + k_c_state = torch.stack([ + torch.zeros_like(t0_1), + torch.ones_like(t0_1), + torch.ones_like(t0_1), + torch.where(boolean, torch.zeros_like(t0_1), torch.ones_like(t0_1)), + torch.zeros_like(t0_1) + ], dim=1) # genes, states + + alpha_vec = alpha_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + k_c_vec = k_c_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec + @beartype def inv(x: Tensor) -> Tensor: diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 4022bc246..53c70b03a 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -30,7 +30,7 @@ ) from pyrovelocity.logging import configure_logging from pyrovelocity.models._trainer import VelocityTrainingMixin -from pyrovelocity.models._velocity_module import VelocityModule +from pyrovelocity.models._velocity_module import VelocityModule, MultiVelocityModule __all__ = ["PyroVelocity"] @@ -99,6 +99,7 @@ class PyroVelocity(VelocityTrainingMixin, BaseModelClass): def __init__( self, adata: AnnData, + adata_atac: Optional[AnnData] = None, input_type: str 
= "raw", shared_time: bool = True, model_type: str = "auto", @@ -126,6 +127,7 @@ def __init__( Args: adata (AnnData): An AnnData object containing the gene expression data. + adata_atac (Optional[AnnData], optional): An AnnData object containing atac data. input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". shared_time (bool, optional): Whether to use shared time. Defaults to True. model_type (str, optional): Type of model to use. Defaults to "auto". @@ -244,30 +246,56 @@ def __init__( # else: initial_values = {} logger.info(self.summary_stats) - self.module = VelocityModule( - self.summary_stats["n_cells"], - self.summary_stats["n_vars"], - model_type=model_type, - guide_type=guide_type, - likelihood=likelihood, - shared_time=shared_time, - t_scale_on=t_scale_on, - plate_size=plate_size, - latent_factor=latent_factor, - latent_factor_operation=latent_factor_operation, - latent_factor_size=latent_factor_size, - inducing_point_size=inducing_point_size, - include_prior=include_prior, - use_gpu=use_gpu, - num_aux_cells=num_aux_cells, - only_cell_times=only_cell_times, - decoder_on=decoder_on, - add_offset=add_offset, - correct_library_size=correct_library_size, - cell_specific_kinetics=cell_specific_kinetics, - kinetics_num=self.k, - **initial_values, - ) + if not adata_atac: + self.module = VelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + 
cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + else: + self.module = MultiVelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=False, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) self.num_cells = self.module.num_cells self._model_summary_string = """ RNA velocity Pyro model with parameters: @@ -296,7 +324,7 @@ def enum_parallel_predict(self): return @classmethod - def setup_anndata(cls, adata: AnnData, *args, **kwargs): + def setup_anndata(cls, adata: AnnData, adata_atac = None, *args, **kwargs): """ Set up AnnData object for compatibility with the scvi-tools model training interface. 
@@ -330,9 +358,18 @@ def setup_anndata(cls, adata: AnnData, *args, **kwargs): NumericalObsField("s_lib_size_scale", "s_lib_size_scale"), NumericalObsField("ind_x", "ind_x"), ] + + if adata_atac: + adata.layers['atac'] = adata_atac.X + anndata_fields += [LayerField('atac', 'atac')] + adata.uns['atac'] = True + else: + adata.uns['atac'] = None + adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) + adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 63ca3fd8a..3aaad3492 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -1,27 +1,19 @@ -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import pyro import torch from beartype import beartype -from jaxtyping import Float -from jaxtyping import jaxtyped +from jaxtyping import Float, jaxtyped from pyro import poutine -from pyro.distributions import Bernoulli -from pyro.distributions import LogNormal -from pyro.distributions import Normal -from pyro.distributions import Poisson -from pyro.nn import PyroModule -from pyro.nn import PyroSample +from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson +from pyro.nn import PyroModule, PyroSample from pyro.primitives import plate from scvi.nn import Decoder -from torch.nn.functional import relu -from torch.nn.functional import softplus +from torch.nn.functional import relu, softplus +from torch import Tensor from pyrovelocity.logging import configure_logging -from pyrovelocity.models._transcription_dynamics import mrna_dynamics - +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters logger = configure_logging(__name__) @@ -36,6 +28,11 @@ Float[torch.Tensor, "samples num_cells num_genes"], 
] +__all__ = [ + "LogNormalModel", + "VelocityModelAuto", + "MultiVelocityModelAuto", +] class LogNormalModel(PyroModule): """ @@ -154,6 +151,10 @@ def create_plates( gene_plate = pyro.plate("genes", self.num_genes, dim=-1) return cell_plate, gene_plate + @PyroSample + def alpha_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def alpha(self): return self._pyrosample_helper(1.0) @@ -182,6 +183,18 @@ def u_inf(self): def s_inf(self): return self._pyrosample_helper(0.1) + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def delay(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def dt_switching(self): return self._pyrosample_helper(1.0) @@ -305,8 +318,99 @@ def get_likelihood( u_dist = Poisson(ut) s_dist = Poisson(st) + return u_dist, s_dist + + @beartype + def get_likelihood_multiome( + self, + ct: torch.Tensor, + ut: torch.Tensor, + st: torch.Tensor, + sigma_c: torch.Tensor, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_scale: Optional[torch.Tensor] = None, + s_scale: Optional[torch.Tensor] = None, + u_read_depth: Optional[torch.Tensor] = None, + s_read_depth: Optional[torch.Tensor] = None, + u_cell_size_coef: None = None, + ut_coef: None = None, + s_cell_size_coef: None = None, + st_coef: None = None, + ) -> Tuple[LogNormal, Poisson, Poisson]: + """ + Compute the likelihood of the given count data. + + Args: + ct (torch.Tensor): Tensor representing chromatin state. + ut (torch.Tensor): Tensor representing unspliced transcripts. + st (torch.Tensor): Tensor representing spliced transcripts. + sigma_c (torch.Tensor): Tensor representing standard deviation of chromatin state. + u_log_library (Optional[torch.Tensor], optional): Log library tensor for unspliced transcripts. Defaults to None. 
+ s_log_library (Optional[torch.Tensor], optional): Log library tensor for spliced transcripts. Defaults to None. + u_scale (Optional[torch.Tensor], optional): Scale tensor for unspliced transcripts. Defaults to None. + s_scale (Optional[torch.Tensor], optional): Scale tensor for spliced transcripts. Defaults to None. + u_read_depth (Optional[torch.Tensor], optional): Read depth tensor for unspliced transcripts. Defaults to None. + s_read_depth (Optional[torch.Tensor], optional): Read depth tensor for spliced transcripts. Defaults to None. + u_cell_size_coef (Optional[Any], optional): Cell size coefficient for unspliced transcripts. Defaults to None. + ut_coef (Optional[Any], optional): Coefficient for unspliced transcripts. Defaults to None. + s_cell_size_coef (Optional[Any], optional): Cell size coefficient for spliced transcripts. Defaults to None. + st_coef (Optional[Any], optional): Coefficient for spliced transcripts. Defaults to None. + Returns: + Tuple[Poisson, Poisson]: A tuple of Poisson distributions for unspliced and spliced transcripts, respectively. 
+ + Example: + >>> import torch + >>> from pyrovelocity.models._velocity_model import LogNormalModel + >>> num_cells = 10 + >>> num_genes = 20 + >>> likelihood = "Poisson" + >>> plate_size = 2 + >>> model = LogNormalModel(num_cells, num_genes, likelihood, plate_size) + >>> logger.info(model) + >>> ut = torch.rand(num_cells, num_genes) + >>> st = torch.rand(num_cells, num_genes) + >>> u_read_depth = torch.rand(num_cells, 1) + >>> s_read_depth = torch.rand(num_cells, 1) + >>> u_dist, s_dist = model.get_likelihood(ut, st, u_read_depth=u_read_depth, s_read_depth=s_read_depth) + >>> logger.info(f"u_dist: {u_dist}") + >>> logger.info(f"s_dist: {s_dist}") + >>> assert isinstance(u_dist, torch.distributions.Poisson) + >>> assert isinstance(s_dist, torch.distributions.Poisson) + """ + if self.likelihood != "Poisson": + likelihood_not_implemented_msg = ( + "In the future, the likelihood will be referred to via a " + "member of a sum type over supported distributions" + ) + raise NotImplementedError(likelihood_not_implemented_msg) + + if self.correct_library_size: + ut = relu(ut) + self.one * 1e-6 + st = relu(st) + self.one * 1e-6 + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut / torch.sum(ut, dim=-1, keepdim=True) + st = st / torch.sum(st, dim=-1, keepdim=True) + ut = pyro.deterministic("ut_norm", ut, event_dim=0) + st = pyro.deterministic("st_norm", st, event_dim=0) + ut = (ut + self.one * 1e-6) * u_read_depth + st = (st + self.one * 1e-6) * s_read_depth + else: + ut = relu(ut) + st = relu(st) + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut + self.one * 1e-6 + st = st + self.one * 1e-6 + + c_dist = LogNormal(ct, sigma_c) + u_dist = Poisson(ut) + s_dist = Poisson(st) + + return c_dist, u_dist, s_dist class VelocityModelAuto(LogNormalModel): """Automatically configured velocity model. 
class MultiVelocityModelAuto(LogNormalModel):
    """Automatically configured multiome (chromatin + RNA) velocity model.

    In addition to the unspliced (``u``) and spliced (``s``) counts, the
    ``forward`` method of this model observes a chromatin signal ``c``.

    Args:
        num_cells (int): Number of cells; must be positive.
        num_genes (int): Number of genes; must be positive.
        likelihood (str, optional): Likelihood name forwarded to the parent
            class. Defaults to "Poisson".
        shared_time (bool, optional): Stored on the instance. Defaults to True.
        t_scale_on (bool, optional): Stored on the instance. Defaults to False.
        plate_size (int, optional): Forwarded to the parent class and stored on
            the instance. Defaults to 2.
        latent_factor (str, optional): Stored on the instance.
            Defaults to "none".
        latent_factor_size (int, optional): Stored on the instance.
            Defaults to 30.
        latent_factor_operation (str, optional): Stored on the instance.
            Defaults to "selection".
        include_prior (bool, optional): Mask applied to the kinetic-parameter
            and cell-time priors in ``forward``. Defaults to False.
        num_aux_cells (int, optional): Stored on the instance. Defaults to 100.
        only_cell_times (bool, optional): Stored on the instance.
            Defaults to False.
        decoder_on (bool, optional): If True, a
            ``Decoder(1, num_genes, n_layers=2)`` is attached as
            ``self.decoder``. Defaults to False.
        add_offset (bool, optional): Stored on the instance. Defaults to False.
        correct_library_size (Union[bool, str], optional): Stored on the
            instance. Defaults to True.
        guide_type (str, optional): Stored on the instance.
            Defaults to "velocity".
        cell_specific_kinetics (Optional[str], optional): Stored on the
            instance. Defaults to None.
        kinetics_num (Optional[int], optional): Stored as ``self.k``.
            Defaults to None.
        **initial_values: Tensors registered as buffers named ``<key>_init``.
            An optional ``mask`` entry is additionally stored as ``self.mask``
            (default: an all-True ``(num_cells, num_genes)`` tensor).

    Examples:
        >>> import torch
        >>> from pyrovelocity.models._velocity_model import MultiVelocityModelAuto
        >>> model = MultiVelocityModelAuto(
        ...     3,
        ...     4,
        ...     "Poisson",
        ...     True,
        ...     False,
        ...     2,
        ...     "none",
        ...     latent_factor_operation="selection",
        ...     latent_factor_size=10,
        ...     include_prior=False,
        ...     num_aux_cells=0,
        ...     only_cell_times=True,
        ...     decoder_on=False,
        ...     add_offset=False,
        ...     correct_library_size=True,
        ...     guide_type="auto_t0_constraint",
        ...     cell_specific_kinetics=None,
        ...     **{}
        ... )
        >>> logger.info(model)
    """

    @beartype
    def __init__(
        self,
        num_cells: int,
        num_genes: int,
        likelihood: str = "Poisson",
        shared_time: bool = True,
        t_scale_on: bool = False,
        plate_size: int = 2,
        latent_factor: str = "none",
        latent_factor_size: int = 30,
        latent_factor_operation: str = "selection",
        include_prior: bool = False,
        num_aux_cells: int = 100,
        only_cell_times: bool = False,
        decoder_on: bool = False,
        add_offset: bool = False,
        correct_library_size: Union[bool, str] = True,
        guide_type: str = "velocity",
        cell_specific_kinetics: Optional[str] = None,
        kinetics_num: Optional[int] = None,
        **initial_values,
    ) -> None:
        assert num_cells > 0 and num_genes > 0
        super().__init__(num_cells, num_genes, likelihood, plate_size)
        self.num_aux_cells = num_aux_cells
        self.only_cell_times = only_cell_times
        self.guide_type = guide_type
        self.cell_specific_kinetics = cell_specific_kinetics
        self.k = kinetics_num

        self.mask = initial_values.get(
            "mask", torch.ones(self.num_cells, self.num_genes).bool()
        )
        # Every supplied initial value (including "mask") becomes a buffer.
        for key in initial_values:
            self.register_buffer(f"{key}_init", initial_values[key])

        self.shared_time = shared_time
        self.t_scale_on = t_scale_on
        self.add_offset = add_offset
        self.plate_size = plate_size

        self.latent_factor = latent_factor
        self.latent_factor_size = latent_factor_size
        self.latent_factor_operation = latent_factor_operation
        self.include_prior = include_prior
        self.decoder_on = decoder_on
        self.correct_library_size = correct_library_size
        if self.decoder_on:
            self.decoder = Decoder(1, self.num_genes, n_layers=2)
        # Enumeration strategy used for the discrete "cell_gene_state" site.
        self.enumeration = "parallel"

    @beartype
    def create_plates(
        self,
        c_obs: Optional[torch.Tensor] = None,
        u_obs: Optional[torch.Tensor] = None,
        s_obs: Optional[torch.Tensor] = None,
        u_log_library: Optional[torch.Tensor] = None,
        s_log_library: Optional[torch.Tensor] = None,
        u_log_library_loc: Optional[torch.Tensor] = None,
        s_log_library_loc: Optional[torch.Tensor] = None,
        u_log_library_scale: Optional[torch.Tensor] = None,
        s_log_library_scale: Optional[torch.Tensor] = None,
        ind_x: Optional[torch.Tensor] = None,
        cell_state: Optional[torch.Tensor] = None,
        time_info: Optional[torch.Tensor] = None,
    ) -> Tuple[plate, plate]:
        """Build the cell and gene plates by delegating to the parent class.

        The signature mirrors ``forward`` so that the method can be used as the
        ``create_plates`` callback of an autoguide. ``c_obs`` is accepted only
        for that reason; it is not forwarded to the parent implementation.

        Returns:
            Tuple[plate, plate]: ``(cell_plate, gene_plate)``.
        """
        cell_plate, gene_plate = super().create_plates(
            u_obs=u_obs,
            s_obs=s_obs,
            u_log_library=u_log_library,
            s_log_library=s_log_library,
            u_log_library_loc=u_log_library_loc,
            s_log_library_loc=s_log_library_loc,
            u_log_library_scale=u_log_library_scale,
            s_log_library_scale=s_log_library_scale,
            ind_x=ind_x,
            cell_state=cell_state,
            time_info=time_info,
        )
        return cell_plate, gene_plate

    def sample_cell_gene_state(self, t, switching):
        """Sample the discrete state site and return where the draw is zero.

        Draws the ``"cell_gene_state"`` site from
        ``Bernoulli(logits=t - switching)`` using ``self.enumeration`` as the
        enumeration strategy, and returns a boolean tensor that is True where
        the sampled value equals ``self.zero``.
        """
        return (
            pyro.sample(
                "cell_gene_state",
                Bernoulli(logits=t - switching),
                infer={"enumerate": self.enumeration},
            )
            == self.zero
        )

    @beartype
    def __repr__(self) -> str:
        return (
            f"\nMultiVelocityModelAuto(\n"
            f"\tnum_cells={self.num_cells}, \n"
            f"\tnum_genes={self.num_genes}, \n"
            f'\tlikelihood="{self.likelihood}", \n'
            f"\tshared_time={self.shared_time}, \n"
            f"\tt_scale_on={self.t_scale_on}, \n"
            f"\tplate_size={self.plate_size}, \n"
            f'\tlatent_factor="{self.latent_factor}", \n'
            f"\tlatent_factor_size={self.latent_factor_size}, \n"
            f'\tlatent_factor_operation="{self.latent_factor_operation}", \n'
            f"\tinclude_prior={self.include_prior}, \n"
            f"\tnum_aux_cells={self.num_aux_cells}, \n"
            f"\tonly_cell_times={self.only_cell_times}, \n"
            f"\tdecoder_on={self.decoder_on}, \n"
            f"\tadd_offset={self.add_offset}, \n"
            f"\tcorrect_library_size={self.correct_library_size}, \n"
            f'\tguide_type="{self.guide_type}", \n'
            f"\tcell_specific_kinetics={self.cell_specific_kinetics}, \n"
            f"\tkinetics_num={self.k}\n"
            f")\n"
        )

    @jaxtyped(typechecker=beartype)
    def get_atac_rna(
        self,
        u_scale: RNAInputType,
        s_scale: RNAInputType,
        t: Tensor,  # cells, 1
        t0_1: Tensor,
        dt_1: Tensor,
        dt_2: Tensor,
        dt_3: Tensor,
        alpha_c: Tensor,
        alpha: Tensor,
        alpha_off: Tensor,
        beta: Tensor,
        gamma: Tensor,
    ) -> Tuple[RNAOutputType, RNAOutputType, RNAOutputType]:
        """
        Computes the chromatin opening state (c) and the unspliced (u) and
        spliced (s) RNA expression levels given the model parameters.

        Args:
            u_scale (torch.Tensor): Scaling factor for unspliced expression.
            s_scale (torch.Tensor): Scaling factor for spliced expression.
            t (Tensor): The time of each cell.
            t0_1 (Tensor): Start time for chromatin opening.
            dt_1 (Tensor): Time gap since chromatin opening for transcription
                start for each gene.
            dt_2 (Tensor): Time gap since transcription start for chromatin
                closing for each gene.
            dt_3 (Tensor): Time gap since transcription start for
                transcription stopping for each gene.
            alpha_c (Tensor): The chromatin opening and closing rate.
            alpha (Tensor): The transcription rate of each gene in the on state.
            alpha_off (Tensor): The transcription rate of each gene in the off
                state.
            beta (torch.Tensor): Splicing rate.
            gamma (torch.Tensor): Degradation rate.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The chromatin
            state (c), unspliced (u) and spliced (s) RNA expression levels.
            ``u`` is rescaled by ``u_scale / s_scale`` before being returned.

        Examples:
            >>> from pyrovelocity.models._velocity_model import MultiVelocityModelAuto
            >>> import torch
            >>> n_cells = 4
            >>> u_scale = torch.tensor(1.0)
            >>> s_scale = torch.tensor(1.0)
            >>> t = torch.arange(0, 120, 30).reshape(n_cells, 1)  # cells, 1
            >>> t0_1 = torch.tensor((10.0, 10.0))
            >>> dt_1 = torch.tensor((15.0, 15.0))
            >>> dt_2 = torch.tensor((77.0, 53.0))
            >>> dt_3 = torch.tensor((50.0, 70.0))
            >>> alpha_c = torch.tensor((0.1, 0.2))
            >>> alpha = torch.tensor((0.5, 0.3))
            >>> alpha_off = torch.tensor(0.0)
            >>> beta = torch.tensor((0.09, 0.11))
            >>> gamma = torch.tensor((0.05, 0.06))
            >>> mod = MultiVelocityModelAuto(num_cells=n_cells, num_genes=2)
            >>> output = MultiVelocityModelAuto.get_atac_rna(
            ...     mod, u_scale, s_scale, t, t0_1, dt_1, dt_2, dt_3,
            ...     alpha_c, alpha, alpha_off, beta, gamma,
            ... )  # doctest: +SKIP
            >>> output  # doctest: +SKIP
            (tensor([[0.0000, 0.0000],
                    [0.8647, 0.9817],
                    [0.9933, 1.0000],
                    [0.2228, 1.0000]]),
            tensor([[0.0000, 0.0000],
                    [1.6662, 1.1191],
                    [5.1763, 2.6659],
                    [3.2144, 0.7263]]),
            tensor([[0.0000, 0.0000],
                    [2.2120, 1.2959],
                    [8.2623, 4.3877],
                    [4.3359, 2.3326]]))
        """
        # Discrete per-(cell, gene) indicator sampled against the chromatin
        # opening time t0_1.
        k = self.sample_cell_gene_state(t, t0_1)

        (
            state,
            k_c_state,
            alpha_state,
            t0_state,
            k_c_vec,
            alpha_vec,
            tau_vec,
        ) = get_cell_parameters(
            t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off, k
        )

        c0_vec, u0_vec, s0_vec = get_initial_states(
            t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state
        )

        ct, ut, st = atac_mrna_dynamics(
            tau_vec,
            c0_vec,
            u0_vec,
            s0_vec,
            k_c_vec,
            alpha_c,
            alpha_vec,
            beta,
            gamma,
        )

        ut = ut * u_scale / s_scale
        return ct, ut, st

    @beartype
    def forward(
        self,
        c_obs: torch.Tensor,
        u_obs: torch.Tensor,
        s_obs: torch.Tensor,
        u_log_library: Optional[torch.Tensor] = None,
        s_log_library: Optional[torch.Tensor] = None,
        u_log_library_loc: Optional[torch.Tensor] = None,
        s_log_library_loc: Optional[torch.Tensor] = None,
        u_log_library_scale: Optional[torch.Tensor] = None,
        s_log_library_scale: Optional[torch.Tensor] = None,
        ind_x: Optional[torch.Tensor] = None,
        cell_state: Optional[torch.Tensor] = None,
        time_info: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Defines the forward model for chromatin state (c), unspliced (u) and
        spliced (s) RNA expression given the observations.

        Args:
            c_obs (torch.Tensor): Observed chromatin state.
            u_obs (torch.Tensor): Observed unspliced RNA expression.
            s_obs (torch.Tensor): Observed spliced RNA expression.
            u_log_library (Optional[torch.Tensor], optional): Log-transformed
                library size for unspliced RNA; used as the location of the
                ``u_read_depth`` prior.
            s_log_library (Optional[torch.Tensor], optional): Log-transformed
                library size for spliced RNA; used as the location of the
                ``s_read_depth`` prior.
            u_log_library_loc (Optional[torch.Tensor], optional): Mean of
                log-transformed library size for unspliced RNA. Only passed on
                to ``create_plates``.
            s_log_library_loc (Optional[torch.Tensor], optional): Mean of
                log-transformed library size for spliced RNA. Only passed on
                to ``create_plates``.
            u_log_library_scale (Optional[torch.Tensor], optional): Scale of
                the ``u_read_depth`` prior.
            s_log_library_scale (Optional[torch.Tensor], optional): Scale of
                the ``s_read_depth`` prior.
            ind_x (Optional[torch.Tensor], optional): Indices for the cells.
            cell_state (Optional[torch.Tensor], optional): Cell state
                information.
            time_info (Optional[torch.Tensor], optional): Time information for
                the cells.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The values of the
            observed sites ``c``, ``u`` and ``s``.

        Examples:
            >>> import torch
            >>> from pyrovelocity.models._velocity_model import MultiVelocityModelAuto
            >>> c_obs = torch.tensor(
            ...     [[1.0, 0.2, 0.4, 0.0],
            ...      [0.8, 0.2, 0.5, 0.3],
            ...      [0.0, 0.0, 0.1, 0.9]],
            ...     device="cpu",
            ... )
            >>> u_obs = torch.tensor(
            ...     [[33., 1., 7., 1.],
            ...      [12., 30., 11., 3.],
            ...      [ 1., 1., 8., 5.]],
            ...     device="cpu",
            ... )
            >>> s_obs = torch.tensor(
            ...     [[32.0, 0.0, 6.0, 0.0],
            ...      [11.0, 29.0, 10.0, 2.0],
            ...      [0.0, 0.0, 7.0, 4.0]],
            ...     device="cpu",
            ... )
            >>> u_log_library = torch.tensor([[3.7377], [4.0254], [2.7081]], device="cpu")
            >>> s_log_library = torch.tensor([[3.6376], [3.9512], [2.3979]], device="cpu")
            >>> u_log_library_loc = torch.tensor([[3.4904], [3.4904], [3.4904]], device="cpu")
            >>> s_log_library_loc = torch.tensor([[3.3289], [3.3289], [3.3289]], device="cpu")
            >>> u_log_library_scale = torch.tensor([[0.6926], [0.6926], [0.6926]], device="cpu")
            >>> s_log_library_scale = torch.tensor([[0.8214], [0.8214], [0.8214]], device="cpu")
            >>> ind_x = torch.tensor([2, 0, 1], device="cpu")
            >>> model = MultiVelocityModelAuto(3, 4)
            >>> c, u, s = model.forward(
            ...     c_obs,
            ...     u_obs,
            ...     s_obs,
            ...     u_log_library,
            ...     s_log_library,
            ...     u_log_library_loc,
            ...     s_log_library_loc,
            ...     u_log_library_scale,
            ...     s_log_library_scale,
            ...     ind_x,
            ... )  # doctest: +SKIP
        """
        # Keyword arguments are used on purpose: create_plates takes c_obs
        # first, so a positional (u_obs, s_obs, c_obs, ...) call would shift
        # every observation into the wrong parameter.
        cell_plate, gene_plate = self.create_plates(
            c_obs=c_obs,
            u_obs=u_obs,
            s_obs=s_obs,
            u_log_library=u_log_library,
            s_log_library=s_log_library,
            u_log_library_loc=u_log_library_loc,
            s_log_library_loc=s_log_library_loc,
            u_log_library_scale=u_log_library_scale,
            s_log_library_scale=s_log_library_scale,
            ind_x=ind_x,
            cell_state=cell_state,
            time_info=time_info,
        )

        # Gene-level kinetic parameters and switching-time offsets.
        with gene_plate, poutine.mask(mask=self.include_prior):
            alpha_c = pyro.sample("alpha_c", LogNormal(self.one, self.one))
            alpha = pyro.sample(
                "alpha", LogNormal(self.one * 20, self.one * 10)
            )
            gamma = pyro.sample(
                "gamma", LogNormal(self.one * 20, self.one * 10)
            )
            beta = pyro.sample("beta", LogNormal(self.one * 20, self.one * 10))
            alpha_off = self.zero

            t0_1 = pyro.sample("t0_1", Normal(self.zero, self.one * 10))
            dt_1 = pyro.sample("dt_1", LogNormal(self.one * 20, self.one * 10))
            dt_2 = pyro.sample("dt_2", LogNormal(self.one * 20, self.one * 10))
            dt_3 = pyro.sample("dt_3", LogNormal(self.one * 20, self.one * 10))

            u_scale = self.u_scale
            s_scale = self.one

        with cell_plate:
            t = pyro.sample(
                "cell_time",
                LogNormal(self.zero, self.one * 50).mask(self.include_prior),
            )

        # NOTE(review): the nesting below (everything inside cell_plate, the
        # likelihood additionally inside gene_plate) is assumed to follow the
        # RNA-only model — confirm against the intended plate structure.
        with cell_plate:
            u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None
            u_read_depth = pyro.sample(
                "u_read_depth", LogNormal(u_log_library, u_log_library_scale)
            )
            s_read_depth = pyro.sample(
                "s_read_depth", LogNormal(s_log_library, s_log_library_scale)
            )
            # Built from the ``one`` buffer (rather than Python floats) so the
            # sample lives on the same device as the other model tensors.
            sigma_c = pyro.sample(
                "sigma_c", LogNormal(self.one * 0.2, self.one * 0.2)
            )

            ct, ut, st = self.get_atac_rna(
                u_scale,
                s_scale,
                t,  # cells, 1
                t0_1,
                dt_1,
                dt_2,
                dt_3,
                alpha_c,
                alpha,
                alpha_off,
                beta,
                gamma,
            )

            with gene_plate:
                c_dist, u_dist, s_dist = self.get_likelihood_multiome(
                    ct,
                    ut,
                    st,
                    sigma_c,
                    u_log_library,
                    s_log_library,
                    u_scale,
                    s_scale,
                    u_read_depth=u_read_depth,
                    s_read_depth=s_read_depth,
                    u_cell_size_coef=u_cell_size_coef,
                    ut_coef=ut_coef,
                    s_cell_size_coef=s_cell_size_coef,
                    st_coef=st_coef,
                )
                c = pyro.sample("c", c_dist, obs=c_obs)
                u = pyro.sample("u", u_dist, obs=u_obs)
                s = pyro.sample("s", s_dist, obs=s_obs)
        return c, u, s
class MultiVelocityModule(PyroBaseModuleClass):
    """
    MultiVelocityModule is an scvi-tools pyro module that combines the
    MultiVelocityModelAuto model with a pyro AutoGuideList guide.

    Args:
        num_cells (int): Number of cells.
        num_genes (int): Number of genes.
        model_type (str, optional): Model type. Stored and logged only; the
            model is always a MultiVelocityModelAuto. Default is "auto".
        guide_type (str, optional): Guide type, forwarded to the model.
            Default is "velocity_auto".
        likelihood (str, optional): Likelihood type. Default is "Poisson".
        shared_time (bool, optional): Forwarded to the model. Default is True.
        t_scale_on (bool, optional): Forwarded to the model. Default is False.
        plate_size (int, optional): Size of the plate set. Default is 2.
        latent_factor (str, optional): Latent factor. Default is "none".
        latent_factor_operation (str, optional): Latent factor operation mode.
            Default is "selection".
        latent_factor_size (int, optional): Size of the latent factor.
            Default is 10.
        inducing_point_size (int, optional): Accepted for signature
            compatibility; not used by this class. Default is 0.
        include_prior (bool, optional): If True, include prior in the model.
            Default is False.
        use_gpu (str, optional): Accepted for signature compatibility; not
            used by this class. Default is "auto".
        num_aux_cells (int, optional): Number of auxiliary cells. Default is 0.
        only_cell_times (bool, optional): Forwarded to the model.
            Default is True.
        decoder_on (bool, optional): If True, use the decoder. Default is False.
        add_offset (bool, optional): Forwarded to the model; also adds the
            ``u_offset``/``s_offset`` sites to the low-rank guide's exposed
            sites. Default is True.
        correct_library_size (Union[bool, str], optional): Library size
            correction method. Default is True.
        cell_specific_kinetics (Optional[str], optional): Cell-specific
            kinetics method. Default is None.
        kinetics_num (Optional[int], optional): Number of kinetics, forwarded
            to the model. Default is None.
        **initial_values: Initial values for the model parameters.

    Examples:
        >>> from pyrovelocity.models._velocity_module import MultiVelocityModule
        >>> num_cells = 10
        >>> num_genes = 20
        >>> module = MultiVelocityModule(
        ...     num_cells, num_genes, model_type="auto",
        ...     guide_type="auto_t0_constraint", add_offset=False
        ... )
        >>> type(module.model).__name__
        'MultiVelocityModelAuto'
        >>> type(module.guide).__name__
        'AutoGuideList'
    """

    def __init__(
        self,
        num_cells: int,
        num_genes: int,
        model_type: str = "auto",
        guide_type: str = "velocity_auto",
        likelihood: str = "Poisson",
        shared_time: bool = True,
        t_scale_on: bool = False,
        plate_size: int = 2,
        latent_factor: str = "none",
        latent_factor_operation: str = "selection",
        latent_factor_size: int = 10,
        inducing_point_size: int = 0,
        include_prior: bool = False,
        use_gpu: str = "auto",
        num_aux_cells: int = 0,
        only_cell_times: bool = True,
        decoder_on: bool = False,
        add_offset: bool = True,
        correct_library_size: Union[bool, str] = True,
        cell_specific_kinetics: Optional[str] = None,
        kinetics_num: Optional[int] = None,
        **initial_values,
    ) -> None:
        super().__init__()
        self.num_cells = num_cells
        self.num_genes = num_genes
        self.model_type = model_type
        self.guide_type = guide_type
        self._model = None
        self.plate_size = plate_size
        self.num_aux_cells = num_aux_cells
        self.only_cell_times = only_cell_times
        logger.info(
            f"Model type: {self.model_type}, Guide type: {self.guide_type}"
        )

        self.cell_specific_kinetics = cell_specific_kinetics

        self._model = MultiVelocityModelAuto(
            self.num_cells,
            self.num_genes,
            likelihood,
            shared_time,
            t_scale_on,
            self.plate_size,
            latent_factor,
            latent_factor_operation=latent_factor_operation,
            latent_factor_size=latent_factor_size,
            include_prior=include_prior,
            num_aux_cells=num_aux_cells,
            only_cell_times=self.only_cell_times,
            decoder_on=decoder_on,
            add_offset=add_offset,
            correct_library_size=correct_library_size,
            guide_type=self.guide_type,
            cell_specific_kinetics=self.cell_specific_kinetics,
            # Previously dropped, which left the model's ``k`` at None
            # regardless of what the caller requested.
            kinetics_num=kinetics_num,
            **initial_values,
        )

        guide = AutoGuideList(
            self._model, create_plates=self._model.create_plates
        )
        # Mean-field guide over the sites listed in ``expose``.
        guide.append(
            AutoNormal(
                poutine.block(
                    self._model,
                    expose=[
                        "cell_time",
                        "u_read_depth",
                        "s_read_depth",
                        "kinetics_prob",
                        "kinetics_weights",
                    ],
                ),
                init_scale=0.1,
            )
        )

        # NOTE(review): MultiVelocityModelAuto.forward samples alpha_c, alpha,
        # beta, gamma, t0_1, dt_1, dt_2, dt_3 and sigma_c, none of which is
        # exposed to either guide here — confirm whether those sites are
        # intentionally left without a variational distribution.
        low_rank_sites = ["dt_switching", "t0", "u_scale", "s_scale"]
        if add_offset:
            low_rank_sites += ["u_offset", "s_offset"]
        guide.append(
            AutoLowRankMultivariateNormal(
                poutine.block(self._model, expose=low_rank_sites),
                rank=10,
                init_scale=0.1,
            )
        )
        self._guide = guide

    @property
    def model(self) -> MultiVelocityModelAuto:
        """The wrapped MultiVelocityModelAuto instance."""
        return self._model

    @property
    def guide(self) -> AutoGuideList:
        """The AutoGuideList built in ``__init__``."""
        return self._guide

    @staticmethod
    def _get_fn_args_from_batch(
        tensor_dict: Dict[str, torch.Tensor]
    ) -> Tuple[
        Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            Optional[torch.Tensor],
            Optional[torch.Tensor],
        ],
        Dict[Any, Any],
    ]:
        """Translate a minibatch dict into positional args for the model.

        The returned tuple is ordered to match
        ``MultiVelocityModelAuto.forward``: chromatin first, then unspliced
        and spliced counts, library-size statistics, cell indices, and the two
        optional entries (``None`` when the key is absent from the batch).

        Args:
            tensor_dict: Batch mapping with the keys read below.

        Returns:
            ``(args, kwargs)`` where ``kwargs`` is always an empty dict.
        """
        u_obs = tensor_dict["U"]
        s_obs = tensor_dict["X"]
        c_obs = tensor_dict['atac']
        u_log_library = tensor_dict["u_lib_size"]
        s_log_library = tensor_dict["s_lib_size"]
        u_log_library_mean = tensor_dict["u_lib_size_mean"]
        s_log_library_mean = tensor_dict["s_lib_size_mean"]
        u_log_library_scale = tensor_dict["u_lib_size_scale"]
        s_log_library_scale = tensor_dict["s_lib_size_scale"]
        ind_x = tensor_dict["ind_x"].long().squeeze()
        cell_state = tensor_dict.get("pyro_cell_state")
        time_info = tensor_dict.get("time_info")
        return (
            c_obs,
            u_obs,
            s_obs,
            u_log_library,
            s_log_library,
            u_log_library_mean,
            s_log_library_mean,
            u_log_library_scale,
            s_log_library_scale,
            ind_x,
            cell_state,
            time_info,
        ), {}
+ + Returns: + Tuple: Rates of change in y (= unspliced and spliced counts) + + Examples: + >>> + """ + + u, s = y + regulatory_function = args[0] + alpha, beta, gamma = regulatory_function(u,s) + du = alpha - beta*u + ds = beta*u - gamma*s + dy = du, ds + + return dy \ No newline at end of file diff --git a/src/pyrovelocity/models/knn_model/_velocity.py b/src/pyrovelocity/models/knn_model/_velocity.py new file mode 100644 index 000000000..99ca3155f --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity.py @@ -0,0 +1,749 @@ +import os +import pickle +import sys +from statistics import harmonic_mean +from typing import Dict, Optional, Sequence, Union +import mlflow +import numpy as np +import pyro +import torch +from anndata import AnnData +from beartype import beartype +from numpy import ndarray +from scvi.data import AnnDataManager +from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME +from scvi.data.fields import LayerField, NumericalObsField, CategoricalObsField, ObsmField +from scvi.model._utils import parse_device_args +from scvi.model.base import BaseModelClass, PyroSampleMixin +from scvi import REGISTRY_KEYS +from datetime import date +from scvi.model.base._utils import ( + _initialize_model, + _load_saved_files, + _validate_var_names, +) +from scvi.module.base import PyroBaseModuleClass + +from pyrovelocity.analysis.analyze import ( + compute_mean_vector_field, + compute_volcano_data, + vector_field_uncertainty, +) +from pyrovelocity.logging import configure_logging +from scvi.model.base import PyroSviTrainMixin +from pyrovelocity.models.knn_model._velocity_module import VelocityModule, MultiVelocityModule + +__all__ = ["PyroVelocity"] + +logger = configure_logging(__name__) + +class PyroVelocity(PyroSviTrainMixin, BaseModelClass, PyroSampleMixin): + """ + PyroVelocity is a class for constructing and training a Pyro model for + probabilistic RNA velocity estimation. 
This model leverages the + probabilistic programming language Pyro to estimate the parameters of models + for the dynamics of RNA transcription, splicing, and degradation, providing + the opportunity for insight into cellular states and associated state + transitions. It makes use of AnnData, scvi-tools, and other scverse + ecosystem libraries. + + Public methods include training the model with various configurations, + generating posterior samples for further analysis, and saving/loading the + model for reproducibility and further analysis. + + Attributes: + use_gpu (str): Whether and which GPU to use. + cell_specific_kinetics (Optional[str]): Type of cell-specific kinetics. + k (Optional[int]): Number of kinetics. + layers (List[str]): List of layers in the dataset. + input_type (str): Type of input data. + module (VelocityModule): + The Pyro module used for the velocity estimation model. + num_cells (int): Number of cells in the dataset. + num_samples (int): Number of posterior samples to generate. + _model_summary_string (str): Summary string for the model. + init_params_ (Dict[str, Any]): Initial parameters for the model. + + For usage examples, including training the model and generating posterior + samples, refer to the individual method docstrings. + """ + + """ + The `Methods` section is not supported by all documentation generators but + is provided detached from the class docstring for reference. Please + see the docstrings for each method for more details. This list may ignore + unused or private methods. + + Methods: + train: + Trains the PyroVelocity model using the provided data and configuration. + setup_anndata: + Set up AnnData object for compatibility with the scvi-tools + model training interface. + generate_posterior_samples: + Generates posterior samples for the given data using the trained + PyroVelocity model. 
+ compute_statistics_from_posterior_samples: + Estimate statistics from posterior samples and add them to the + `posterior_samples` dictionary. + save_pyrovelocity_data: + Saves the PyroVelocity data to a pickle file. + save_model: + Saves the Pyro-Velocity model to a directory. + load_model: + Load the model from a directory with the same structure as that produced + by the save method. + """ + + def __init__( + self, + adata: AnnData, + adata_atac: Optional[AnnData] = None, + input_type: str = "raw", + shared_time: bool = True, + model_type: str = "auto", + guide_type: str = "auto", + likelihood: str = "Poisson", + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + inducing_point_size: int = 0, + latent_factor_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + init: bool = False, + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + ) -> None: + """ + PyroVelocity class for estimating RNA velocity and related tasks. + + Args: + adata (AnnData): An AnnData object containing the gene expression data. + adata_atac (Optional[AnnData], optional): An AnnData object containing atac data. + input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". + shared_time (bool, optional): Whether to use shared time. Defaults to True. + model_type (str, optional): Type of model to use. Defaults to "auto". + guide_type (str, optional): Type of guide to use. Defaults to "auto". + likelihood (str, optional): Type of likelihood to use. Defaults to "Poisson". + t_scale_on (bool, optional): Whether to use t_scale. Defaults to False. + plate_size (int, optional): Size of the plate. Defaults to 2. + latent_factor (str, optional): Type of latent factor. 
Defaults to "none". + latent_factor_operation (str, optional): Operation to perform on the latent factor. Defaults to "selection". + inducing_point_size (int, optional): Size of inducing points. Defaults to 0. + latent_factor_size (int, optional): Size of latent factors. Defaults to 0. + include_prior (bool, optional): Whether to include prior information. Defaults to False. + use_gpu (Union[bool, int], optional): Whether and which GPU to use. Defaults to 0. Can be False. + init (bool, optional): Whether to initialize the model. Defaults to False. + num_aux_cells (int, optional): Number of auxiliary cells. Defaults to 0. + only_cell_times (bool, optional): Whether to use only cell times. Defaults to True. + decoder_on (bool, optional): Whether to use decoder. Defaults to False. + add_offset (bool, optional): Whether to add offset. Defaults to False. + correct_library_size (Union[bool, str], optional): Whether to correct library size or method to correct. Defaults to True. + cell_specific_kinetics (Optional[str], optional): Type of cell-specific kinetics. Defaults to None. + kinetics_num (Optional[int], optional): Number of kinetics. Defaults to None. + + Examples: + >>> # import necessary libraries + >>> import numpy as np + >>> import anndata + >>> from pyrovelocity.utils import pretty_log_dict, print_anndata, generate_sample_data + >>> from pyrovelocity.tasks.preprocess import copy_raw_counts + >>> from pyrovelocity.models._velocity import PyroVelocity + ... + >>> # define fixtures + >>> try: + >>> tmp = getfixture("tmp_path") + >>> except NameError: + >>> import tempfile + >>> tmp = tempfile.TemporaryDirectory().name + >>> doctest_model_path = str(tmp) + "/save_pyrovelocity_doctest_model" + >>> print(doctest_model_path) + ... 
+ >>> # setup sample data + >>> n_obs = 10 + >>> n_vars = 5 + >>> adata = generate_sample_data(n_obs=n_obs, n_vars=n_vars) + >>> copy_raw_counts(adata) + >>> print_anndata(adata) + >>> print(adata.X) + >>> print(adata.layers['spliced']) + >>> print(adata.layers['unspliced']) + >>> print(adata.obs['u_lib_size_raw']) + >>> print(adata.obs['s_lib_size_raw']) + >>> PyroVelocity.setup_anndata(adata) + ... + >>> # train model with macroscopic validation set + >>> model = PyroVelocity(adata) + >>> model.train(max_epochs=5, train_size=0.8, valid_size=0.2, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> print(posterior_samples.keys()) + >>> assert isinstance(posterior_samples, dict), f"Expected a dictionary, got {type(posterior_samples)}" + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + ... + >>> # train model with default parameters + >>> model = PyroVelocity(adata) + >>> model.train_faster(max_epochs=5, use_gpu="auto") + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> print(posterior_samples.keys()) + ... + >>> # train model with specified batch size + >>> model = PyroVelocity(adata) + >>> model.train_faster_with_batch(batch_size=24, max_epochs=5, use_gpu="auto") + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> print(posterior_samples.keys()) + ... 
+ >>> # If running from an interactive session, the temporary directory + >>> # can be inspected to review the saved model files. When run as a + >>> # doctest it is automatically cleaned up after the test completes. + >>> print(f"Output located in {doctest_model_path}") + """ + self.use_gpu = use_gpu + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + if input_type == "knn": + layers = ["Mu", "Ms"] + assert likelihood in {"Normal", "LogNormal"} + assert "Mu" in adata.layers + elif input_type == "raw_cpm": + layers = ["unspliced", "spliced"] + assert likelihood in {"Normal", "LogNormal"} + else: + layers = ["raw_unspliced", "raw_spliced"] + assert likelihood != "Normal" + + self.layers = layers + self.input_type = input_type + + super().__init__(adata) + # TODO: remove unused code + # from pyrovelocity.utils import init_with_all_cells + # if init: + # initial_values = init_with_all_cells( + # self.adata, + # input_type, + # shared_time, + # latent_factor, + # latent_factor_size, + # plate_size, + # num_aux_cells=num_aux_cells, + # ) + # else: + initial_values = {} + logger.info(self.summary_stats) + if not adata_atac: + self.module = VelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + self.summary_stats["n_batch"], + **initial_values, + ) + else: + self.module = MultiVelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=False, + correct_library_size=correct_library_size, + 
cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + self.num_cells = self.module.num_cells + self._model_summary_string = """ + RNA velocity Pyro model with parameters: + """ + self.init_params_ = self._get_init_params(locals()) + logger.info("Model initialized") + + def train( + self, + max_epochs: int = 500, + batch_size: int = 1000, + train_size: float = 1, + lr: float = 0.01, + **kwargs, + ): + """ + Training function for the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset. If `None`, defaults to + ``np.min([round((20000 / n_cells) * 400), 400])`` + train_size + Size of training set in the range [0.0, 1.0]. + batch_size + Minibatch size to use during training. If `None`, no minibatching occurs and all + data is copied to device (e.g., GPU). + lr + Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). + Specifying optimiser via plan_kwargs overrides this choice of lr. + kwargs + Other arguments to :py:meth:`scvi.model.base.PyroSviTrainMixin().train` method + """ + + self.max_epochs = max_epochs + kwargs["max_epochs"] = max_epochs + kwargs["batch_size"] = batch_size + kwargs["train_size"] = train_size + kwargs["lr"] = lr + + super().train(**kwargs) + + def enum_parallel_predict(self): + """work for parallel enumeration""" + return + + @classmethod + def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *args, **kwargs): + """ + Set up AnnData object for compatibility with the scvi-tools + model training interface. + + Args: + adata (AnnData): Anndata object to be used in model training. 
+ """ + setup_method_args = cls._get_setup_method_args(**locals()) + + adata.obs["ind_x"] = np.arange(adata.n_obs).astype("int64") + + anndata_fields = [ + LayerField("U", "raw_unspliced", is_count_data=True), + LayerField("X", "raw_spliced", is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), + NumericalObsField("ind_x", "ind_x"), + NumericalObsField("M_c", "n_cells") + ] + + if adata_atac: + adata.layers['atac'] = adata_atac.X + anndata_fields += [LayerField('atac', 'atac')] + adata.uns['atac'] = True + else: + adata.uns['atac'] = None + + if 'N_cn' in adata.obsm: + anndata_fields += [ObsmField('N_cn', 'N_cn')] + + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + def _export2adata(self, samples): + r""" + Export key model variables and samples + + Parameters + ---------- + samples + Dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()`` + + Returns + ------- + Dict + Updated dictionary with additional details is saved to ``adata.uns['mod']``. + """ + # add factor filter and samples of all parameters to unstructured data + results = { + "model_name": str(self.module.__class__.__name__), + "date": str(date.today()), + "var_names": self.adata.var_names.tolist(), + "obs_names": self.adata.obs_names.tolist(), + "post_sample_means": samples["post_sample_means"], + "post_sample_stds": samples["post_sample_stds"], + "post_sample_q05": samples["post_sample_q05"], + "post_sample_q95": samples["post_sample_q95"], + } + + return results + + def export_posterior( + self, + adata, + sample_kwargs = {"num_samples": 30, "batch_size" : None, + 'return_samples': True}, + export_slot: str = "mod", + full_velocity_posterior = False, + normalize = True): + """ + Summarises posterior distribution and exports results to anndata object. 
Also computes RNA velocity (based on posterior of rates) and normalized counts (based on posterior of technical variables). + + - **adata.obs:** Latent time, sequencing depth constant. + + - **adata.var:** transcription/splicing/degradation rates, switch on and off times. + + - **adata.uns:** Posterior of all parameters ('mean', 'sd', 'q05', 'q95' and optionally all samples), model name, date. + + - **adata.layers:** ``velocity`` (expected gradient of spliced counts), ``velocity_sd`` (uncertainty in this gradient), ``spliced_norm``, ``unspliced_norm`` (normalized counts). + + - **adata.uns:** If ``return_samples: True`` and ``full_velocity_posterior = True`` full posterior distribution for velocity is saved in ``adata.uns['velocity_posterior']``. + + Parameters + ---------- + adata + AnnData object where results should be saved. + sample_kwargs + Optionally a dictionary of arguments for self.sample_posterior, namely: + + - **num_samples:** Number of samples to use (the default ``sample_kwargs`` passes 30). + - **batch_size:** Data batch size (keep low enough to fit on GPU; the default ``None`` is replaced by ``adata.n_obs``). + - **use_gpu:** Use gpu for generating samples. + - **return_samples:** Export all posterior samples (Otherwise just summary statistics). + export_slot + adata.uns slot where to export results. + full_velocity_posterior + Whether to save full posterior of velocity (only possible if "return_samples: True"). + normalize + Whether to compute normalized spliced and unspliced counts based on posterior of technical variables. + Returns + ------- + AnnData + AnnData object with posterior added in adata.obs, adata.var and adata.uns.
+ + """ + + if sample_kwargs['batch_size'] == None: + sample_kwargs['batch_size'] = adata.n_obs + + # generate samples from posterior distributions for all parameters + # and compute mean, 5%/95% quantiles and standard deviation + self.samples = self.sample_posterior(**sample_kwargs) + + # export posterior distribution summary for all parameters and + # annotation (model, date, var, obs and cell type names) to anndata object + adata.uns[export_slot] = self._export2adata(self.samples) + + if sample_kwargs['return_samples']: + print('Warning: Saving ALL posterior samples. Specify "return_samples: False" to save just summary statistics.') + adata.uns[export_slot]['post_samples'] = self.samples['posterior_samples'] + + adata.obs['Time (hours)'] = self.samples['post_sample_means']['T_c'].flatten() - np.min(self.samples['post_sample_means']['T_c'].flatten()) + adata.obs['Time Uncertainty (sd)'] = self.samples['post_sample_stds']['T_c'].flatten() + +# adata.layers['spliced mean'] = self.samples['post_sample_means']['mu_expression'][...,1] +# adata.layers['velocity'] = torch.tensor(self.samples['post_sample_means']['beta_g']) * \ +# self.samples['post_sample_means']['mu_expression'][...,0] - \ +# torch.tensor(self.samples['post_sample_means']['gamma_g']) * \ +# self.samples['post_sample_means']['mu_expression'][...,1] + + return adata + + def get_mlflow_logs(self): + return + + def compute_statistics_from_posterior_samples( + self, + adata: AnnData, + posterior_samples: Dict[str, ndarray], + vector_field_basis: str = "umap", + ncpus_use: int = 1, + ) -> Dict[str, ndarray]: + """ + Estimate statistics from posterior samples and add them to the + `posterior_samples` dictionary. 
The names of the statistics incorporated into + the dictionary are: + + - `gene_ranking` + - `original_spaces_embeds_magnitude` + - `genes` + - `vector_field_posterior_samples` + - `vector_field_posterior_mean` + - `fdri` + - `embeds_magnitude` + - `embeds_angle` + - `ut_mean` + - `st_mean` + - `pca_vector_field_posterior_samples` + - `pca_embeds_angle` + - `pca_fdri` + + The following data are removed from the `posterior_samples` dictionary: + + - `u` + - `s` + - `ut` + - `st` + + Each of these sets requires further documentation. + + Args: + adata (AnnData): Anndata object containing the data for which posterior samples + were computed. + posterior_samples (Dict[str, ndarray]): Dictionary containing the posterior samples + for each parameter. + vector_field_basis (str, optional): Basis for the vector field. Defaults to "umap". + ncpus_use (int, optional): Number of CPUs to use for computation. Defaults to 1. + + Returns: + Dict[str, ndarray]: Dictionary containing the posterior samples with added statistics. 
+ """ + if ("u_scale" in posterior_samples) and ( + "s_scale" in posterior_samples + ): + scale = posterior_samples["u_scale"] / posterior_samples["s_scale"] + elif ("u_scale" in posterior_samples) and not ( + "s_scale" in posterior_samples + ): + scale = posterior_samples["u_scale"] + else: + scale = 1 + original_spaces_velocity_samples = ( + posterior_samples["beta"] * posterior_samples["ut"] / scale + - posterior_samples["gamma"] * posterior_samples["st"] + ) + original_spaces_embeds_magnitude = np.sqrt( + (original_spaces_velocity_samples**2).sum(axis=-1) + ) + + ( + vector_field_posterior_samples, + embeds_radian, + fdri, + ) = vector_field_uncertainty( + adata, + posterior_samples, + basis=vector_field_basis, + n_jobs=ncpus_use, + ) + embeds_magnitude = np.sqrt( + (vector_field_posterior_samples**2).sum(axis=-1) + ) + + mlflow.log_metric( + "FDR_sig_frac", round((fdri < 0.05).sum() / fdri.shape[0], 3) + ) + mlflow.log_metric("FDR_HMP", harmonic_mean(fdri)) + + compute_mean_vector_field( + posterior_samples=posterior_samples, + adata=adata, + basis=vector_field_basis, + n_jobs=ncpus_use, + ) + + vector_field_posterior_mean = adata.obsm[ + f"velocity_pyro_{vector_field_basis}" + ] + + gene_ranking, genes = compute_volcano_data( + [posterior_samples], [adata], time_correlation_with="st" + ) + gene_ranking = ( + gene_ranking.sort_values("mean_mae", ascending=False) + .head(300) + .sort_values("time_correlation", ascending=False) + ) + posterior_samples["gene_ranking"] = gene_ranking + posterior_samples[ + "original_spaces_embeds_magnitude" + ] = original_spaces_embeds_magnitude + posterior_samples["genes"] = genes + posterior_samples[ + "vector_field_posterior_samples" + ] = vector_field_posterior_samples + posterior_samples[ + "vector_field_posterior_mean" + ] = vector_field_posterior_mean + posterior_samples["fdri"] = fdri + posterior_samples["embeds_magnitude"] = embeds_magnitude + posterior_samples["embeds_angle"] = embeds_radian + 
posterior_samples["ut_mean"] = posterior_samples["ut"].mean(0).squeeze() + posterior_samples["st_mean"] = posterior_samples["st"].mean(0).squeeze() + + ( + pca_vector_field_posterior_samples, + pca_embeds_radian, + pca_fdri, + ) = vector_field_uncertainty( + adata, + posterior_samples, + basis="pca", + n_jobs=ncpus_use, + ) + posterior_samples[ + "pca_vector_field_posterior_samples" + ] = pca_vector_field_posterior_samples + posterior_samples["pca_embeds_angle"] = pca_embeds_radian + posterior_samples["pca_fdri"] = pca_fdri + + del posterior_samples["u"] + del posterior_samples["s"] + del posterior_samples["ut"] + del posterior_samples["st"] + return posterior_samples + + @beartype + def save_pyrovelocity_data( + self, + posterior_samples: Dict[str, ndarray], + pyrovelocity_data_path: os.PathLike | str, + ): + """ + Save the PyroVelocity data to a pickle file. + + Args: + posterior_samples (Dict[str, ndarray]): Dictionary containing the posterior samples + pyrovelocity_data_path (os.PathLike | str): Path to save the PyroVelocity data. + """ + with open(pyrovelocity_data_path, "wb") as f: + pickle.dump(posterior_samples, f) + for k in posterior_samples: + logger.debug(k, "after", sys.getsizeof(posterior_samples[k])) + + def save_model( + self, + dir_path: str, + prefix: Optional[str] = None, + overwrite: bool = True, + save_anndata: bool = False, + **anndata_write_kwargs, + ) -> None: + """ + Save the Pyro-Velocity model to a directory. + + Dispatches to the `save` method of the inherited `BaseModelClass` which + calls `torch.save` on a model state dictionary, variable names, and user + attributes. + + Args: + dir_path (str): Path to the directory where the model will be saved. + prefix (Optional[str], optional): Prefix to add to the saved files. Defaults to None. + overwrite (bool, optional): Whether to overwrite existing files. Defaults to True. + save_anndata (bool, optional): Whether to save the AnnData object. Defaults to False. 
+ """ + super().save( + dir_path, prefix, overwrite, save_anndata, **anndata_write_kwargs + ) + pyro.get_param_store().save( + os.path.join(dir_path, "param_store_test.pt") + ) + + @classmethod + def load_model( + cls, + dir_path: str, + adata: Optional[AnnData] = None, + use_gpu: str = "auto", + prefix: Optional[str] = None, + backup_url: Optional[str] = None, + ) -> BaseModelClass: + """ + Load the model from a directory with the same structure as that produced + by the save method. + + Args: + dir_path (str): Path to the directory where the model is saved. + adata (Optional[AnnData], optional): Anndata object to load into the model. Defaults to None. + use_gpu (str, optional): Whether and which GPU to use. Defaults to "auto". + prefix (Optional[str], optional): Prefix to add to the saved files. Defaults to None. + backup_url (Optional[str], optional): URL to download the model from. Defaults to None. + + Raises: + RuntimeError: If the model is not an instance of PyroBaseModuleClass. + + Returns: + PyroVelocity: The loaded PyroVelocity model. 
+ """ + load_adata = adata is None + _accelerator, _devices, device = parse_device_args( + accelerator=use_gpu, return_device="torch" + ) + logger.info( + f"\nLoading model with:\n" + f"\taccelerator: {_accelerator}\n" + f"\tdevices: {_devices}\n" + f"\tdevice: {device}\n" + ) + + ( + attr_dict, + var_names, + model_state_dict, + new_adata, + ) = _load_saved_files( + dir_path, + load_adata, + map_location=device, + prefix=prefix, + backup_url=backup_url, + ) + + adata = new_adata if new_adata is not None else adata + + _validate_var_names(adata, var_names) + + registry = attr_dict.pop("registry_") + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + getattr(cls, method_name)( + adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] + ) + + model = _initialize_model(cls, adata, attr_dict) + + for attr, val in attr_dict.items(): + setattr(model, attr, val) + + pyro.clear_param_store() + old_history = model.history_ + try: + model.module.load_state_dict(model_state_dict) + except RuntimeError as err: + if not isinstance(model.module, PyroBaseModuleClass): + raise err + logger.info( + "Preparing underlying `PyroBaseModuleClass` module for load" + ) + try: + model.train(max_epochs=1, max_steps=1) + except Exception: + model.train( + max_epochs=1, + max_steps=1, + batch_size=adata.shape[0], + train_size=0.8, + valid_size=0.2, + ) + model.module.load_state_dict(model_state_dict) + + model.history_ = old_history + model.to_device(device) + model.module.eval() + model._validate_anndata(adata) + pyro.get_param_store().load( + os.path.join(dir_path, "param_store_test.pt"), + map_location=device, + ) + return model diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py new file mode 100644 index 000000000..df95ce651 --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -0,0 +1,485 @@ +from typing import Optional, Tuple, Union + +import pyro +from pyro.nn import 
PyroModule +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped +from pyro import poutine +from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson +from pyro.nn import PyroModule, PyroSample +from pyro.primitives import plate +import pyro.distributions as dist +from scvi.nn import Decoder +from scvi.nn import one_hot +from torch.nn.functional import relu, softplus +from torch import Tensor +import torch.nn.functional as F + +from pyrovelocity.logging import configure_logging +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters + +from pyrovelocity.models.knn_model.regulatory_functions_torch import regulatory_function_1 +from pyrovelocity.models.knn_model._vector_fields import vector_field_1 + +logger = configure_logging(__name__) + +RNAInputType = Union[ + Float[torch.Tensor, ""], + Float[torch.Tensor, "num_genes"], + Float[torch.Tensor, "samples num_genes"], +] + +RNAOutputType = Union[ + Float[torch.Tensor, "num_cells num_genes"], + Float[torch.Tensor, "samples num_cells num_genes"], +] + +__all__ = [ + "VelocityModelAuto", + "MultiVelocityModelAuto", +] + +def G_a(mu, sd): + """ + Converts mean and standard deviation for a Gamma distribution into the shape parameter. + + Parameters + ---------- + mu + The mean of the Gamma distribution. + sd + The standard deviation of the Gamma distribution. + + Returns + ------- + Float + The shape parameter of the Gamma distribution. + """ + return mu**2/sd**2 + +def G_b(mu, sd): + """ + Converts mean and standard deviation for a Gamma distribution into the rate (inverse scale) parameter. + + Parameters + ---------- + mu + The mean of the Gamma distribution. + sd + The standard deviation of the Gamma distribution. + + Returns + ------- + Float + The rate (inverse scale) parameter of the Gamma distribution. + """ + + return mu/sd**2 + +class VelocityModelAuto(PyroModule): + """Automatically configured velocity model.
+ + Args: + num_cells (int): _description_ + num_genes (int): _description_ + likelihood (str, optional): _description_. Defaults to "Poisson". + shared_time (bool, optional): _description_. Defaults to True. + t_scale_on (bool, optional): _description_. Defaults to False. + plate_size (int, optional): _description_. Defaults to 2. + latent_factor (str, optional): _description_. Defaults to "none". + latent_factor_size (int, optional): _description_. Defaults to 30. + latent_factor_operation (str, optional): _description_. Defaults to "selection". + include_prior (bool, optional): _description_. Defaults to False. + num_aux_cells (int, optional): _description_. Defaults to 100. + only_cell_times (bool, optional): _description_. Defaults to False. + decoder_on (bool, optional): _description_. Defaults to False. + add_offset (bool, optional): _description_. Defaults to False. + correct_library_size (Union[bool, str], optional): _description_. Defaults to True. + guide_type (str, optional): _description_. Defaults to "velocity". + cell_specific_kinetics (Optional[str], optional): _description_. Defaults to None. + kinetics_num (Optional[int], optional): _description_. Defaults to None. + + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> model = VelocityModelAuto( + ... 3, + ... 4, + ... "Poisson", + ... True, + ... False, + ... 2, + ... "none", + ... latent_factor_operation="selection", + ... latent_factor_size=10, + ... include_prior=False, + ... num_aux_cells=0, + ... only_cell_times=True, + ... decoder_on=False, + ... add_offset=False, + ... correct_library_size=True, + ... guide_type="auto_t0_constraint", + ... cell_specific_kinetics=None, + ... **{} + ... 
) + >>> logger.info(model) + """ + + @beartype + def __init__( + self, + num_cells: int, + num_genes: int, + n_batch: int, + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_size: int = 30, + latent_factor_operation: str = "selection", + include_prior: bool = False, + num_aux_cells: int = 100, + only_cell_times: bool = False, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + guide_type: str = "velocity", + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + stochastic_v_ag_hyp_prior={"alpha": 6.0, "beta": 3.0}, + s_overdispersion_factor_hyp_prior={'alpha_mean': 100., 'beta_mean': 1., + 'alpha_sd': 1., 'beta_sd': 0.1}, + detection_hyp_prior={"alpha": 10.0, "mean_alpha": 1.0, "mean_beta": 1.0}, + detection_i_prior={"mean": 1, "alpha": 100}, + detection_gi_prior={"mean": 1, "alpha": 200}, + gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, + gene_add_mean_hyp_prior={"alpha": 1.0, "beta": 100.0}, + Tmax_prior={"mean": 50., "sd": 20.}, + **initial_values, + ) -> None: + + super().__init__() + + assert num_cells > 0 and num_genes > 0 + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + self.guide_type = guide_type + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + self.num_cells = num_cells + self.num_genes = num_genes + self.n_genes = num_genes + + self.mask = initial_values.get( + "mask", torch.ones(self.num_cells, self.num_genes).bool() + ) + for key in initial_values: + self.register_buffer(f"{key}_init", initial_values[key]) + + self.shared_time = shared_time + self.t_scale_on = t_scale_on + self.add_offset = add_offset + self.plate_size = plate_size + + self.latent_factor = latent_factor + self.latent_factor_size = latent_factor_size + self.latent_factor_operation = latent_factor_operation + self.include_prior = 
include_prior + self.decoder_on = decoder_on + self.correct_library_size = correct_library_size + if self.decoder_on: + self.decoder = Decoder(1, self.num_genes, n_layers=2) + + self.enumeration = "parallel" + # self.set_enumeration_strategy() + + self.n_obs = num_cells + self.n_vars = num_genes + self.n_batch = n_batch + + self.stochastic_v_ag_hyp_prior = stochastic_v_ag_hyp_prior + self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior + self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior + self.detection_hyp_prior = detection_hyp_prior + self.s_overdispersion_factor_hyp_prior = s_overdispersion_factor_hyp_prior + self.detection_gi_prior = detection_gi_prior + self.detection_i_prior = detection_i_prior + + h1 = 10 + h2 = 10 + self.l1 = [] + self.l2 = [] + self.l3 = [] + for i in range(self.n_vars): + self.l1 += [PyroModule[torch.nn.Linear](2, h1)] + self.l2 += [PyroModule[torch.nn.Linear](h1, h2)] + self.l3 += [PyroModule[torch.nn.Linear](h2, 3)] + + self.dropout = torch.nn.Dropout(p=0.1) + + + self.register_buffer( + "s_overdispersion_factor_alpha_mean", + torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_mean"]), + ) + self.register_buffer( + "s_overdispersion_factor_beta_mean", + torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_mean"]), + ) + self.register_buffer( + "s_overdispersion_factor_alpha_sd", + torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_sd"]), + ) + self.register_buffer( + "s_overdispersion_factor_beta_sd", + torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_sd"]), + ) + + self.register_buffer( + "detection_gi_prior_alpha", + torch.tensor(self.detection_gi_prior["alpha"]), + ) + self.register_buffer( + "detection_gi_prior_beta", + torch.tensor(self.detection_gi_prior["alpha"] / self.detection_gi_prior["mean"]), + ) + + self.register_buffer( + "detection_i_prior_alpha", + torch.tensor(self.detection_i_prior["alpha"]), + ) + self.register_buffer( + "detection_i_prior_beta", + 
torch.tensor(self.detection_i_prior["alpha"] / self.detection_i_prior["mean"]), + ) + + self.register_buffer( + "Tmax_mean", + torch.tensor(Tmax_prior["mean"]), + ) + + self.register_buffer( + "Tmax_sd", + torch.tensor(Tmax_prior["sd"]), + ) + + self.register_buffer( + "detection_mean_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["mean_alpha"]), + ) + self.register_buffer( + "detection_mean_hyp_prior_beta", + torch.tensor(self.detection_hyp_prior["mean_beta"]), + ) + + self.register_buffer( + "stochastic_v_ag_hyp_prior_alpha", + torch.tensor(self.stochastic_v_ag_hyp_prior["alpha"]), + ) + self.register_buffer( + "stochastic_v_ag_hyp_prior_beta", + torch.tensor(self.stochastic_v_ag_hyp_prior["beta"]), + ) + self.register_buffer( + "gene_add_alpha_hyp_prior_alpha", + torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]), + ) + self.register_buffer( + "gene_add_alpha_hyp_prior_beta", + torch.tensor(self.gene_add_alpha_hyp_prior["beta"]), + ) + self.register_buffer( + "gene_add_mean_hyp_prior_alpha", + torch.tensor(self.gene_add_mean_hyp_prior["alpha"]), + ) + self.register_buffer( + "gene_add_mean_hyp_prior_beta", + torch.tensor(self.gene_add_mean_hyp_prior["beta"]), + ) + + self.register_buffer( + "detection_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["alpha"]), + ) + + self.register_buffer("one", torch.tensor(1.)) + self.register_buffer("ones", torch.ones((1, 1))) + + @beartype + def create_plates(self, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + N_cn: torch.Tensor, + M_c: torch.Tensor, + ind_x: torch.Tensor, + batch_index: torch.Tensor): + """ + Creates a Pyro plate for observations. + + Parameters + ---------- + u_obs + Unspliced count data. + s_obs + Spliced count data. + ind_x + Index tensor to subsample. + batch_index + Index tensor indicating batch assignments. + + Returns + ------- + Pyro.plate + A Pyro plate representing the observations in the dataset. 
+ """ + + return pyro.plate("obs_plate", size=self.n_obs, dim=-3, subsample=ind_x) + + @beartype + def __repr__(self) -> str: + return ( + f"\nKnnModel(\n" + f"\tnum_cells={self.num_cells}, \n" + f"\tnum_genes={self.num_genes}, \n" + f")\n" + ) + + @beartype + def forward(self, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + N_cn: torch.Tensor, + M_c: torch.Tensor, + ind_x: torch.Tensor, + batch_index: torch.Tensor): + """ + Defines the forward model, which computes the unspliced (u) and spliced + (s) RNA expression levels given the observations and model parameters. + + Args: + u_obs (Optional[torch.Tensor], optional): Observed unspliced RNA expression. Default is None. + s_obs (Optional[torch.Tensor], optional): Observed spliced RNA expression. Default is None. + ind_x (Optional[torch.Tensor], optional): Indices for the cells. + batch_index (Optional[torch.Tensor], optional): Experimental batch index of cells. + + Returns: + + Examples: + >>> . + """ + + batch_size = len(ind_x) + obs2sample = one_hot(batch_index, self.n_batch) + k = N_cn.shape[1] + N_cn = N_cn.long() + M_c = M_c.long().unsqueeze(-1) + obs_plate = self.create_plates(u_obs, s_obs, N_cn, M_c, ind_x, batch_index) + + # ============= Expression Model =============== # + T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) + t_c_loc = pyro.sample('t_c_loc', dist.Gamma(self.one, self.one/0.5)) + t_c_scale = pyro.sample('t_c_scale', dist.Gamma(self.one, self.one/0.25)) + with obs_plate: + t_c = pyro.sample('t_c', dist.Normal(t_c_loc, t_c_scale).expand([batch_size, 1, 1])) + T_c = pyro.deterministic('T_c', t_c*T_max) + + # Time difference between neighbors (previously: T_c.unsqueeze(-1) - T_c[N_cn, :]): + with obs_plate: + delta_cn = pyro.sample('delta_cn', dist.Gamma(self.one, self.one).expand([batch_size, k, 1])) + + # Counts in each cell: + # with obs_plate: + # mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, 
self.one*1.0).expand([batch_size, self.n_genes, 2])) + + mu0_cg = pyro.deterministic('mu0_cg', torch.stack([u_obs, s_obs], axis = 2)/M_c) + + # ============= Measurement Model =============== # + # Cell specific relative detection efficiency with hierarchical prior across batches: + detection_mean_y_e = pyro.sample( + "detection_mean_y_e", + dist.Beta( + self.ones * self.detection_mean_hyp_prior_alpha, + self.ones * self.detection_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1]) + .to_event(2), + ) + detection_hyp_prior_alpha = pyro.deterministic( + "detection_hyp_prior_alpha", + self.detection_hyp_prior_alpha, + ) + + beta = detection_hyp_prior_alpha / (obs2sample @ detection_mean_y_e) + with obs_plate: + detection_y_c = pyro.sample( + "detection_y_c", + dist.Gamma(detection_hyp_prior_alpha.unsqueeze(dim=-1), beta.unsqueeze(dim=-1)), + ) # (self.n_obs, 1) + + # =====================Expected observed expression ======================= # + with obs_plate: + mu = pyro.deterministic('mu', (self.one*10**(-5) + mu0_cg * detection_y_c)) + + # Weight of each nearest neighbor: + wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.000001, self.one*1000000.0).expand([1,batch_size]).to_event(2)) + wdash5_nc = pyro.sample('wdash5_nc', dist.Gamma(self.one*0.1, self.one*0.1).expand([k-1,batch_size]).to_event(2)) + wdash_nc = pyro.deterministic('wdash_nc',torch.concat([wdash0_nc, wdash5_nc], axis = 0)) + w_nc = pyro.deterministic('w_nc', wdash_nc/torch.sum(wdash_nc, axis = 0)) + + # Vector field: + + # x_u = self.l1(torch.log(mu[...,0])) + # x_u = F.leaky_relu(x_u) + + alpha0_g = pyro.sample('alpha0_g', dist.Gamma(self.one, self.one).expand([1,self.n_vars]).to_event(2)) + beta0_g = pyro.sample('beta0_g', dist.Gamma(self.one, self.one/2.0).expand([1,self.n_vars]).to_event(2)) + gamma0_g = pyro.sample('gamma0_g', dist.Gamma(self.one, self.one).expand([1,self.n_vars]).to_event(2)) + + print('mu', mu.shape) + betas = [] + alphas = [] + gammas = [] + for i in 
range(self.n_vars): + x = self.l1[i](mu[:,i,...]) + x = F.leaky_relu(x) + x = self.l2[i](x) + x = F.leaky_relu(x) + x = self.l3[i](x) + output = torch.sigmoid(x) + betas += [output[:,0]] + gammas += [output[:,1]] + alphas += [output[:,2]] + + alphas = torch.concat(alphas, axis = -1) * alpha0_g + beta = torch.concat(betas, axis = -1) * beta0_g + gamma = torch.concat(gammas, axis = -1) * gamma0_g + + pyro.deterministic('alpha', alpha) + pyro.deterministic('beta', beta) + pyro.deterministic('gamma', gamma) + + # print('alpha', alpha.shape) + # print('beta', beta.shape) + # print('gamma', gamma.shape) + + du = alpha - beta*mu[...,0] + ds = beta*mu[...,0] - gamma*mu[...,1] + dy = du, ds + + # Predicted counts from each neighbor: + y = (mu[...,0], mu[...,1]) + # dy = vector_field_1(0.0,y,[regulatory_function_1]) + velocity = pyro.deterministic('velocity', dy[1]) + dy_cn = pyro.deterministic('dy_cn', torch.stack(dy, axis = -1)[N_cn,...]) + muhat_cg = pyro.deterministic('muhat_cg', (torch.stack(y, axis = -1) + torch.sum((w_nc.T.unsqueeze(-1).unsqueeze(-1) * delta_cn.unsqueeze(-1) * dy_cn), axis = 1)/k)) + + # =====================DATA likelihood ======================= # + with obs_plate: + # pyro.sample("data_target", dist.Poisson(rate = mu), + # obs=torch.stack([u_obs, s_obs], axis = 2)) + pyro.sample("constrain", dist.Normal(muhat_cg, 0.01), obs=mu) + + # print('MAE', torch.sum((torch.abs(muhat_cg - mu)))) + # print('1', torch.sum((mu - torch.stack([u_obs, s_obs], axis = 2))**2)) \ No newline at end of file diff --git a/src/pyrovelocity/models/knn_model/_velocity_module.py b/src/pyrovelocity/models/knn_model/_velocity_module.py new file mode 100644 index 000000000..bd53796bf --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity_module.py @@ -0,0 +1,402 @@ +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from pyro import poutine +from pyro.infer.autoguide import 
class VelocityModule(PyroBaseModuleClass):
    """
    scvi-tools Pyro module pairing a ``VelocityModelAuto`` model with an ``AutoGuideList`` guide.

    The model is built as ``VelocityModelAuto(num_cells, num_genes, n_batch, **initial_values)``.
    The guide always contains an ``AutoNormal`` component and, depending on
    ``add_offset``, one ``AutoLowRankMultivariateNormal`` component.

    Args:
        num_cells (int): Number of cells.
        num_genes (int): Number of genes.
        n_batch (int): Number of batches; forwarded to the model.
        model_type (str, optional): Model type. Only stored and logged. Default is "auto".
        guide_type (str, optional): Guide type. Only stored and logged. Default is "velocity_auto".
        likelihood (str, optional): Accepted for signature compatibility; not used here. Default is "Poisson".
        shared_time (bool, optional): Accepted for signature compatibility; not used here. Default is True.
        t_scale_on (bool, optional): Accepted for signature compatibility; not used here. Default is False.
        plate_size (int, optional): Size of the plate set. Only stored. Default is 2.
        latent_factor (str, optional): Accepted for signature compatibility; not used here. Default is "none".
        latent_factor_operation (str, optional): Accepted for signature compatibility; not used here. Default is "selection".
        latent_factor_size (int, optional): Accepted for signature compatibility; not used here. Default is 10.
        inducing_point_size (int, optional): Accepted for signature compatibility; not used here. Default is 0.
        include_prior (bool, optional): Accepted for signature compatibility; not used here. Default is False.
        use_gpu (str, optional): Accepted for signature compatibility; not used here. Default is "auto".
        num_aux_cells (int, optional): Number of auxiliary cells. Only stored. Default is 0.
        only_cell_times (bool, optional): Only stored. Default is True.
        decoder_on (bool, optional): Accepted for signature compatibility; not used here. Default is False.
        add_offset (bool, optional): Selects which sites the low-rank guide component exposes
            (``detection_y_c`` when True, the kinetic-parameter sites otherwise). Default is True.
        correct_library_size (Union[bool, str], optional): Accepted for signature compatibility; not used here. Default is True.
        cell_specific_kinetics (Optional[str], optional): Only stored. Default is None.
        kinetics_num (Optional[int], optional): Accepted for signature compatibility; not used here. Default is None.
        **initial_values: Extra keyword arguments forwarded to ``VelocityModelAuto``.

    Examples:
        >>> from pyrovelocity.models.knn_model._velocity_module import VelocityModule
        >>> num_cells, num_genes, n_batch = 10, 20, 1
        >>> velocity_module = VelocityModule(
        ...     num_cells, num_genes, n_batch, add_offset=True
        ... )  # doctest: +SKIP
        >>> velocity_module.model is velocity_module._model  # doctest: +SKIP
        True
        >>> velocity_module.guide is velocity_module._guide  # doctest: +SKIP
        True
    """

    def __init__(
        self,
        num_cells: int,
        num_genes: int,
        n_batch: int,
        model_type: str = "auto",
        guide_type: str = "velocity_auto",
        likelihood: str = "Poisson",
        shared_time: bool = True,
        t_scale_on: bool = False,
        plate_size: int = 2,
        latent_factor: str = "none",
        latent_factor_operation: str = "selection",
        latent_factor_size: int = 10,
        inducing_point_size: int = 0,
        include_prior: bool = False,
        use_gpu: str = "auto",
        num_aux_cells: int = 0,
        only_cell_times: bool = True,
        decoder_on: bool = False,
        add_offset: bool = True,
        correct_library_size: Union[bool, str] = True,
        cell_specific_kinetics: Optional[str] = None,
        kinetics_num: Optional[int] = None,
        **initial_values,
    ) -> None:
        super().__init__()
        self.num_cells = num_cells
        self.num_genes = num_genes
        self.n_batch = n_batch
        self.model_type = model_type
        self.guide_type = guide_type
        self.plate_size = plate_size
        self.num_aux_cells = num_aux_cells
        self.only_cell_times = only_cell_times
        self.cell_specific_kinetics = cell_specific_kinetics
        logger.info(
            f"Model type: {self.model_type}, Guide type: {self.guide_type}"
        )

        self._model = VelocityModelAuto(
            self.num_cells,
            self.num_genes,
            self.n_batch,
            **initial_values,
        )

        # NOTE(review): every latent site the model samples must be exposed by
        # exactly one guide component below; confirm these site names against
        # the model's `forward` -- TODO confirm.
        guide = AutoGuideList(
            self._model, create_plates=self._model.create_plates
        )
        guide.append(
            AutoNormal(
                poutine.block(
                    self._model,
                    expose=[
                        "Tmax",
                        "u_read_depth",
                        "s_read_depth",
                        "kinetics_prob",
                        "kinetics_weights",
                    ],
                ),
                init_scale=0.1,
            )
        )

        if add_offset:
            low_rank_sites = ["detection_y_c"]
        else:
            low_rank_sites = [
                "beta",
                "gamma",
                "dt_switching",
                "t0",
                "u_scale",
                "s_scale",
            ]
        guide.append(
            AutoLowRankMultivariateNormal(
                poutine.block(self._model, expose=low_rank_sites),
                rank=10,
                init_scale=0.1,
            )
        )
        self._guide = guide

    @property
    def model(self) -> VelocityModelAuto:
        """The wrapped ``VelocityModelAuto`` instance."""
        return self._model

    @property
    def guide(self) -> AutoGuideList:
        """The ``AutoGuideList`` built in ``__init__``."""
        return self._guide

    @staticmethod
    def _get_fn_args_from_batch(
        tensor_dict: Dict[str, torch.Tensor]
    ) -> Tuple[
        Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
        Dict[Any, Any],
    ]:
        """
        Unpack a minibatch dictionary into positional model/guide arguments.

        Args:
            tensor_dict: Minibatch with the keys ``"U"``, ``"X"``, ``"N_cn"``,
                ``"M_c"``, ``"ind_x"`` and ``REGISTRY_KEYS.BATCH_KEY``.

        Returns:
            ``((u_obs, s_obs, N_cn, M_c, ind_x, batch_index), {})`` where
            ``ind_x`` has been cast to ``long`` and squeezed.
        """
        u_obs = tensor_dict["U"]
        s_obs = tensor_dict["X"]
        N_cn = tensor_dict["N_cn"]
        M_c = tensor_dict["M_c"]
        ind_x = tensor_dict["ind_x"].long().squeeze()
        batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY]
        return (u_obs, s_obs, N_cn, M_c, ind_x, batch_index), {}
+ shared_time (bool, optional): If True, a shared time parameter will be used. Default is True. + t_scale_on (bool, optional): If True, scale time parameter. Default is False. + plate_size (int, optional): Size of the plate set. Default is 2. + latent_factor (str, optional): Latent factor. Default is "none". + latent_factor_operation (str, optional): Latent factor operation mode. Default is "selection". + latent_factor_size (int, optional): Size of the latent factor. Default is 10. + inducing_point_size (int, optional): Inducing point size. Default is 0. + include_prior (bool, optional): If True, include prior in the model. Default is False. + use_gpu (str, optional): Accelerator type. Default is "auto". + num_aux_cells (int, optional): Number of auxiliary cells. Default is 0. + only_cell_times (bool, optional): If True, only model cell times. Default is True. + decoder_on (bool, optional): If True, use the decoder. Default is False. + add_offset (bool, optional): If True, add offset to the model. Default is True. + correct_library_size (Union[bool, str], optional): Library size correction method. Default is True. + cell_specific_kinetics (Optional[str], optional): Cell-specific kinetics method. Default is None. + kinetics_num (Optional[int], optional): Number of kinetics. Default is None. + **initial_values: Initial values for the model parameters. + + Examples: + >>> from scvi.module.base import PyroBaseModuleClass + >>> from pyrovelocity.models._velocity_module import VelocityModule + >>> num_cells = 10 + >>> num_genes = 20 + >>> velocity_module1 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto_t0_constraint", add_offset=False + ... ) + >>> type(velocity_module1.model) + + >>> type(velocity_module1.guide) + + >>> velocity_module2 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto", add_offset=True + ... 
) + >>> type(velocity_module2.model) + + >>> type(velocity_module2.guide) + + """ + + def __init__( + self, + num_cells: int, + num_genes: int, + model_type: str = "auto", + guide_type: str = "velocity_auto", + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + latent_factor_size: int = 10, + inducing_point_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = True, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + super().__init__() + self.num_cells = num_cells + self.num_genes = num_genes + self.n_genes = num_genes + self.model_type = model_type + self.guide_type = guide_type + self._model = None + self.plate_size = plate_size + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + logger.info( + f"Model type: {self.model_type}, Guide type: {self.guide_type}" + ) + + self.cell_specific_kinetics = cell_specific_kinetics + + self._model = VelocityModelAuto( + self.num_cells, + self.num_genes, + likelihood, + shared_time, + t_scale_on, + self.plate_size, + latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + include_prior=include_prior, + num_aux_cells=num_aux_cells, + only_cell_times=self.only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + guide_type=self.guide_type, + cell_specific_kinetics=self.cell_specific_kinetics, + **initial_values, + ) + + guide = AutoGuideList( + self._model, create_plates=self._model.create_plates + ) + guide.append( + AutoNormal( + poutine.block( + self._model, + expose=[ + "Tmax", + 't_c_loc' + ], + ), + init_scale=0.1, + ) 
+ ) + + if add_offset: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "Tmax", + + ], + ), + rank=10, + init_scale=0.1, + ) + ) + else: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "Tmax", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + self._guide = guide + + @property + def model(self) -> VelocityModelAuto: + return self._model + + @property + def guide(self) -> AutoGuideList: + return self._guide + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: Dict[str, torch.Tensor] + ) -> Tuple[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + Dict[Any, Any], + ]: + u_obs = tensor_dict["U"] + s_obs = tensor_dict["X"] + N_cn = tensor_dict["N_cn"] + ind_x = tensor_dict["ind_x"].long().squeeze() + return ( + u_obs, + s_obs, + N_cn + ), {} \ No newline at end of file diff --git a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py new file mode 100644 index 000000000..1f2695b9e --- /dev/null +++ b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py @@ -0,0 +1,46 @@ +import torch +import torch.nn.functional as F +from beartype import beartype +from torch import Tensor +import numpy as np + +@beartype +def regulatory_function_1(u: Tensor, + s: Tensor, + h1: int = 100, + h2: int = 100): + """ + Regulatory function that contains a neural net with two hidden layers that takes unspliced and spliced + counts as input and returns transcription (alpha), splicing (beta) and degradation rates (gamma). + + Args: + u (Tensor): Unspliced counts + s (Tensor): Spliced counts + h1 (int): Nodes in hidden layer 1 + h2 (int): Nodes in hidden layer 2 + + Returns: + Tuple: transcription (alpha), splicing (beta) and degradation rate (gamma). 
def _sum_rows_by_cluster(matrix, labels, clusters):
    """Sum the rows of ``matrix`` per cluster into a dense (n_clusters, n_vars) array."""
    # `np.asarray(...).ravel()` flattens both the (1, n_vars) np.matrix returned
    # by sparse sums and the 1-D result of dense sums, so either input stacks
    # into a 2-D array.
    return np.vstack(
        [
            np.asarray(matrix[labels == cluster].sum(axis=0)).ravel()
            for cluster in clusters
        ]
    )


def _as_count_matrix(summed):
    """Store summed counts sparsely, widening past uint16 when values would overflow."""
    fits_uint16 = summed.max(initial=0) <= np.iinfo(np.uint16).max
    return csr_matrix(summed, dtype=np.uint16 if fits_uint16 else np.uint32)


def _merge_modality(
    adata,
    labels,
    clusters,
    cluster_sizes,
    modality,
    celltype_key=None,
    layers=(),
    verbose=True,
):
    """Build the metacell AnnData for one modality ("RNA" or "ATAC") by summing counts per cluster."""
    if verbose:
        print(f'merging {modality} counts')

    counts_key = f'{modality} counts'
    adata_meta = AnnData(X=_sum_rows_by_cluster(adata.X, labels, clusters))
    adata_meta.var = adata.var.copy()
    adata_meta.obs['n_cells'] = cluster_sizes
    if celltype_key:
        # Name each metacell after its most frequent cell type label.
        adata_meta.obs[celltype_key] = [
            adata.obs[celltype_key][labels == cluster].mode().iloc[0]
            for cluster in clusters
        ]
    adata_meta.obs[counts_key] = np.sum(adata_meta.X, axis=1)

    for layer in layers:
        if layer in adata.layers:
            adata_meta.layers[layer] = _as_count_matrix(
                _sum_rows_by_cluster(adata.layers[layer], labels, clusters)
            )

    if verbose:
        print(f'Mean {modality} counts per cell before: ', np.mean(adata.obs[counts_key]))
        print(f'Mean {modality} counts per cell after: ', np.mean(adata_meta.obs[counts_key]))
        plt.hist(adata.obs[counts_key], bins=10, label='single cells', alpha=0.5)
        plt.hist(adata_meta.obs[counts_key], bins=20, label='meta cells', alpha=0.5)
        plt.xlabel('Total Counts')
        plt.ylabel('Occurences')
        plt.legend()
        plt.show()

    return adata_meta


def _metacell_neighbors(distances, membership, n_clusters, n_neighbors):
    """Rank, for every metacell, the metacells it shares the most cell-level kNN edges with."""
    n_cells = membership.shape[0]
    adjacency = (csr_matrix(distances) != 0).astype(np.int64)
    indicator = csr_matrix(
        (np.ones(n_cells, dtype=np.int64), (np.arange(n_cells), membership)),
        shape=(n_cells, n_clusters),
    )
    # indicator.T @ adjacency @ indicator sums the adjacency rows and columns
    # of each cluster, i.e. entry (a, b) counts edges from cells in cluster a
    # to cells in cluster b, without building a dense n_cells x n_cells matrix.
    edge_counts = (indicator.T @ adjacency @ indicator).toarray()
    return np.argsort(-edge_counts, axis=1)[:, : n_neighbors + 1]


@beartype
def compute_metacells(
    adata_rna: AnnData,
    adata_atac: AnnData,
    latent_key: str,
    celltype_key: Optional[str] = None,
    n_neighbors: int = 10,
    n_neighbors_metacell: int = 5,
    resolution: int | float = 50,
    verbose: bool = True,
    merge_knn_graph: bool = True,
    merge_umap: bool = True,
    umap_key: Optional[str] = None
) -> Tuple[AnnData, AnnData]:
    """
    Compute metacells by high-resolution Leiden clustering in a given latent space.

    RNA and ATAC counts are summed per cluster. Optionally the UMAP coordinates
    are averaged per cluster and a metacell-level neighbor table is stored in
    ``.obsm['N_cn']``. If a cell type key is provided, each metacell is labelled
    with its most frequent cell type.

    Note:
        ``adata_rna`` is modified in place: ``.X`` is converted to CSR, the
        neighbor graph and a ``'leiden'`` clustering are computed on it, and an
        ``'RNA counts'`` column is added to ``.obs``. ``adata_atac.X`` is
        converted to CSR in place; all other ATAC changes happen on a copy.

    Args:
        adata_rna (AnnData): AnnData object with RNA counts.
        adata_atac (AnnData): AnnData object with ATAC counts for the same cells.
        latent_key (str): Name of latent space key in .obsm slot in adata_rna, e.g. X_pca.
        celltype_key (Optional[str], optional): Name of cell type key in .obs column in adata_rna.
        n_neighbors (int, optional): Number of nearest neighbors to use in initial knn-graph used
            for metacell construction. Defaults to 10.
        n_neighbors_metacell (int, optional): Number of nearest neighbors to keep per metacell
            (``n_neighbors_metacell + 1`` columns are stored in ``N_cn``). Defaults to 5.
        resolution (int | float, optional): Resolution at which to do leiden clustering for metacells.
            Defaults to 50.
        verbose (bool, optional): Whether to print out progress and make diagnostic plots
            of metacell computations. Defaults to True.
        merge_knn_graph (bool, optional): Whether to produce a new knn graph for metacells.
        merge_umap (bool, optional): Whether to produce a umap embedding for metacells by averaging
            the embedding of their cells. If no embedding is present it is computed on adata_rna.
        umap_key (str, optional): Key of UMAP embedding in adata_rna.obsm. When omitted,
            ``'X_umap'`` is used.

    Returns:
        Tuple[AnnData, AnnData]: Tuple containing two anndata objects containing RNA and ATAC counts for metacells.

    Raises:
        ValueError: If the two objects differ in their number of cells or in their obs_names.

    Examples:
        >>> import scanpy as sc
        >>> from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData
        >>> adata_rna = synthetic_AnnData(seed=1)
        >>> adata_atac = synthetic_AnnData(seed=2)
        >>> sc.tl.pca(adata_rna, n_comps=3)
        >>> adata_meta, adata_atac_meta = compute_metacells(
        ...     adata_rna, adata_atac,
        ...     latent_key='X_pca',
        ...     resolution=1,
        ...     celltype_key='cell_type',
        ...     verbose=False,
        ... )
    """

    # Check input makes sense:
    if len(adata_rna.obs_names) != len(adata_atac.obs_names):
        raise ValueError("RNA and ATAC data do not contain the same number of cells.")
    # The ATAC object is re-ordered to the RNA cell order below, so only the
    # set of names has to agree, not their order.
    if set(adata_rna.obs_names) != set(adata_atac.obs_names):
        raise ValueError("RNA and ATAC data do not contain the same cells in obs_names.")

    if not isinstance(adata_rna.X, csr_matrix):
        adata_rna.X = csr_matrix(adata_rna.X)
    if not isinstance(adata_atac.X, csr_matrix):
        adata_atac.X = csr_matrix(adata_atac.X)

    # Explicit copy: the subset is a view and would otherwise be copied
    # implicitly (with a warning) on the first .obs assignment.
    adata_atac = adata_atac[adata_rna.obs_names, :].copy()
    if celltype_key:
        adata_atac.obs[celltype_key] = adata_rna.obs[celltype_key]
    # Sparse row sums come back as (n, 1) matrices; flatten them for .obs.
    adata_rna.obs['RNA counts'] = np.asarray(adata_rna.X.sum(axis=1)).ravel()
    adata_atac.obs['ATAC counts'] = np.asarray(adata_atac.X.sum(axis=1)).ravel()

    if verbose:
        print('low resolution clustering cells')
    sc.pp.neighbors(adata_rna, use_rep=latent_key, n_neighbors=n_neighbors)
    cluster_key = "leiden"
    sc.tl.leiden(adata_rna, key_added=cluster_key, resolution=resolution)
    adata_atac.obs[cluster_key] = adata_rna.obs[cluster_key]

    # One shared cluster ordering keeps the rows of every merged object aligned.
    labels = adata_rna.obs[cluster_key].to_numpy()
    clusters, membership, cluster_sizes = np.unique(
        labels, return_inverse=True, return_counts=True
    )
    membership = np.asarray(membership).ravel()

    if verbose:
        print('total number of cells', len(adata_rna.obs_names))
        print('total number of meta-cells', len(clusters))
        print('minimum cells/meta-cell', np.min(cluster_sizes))
        print('average cells/meta-cell', np.round(np.mean(cluster_sizes), 1))
        print('maximum cells/meta_cell', np.max(cluster_sizes))
        plt.hist(cluster_sizes, bins=10)
        plt.xlabel('cells per meta-cell')
        plt.ylabel('number of meta-cells')
        plt.show()

    adata_meta = _merge_modality(
        adata_rna,
        labels,
        clusters,
        cluster_sizes,
        modality='RNA',
        celltype_key=celltype_key,
        layers=('unspliced', 'spliced'),
        verbose=verbose,
    )

    if merge_umap:
        if umap_key and umap_key not in adata_rna.obsm:
            warnings.warn("Umap_key not found in AnnData. Computing with sc.tl.umap()...", category=UserWarning)
            sc.tl.umap(adata_rna)
            umap_key = 'X_umap'
        elif not umap_key:
            # Reuse an existing embedding; only compute one when none is present.
            if 'X_umap' not in adata_rna.obsm:
                sc.tl.umap(adata_rna)
            umap_key = 'X_umap'
        if verbose:
            print('merging UMAP')
        embedding = adata_rna.obsm[umap_key]
        adata_meta.obsm[umap_key] = np.vstack(
            [np.mean(embedding[labels == cluster, :], axis=0) for cluster in clusters]
        )

    adata_atac_meta = _merge_modality(
        adata_atac,
        labels,
        clusters,
        cluster_sizes,
        modality='ATAC',
        celltype_key=celltype_key,
        verbose=verbose,
    )

    adata_meta.obs['ATAC counts'] = adata_atac_meta.obs['ATAC counts']

    if merge_knn_graph:
        if verbose:
            print('merging knn graph')
        adata_meta.obsm['N_cn'] = _metacell_neighbors(
            adata_rna.obsp['distances'],
            membership,
            n_clusters=len(clusters),
            n_neighbors=n_neighbors_metacell,
        )

    if verbose:
        print('Done.')

    return adata_meta, adata_atac_meta
def test_get_cell_parameters():
    """Check `get_cell_parameters` against reference outputs for 4 cells and 2 genes."""
    import torch

    from pyrovelocity.models._transcription_dynamics import get_cell_parameters

    n_cells = 4
    t = torch.arange(0, 120, 30).reshape(n_cells, 1)
    t0_1 = torch.tensor((10.0, 10.0))
    dt_1 = torch.tensor((15.0, 15.0))
    dt_2 = torch.tensor((77.0, 53.0))
    dt_3 = torch.tensor((50.0, 70.0))
    alpha = torch.tensor((0.5, 0.3))
    alpha_off = torch.tensor(0.0)
    # `torch.tensor` takes a single data argument, so the 2x2 matrix must be
    # passed as one nested sequence; two positional tuples raise a TypeError.
    # NOTE(review): the 2x2 shape is inferred from the original call -- confirm
    # against the signature of `get_cell_parameters`.
    k = torch.tensor(((1.0, 1.0), (1.0, 1.0)))

    output = get_cell_parameters(
        t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off, k
    )

    correct_output = (
        torch.tensor([[0, 0],
                      [2, 2],
                      [2, 2],
                      [3, 3]]),
        torch.tensor([[0., 1., 1., 0., 0.],
                      [0., 1., 1., 1., 0.]]),
        torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],
                      [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]),
        torch.tensor([[0., 10., 25., 75., 102.],
                      [0., 10., 25., 78., 95.]]),
        torch.tensor([[0., 0.],
                      [1., 1.],
                      [1., 1.],
                      [0., 1.]]),
        torch.tensor([[0.0000, 0.0000],
                      [0.5000, 0.3000],
                      [0.5000, 0.3000],
                      [0.5000, 0.0000]]),
        torch.tensor([[0., 0.],
                      [5., 5.],
                      [35., 35.],
                      [15., 12.]]),
    )

    # Guard against a silently truncated comparison when fewer values are returned.
    assert len(output) == len(correct_output)
    for i, (actual, expected) in enumerate(zip(output, correct_output)):
        assert torch.allclose(actual, expected, atol=1e-3), f"Output at index {i} is incorrect"
[torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + + tau_vec = input[0] + c0_vec = input[1] + u0_vec = input[2] + s0_vec = input[3] + k_c_vec = input[4] + alpha_c = input[5] + alpha_vec = input[6] + beta = input[7] + gamma = input[8] + + output = atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + + correct_output = (torch.tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), torch.tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" diff --git a/src/pyrovelocity/tests/models/test_velocity_model.py b/src/pyrovelocity/tests/models/test_velocity_model.py index 1804fb2e8..7f1e0db87 100644 --- a/src/pyrovelocity/tests/models/test_velocity_model.py +++ b/src/pyrovelocity/tests/models/test_velocity_model.py @@ -122,3 +122,45 @@ def test_forward_method(self, velocity_model_auto): assert u.shape == (3, 4) assert s.shape == (3, 4) + +def test_MultiVelocityModelAuto(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + +def test_MultiVelocityModelAuto_get_atac_rna(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + import torch + + n_cells = 
def synthetic_AnnData(
    n_cell_types: int = 3,
    n_genes: int = 10,
    cells_per_type: int = 20,
    seed: int = 42,
) -> "ad.AnnData":
    """
    Produces a simple synthetic AnnData object.

    Expression values for cell type ``i`` are drawn from a normal distribution
    with mean ``i`` and standard deviation 0.5, so the cell types are separable.
    Cells are named ``Cell1..CellN`` and genes ``Gene1..GeneM``.

    Args:
        n_cell_types (int): Number of cell types.
        n_genes (int): Number of genes.
        cells_per_type (int): Number of cells per cell type.
        seed (int): Random seed. Note that this seeds NumPy's global generator.

    Returns:
        AnnData: Synthetic AnnData object with ``n_cell_types * cells_per_type``
        cells, ``n_genes`` genes, a categorical ``obs['cell_type']`` column and
        a ``var['gene_names']`` column.

    Examples:
        >>> adata = synthetic_AnnData(n_cell_types=2, n_genes=5, cells_per_type=3)
        >>> adata.shape
        (6, 5)
    """
    n_cells = cells_per_type * n_cell_types

    # Seed the global generator so repeated calls with the same seed give the same draws.
    np.random.seed(seed)

    # One block of cells per type; the block mean equals the type index.
    expression_data = np.vstack([
        np.random.normal(loc=i, scale=0.5, size=(cells_per_type, n_genes))
        for i in range(n_cell_types)
    ])

    adata = ad.AnnData(X=expression_data)

    # Cell type annotations, in the same block order as the expression rows.
    cell_types = [
        f'Type {i}' for i in range(n_cell_types) for _ in range(cells_per_type)
    ]
    adata.obs['cell_type'] = pd.Categorical(cell_types)

    gene_names = [f'Gene{i+1}' for i in range(n_genes)]
    adata.var['gene_names'] = gene_names

    adata.obs_names = [f'Cell{i+1}' for i in range(n_cells)]
    adata.var_names = gene_names

    return adata
test_train_model(tmp_path): + loss_plot_path = str(tmp_path) + "/loss_plot_docs.png" + print(loss_plot_path) + adata = generate_sample_data(random_seed=99) + copy_raw_counts(adata) + _, model, posterior_samples = train_model( + adata, + adata_atac=None, + use_gpu="auto", + seed=99, + max_epochs=200, + loss_plot_path=loss_plot_path, + ) diff --git a/src/pyrovelocity/tests/test_preprocess.py b/src/pyrovelocity/tests/test_preprocess.py new file mode 100644 index 000000000..9f706a7aa --- /dev/null +++ b/src/pyrovelocity/tests/test_preprocess.py @@ -0,0 +1,20 @@ +"""Tests for `pyrovelocity.tasks.preprocess` module.""" + + +def test_load_preprocess(): + from pyrovelocity.tasks import preprocess + + print(preprocess.__file__) + +def test_compute_metacells(): + from pyrovelocity.tasks.preprocess import compute_metacells + from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData + import scanpy as sc + adata_rna = synthetic_AnnData(seed = 1) + adata_atac = synthetic_AnnData(seed = 2) + sc.tl.pca(adata_rna, n_comps=3) + compute_metacells(adata_rna, adata_atac, + latent_key = 'X_pca', + resolution = 1, + celltype_key = 'cell_type') + diff --git a/src/pyrovelocity/tests/test_synthetic_AnnData.py b/src/pyrovelocity/tests/test_synthetic_AnnData.py new file mode 100644 index 000000000..ae0f838db --- /dev/null +++ b/src/pyrovelocity/tests/test_synthetic_AnnData.py @@ -0,0 +1,5 @@ +"""Test synthetic_AnnData function.""" + +def test_synthetic_AnnData(): + from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData + synthetic_AnnData() \ No newline at end of file