From b4c100f41b4a16085dd264cec21974e40b33400e Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 8 Jul 2024 21:22:18 +0000 Subject: [PATCH 01/47] added atac_layer argument to train_model tasks and made tests for it --- src/pyrovelocity/tasks/train.py | 13 ++++++++++++ src/pyrovelocity/tests/tasks/__init__.py | 2 +- .../tests/tasks/test_train_model.py | 20 +++++++++++++++++++ 3 files changed, 34 insertions(+), 1 deletion(-) create mode 100644 src/pyrovelocity/tests/tasks/test_train_model.py diff --git a/src/pyrovelocity/tasks/train.py b/src/pyrovelocity/tasks/train.py index 0e69e694e..15fca8135 100644 --- a/src/pyrovelocity/tasks/train.py +++ b/src/pyrovelocity/tasks/train.py @@ -284,6 +284,7 @@ def check_shared_time(posterior_samples, adata): @beartype def train_model( adata: str | Path | AnnData, + atac_layer: Optional[str] = None, guide_type: str = "auto", model_type: str = "auto", batch_size: int = -1, @@ -311,6 +312,7 @@ def train_model( Args: adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object. + atac_layer (Optional[str], optional): Name of AnnData layer that contains atac data, if present. guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto". model_type (str, optional): The type of Pyro model. Default is "auto". batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset. @@ -353,7 +355,18 @@ def train_model( >>> copy_raw_counts(adata) >>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path) """ +<<<<<<< HEAD if isinstance(adata, str | Path): +||||||| parent of 9d2f67d (added atac_layer argument to train_model tasks and made tests for it) + if isinstance(adata, str): +======= + + if atac_layer: + logger.info( + "Multiome model not yet implemented. Proceeding without atac data." 
+ ) + if isinstance(adata, str): +>>>>>>> 9d2f67d (added atac_layer argument to train_model tasks and made tests for it) adata = load_anndata_from_path(adata) logger.info(f"AnnData object prior to model training") diff --git a/src/pyrovelocity/tests/tasks/__init__.py b/src/pyrovelocity/tests/tasks/__init__.py index 9148a2e22..69f0ed42b 100644 --- a/src/pyrovelocity/tests/tasks/__init__.py +++ b/src/pyrovelocity/tests/tasks/__init__.py @@ -1 +1 @@ -"""Unit test package for pyrovelocity.tasks""" +"""Unit test package for pyrovelocity.tasks""" \ No newline at end of file diff --git a/src/pyrovelocity/tests/tasks/test_train_model.py b/src/pyrovelocity/tests/tasks/test_train_model.py new file mode 100644 index 000000000..568a64ab7 --- /dev/null +++ b/src/pyrovelocity/tests/tasks/test_train_model.py @@ -0,0 +1,20 @@ +"""Tests for `pyrovelocity._train_model` task.""" + +from pyrovelocity.tasks.train import train_model +from pyrovelocity.utils import generate_sample_data +from pyrovelocity.tasks.preprocess import copy_raw_counts + + +def test_train_model(tmp_path): + loss_plot_path = str(tmp_path) + "/loss_plot_docs.png" + print(loss_plot_path) + adata = generate_sample_data(random_seed=99) + copy_raw_counts(adata) + _, model, posterior_samples = train_model( + adata, + atac_layer="atac", + use_gpu="auto", + seed=99, + max_epochs=200, + loss_plot_path=loss_plot_path, + ) From c1ee5cb3031850d75fae80c68e43dba19d57afd0 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 8 Jul 2024 21:47:59 +0000 Subject: [PATCH 02/47] fix(vscode): disable automatic python env activation --- .vscode/settings.json | 1 + 1 file changed, 1 insertion(+) diff --git a/.vscode/settings.json b/.vscode/settings.json index 7d5509522..b41b8ad69 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -56,6 +56,7 @@ "search.followSymlinks": false, "terminal.integrated.fontSize": 14, "terminal.integrated.scrollback": 100000, + "python.terminal.activateEnvironment": false, "workbench.colorTheme": "Catppuccin Mocha", "workbench.iconTheme": "vscode-icons", // Passing --no-cov to pytestArgs is required to respect breakpoints From f8d04f3e8ede02df0102e1725512ae90456353d0 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Wed, 10 Jul 2024 15:55:48 +0000 Subject: [PATCH 03/47] corrected ATAC argument type from layer name to anndata object --- src/pyrovelocity/models/_velocity.py | 2 ++ src/pyrovelocity/tasks/train.py | 6 +++--- src/pyrovelocity/tests/tasks/test_train_model.py | 4 ++-- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 4022bc246..73e5f31ea 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -99,6 +99,7 @@ class PyroVelocity(VelocityTrainingMixin, BaseModelClass): def __init__( self, adata: AnnData, + adata_atac: Optional[AnnData] = None, input_type: str = "raw", shared_time: bool = True, model_type: str = "auto", @@ -126,6 +127,7 @@ def __init__( Args: adata (AnnData): An AnnData object containing the gene expression data. + adata_atac (Optional[AnnData], optional) An AnnData object containing atac data. input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". shared_time (bool, optional): Whether to use shared time. Defaults to True. model_type (str, optional): Type of model to use. Defaults to "auto". 
diff --git a/src/pyrovelocity/tasks/train.py b/src/pyrovelocity/tasks/train.py index 15fca8135..eb3483189 100644 --- a/src/pyrovelocity/tasks/train.py +++ b/src/pyrovelocity/tasks/train.py @@ -284,7 +284,7 @@ def check_shared_time(posterior_samples, adata): @beartype def train_model( adata: str | Path | AnnData, - atac_layer: Optional[str] = None, + adata_atac: Optional[AnnData] = None, guide_type: str = "auto", model_type: str = "auto", batch_size: int = -1, @@ -312,7 +312,7 @@ def train_model( Args: adata (str | AnnData): Path to a file that can be read to an AnnData object or an AnnData object. - atac_layer (Optional[str], optional): Name of AnnData layer that contains atac data, if present. + adata_atac (Optional[AnnData], optional): An anndata object with atac data, matching the default adata input with RNA data. guide_type (str, optional): The type of guide function for the Pyro model. Default is "auto". model_type (str, optional): The type of Pyro model. Default is "auto". batch_size (int, optional): Batch size for training. Default is -1, which indicates using the full dataset. @@ -361,7 +361,7 @@ def train_model( if isinstance(adata, str): ======= - if atac_layer: + if adata_atac: logger.info( "Multiome model not yet implemented. Proceeding without atac data." ) diff --git a/src/pyrovelocity/tests/tasks/test_train_model.py b/src/pyrovelocity/tests/tasks/test_train_model.py index 568a64ab7..7f516485a 100644 --- a/src/pyrovelocity/tests/tasks/test_train_model.py +++ b/src/pyrovelocity/tests/tasks/test_train_model.py @@ -1,8 +1,8 @@ """Tests for `pyrovelocity._train_model` task.""" +from pyrovelocity.tasks.preprocess import copy_raw_counts from pyrovelocity.tasks.train import train_model from pyrovelocity.utils import generate_sample_data -from pyrovelocity.tasks.preprocess import copy_raw_counts def test_train_model(tmp_path): @@ -12,7 +12,7 @@ def test_train_model(tmp_path): copy_raw_counts(adata) _, model, posterior_samples = train_model( adata, - atac_layer="atac", + adata_atac=None, use_gpu="auto", seed=99, max_epochs=200, From c83e0650c5e6bb2d58040a5eb35fcc9b34de0ac7 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Wed, 10 Jul 2024 18:43:20 +0000 Subject: [PATCH 04/47] feat[PyroVelocity]: added atac_data to setup_anndata method --- src/pyrovelocity/models/_velocity.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 73e5f31ea..fcb85fe37 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -332,9 +332,15 @@ def setup_anndata(cls, adata: AnnData, *args, **kwargs): NumericalObsField("s_lib_size_scale", "s_lib_size_scale"), NumericalObsField("ind_x", "ind_x"), ] + + if adata_atac: + adata.layers['atac'] = adata_atac.X + anndata_fields += [LayerField('atac', 'atac')] + adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args ) + adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) From 8661a36d24e4ca7e7b72592b60631a7ad4840f9d Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Wed, 10 Jul 2024 18:54:20 +0000 Subject: [PATCH 05/47] fix[PyroVelocity]: missing colon in Args description Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index fcb85fe37..293cccdb6 100644 --- a/src/pyrovelocity/models/_velocity.py +++ 
b/src/pyrovelocity/models/_velocity.py @@ -127,7 +127,7 @@ def __init__( Args: adata (AnnData): An AnnData object containing the gene expression data. - adata_atac (Optional[AnnData], optional) An AnnData object containing atac data. + adata_atac (Optional[AnnData], optional): An AnnData object containing atac data. input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". shared_time (bool, optional): Whether to use shared time. Defaults to True. model_type (str, optional): Type of model to use. Defaults to "auto". From c64346554045eed39a172c2d235826015378f8e5 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Wed, 10 Jul 2024 19:44:49 +0000 Subject: [PATCH 06/47] feat(VelocityTrainingMixin): Added atac data to train_faster method In the long run we should refactor code, so that training uses scvi-tools modules. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_trainer.py | 163 ++++++++++++++++++++-------- 1 file changed, 115 insertions(+), 48 deletions(-) diff --git a/src/pyrovelocity/models/_trainer.py b/src/pyrovelocity/models/_trainer.py index fe3505cf2..1fcd13b7e 100644 --- a/src/pyrovelocity/models/_trainer.py +++ b/src/pyrovelocity/models/_trainer.py @@ -271,8 +271,8 @@ def train_faster( if scipy.sparse.issparse(self.adata.layers["raw_spliced"]) else self.adata.layers["raw_spliced"], dtype=torch.float32, - ).to(device) - + ).to(device) + epsilon = 1e-6 log_u_library_size = np.log( @@ -335,60 +335,127 @@ def train_faster( losses = [] patience = patient_init - for step in range(max_epochs): - if cell_state is None: - elbos = ( - svi.step( - u, - s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - None, + + if not self.adata.layers['atac']: + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - else: - elbos = ( - svi.step( - u, - s, - u_library.reshape(-1, 1), - s_library.reshape(-1, 1), - u_library_mean.reshape(-1, 1), - s_library_mean.reshape(-1, 1), - u_library_scale.reshape(-1, 1), - s_library_scale.reshape(-1, 1), - None, - cell_state.reshape(-1, 1), + else: + elbos = ( + svi.step( + u, + s, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + + else: + + atac = torch.tensor( + np.array( + self.adata.layers["atac"].toarray(), dtype="float32" + ) + if scipy.sparse.issparse(self.adata.layers["atac"]) + else self.adata.layers["atac"], + dtype=torch.float32, + ).to(device) + + + for step in range(max_epochs): + if cell_state is None: + elbos = ( + svi.step( + u, + s, + atac, + 
u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + None, + ) + / normalizer ) - / normalizer - ) - if (step == 0) or ( - ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) - ): - mlflow.log_metric("-ELBO", -elbos, step=step + 1) - logger.info( - f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" - ) - if step > log_every: - if (losses[-1] - elbos) < losses[-1] * patient_improve: - patience -= 1 else: - patience = patient_init - if patience <= 0: - break - losses.append(elbos) + elbos = ( + svi.step( + u, + s, + atac, + u_library.reshape(-1, 1), + s_library.reshape(-1, 1), + u_library_mean.reshape(-1, 1), + s_library_mean.reshape(-1, 1), + u_library_scale.reshape(-1, 1), + s_library_scale.reshape(-1, 1), + None, + cell_state.reshape(-1, 1), + ) + / normalizer + ) + if (step == 0) or ( + ((step + 1) % log_every == 0) and ((step + 1) < max_epochs) + ): + mlflow.log_metric("-ELBO", -elbos, step=step + 1) + logger.info( + f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" + ) + if step > log_every: + if (losses[-1] - elbos) < losses[-1] * patient_improve: + patience -= 1 + else: + patience = patient_init + if patience <= 0: + break + losses.append(elbos) + mlflow.log_metric("-ELBO", -elbos, step=step + 1) mlflow.log_metric("real_epochs", step + 1) logger.info( f"step {step + 1: >4d} loss = {elbos:0.6g} patience = {patience}" ) - return losses + return losses def train_faster_with_batch( self, From b0ce0f9a3b1677da9fcfda54462f4bc6fde1b165 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Wed, 10 Jul 2024 20:20:17 +0000 Subject: [PATCH 07/47] feat(_velocity_module): Added a MultiVelocityModule for multiome data. 
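MultiVelocityModule mirrors the existing VelocityModule interface and exposes the same
model/guide pair expected by the scvi-tools training loop. A minimal construction sketch
(illustrative only; the cell/gene dimensions are hypothetical, the signature follows this patch):

    from pyrovelocity.models._velocity_module import MultiVelocityModule

    # hypothetical dimensions for illustration
    num_cells, num_genes = 10, 20
    module = MultiVelocityModule(
        num_cells,
        num_genes,
        model_type="auto",
        guide_type="auto_t0_constraint",
        add_offset=False,
    )
    # the module exposes the underlying Pyro model and AutoGuideList
    print(type(module.model), type(module.guide))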
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity.py | 76 ++++--- src/pyrovelocity/models/_velocity_module.py | 228 ++++++++++++++++++++ 2 files changed, 279 insertions(+), 25 deletions(-) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 293cccdb6..7bd36ee47 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -30,7 +30,7 @@ ) from pyrovelocity.logging import configure_logging from pyrovelocity.models._trainer import VelocityTrainingMixin -from pyrovelocity.models._velocity_module import VelocityModule +from pyrovelocity.models._velocity_module import VelocityModule, MultiVelocityModule __all__ = ["PyroVelocity"] @@ -246,30 +246,56 @@ def __init__( # else: initial_values = {} logger.info(self.summary_stats) - self.module = VelocityModule( - self.summary_stats["n_cells"], - self.summary_stats["n_vars"], - model_type=model_type, - guide_type=guide_type, - likelihood=likelihood, - shared_time=shared_time, - t_scale_on=t_scale_on, - plate_size=plate_size, - latent_factor=latent_factor, - latent_factor_operation=latent_factor_operation, - latent_factor_size=latent_factor_size, - inducing_point_size=inducing_point_size, - include_prior=include_prior, - use_gpu=use_gpu, - num_aux_cells=num_aux_cells, - only_cell_times=only_cell_times, - decoder_on=decoder_on, - add_offset=add_offset, - correct_library_size=correct_library_size, - cell_specific_kinetics=cell_specific_kinetics, - kinetics_num=self.k, - **initial_values, - ) + if not adata_atac: + self.module = VelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + else: + self.module = MultiVelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) self.num_cells = self.module.num_cells self._model_summary_string = """ RNA velocity Pyro model with parameters: diff --git a/src/pyrovelocity/models/_velocity_module.py b/src/pyrovelocity/models/_velocity_module.py index 16a75fb99..e931a9f67 100644 --- a/src/pyrovelocity/models/_velocity_module.py +++ b/src/pyrovelocity/models/_velocity_module.py @@ -242,3 +242,231 @@ def _get_fn_args_from_batch( cell_state, time_info, ), {} + +class MultiVelocityModule(PyroBaseModuleClass): + """ + VelocityModule is an scvi-tools pyro 
module that combines the VelocityModelAuto and pyro AutoGuideList classes. + + Args: + num_cells (int): Number of cells. + num_genes (int): Number of genes. + model_type (str, optional): Model type. Default is "auto". + guide_type (str, optional): Guide type. Default is "velocity_auto". + likelihood (str, optional): Likelihood type. Default is "Poisson". + shared_time (bool, optional): If True, a shared time parameter will be used. Default is True. + t_scale_on (bool, optional): If True, scale time parameter. Default is False. + plate_size (int, optional): Size of the plate set. Default is 2. + latent_factor (str, optional): Latent factor. Default is "none". + latent_factor_operation (str, optional): Latent factor operation mode. Default is "selection". + latent_factor_size (int, optional): Size of the latent factor. Default is 10. + inducing_point_size (int, optional): Inducing point size. Default is 0. + include_prior (bool, optional): If True, include prior in the model. Default is False. + use_gpu (str, optional): Accelerator type. Default is "auto". + num_aux_cells (int, optional): Number of auxiliary cells. Default is 0. + only_cell_times (bool, optional): If True, only model cell times. Default is True. + decoder_on (bool, optional): If True, use the decoder. Default is False. + add_offset (bool, optional): If True, add offset to the model. Default is True. + correct_library_size (Union[bool, str], optional): Library size correction method. Default is True. + cell_specific_kinetics (Optional[str], optional): Cell-specific kinetics method. Default is None. + kinetics_num (Optional[int], optional): Number of kinetics. Default is None. + **initial_values: Initial values for the model parameters. + + Examples: + >>> from scvi.module.base import PyroBaseModuleClass + >>> from pyrovelocity.models._velocity_module import VelocityModule + >>> num_cells = 10 + >>> num_genes = 20 + >>> velocity_module1 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto_t0_constraint", add_offset=False + ... ) + >>> type(velocity_module1.model) + + >>> type(velocity_module1.guide) + + >>> velocity_module2 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto", add_offset=True + ... 
) + >>> type(velocity_module2.model) + + >>> type(velocity_module2.guide) + + """ + + def __init__( + self, + num_cells: int, + num_genes: int, + model_type: str = "auto", + guide_type: str = "velocity_auto", + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + latent_factor_size: int = 10, + inducing_point_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = True, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + super().__init__() + self.num_cells = num_cells + self.num_genes = num_genes + self.model_type = model_type + self.guide_type = guide_type + self._model = None + self.plate_size = plate_size + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + logger.info( + f"Model type: {self.model_type}, Guide type: {self.guide_type}" + ) + + self.cell_specific_kinetics = cell_specific_kinetics + + self._model = VelocityModelAuto( + self.num_cells, + self.num_genes, + likelihood, + shared_time, + t_scale_on, + self.plate_size, + latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + include_prior=include_prior, + num_aux_cells=num_aux_cells, + only_cell_times=self.only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + guide_type=self.guide_type, + cell_specific_kinetics=self.cell_specific_kinetics, + **initial_values, + ) + + guide = AutoGuideList( + self._model, create_plates=self._model.create_plates + ) + guide.append( + AutoNormal( + poutine.block( + self._model, + expose=[ + "cell_time", + "u_read_depth", + "s_read_depth", + "kinetics_prob", + "kinetics_weights", + ], + ), + init_scale=0.1, + ) + ) + + if add_offset: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "alpha", + "beta", + "gamma", + "dt_switching", + "t0", + "u_scale", + "s_scale", + "u_offset", + "s_offset", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + else: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "alpha", + "beta", + "gamma", + "dt_switching", + "t0", + "u_scale", + "s_scale", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + self._guide = guide + + @property + def model(self) -> VelocityModelAuto: + return self._model + + @property + def guide(self) -> AutoGuideList: + return self._guide + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: Dict[str, torch.Tensor] + ) -> Tuple[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + Dict[Any, Any], + ]: + u_obs = tensor_dict["U"] + s_obs = tensor_dict["X"] + c_obs = tensor_dict['atac'] + u_log_library = tensor_dict["u_lib_size"] + s_log_library = tensor_dict["s_lib_size"] + u_log_library_mean = tensor_dict["u_lib_size_mean"] + s_log_library_mean = tensor_dict["s_lib_size_mean"] + u_log_library_scale = tensor_dict["u_lib_size_scale"] + s_log_library_scale = tensor_dict["s_lib_size_scale"] + ind_x = tensor_dict["ind_x"].long().squeeze() + cell_state = tensor_dict.get("pyro_cell_state") + time_info = 
tensor_dict.get("time_info") + return ( + c_obs, + u_obs, + s_obs, + u_log_library, + s_log_library, + u_log_library_mean, + s_log_library_mean, + u_log_library_scale, + s_log_library_scale, + ind_x, + cell_state, + time_info, + ), {} From 2abf5ecfd886931f1a52c84376f4d1d643f2bd3f Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 11 Jul 2024 20:28:54 +0000 Subject: [PATCH 08/47] feat(_velocity_model): Added a MultiVelocityModelAuto class for multiome data Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity_model.py | 573 +++++++++++++++++++- src/pyrovelocity/models/_velocity_module.py | 2 +- 2 files changed, 558 insertions(+), 17 deletions(-) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 63ca3fd8a..639321610 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -1,27 +1,19 @@ -from typing import Optional -from typing import Tuple -from typing import Union +from typing import Optional, Tuple, Union import pyro import torch from beartype import beartype -from jaxtyping import Float -from jaxtyping import jaxtyped +from jaxtyping import Float, jaxtyped from pyro import poutine -from pyro.distributions import Bernoulli -from pyro.distributions import LogNormal -from pyro.distributions import Normal -from pyro.distributions import Poisson -from pyro.nn import PyroModule -from pyro.nn import PyroSample +from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson +from pyro.nn import PyroModule, PyroSample from pyro.primitives import plate from scvi.nn import Decoder -from torch.nn.functional import relu -from torch.nn.functional import softplus +from torch.nn.functional import relu, softplus +from torch import Tensor from pyrovelocity.logging import configure_logging -from pyrovelocity.models._transcription_dynamics import mrna_dynamics - +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters logger = configure_logging(__name__) @@ -36,6 +28,11 @@ Float[torch.Tensor, "samples num_cells num_genes"], ] +__all__ = [ + "LogNormalModel", + "VelocityModelAuto", + "MultiVelocityModelAuto", +] class LogNormalModel(PyroModule): """ @@ -154,6 +151,10 @@ def create_plates( gene_plate = pyro.plate("genes", self.num_genes, dim=-1) return cell_plate, gene_plate + @PyroSample + def alpha_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def alpha(self): return self._pyrosample_helper(1.0) @@ -166,6 +167,10 @@ def beta(self): def gamma(self): return self._pyrosample_helper(1.0) + @PyroSample + def sigma_c(self): + return self._pyrosample_helper(0.1) + @PyroSample def u_scale(self): return self._pyrosample_helper(0.1) @@ -182,6 +187,18 @@ def u_inf(self): def s_inf(self): return self._pyrosample_helper(0.1) + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def delay(self): + return self._pyrosample_helper(1.0) + + @PyroSample + def dt_switching_c(self): + return self._pyrosample_helper(1.0) + @PyroSample def dt_switching(self): return self._pyrosample_helper(1.0) @@ -305,7 +322,101 @@ def get_likelihood( u_dist = Poisson(ut) s_dist = Poisson(st) + return u_dist, s_dist + + @beartype + def get_likelihood_multiome( + self, + ct: torch.Tensor, + ut: torch.Tensor, + st: torch.Tensor, + sigma_c: torch.Tensor, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_scale: 
Optional[torch.Tensor] = None, + s_scale: Optional[torch.Tensor] = None, + u_read_depth: Optional[torch.Tensor] = None, + s_read_depth: Optional[torch.Tensor] = None, + u_cell_size_coef: None = None, + ut_coef: None = None, + s_cell_size_coef: None = None, + st_coef: None = None, + ) -> Tuple[Poisson, Poisson]: + """ + Compute the likelihood of the given count data. + + Args: + ct (torch.Tensor): Tensor representing chromatin state. + ut (torch.Tensor): Tensor representing unspliced transcripts. + st (torch.Tensor): Tensor representing spliced transcripts. + sigma_c (torch.Tensor): Tensor representing standard deviation of chromatin state. + u_log_library (Optional[torch.Tensor], optional): Log library tensor for unspliced transcripts. Defaults to None. + s_log_library (Optional[torch.Tensor], optional): Log library tensor for spliced transcripts. Defaults to None. + u_scale (Optional[torch.Tensor], optional): Scale tensor for unspliced transcripts. Defaults to None. + s_scale (Optional[torch.Tensor], optional): Scale tensor for spliced transcripts. Defaults to None. + u_read_depth (Optional[torch.Tensor], optional): Read depth tensor for unspliced transcripts. Defaults to None. + s_read_depth (Optional[torch.Tensor], optional): Read depth tensor for spliced transcripts. Defaults to None. + u_cell_size_coef (Optional[Any], optional): Cell size coefficient for unspliced transcripts. Defaults to None. + ut_coef (Optional[Any], optional): Coefficient for unspliced transcripts. Defaults to None. + s_cell_size_coef (Optional[Any], optional): Cell size coefficient for spliced transcripts. Defaults to None. + st_coef (Optional[Any], optional): Coefficient for spliced transcripts. Defaults to None. + + Returns: + Tuple[Poisson, Poisson]: A tuple of Poisson distributions for unspliced and spliced transcripts, respectively. 
+ + Example: + >>> import torch + >>> from pyrovelocity.models._velocity_model import LogNormalModel + >>> num_cells = 10 + >>> num_genes = 20 + >>> likelihood = "Poisson" + >>> plate_size = 2 + >>> model = LogNormalModel(num_cells, num_genes, likelihood, plate_size) + >>> logger.info(model) + >>> ut = torch.rand(num_cells, num_genes) + >>> st = torch.rand(num_cells, num_genes) + >>> u_read_depth = torch.rand(num_cells, 1) + >>> s_read_depth = torch.rand(num_cells, 1) + >>> u_dist, s_dist = model.get_likelihood(ut, st, u_read_depth=u_read_depth, s_read_depth=s_read_depth) + >>> logger.info(f"u_dist: {u_dist}") + >>> logger.info(f"s_dist: {s_dist}") + >>> assert isinstance(u_dist, torch.distributions.Poisson) + >>> assert isinstance(s_dist, torch.distributions.Poisson) + """ + if self.likelihood != "Poisson": + likelihood_not_implemented_msg = ( + "In the future, the likelihood will be referred to via a " + "member of a sum type over supported distributions" + ) + raise NotImplementedError(likelihood_not_implemented_msg) + + if self.correct_library_size: + ut = relu(ut) + self.one * 1e-6 + st = relu(st) + self.one * 1e-6 + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut / torch.sum(ut, dim=-1, keepdim=True) + st = st / torch.sum(st, dim=-1, keepdim=True) + ut = pyro.deterministic("ut_norm", ut, event_dim=0) + st = pyro.deterministic("st_norm", st, event_dim=0) + ut = (ut + self.one * 1e-6) * u_read_depth + st = (st + self.one * 1e-6) * s_read_depth + else: + ut = relu(ut) + st = relu(st) + ut = pyro.deterministic("ut", ut, event_dim=0) + st = pyro.deterministic("st", st, event_dim=0) + ut = ut + self.one * 1e-6 + st = st + self.one * 1e-6 + + c_dist = Normal(ct, sigma=sigma_c) + u_dist = Poisson(ut) + s_dist = Poisson(st) + + return c_dist, u_dist, s_dist + + class VelocityModelAuto(LogNormalModel): @@ -613,7 +724,7 @@ def forward( [ 0., 0., 7., 4.]])) """ cell_plate, gene_plate = self.create_plates( - u_obs, + _obs, s_obs, u_log_library, s_log_library, @@ -698,3 +809,433 @@ def forward( u = pyro.sample("u", u_dist, obs=u_obs) s = pyro.sample("s", s_dist, obs=s_obs) return u, s + + +class MultiVelocityModelAuto(LogNormalModel): + """Automatically configured MULTIOME velocity model. + + Args: + num_cells (int): _description_ + num_genes (int): _description_ + likelihood (str, optional): _description_. Defaults to "Poisson". + shared_time (bool, optional): _description_. Defaults to True. + t_scale_on (bool, optional): _description_. Defaults to False. + plate_size (int, optional): _description_. Defaults to 2. + latent_factor (str, optional): _description_. Defaults to "none". + latent_factor_size (int, optional): _description_. Defaults to 30. + latent_factor_operation (str, optional): _description_. Defaults to "selection". + include_prior (bool, optional): _description_. Defaults to False. + num_aux_cells (int, optional): _description_. Defaults to 100. + only_cell_times (bool, optional): _description_. Defaults to False. + decoder_on (bool, optional): _description_. Defaults to False. + add_offset (bool, optional): _description_. Defaults to False. + correct_library_size (Union[bool, str], optional): _description_. Defaults to True. + guide_type (str, optional): _description_. Defaults to "velocity". + cell_specific_kinetics (Optional[star], optional): _description_. Defaults to None. + kinetics_num (Optional[int], optional): _description_. Defaults to None. 
+ + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> model = VelocityModelAuto( + ... 3, + ... 4, + ... "Poisson", + ... True, + ... False, + ... 2, + ... "none", + ... latent_factor_operation="selection", + ... latent_factor_size=10, + ... include_prior=False, + ... num_aux_cells=0, + ... only_cell_times=True, + ... decoder_on=False, + ... add_offset=False, + ... correct_library_size=True, + ... guide_type="auto_t0_constraint", + ... cell_specific_kinetics=None, + ... **{} + ... ) + >>> logger.info(model) + """ + + @beartype + def __init__( + self, + num_cells: int, + num_genes: int, + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_size: int = 30, + latent_factor_operation: str = "selection", + include_prior: bool = False, + num_aux_cells: int = 100, + only_cell_times: bool = False, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + guide_type: str = "velocity", + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + assert num_cells > 0 and num_genes > 0 + super().__init__(num_cells, num_genes, likelihood, plate_size) + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + self.guide_type = guide_type + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + + self.mask = initial_values.get( + "mask", torch.ones(self.num_cells, self.num_genes).bool() + ) + for key in initial_values: + self.register_buffer(f"{key}_init", initial_values[key]) + + self.shared_time = shared_time + self.t_scale_on = t_scale_on + self.add_offset = add_offset + self.plate_size = plate_size + + self.latent_factor = latent_factor + self.latent_factor_size = latent_factor_size + self.latent_factor_operation = latent_factor_operation + self.include_prior = include_prior + self.decoder_on = decoder_on + self.correct_library_size = correct_library_size + if self.decoder_on: + self.decoder = Decoder(1, self.num_genes, n_layers=2) + + self.enumeration = "parallel" + # self.set_enumeration_strategy() + + def sample_cell_gene_chromatin_state(self, t, switching): + return ( + pyro.sample( + "cell_gene_chromatin_state", + Bernoulli(logits=t - switching), + infer={"enumerate": self.enumeration}, + ) + == self.zero + ) + + def sample_cell_gene_state(self, t, switching, state_c): + return ( + pyro.sample( + "cell_gene_state", + Bernoulli(logits=t - switching), + infer={"enumerate": self.enumeration}, + ) + == self.zero + ) + + @beartype + def __repr__(self) -> str: + return ( + f"\nVelocityModelAuto(\n" + f"\tnum_cells={self.num_cells}, \n" + f"\tnum_genes={self.num_genes}, \n" + f'\tlikelihood="{self.likelihood}", \n' + f"\tshared_time={self.shared_time}, \n" + f"\tt_scale_on={self.t_scale_on}, \n" + f"\tplate_size={self.plate_size}, \n" + f'\tlatent_factor="{self.latent_factor}", \n' + f"\tlatent_factor_size={self.latent_factor_size}, \n" + f'\tlatent_factor_operation="{self.latent_factor_operation}", \n' + f"\tinclude_prior={self.include_prior}, \n" + f"\tnum_aux_cells={self.num_aux_cells}, \n" + f"\tonly_cell_times={self.only_cell_times}, \n" + f"\tdecoder_on={self.decoder_on}, \n" + f"\tadd_offset={self.add_offset}, \n" + f"\tcorrect_library_size={self.correct_library_size}, \n" + f'\tguide_type="{self.guide_type}", \n' + f"\tcell_specific_kinetics={self.cell_specific_kinetics}, \n" + 
f"\tkinetics_num={self.k}\n" + f")\n" + ) + + @jaxtyped(typechecker=beartype) + def get_atac_rna( + self, + u_scale: RNAInputType, + s_scale: RNAInputType, + t: Tensor, # cells, 1 + t0_1: Tensor, + dt_1: Tensor, + dt_2: Tensor, + dt_3: Tensor, + alpha_c: Tensor, + alpha: Tensor, + alpha_off: Tensor, + beta: Tensor, + gamma: Tensor, + ) -> Tuple[RNAOutputType, RNAOutputType, RNAOutputType]: + """ + Computes the unspliced (u) and spliced (s) RNA expression levels and chromatin opening state (c) given + the model parameters. + + Args: + u_scale (torch.Tensor): Scaling factor for unspliced expression. + s_scale (torch.Tensor): Scaling factor for spliced expression. + t (Tensor): The time of each cell. + t0_1 (Tensor): Start time for chromatin opening. + dt_1 (Tensor): Time gap since chromatin opening for transcription start for each gene. + dt_2 (Tensor): Time gap since transcription start for chromatin closing for each gene. + dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. + alpha_c (Tensor): The chromatin opening and closing rate. + alpha (Tensor): The transcription rate of each gene in the on state. + alpha_off (Tensor): The transcription rate of each gene in the off state. + beta (torch.Tensor): Splicing rate. + gamma (torch.Tensor): Degradation rate. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The chromatin state (c), unspliced (u) and + spliced (s) RNA expression levels. + + + Examples: + >>> from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + >>> import torch + >>> n_cells = 4 + >>> u_scale = torch.tensor(1.0) + >>> s_scale = torch.tensor(1.0) + >>> t = torch.arange(0, 120, 30).reshape(n_cells,1) # cells, 1 + >>> t0_1 = torch.tensor((10.0, 10.0)) + >>> dt_1 = torch.tensor((15.0, 15.0)) + >>> dt_2 = torch.tensor((77.0, 53.0)) + >>> dt_3 = torch.tensor((50.0, 70.0)) + >>> alpha_c = torch.tensor((0.1, 0.2)) + >>> alpha = torch.tensor((0.5, 0.3)) + >>> alpha_off = torch.tensor(0.0) + >>> beta = torch.tensor((0.09, 0.11)) + >>> gamma = torch.tensor((0.05, 0.06)) + >>> mod = MultiVelocityModelAuto(num_cells = n_cells, num_genes = 2) + >>> output = MultiVelocityModelAuto.get_atac_rna( + mod, u_scale, s_scale, t, t0_1, dt_1, dt_2, dt_3, alpha_c, alpha, alpha_off, beta, gamma + ) + (tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), + tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), + tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + """ + + state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec = get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off + ) + + c0_vec, u0_vec, s0_vec = get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + + ct, ut, st = atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + + ut = ut * u_scale / s_scale + return ct, ut, st + + @beartype + def forward( + self, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + c_obs: torch.Tensor, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_log_library_loc: Optional[torch.Tensor] = None, + s_log_library_loc: Optional[torch.Tensor] = None, + u_log_library_scale: Optional[torch.Tensor] = None, + s_log_library_scale: Optional[torch.Tensor] = None, + ind_x: Optional[torch.Tensor] = None, + cell_state: Optional[torch.Tensor] = None, + time_info: Optional[torch.Tensor] = None, + ) 
-> Tuple[torch.Tensor, torch.Tensor]: + """ + Defines the forward model, which computes the unspliced (u) and spliced + (s) RNA expression levels given the observations and model parameters. + + Args: + u_obs (Optional[torch.Tensor], optional): Observed unspliced RNA expression. Default is None. + s_obs (Optional[torch.Tensor], optional): Observed spliced RNA expression. Default is None. + c_obs (Optional[torch.Tensor], optional): Observed chromatin state. Default is None. + u_log_library (Optional[torch.Tensor], optional): Log-transformed library size for unspliced RNA. Default is None. + s_log_library (Optional[torch.Tensor], optional): Log-transformed library size for spliced RNA. Default is None. + u_log_library_loc (Optional[torch.Tensor], optional): Mean of log-transformed library size for unspliced RNA. Default is None. + s_log_library_loc (Optional[torch.Tensor], optional): Mean of log-transformed library size for spliced RNA. Default is None. + u_log_library_scale (Optional[torch.Tensor], optional): Scale of log-transformed library size for unspliced RNA. Default is None. + s_log_library_scale (Optional[torch.Tensor], optional): Scale of log-transformed library size for spliced RNA. Default is None. + ind_x (Optional[torch.Tensor], optional): Indices for the cells. Default is None. + cell_state (Optional[torch.Tensor], optional): Cell state information. Default is None. + time_info (Optional[torch.Tensor], optional): Time information for the cells. Default is None. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The unspliced (u) and spliced (s) RNA expression levels. + + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> u_obs=torch.tensor( + ... [[33., 1., 7., 1.], + ... [12., 30., 11., 3.], + ... [ 1., 1., 8., 5.]], + ... device="cpu", + >>> ) + >>> s_obs=torch.tensor( + ... [[32.0, 0.0, 6.0, 0.0], + ... [11.0, 29.0, 10.0, 2.0], + ... [0.0, 0.0, 7.0, 4.0]], + ... device="cpu", + >>> ) + >>> c_obs=torch.tensor( + ... [[1.0, 0.2, 0.4, 0.0], + ... [0.8, 0.2, 0.5, 0.3], + ... [0.0, 0.0, 0.1, 0.9]], + ... 
device="cpu", + >>> ) + >>> u_log_library=torch.tensor([[3.7377], [4.0254], [2.7081]], device="cpu") + >>> s_log_library=torch.tensor([[3.6376], [3.9512], [2.3979]], device="cpu") + >>> u_log_library_loc=torch.tensor([[3.4904], [3.4904], [3.4904]], device="cpu") + >>> s_log_library_loc=torch.tensor([[3.3289], [3.3289], [3.3289]], device="cpu") + >>> u_log_library_scale=torch.tensor([[0.6926], [0.6926], [0.6926]], device="cpu") + >>> s_log_library_scale=torch.tensor([[0.8214], [0.8214], [0.8214]], device="cpu") + >>> ind_x=torch.tensor([2, 0, 1], device="cpu") + >>> model = VelocityModelAuto(3,4) + >>> u, s = model.forward( + >>> u_obs, + >>> s_obs, + >>> u_log_library, + >>> s_log_library, + >>> u_log_library_loc, + >>> s_log_library_loc, + >>> u_log_library_scale, + >>> s_log_library_scale, + >>> ind_x, + >>> ) + >>> u, s + (tensor([[33., 1., 7., 1.], + [12., 30., 11., 3.], + [ 1., 1., 8., 5.]]), + tensor([[32., 0., 6., 0.], + [11., 29., 10., 2.], + [ 0., 0., 7., 4.]])) + """ + cell_plate, gene_plate = self.create_plates( + u_obs, + s_obs, + c_obs, + u_log_library, + s_log_library, + u_log_library_loc, + s_log_library_loc, + u_log_library_scale, + s_log_library_scale, + ind_x, + cell_state, + time_info, + ) + + with gene_plate, poutine.mask(mask=self.include_prior): + + alpha_c = self.alpha_c + alpha = self.alpha + gamma = self.gamma + beta = self.beta + sigma_c = self.sigma_c + + if self.add_offset: + u0 = pyro.sample("u_offset", LogNormal(self.zero, self.one)) + s0 = pyro.sample("s_offset", LogNormal(self.zero, self.one)) + c0 = pyro.sample("c_offset", LogNormal(self.zero, self.one)) + else: + s0 = u0 = c0 = self.zero + + t0_c = pyro.sample("t0_c", Normal(self.zero, self.one)) + t0 = t0_c + self.delay + + u_scale = self.u_scale + s_scale = self.one + + dt_switching_c = self.dt_switching_c + dt_switching = self.dt_switching + + c_inf, u_inf, s_inf = atac_mrna_dynamics( + dt_switching_c, dt_switching, c0, u0, s0, alpha_c, alpha, beta, gamma + ) + c_inf = pyro.deterministic("c_inf", c_inf, event_dim=0) + u_inf = pyro.deterministic("u_inf", u_inf, event_dim=0) + s_inf = pyro.deterministic("s_inf", s_inf, event_dim=0) + + switching_c = pyro.deterministic( + "switching_c", dt_switching_c + t0, event_dim=0 + ) + + switching = pyro.deterministic( + "switching", dt_switching + t0, event_dim=0 + ) + + with cell_plate: + t = pyro.sample( + "cell_time", + LogNormal(self.zero, self.one).mask(self.include_prior), + ) + + with cell_plate: + u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None + u_read_depth = pyro.sample( + "u_read_depth", LogNormal(u_log_library, u_log_library_scale) + ) + s_read_depth = pyro.sample( + "s_read_depth", LogNormal(s_log_library, s_log_library_scale) + ) + with gene_plate: + ct, ut, st = self.get_atac_rna( + u_scale, + s_scale, + alpha_c, + alpha, + beta, + gamma, + t, + c0, + u0, + s0, + t0, + switching_c, + switching, + c_inf, + u_inf, + s_inf, + ) + c_dist, u_dist, s_dist = self.get_likelihood_multiome( + ct, + ut, + st, + sigma_c, + u_log_library, + s_log_library, + u_scale, + s_scale, + u_read_depth=u_read_depth, + s_read_depth=s_read_depth, + u_cell_size_coef=u_cell_size_coef, + ut_coef=ut_coef, + s_cell_size_coef=s_cell_size_coef, + st_coef=st_coef, + ) + c = pyro.sample("c", c_dist, obs=c_obs) + u = pyro.sample("u", u_dist, obs=u_obs) + s = pyro.sample("s", s_dist, obs=s_obs) + return c, u, s diff --git a/src/pyrovelocity/models/_velocity_module.py b/src/pyrovelocity/models/_velocity_module.py index e931a9f67..440aebb87 100644 --- 
a/src/pyrovelocity/models/_velocity_module.py +++ b/src/pyrovelocity/models/_velocity_module.py @@ -334,7 +334,7 @@ def __init__( self.cell_specific_kinetics = cell_specific_kinetics - self._model = VelocityModelAuto( + self._model = MultiVelocityModelAuto( self.num_cells, self.num_genes, likelihood, From 7b9c355a3d4a977ad48f09be3d9e64e49f45cff5 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 11 Jul 2024 20:27:04 +0000 Subject: [PATCH 09/47] feat(_transcription_dynamics): Added function for multiome dynamics. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/__init__.py | 3 +- .../models/_transcription_dynamics.py | 266 ++++++++++++++++++ 2 files changed, 268 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/__init__.py b/src/pyrovelocity/models/__init__.py index 4dff5d0ed..20fdb7542 100644 --- a/src/pyrovelocity/models/__init__.py +++ b/src/pyrovelocity/models/__init__.py @@ -7,13 +7,14 @@ from pyrovelocity.models._deterministic_simulation import ( solve_transcription_splicing_model_analytical, ) -from pyrovelocity.models._transcription_dynamics import mrna_dynamics +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics from pyrovelocity.models._velocity import PyroVelocity __all__ = [ deterministic_transcription_splicing_probabilistic_model, mrna_dynamics, + atac_mrna_dynamics, PyroVelocity, solve_transcription_splicing_model, solve_transcription_splicing_model_analytical, diff --git a/src/pyrovelocity/models/_transcription_dynamics.py b/src/pyrovelocity/models/_transcription_dynamics.py index 03c531828..00e8c359f 100644 --- a/src/pyrovelocity/models/_transcription_dynamics.py +++ b/src/pyrovelocity/models/_transcription_dynamics.py @@ -60,6 +60,272 @@ def mrna_dynamics( return ut, st +@beartype +def atac_mrna_dynamics( + tau: Tensor, + c0: Tensor, + u0: Tensor, + s0: Tensor, + k_c: Tensor, + alpha_c: Tensor, + alpha: Tensor, + beta: Tensor, + gamma: Tensor, +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes the ATAC and mRNA dynamics given temporal coordinate, parameter values, and + initial conditions. + + `st_gamma_equals_beta` for the case where the gamma parameter is equal + to the beta parameter is taken from Equation 2.12 of + + Args: + tau (Tensor): Time points starting at last change in RNA transcription rate. + c0 (Tensor): Initial value of c. + u0 (Tensor): Initial value of u. + s0 (Tensor): Initial value of s. + k_c (Tensor): Chromatin state. + alpha_c (Tensor): Rate of chromatin opening/closing. + alpha (Tensor): Alpha parameter. + beta (Tensor): Beta parameter. + gamma (Tensor): Gamma parameter. + + Returns: + Tuple[Tensor, Tensor]: Tuple containing the final values of c, u and s. 
+ + Examples: + >>> import torch + >>> tau = torch.tensor(2.0) + >>> c0 = torch.tensor(1.0) + >>> u0 = torch.tensor(1.0) + >>> s0 = torch.tensor(0.5) + >>> alpha_c = torch.tensor(0.45) + >>> alpha = torch.tensor(0.5) + >>> beta = torch.tensor(0.4) + >>> gamma = torch.tensor(0.3) + >>> k_c = torch.tensor(1.0) + >>> atac_mrna_dynamics(tau_c, tau, c0, u0, s0, k_c, alpha_c, alpha, beta, gamma) + >>> import torch + >>> input = [torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + >>> tau_vec = input[0] + >>> c0_vec = input[1] + >>> u0_vec = input[2] + >>> s0_vec = input[3] + >>> k_c_vec = input[4] + >>> alpha_c = input[5] + >>> alpha_vec = input[6] + >>> beta = input[7] + >>> gamma = input[8] + >>> atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + (tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + """ + + A = torch.exp(-alpha_c * tau) + B = torch.exp(-beta * tau) + C = torch.exp(-gamma * tau) + + ct = c0 * A + k_c * (1 - A) + ut = ( + u0 * B + + alpha * k_c / beta * (1 - B) + + (k_c - c0) * alpha / (beta - alpha_c) * (B - A) + ) + st = s0 * C + alpha * k_c / gamma * (1 - C) + +beta / (gamma - beta) * ( + (alpha * k_c) / beta - u0 - (k_c - c0) * alpha / (beta - alpha_c) + ) * (C - B) + +beta / (gamma - alpha_c) * (k_c - c0) * alpha / (beta - alpha_c) * (C - A) + + return ct, ut, st + +@beartype +def get_initial_states( + t0_state: Tensor, + k_c_state: Tensor, + alpha_c: Tensor, + alpha_state: Tensor, + beta: Tensor, + gamma: Tensor, + state: Tensor +) -> Tuple[Tensor, Tensor, Tensor]: + """ + Computes initial conditions of chromatin and mRNA in each cell. + + Args: + t0_state (Tensor): The switch times of each gene (1 for each state). + k_c_state (Tensor): The chromatin state in each state. + alpha_c (Tensor): The chromatin opening and closing rate. + alpha_state (Tensor): The transcription rate of each gene in each state. + beta (Tensor): The splicing rate of each gene. + gamma (Tensor): The degradation rate of each gene. + state (Tensor): The state of each cell. + + Returns: + Tuple[Tensor, Tensor, Tensor]: Tuple containing the initial conditions of + c, u and s for each cell. 
+ + Examples: + >>> import torch + >>> alpha_c = torch.tensor((0.1, 0.2)) + >>> beta = torch.tensor((0.09, 0.11)) + >>> gamma = torch.tensor((0.05, 0.06)) + >>> state = torch.tensor([[0, 0],[2, 2],[2, 2],[3, 3]]) + >>> k_c_state = torch.tensor([[0., 1., 1., 0., 0.], [0., 1., 1., 1., 0.]]) + >>> alpha_state = torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],[0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]) + >>> t0_state = torch.tensor([[ 0., 10., 25., 75., 102.],[ 0., 10., 25., 78., 95.]]) + >>> get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + (torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]])) + """ + + n_genes = t0_state.shape[0] + c0_state = torch.zeros(n_genes, 5) + u0_state = torch.zeros(n_genes, 5) + s0_state = torch.zeros(n_genes, 5) + dt_state = t0_state - torch.stack([torch.zeros((2)), torch.zeros((2)), + t0_state[:,1], t0_state[:,2], t0_state[:,3]], dim = 1) # genes, states + for i in range(1,4): + c0_state[:,i+1], u0_state[:,i+1], s0_state[:,i+1] = atac_mrna_dynamics( + dt_state[:,i+1], c0_state[:,i], u0_state[:,i], s0_state[:,i], k_c_state[:,i], + alpha_c, alpha_state[:,i], beta, gamma) + c0_vec = c0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + u0_vec = u0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + s0_vec = s0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return c0_vec, u0_vec, s0_vec + +@beartype +def get_cell_parameters( + t: Tensor, + t0_1: Tensor, + dt_1: Tensor, + dt_2: Tensor, + dt_3: Tensor, + alpha: Tensor, + alpha_off: Tensor, +) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: + """ + Gets the ODE parameters for each cell, by first assign each gene in each cell to a state + based on state switch times of a gene and then computes the transcription rate, chromatin state + and time since last state switch(tau) for each gene in each cell. + + Args: + t (Tensor): The time of each cell. + t0_1 (Tensor): Start time for chromatin opening. + dt_1 (Tensor): Time gap since chromatin opening for transcription start for each gene. + dt_2 (Tensor): Time gap since transcription start for chromatin closing for each gene. + dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. + alpha (Tensor): The transcription rate of each gene in the on state. + alpha_off (Tensor): The transcription rate of each gene in the off state. + + Returns: + Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: Tuple containing the state of each cell (state), + the switch time of each state (t0_state), the chromatin opening state (k_c_state), the transcription rate in each cell + (alpha_state) and cell-specific parameters for the chromatin state (k_c_vec) transcription rate (alpha_vec) and + time (tau_vec) since last state switch. 
+ + Examples: + >>> import torch + + >>> n_cells = 4 + >>> t = torch.arange(0, 120, 30).reshape(n_cells, 1) + >>> t0_1 = torch.tensor((10.0, 10.0)) + >>> dt_1 = torch.tensor((15.0, 15.0)) + >>> dt_2 = torch.tensor((77.0, 53.0)) + >>> dt_3 = torch.tensor((50.0, 70.0)) + >>> alpha = torch.tensor((0.5, 0.3)) + >>> alpha_off = torch.tensor(0.0) + >>> get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off + ) + (tensor([[0, 0], + [2, 2], + [2, 2], + [3, 3]]),tensor([[0., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]), tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]), tensor([[ 0., 10., 25., 75., 102.], + [ 0., 10., 25., 78., 95.]]), tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]])) + """ + + # Assign each gene in each cell to a state: + t0_2 = t0_1 + dt_1 + boolean = dt_2 >= dt_3 # True means chromatin starts closing, before transcription stops. + t0_3 = torch.where(boolean, t0_2 + dt_3, t0_2 + dt_2) + t0_4 = torch.where(~boolean, t0_2 + dt_3, t0_2 + dt_2) + state = ((t0_1 <= t)*1 + (t0_2 <= t)*1 + (t0_3 <= t)*1 + (t0_4 <= t)*1) # cells, genes + n_genes = state.shape[1] + + # Calculate time for each gene in each cell: + t0_state = torch.stack([torch.zeros((2)), t0_1, t0_2, t0_3, t0_4], dim = 1) # genes, states + t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + tau_vec = t - t0_vec # cells, genes + + # Calculate the transcription rate and chromatin state for each gene in each cell. + alpha_state = torch.stack([torch.ones((2))*alpha_off, + torch.ones((2))*alpha_off, + alpha, + torch.where(boolean, alpha, alpha_off), + torch.ones((2))*alpha_off], dim = 1) # genes, states + k_c_state = torch.stack([torch.zeros((2)), + torch.ones((2)), + torch.ones((2)), + torch.where(boolean, 0,1), + torch.zeros((2))], dim = 1) # genes, states + alpha_vec = alpha_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + k_c_vec = k_c_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + + return state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec + @beartype def inv(x: Tensor) -> Tensor: From 6955c46a162ba9b9fcc072da064e229e3d7af55e Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 19:06:03 +0000 Subject: [PATCH 10/47] feat(_test_transcription_dynamics): Unit tests for transcription dynamics functions. 
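The new cases pin get_cell_parameters, get_initial_states, and atac_mrna_dynamics to fixed
expected tensors. A quick way to run only these cases (illustrative invocation, assuming the
test file path added in this patch):

    # run just the new multiome dynamics tests
    import pytest
    pytest.main([
        "-k", "cell_parameters or initial_states or atac_mrna",
        "src/pyrovelocity/tests/models/test_transcription_dynamics.py",
    ])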
Signed-off-by: Alexander Aivazidis --- .../models/test_transcription_dynamics.py | 125 ++++++++++++++++++ 1 file changed, 125 insertions(+) diff --git a/src/pyrovelocity/tests/models/test_transcription_dynamics.py b/src/pyrovelocity/tests/models/test_transcription_dynamics.py index ec9c170ed..cc8e0c1f1 100644 --- a/src/pyrovelocity/tests/models/test_transcription_dynamics.py +++ b/src/pyrovelocity/tests/models/test_transcription_dynamics.py @@ -67,3 +67,128 @@ def test_mRNA_extreme_parameter_values(value): assert u is not None assert s is not None assert s is not None + +from pyrovelocity.models._transcription_dynamics import ( + atac_mrna_dynamics, + get_cell_parameters, + get_initial_states +) + +def test_get_cell_parameters(): + import torch + + n_cells = 4 + t = torch.arange(0, 120, 30).reshape(n_cells, 1) + t0_1 = torch.tensor((10.0, 10.0)) + dt_1 = torch.tensor((15.0, 15.0)) + dt_2 = torch.tensor((77.0, 53.0)) + dt_3 = torch.tensor((50.0, 70.0)) + alpha = torch.tensor((0.5, 0.3)) + alpha_off = torch.tensor(0.0) + + output = get_cell_parameters( + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off + ) + + correct_output = (torch.tensor([[0, 0], + [2, 2], + [2, 2], + [3, 3]]),torch.tensor([[0., 1., 1., 0., 0.], + [0., 1., 1., 1., 0.]]), torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000], + [0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]), torch.tensor([[ 0., 10., 25., 75., 102.], + [ 0., 10., 25., 78., 95.]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + +def test_get_initial_states(): + import torch + + alpha_c = torch.tensor((0.1, 0.2)) + beta = torch.tensor((0.09, 0.11)) + gamma = torch.tensor((0.05, 0.06)) + state = torch.tensor([[0, 0],[2, 2],[2, 2],[3, 3]]) + k_c_state = torch.tensor([[0., 1., 1., 0., 0.], [0., 1., 1., 1., 0.]]) + alpha_state = torch.tensor([[0.0000, 0.0000, 0.5000, 0.5000, 0.0000],[0.0000, 0.0000, 0.3000, 0.0000, 0.0000]]) + t0_state = torch.tensor([[ 0., 10., 25., 75., 102.],[ 0., 10., 25., 78., 95.]]) + + output = get_initial_states( + t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state + ) + + correct_output = (torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), + torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" + + +def test_atac_mrna_dynamics(): + import torch + + input = [torch.tensor([[ 0., 0.], + [ 5., 5.], + [35., 35.], + [15., 12.]]), torch.tensor([[0.0000, 0.0000], + [0.7769, 0.9502], + [0.7769, 0.9502], + [0.9985, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [5.4451, 2.7188]]), torch.tensor([[0.0000, 0.0000], + [0.0000, 0.0000], + [0.0000, 0.0000], + [9.1791, 4.7921]]), torch.tensor([[0., 0.], + [1., 1.], + [1., 1.], + [0., 1.]]), torch.tensor([0.1000, 0.2000]), torch.tensor([[0.0000, 0.0000], + [0.5000, 0.3000], + [0.5000, 0.3000], + [0.5000, 0.0000]]), torch.tensor([0.0900, 0.1100]), torch.tensor([0.0500, 0.0600])] + + tau_vec = input[0] + c0_vec = input[1] + u0_vec = input[2] + s0_vec 
= input[3] + k_c_vec = input[4] + alpha_c = input[5] + alpha_vec = input[6] + beta = input[7] + gamma = input[8] + + output = atac_mrna_dynamics( + tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma + ) + + correct_output = (torch.tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), torch.tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), torch.tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]])) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" From db8cbb1cd1e9fb868914e96e44737cfbad0a9d27 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 19:06:21 +0000 Subject: [PATCH 11/47] feat(_test_velocity_model): Unit tests for velocity model Signed-off-by: Alexander Aivazidis --- .../tests/models/test_velocity_model.py | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/pyrovelocity/tests/models/test_velocity_model.py b/src/pyrovelocity/tests/models/test_velocity_model.py index 1804fb2e8..7f1e0db87 100644 --- a/src/pyrovelocity/tests/models/test_velocity_model.py +++ b/src/pyrovelocity/tests/models/test_velocity_model.py @@ -122,3 +122,45 @@ def test_forward_method(self, velocity_model_auto): assert u.shape == (3, 4) assert s.shape == (3, 4) + +def test_MultiVelocityModelAuto(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + +def test_MultiVelocityModelAuto_get_atac_rna(): + from pyrovelocity.models._velocity_model import MultiVelocityModelAuto + import torch + + n_cells = 4 + u_scale = torch.tensor(1.0) + s_scale = torch.tensor(1.0) + t = torch.arange(0, 120, 30).reshape(n_cells,1) # cells, 1 + t0_1 = torch.tensor((10.0, 10.0)) + dt_1 = torch.tensor((15.0, 15.0)) + dt_2 = torch.tensor((77.0, 53.0)) + dt_3 = torch.tensor((50.0, 70.0)) + alpha_c = torch.tensor((0.1, 0.2)) + alpha = torch.tensor((0.5, 0.3)) + alpha_off = torch.tensor(0.0) + beta = torch.tensor((0.09, 0.11)) + gamma = torch.tensor((0.05, 0.06)) + + mod = MultiVelocityModelAuto(num_cells = n_cells, num_genes = 2) + output = MultiVelocityModelAuto.get_atac_rna( + mod, u_scale, s_scale, t, t0_1, dt_1, dt_2, dt_3, alpha_c, alpha, alpha_off, beta, gamma + ) + + correct_output = ((torch.tensor([[0.0000, 0.0000], + [0.8647, 0.9817], + [0.9933, 1.0000], + [0.2228, 1.0000]]), + torch.tensor([[0.0000, 0.0000], + [1.6662, 1.1191], + [5.1763, 2.6659], + [3.2144, 0.7263]]), + torch.tensor([[0.0000, 0.0000], + [2.2120, 1.2959], + [8.2623, 4.3877], + [4.3359, 2.3326]]))) + + for i in range(len(output)): + assert torch.allclose(output[i], correct_output[i], atol=1e-3), f"Output at index {i} is incorrect" From b708935f9074263f56368320ccf8e15910549ce7 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 19:11:41 +0000 Subject: [PATCH 12/47] feat(.gitignore): Added example_notebooks directory to .gitignore file. 
Signed-off-by: Alexander Aivazidis --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 72c2207e5..21ea229dd 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ # /archive/ +example_notebooks/* # .DS_Store From 7c7c73b0758b17954b1fb827879ea3f4b3f43511 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 21:49:54 +0000 Subject: [PATCH 13/47] fix[_trainer_]: checking for existence of atac data Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/_trainer.py b/src/pyrovelocity/models/_trainer.py index 1fcd13b7e..5d1e4989d 100644 --- a/src/pyrovelocity/models/_trainer.py +++ b/src/pyrovelocity/models/_trainer.py @@ -336,7 +336,7 @@ def train_faster( losses = [] patience = patient_init - if not self.adata.layers['atac']: + if not self.adata.uns['atac']: for step in range(max_epochs): if cell_state is None: From e38338818384011d04d5faf1b10e4e693a9dd875 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 21:50:51 +0000 Subject: [PATCH 14/47] fix(_velocity): Save existence of atac data in adata Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 7bd36ee47..2ef140d53 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -324,7 +324,7 @@ def enum_parallel_predict(self): return @classmethod - def setup_anndata(cls, adata: AnnData, *args, **kwargs): + def setup_anndata(cls, adata: AnnData, adata_atac = None, *args, **kwargs): """ Set up AnnData object for compatibility with the scvi-tools model training interface. @@ -362,6 +362,8 @@ def setup_anndata(cls, adata: AnnData, *args, **kwargs): if adata_atac: adata.layers['atac'] = adata_atac.X anndata_fields += [LayerField('atac', 'atac')] + else: + adata.uns['atac'] = None adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args From fc9be1ac2567c15ea4c073f170a577ff8ad90480 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 22 Jul 2024 21:51:49 +0000 Subject: [PATCH 15/47] fix(_velocity_model): LogNormal instead of Normal likelihood for atac data Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity_model.py | 113 ++++++--------------- 1 file changed, 32 insertions(+), 81 deletions(-) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 639321610..44451847f 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -167,10 +167,6 @@ def beta(self): def gamma(self): return self._pyrosample_helper(1.0) - @PyroSample - def sigma_c(self): - return self._pyrosample_helper(0.1) - @PyroSample def u_scale(self): return self._pyrosample_helper(0.1) @@ -342,7 +338,7 @@ def get_likelihood_multiome( ut_coef: None = None, s_cell_size_coef: None = None, st_coef: None = None, - ) -> Tuple[Poisson, Poisson]: + ) -> Tuple[LogNormal, Poisson, Poisson]: """ Compute the likelihood of the given count data. 
@@ -410,7 +406,7 @@ def get_likelihood_multiome( ut = ut + self.one * 1e-6 st = st + self.one * 1e-6 - c_dist = Normal(ct, sigma=sigma_c) + c_dist = LogNormal(ct, sigma_c) u_dist = Poisson(ut) s_dist = Poisson(st) @@ -724,7 +720,7 @@ def forward( [ 0., 0., 7., 4.]])) """ cell_plate, gene_plate = self.create_plates( - _obs, + u_obs, s_obs, u_log_library, s_log_library, @@ -810,7 +806,6 @@ def forward( s = pyro.sample("s", s_dist, obs=s_obs) return u, s - class MultiVelocityModelAuto(LogNormalModel): """Automatically configured MULTIOME velocity model. @@ -911,29 +906,6 @@ def __init__( if self.decoder_on: self.decoder = Decoder(1, self.num_genes, n_layers=2) - self.enumeration = "parallel" - # self.set_enumeration_strategy() - - def sample_cell_gene_chromatin_state(self, t, switching): - return ( - pyro.sample( - "cell_gene_chromatin_state", - Bernoulli(logits=t - switching), - infer={"enumerate": self.enumeration}, - ) - == self.zero - ) - - def sample_cell_gene_state(self, t, switching, state_c): - return ( - pyro.sample( - "cell_gene_state", - Bernoulli(logits=t - switching), - infer={"enumerate": self.enumeration}, - ) - == self.zero - ) - @beartype def __repr__(self) -> str: return ( @@ -1062,9 +1034,9 @@ def forward( ind_x: Optional[torch.Tensor] = None, cell_state: Optional[torch.Tensor] = None, time_info: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ - Defines the forward model, which computes the unspliced (u) and spliced + Defines the forward model, which computes the chromatin state (c), unspliced (u) and spliced (s) RNA expression levels given the observations and model parameters. Args: @@ -1082,7 +1054,7 @@ def forward( time_info (Optional[torch.Tensor], optional): Time information for the cells. Default is None. Returns: - Tuple[torch.Tensor, torch.Tensor]: The unspliced (u) and spliced (s) RNA expression levels. + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The chromatin state (c), unspliced (u) and spliced (s) RNA expression levels. 
Examples: >>> import torch @@ -1153,46 +1125,24 @@ def forward( alpha = self.alpha gamma = self.gamma beta = self.beta - sigma_c = self.sigma_c - - if self.add_offset: - u0 = pyro.sample("u_offset", LogNormal(self.zero, self.one)) - s0 = pyro.sample("s_offset", LogNormal(self.zero, self.one)) - c0 = pyro.sample("c_offset", LogNormal(self.zero, self.one)) - else: - s0 = u0 = c0 = self.zero + alpha_off = self.zero - t0_c = pyro.sample("t0_c", Normal(self.zero, self.one)) - t0 = t0_c + self.delay + t0_1 = pyro.sample("t0_1", Normal(self.zero, self.one*10)) + dt_1 = pyro.sample("dt_1", LogNormal(self.one*20, self.one*10)) + dt_2 = pyro.sample("dt_2", LogNormal(self.one*20, self.one*10)) + dt_3 = pyro.sample("dt_3", LogNormal(self.one*20, self.one*10)) u_scale = self.u_scale s_scale = self.one - dt_switching_c = self.dt_switching_c - dt_switching = self.dt_switching - - c_inf, u_inf, s_inf = atac_mrna_dynamics( - dt_switching_c, dt_switching, c0, u0, s0, alpha_c, alpha, beta, gamma - ) - c_inf = pyro.deterministic("c_inf", c_inf, event_dim=0) - u_inf = pyro.deterministic("u_inf", u_inf, event_dim=0) - s_inf = pyro.deterministic("s_inf", s_inf, event_dim=0) - - switching_c = pyro.deterministic( - "switching_c", dt_switching_c + t0, event_dim=0 - ) - - switching = pyro.deterministic( - "switching", dt_switching + t0, event_dim=0 - ) - with cell_plate: t = pyro.sample( "cell_time", - LogNormal(self.zero, self.one).mask(self.include_prior), + LogNormal(self.zero, self.one*50).mask(self.include_prior), ) with cell_plate: + u_cell_size_coef = ut_coef = s_cell_size_coef = st_coef = None u_read_depth = pyro.sample( "u_read_depth", LogNormal(u_log_library, u_log_library_scale) @@ -1200,25 +1150,26 @@ def forward( s_read_depth = pyro.sample( "s_read_depth", LogNormal(s_log_library, s_log_library_scale) ) + + sigma_c = pyro.sample( + "sigma_c", LogNormal(0.2,0.2) + ) + + ct, ut, st = self.get_atac_rna( + u_scale, + s_scale, + t, # cells, 1 + t0_1, + dt_1, + dt_2, + dt_3, + alpha_c, + alpha, + alpha_off, + beta, + gamma) + with gene_plate: - ct, ut, st = self.get_atac_rna( - u_scale, - s_scale, - alpha_c, - alpha, - beta, - gamma, - t, - c0, - u0, - s0, - t0, - switching_c, - switching, - c_inf, - u_inf, - s_inf, - ) c_dist, u_dist, s_dist = self.get_likelihood_multiome( ct, ut, From 94569f8d485d000af65748c027d7c6fbe1d5e062 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 14:20:43 +0000 Subject: [PATCH 16/47] fix(_trainer.py): Properly processing atac data --- src/pyrovelocity/models/_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pyrovelocity/models/_trainer.py b/src/pyrovelocity/models/_trainer.py index 5d1e4989d..5bd377b2c 100644 --- a/src/pyrovelocity/models/_trainer.py +++ b/src/pyrovelocity/models/_trainer.py @@ -389,7 +389,7 @@ def train_faster( else: - atac = torch.tensor( + c = torch.tensor( np.array( self.adata.layers["atac"].toarray(), dtype="float32" ) @@ -403,9 +403,9 @@ def train_faster( if cell_state is None: elbos = ( svi.step( + c, u, s, - atac, u_library.reshape(-1, 1), s_library.reshape(-1, 1), u_library_mean.reshape(-1, 1), @@ -420,9 +420,9 @@ def train_faster( else: elbos = ( svi.step( + c, u, s, - atac, u_library.reshape(-1, 1), s_library.reshape(-1, 1), u_library_mean.reshape(-1, 1), From 750be4af955128513d50dd1dca0a2e8695c9e705 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 14:22:16 +0000 Subject: [PATCH 17/47] fix(_transcription_dynamics): Ensuring no inplace tensor operations in functions. 
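
In-place index assignment into preallocated tensors (torch.zeros followed by
writes such as x[:, i] = ...) can break autograd and interferes with Pyro's
parallel enumeration, so get_initial_states and get_cell_parameters now build
their per-state values in Python lists and combine them once with torch.stack.
A minimal stand-alone sketch of the pattern, with step() standing in for the
real atac_mrna_dynamics update:

    import torch

    n_genes = 2
    step = lambda prev: prev + 1.0  # placeholder for the actual dynamics update
    cols = [torch.zeros(n_genes), torch.zeros(n_genes)]
    for i in range(1, 4):
        # each column is computed from the previous one, never written in place
        cols.append(step(cols[-1]))
    out = torch.stack(cols, dim=1)  # (n_genes, 5), assembled without mutation
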
Signed-off-by: Alexander Aivazidis --- .../models/_transcription_dynamics.py | 65 +++++++++++-------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/src/pyrovelocity/models/_transcription_dynamics.py b/src/pyrovelocity/models/_transcription_dynamics.py index 00e8c359f..26a683085 100644 --- a/src/pyrovelocity/models/_transcription_dynamics.py +++ b/src/pyrovelocity/models/_transcription_dynamics.py @@ -218,17 +218,26 @@ def get_initial_states( [0.0000, 0.0000], [9.1791, 4.7921]])) """ - + n_genes = t0_state.shape[0] - c0_state = torch.zeros(n_genes, 5) - u0_state = torch.zeros(n_genes, 5) - s0_state = torch.zeros(n_genes, 5) + c0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + u0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] + s0_state_list = [torch.zeros(n_genes),torch.zeros(n_genes)] dt_state = t0_state - torch.stack([torch.zeros((2)), torch.zeros((2)), t0_state[:,1], t0_state[:,2], t0_state[:,3]], dim = 1) # genes, states - for i in range(1,4): - c0_state[:,i+1], u0_state[:,i+1], s0_state[:,i+1] = atac_mrna_dynamics( - dt_state[:,i+1], c0_state[:,i], u0_state[:,i], s0_state[:,i], k_c_state[:,i], - alpha_c, alpha_state[:,i], beta, gamma) + for i in range(1, 4): + c0_i, u0_i, s0_i = atac_mrna_dynamics( + dt_state[:, i+1], c0_state_list[-1], u0_state_list[-1], s0_state_list[-1], k_c_state[:, i], + alpha_c, alpha_state[:, i], beta, gamma + ) + c0_state_list += [c0_i] + u0_state_list += [u0_i] + s0_state_list += [s0_i] + + c0_state = torch.stack(c0_state_list, dim = 1) + u0_state = torch.stack(u0_state_list, dim = 1) + s0_state = torch.stack(s0_state_list, dim = 1) + c0_vec = c0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes u0_vec = u0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes s0_vec = s0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes @@ -302,27 +311,31 @@ def get_cell_parameters( boolean = dt_2 >= dt_3 # True means chromatin starts closing, before transcription stops. t0_3 = torch.where(boolean, t0_2 + dt_3, t0_2 + dt_2) t0_4 = torch.where(~boolean, t0_2 + dt_3, t0_2 + dt_2) - state = ((t0_1 <= t)*1 + (t0_2 <= t)*1 + (t0_3 <= t)*1 + (t0_4 <= t)*1) # cells, genes + state = ((t0_1 <= t).int() + (t0_2 <= t).int() + (t0_3 <= t).int() + (t0_4 <= t).int()) # cells, genes n_genes = state.shape[1] - # Calculate time for each gene in each cell: - t0_state = torch.stack([torch.zeros((2)), t0_1, t0_2, t0_3, t0_4], dim = 1) # genes, states - t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes - tau_vec = t - t0_vec # cells, genes + t0_state = torch.stack([torch.zeros_like(t0_1), t0_1, t0_2, t0_3, t0_4], dim=1) # genes, states + t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + tau_vec = t - t0_vec # cells, genes - # Calculate the transcription rate and chromatin state for each gene in each cell. 
- alpha_state = torch.stack([torch.ones((2))*alpha_off, - torch.ones((2))*alpha_off, - alpha, - torch.where(boolean, alpha, alpha_off), - torch.ones((2))*alpha_off], dim = 1) # genes, states - k_c_state = torch.stack([torch.zeros((2)), - torch.ones((2)), - torch.ones((2)), - torch.where(boolean, 0,1), - torch.zeros((2))], dim = 1) # genes, states - alpha_vec = alpha_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes - k_c_vec = k_c_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + alpha_state = torch.stack([ + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha_off, + torch.ones_like(t0_1) * alpha, + torch.where(boolean, torch.ones_like(t0_1) * alpha, torch.ones_like(t0_1) * alpha_off), + torch.ones_like(t0_1) * alpha_off + ], dim=1) # genes, states + + k_c_state = torch.stack([ + torch.zeros_like(t0_1), + torch.ones_like(t0_1), + torch.ones_like(t0_1), + torch.where(boolean, torch.zeros_like(t0_1), torch.ones_like(t0_1)), + torch.zeros_like(t0_1) + ], dim=1) # genes, states + + alpha_vec = alpha_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes + k_c_vec = k_c_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes return state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec From a2628d4058d7cf022551ca2e15276c73f44b8bd0 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 14:22:52 +0000 Subject: [PATCH 18/47] fix(_velocity): Handling atac data. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/_velocity.py b/src/pyrovelocity/models/_velocity.py index 2ef140d53..53c70b03a 100644 --- a/src/pyrovelocity/models/_velocity.py +++ b/src/pyrovelocity/models/_velocity.py @@ -290,7 +290,7 @@ def __init__( num_aux_cells=num_aux_cells, only_cell_times=only_cell_times, decoder_on=decoder_on, - add_offset=add_offset, + add_offset=False, correct_library_size=correct_library_size, cell_specific_kinetics=cell_specific_kinetics, kinetics_num=self.k, @@ -362,6 +362,7 @@ def setup_anndata(cls, adata: AnnData, adata_atac = None, *args, **kwargs): if adata_atac: adata.layers['atac'] = adata_atac.X anndata_fields += [LayerField('atac', 'atac')] + adata.uns['atac'] = True else: adata.uns['atac'] = None From ec34568b420a2325396a57c7d4bb964c9f32099a Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 14:23:51 +0000 Subject: [PATCH 19/47] fix(_velocity_module): Removed rates from multivariateNormalGuide, because they caused dimension errors during posterior sampling. 
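
The rate sites (alpha, beta, gamma) are sampled per gene inside the gene plate
in the multiome model, and keeping them in the expose list of the blocked model
used to build the multivariate normal guide produced mismatched shapes when
drawing posterior samples. An abridged, non-authoritative condensation of the
surviving block (only the sites visible in this hunk are listed):

    blocked_model = poutine.block(
        self._model,
        expose=[
            "dt_switching",
            "t0",
            "u_scale",
            # alpha, beta and gamma are no longer exposed here
        ],
    )
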
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity_module.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/pyrovelocity/models/_velocity_module.py b/src/pyrovelocity/models/_velocity_module.py index 440aebb87..814fd1a0d 100644 --- a/src/pyrovelocity/models/_velocity_module.py +++ b/src/pyrovelocity/models/_velocity_module.py @@ -12,7 +12,7 @@ from scvi.module.base import PyroBaseModuleClass from pyrovelocity.logging import configure_logging -from pyrovelocity.models._velocity_model import VelocityModelAuto +from pyrovelocity.models._velocity_model import VelocityModelAuto, MultiVelocityModelAuto logger = configure_logging(__name__) @@ -331,7 +331,7 @@ def __init__( logger.info( f"Model type: {self.model_type}, Guide type: {self.guide_type}" ) - + self.cell_specific_kinetics = cell_specific_kinetics self._model = MultiVelocityModelAuto( @@ -380,9 +380,6 @@ def __init__( poutine.block( self._model, expose=[ - "alpha", - "beta", - "gamma", "dt_switching", "t0", "u_scale", @@ -401,9 +398,6 @@ def __init__( poutine.block( self._model, expose=[ - "alpha", - "beta", - "gamma", "dt_switching", "t0", "u_scale", From 085d7ac7ff76e4a9c5d2fb93bc4ece7ba54c21b2 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 14:24:43 +0000 Subject: [PATCH 20/47] fix(train): Ensure atac data is handled properly. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/train.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/pyrovelocity/tasks/train.py b/src/pyrovelocity/tasks/train.py index eb3483189..3a1a45e48 100644 --- a/src/pyrovelocity/tasks/train.py +++ b/src/pyrovelocity/tasks/train.py @@ -361,10 +361,6 @@ def train_model( if isinstance(adata, str): ======= - if adata_atac: - logger.info( - "Multiome model not yet implemented. Proceeding without atac data." - ) if isinstance(adata, str): >>>>>>> 9d2f67d (added atac_layer argument to train_model tasks and made tests for it) adata = load_anndata_from_path(adata) @@ -372,10 +368,11 @@ def train_model( logger.info(f"AnnData object prior to model training") print_anndata(adata) - PyroVelocity.setup_anndata(adata) + PyroVelocity.setup_anndata(adata, adata_atac = adata_atac) model = PyroVelocity( adata, + adata_atac = adata_atac, likelihood=likelihood, model_type=model_type, guide_type=guide_type, From fbd06a8811e95ab0f33b56f78bf319edf12ddcf9 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 17:41:12 +0000 Subject: [PATCH 21/47] feat(_transcription_dynamics): Added latent discrete parameter for modelling stochastic gene activation. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_transcription_dynamics.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/_transcription_dynamics.py b/src/pyrovelocity/models/_transcription_dynamics.py index 26a683085..aba2a317f 100644 --- a/src/pyrovelocity/models/_transcription_dynamics.py +++ b/src/pyrovelocity/models/_transcription_dynamics.py @@ -253,6 +253,7 @@ def get_cell_parameters( dt_3: Tensor, alpha: Tensor, alpha_off: Tensor, + k: Tensor, ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: """ Gets the ODE parameters for each cell, by first assign each gene in each cell to a state @@ -267,6 +268,7 @@ def get_cell_parameters( dt_3 (Tensor): Time gap since transcription start for transcription stopping for each gene. alpha (Tensor): The transcription rate of each gene in the on state. 
alpha_off (Tensor): The transcription rate of each gene in the off state. + k (Tensor): The activation state of each gene in each state. Returns: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]: Tuple containing the state of each cell (state), @@ -285,8 +287,9 @@ def get_cell_parameters( >>> dt_3 = torch.tensor((50.0, 70.0)) >>> alpha = torch.tensor((0.5, 0.3)) >>> alpha_off = torch.tensor(0.0) + >>> k = torch.tensor((1.0, 1.0),(1.0,1.0)) >>> get_cell_parameters( - t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k ) (tensor([[0, 0], [2, 2], @@ -313,6 +316,7 @@ def get_cell_parameters( t0_4 = torch.where(~boolean, t0_2 + dt_3, t0_2 + dt_2) state = ((t0_1 <= t).int() + (t0_2 <= t).int() + (t0_3 <= t).int() + (t0_4 <= t).int()) # cells, genes n_genes = state.shape[1] + state = state * (1-1*k) t0_state = torch.stack([torch.zeros_like(t0_1), t0_1, t0_2, t0_3, t0_4], dim=1) # genes, states t0_vec = t0_state[torch.arange(n_genes).unsqueeze(1), state.T].T # cells, genes From 113c3a4297b8077647c3c735fc5d6ecb5df57568 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 17:41:48 +0000 Subject: [PATCH 22/47] feat(_velocity_model): Sampling latent discrete parameter for modelling stochastic gene activation. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/_velocity_model.py | 68 ++++++++++++++++++---- 1 file changed, 56 insertions(+), 12 deletions(-) diff --git a/src/pyrovelocity/models/_velocity_model.py b/src/pyrovelocity/models/_velocity_model.py index 44451847f..3aaad3492 100644 --- a/src/pyrovelocity/models/_velocity_model.py +++ b/src/pyrovelocity/models/_velocity_model.py @@ -412,9 +412,6 @@ def get_likelihood_multiome( return c_dist, u_dist, s_dist - - - class VelocityModelAuto(LogNormalModel): """Automatically configured velocity model. 
@@ -905,6 +902,50 @@ def __init__( self.correct_library_size = correct_library_size if self.decoder_on: self.decoder = Decoder(1, self.num_genes, n_layers=2) + self.enumeration = "parallel" + + @beartype + def create_plates( + self, + c_obs: Optional[torch.Tensor] = None, + u_obs: Optional[torch.Tensor] = None, + s_obs: Optional[torch.Tensor] = None, + u_log_library: Optional[torch.Tensor] = None, + s_log_library: Optional[torch.Tensor] = None, + u_log_library_loc: Optional[torch.Tensor] = None, + s_log_library_loc: Optional[torch.Tensor] = None, + u_log_library_scale: Optional[torch.Tensor] = None, + s_log_library_scale: Optional[torch.Tensor] = None, + ind_x: Optional[torch.Tensor] = None, + cell_state: Optional[torch.Tensor] = None, + time_info: Optional[torch.Tensor] = None, + ) -> Tuple[plate, plate]: + # Call the parent class method + cell_plate, gene_plate = super().create_plates( + u_obs=u_obs, + s_obs=s_obs, + u_log_library=u_log_library, + s_log_library=s_log_library, + u_log_library_loc=u_log_library_loc, + s_log_library_loc=s_log_library_loc, + u_log_library_scale=u_log_library_scale, + s_log_library_scale=s_log_library_scale, + ind_x=ind_x, + cell_state=cell_state, + time_info=time_info, + ) + # You can add any additional logic here if needed + return cell_plate, gene_plate + + def sample_cell_gene_state(self, t, switching): + return ( + pyro.sample( + "cell_gene_state", + Bernoulli(logits=t - switching), + infer={"enumerate": self.enumeration}, + ) + == self.zero + ) @beartype def __repr__(self) -> str: @@ -1004,14 +1045,16 @@ def get_atac_rna( [4.3359, 2.3326]])) """ + k = self.sample_cell_gene_state(t, t0_1) + state, k_c_state, alpha_state, t0_state, k_c_vec, alpha_vec, tau_vec = get_cell_parameters( - t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off - ) - + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k, + ) + c0_vec, u0_vec, s0_vec = get_initial_states( t0_state, k_c_state, alpha_c, alpha_state, beta, gamma, state ) - + ct, ut, st = atac_mrna_dynamics( tau_vec, c0_vec, u0_vec, s0_vec, k_c_vec, alpha_c, alpha_vec, beta, gamma ) @@ -1022,9 +1065,9 @@ def get_atac_rna( @beartype def forward( self, + c_obs: torch.Tensor, u_obs: torch.Tensor, s_obs: torch.Tensor, - c_obs: torch.Tensor, u_log_library: Optional[torch.Tensor] = None, s_log_library: Optional[torch.Tensor] = None, u_log_library_loc: Optional[torch.Tensor] = None, @@ -1121,10 +1164,10 @@ def forward( with gene_plate, poutine.mask(mask=self.include_prior): - alpha_c = self.alpha_c - alpha = self.alpha - gamma = self.gamma - beta = self.beta + alpha_c = pyro.sample("alpha_c", LogNormal(self.one, self.one)) + alpha = pyro.sample("alpha", LogNormal(self.one*20, self.one*10)) + gamma = pyro.sample("gamma", LogNormal(self.one*20, self.one*10)) + beta = pyro.sample("beta", LogNormal(self.one*20, self.one*10)) alpha_off = self.zero t0_1 = pyro.sample("t0_1", Normal(self.zero, self.one*10)) @@ -1147,6 +1190,7 @@ def forward( u_read_depth = pyro.sample( "u_read_depth", LogNormal(u_log_library, u_log_library_scale) ) + s_read_depth = pyro.sample( "s_read_depth", LogNormal(s_log_library, s_log_library_scale) ) From 40581dc12d81cf6dd504ac1b54b9ed8ba416615a Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 26 Jul 2024 17:42:12 +0000 Subject: [PATCH 23/47] feat(_test_transcription_dynamics): Adapted test to include latent discrete parameter for gene state. 
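
For context, k is the latent per-cell, per-gene activation state introduced in
the accompanying _velocity_model.py change: it is the indicator that the
parallel-enumerated "cell_gene_state" Bernoulli draw equals zero, and
get_cell_parameters uses it to gate the state assignment. A stand-alone sketch
of that gating (not the enumerated pyro.sample site itself; shapes follow this
test's four cells and two genes):

    import torch
    from pyro.distributions import Bernoulli

    t = torch.arange(0, 120, 30).reshape(4, 1)   # cells x 1, as in this test
    t0_1 = torch.tensor((10.0, 10.0))            # per-gene switching times
    # k is 1 where the draw equals zero, i.e. the gene stays in its inactive state
    k = (Bernoulli(logits=t - t0_1).sample() == 0).float()  # (cells, genes)
    # get_cell_parameters then gates the assignment with: state = state * (1 - k)
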
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tests/models/test_transcription_dynamics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/tests/models/test_transcription_dynamics.py b/src/pyrovelocity/tests/models/test_transcription_dynamics.py index cc8e0c1f1..7493205a1 100644 --- a/src/pyrovelocity/tests/models/test_transcription_dynamics.py +++ b/src/pyrovelocity/tests/models/test_transcription_dynamics.py @@ -85,9 +85,10 @@ def test_get_cell_parameters(): dt_3 = torch.tensor((50.0, 70.0)) alpha = torch.tensor((0.5, 0.3)) alpha_off = torch.tensor(0.0) + k = torch.tensor((1.0, 1.0),(1.0,1.0)) output = get_cell_parameters( - t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off + t, t0_1, dt_1, dt_2, dt_3, alpha, alpha_off,k ) correct_output = (torch.tensor([[0, 0], From c90af64d4d8141c5a89549283335019b51bc6c41 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 23 Aug 2024 20:03:51 +0000 Subject: [PATCH 24/47] feat(_models): Initial commit for new knn_model. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model.py | 1 + 1 file changed, 1 insertion(+) create mode 100644 src/pyrovelocity/models/knn_model.py diff --git a/src/pyrovelocity/models/knn_model.py b/src/pyrovelocity/models/knn_model.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/pyrovelocity/models/knn_model.py @@ -0,0 +1 @@ + From 211d9b501548d6f7e6939b5618c4001b27dc19ab Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 23 Aug 2024 20:04:48 +0000 Subject: [PATCH 25/47] feat(preprocess.py): Added function for computation of metacells. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/preprocess.py | 225 ++++++++++++++++++++++++++- 1 file changed, 224 insertions(+), 1 deletion(-) diff --git a/src/pyrovelocity/tasks/preprocess.py b/src/pyrovelocity/tasks/preprocess.py index 45a11762b..efd41a2b9 100644 --- a/src/pyrovelocity/tasks/preprocess.py +++ b/src/pyrovelocity/tasks/preprocess.py @@ -22,6 +22,10 @@ from pyrovelocity.tasks.data import load_anndata_from_path from pyrovelocity.utils import ensure_numpy_array, print_anndata +from scipy.sparse import csr_matrix +from scipy.sparse import issparse +import warnings + __all__ = [ "assign_colors", "compute_and_plot_qc", @@ -29,7 +33,7 @@ "get_high_us_genes", "get_thresh_histogram_title_from_path", "plot_high_us_genes", - "preprocess_dataset", + "preprocess_dataset" ] logger = configure_logging(__name__) @@ -634,6 +638,225 @@ def get_high_us_genes( logger.info(f"adata.shape after filtering: {adata.shape}") return adata +@beartype +def compute_metacells( + adata_rna: AnnData, + adata_atac: AnnData, + latent_key: str, + celltype_key: Optional[str] = None, + n_neighbors: int = 10, + n_neighbors_metacell: int = 5, + resolution: int = 50, + verbose: bool = True, + merge_knn_graph: bool = True, + merge_umap: bool = True, + umap_key: Optional[str] = None +) -> Tuple[AnnData, AnnData]: + """ + Computes metacells, using low-level clustering in a given a latent-space. By default, includes + summing up RNA counts, ATAC counts and optionally includes averaging UMAP coordinates and computing + a new knn-graph for meta-cells. If a celltype key is provided, metacells are named by their most + frequent celltype label. + + Args: + adata_rna (AnnData): AnnData object with RNA counts. + adata_atac (AnnData): AnnData object with ATAC counts. + latent_key (str): Name of latent space key in .obsm slot in adata_rna, e.g. X_pca. 
+ celltype_key (Optional[str] optional): Name of cell type key in .obs column in adata_rna. + n_neighbors (int, optional): Number of nearest neighbors to use in initial knn-graph used + for metacell construction. Defaults to 10. + n_neighbors_metacell (int, optional): Number of nearest neighbors to use in new knn-graph + for metacells. Defaults to 5. + resolution (int, optional): Resolution at which to do leiden clustering for metacells. + Defaults to 50. + verbose (bool, optional): Whether to print out progress and make diagnostic plots + of metacell computations. Defaults to True. + merge_knn_graph (bool, optional): Whether to produce a new knn graph for metacells. + merge_umap (bool, optional): Whether to produce a umap embedding for metacells from the previous embedding of cells. + If no previous embedding is present it will be recomputed for the original anndata object. + umap_key (str, optional): Key of UMAP embedding in adata_rna.obsm + Returns: + Tuple[AnnData, AnnData]: Tuple containing two anndata objects containing RNA and ATAC counts for metacells. + + Examples: + >>> from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData + >>> adata_rna = synthetic_AnnData(seed = 1) + >>> adata_atac = synthetic_AnnData(seed = 2) + >>> sc.tl.pca(adata_rna, n_comps=3) + >>> compute_metacells(adata_rna, adata_atac, + latent_key = 'X_pca', + resolution = 1, + celltype_key = 'cell_type') + """ + + # Check input makes sense: + if len(adata_rna.obs_names) != len(adata_atac.obs_names): + raise ValueError("RNA and ATAC data do not contain the same cells number of cells.") + if (adata_rna.obs_names != adata_atac.obs_names).all(): + raise ValueError("RNA and ATAC data do not contain the same cells in obs_names.") + + if not isinstance(adata_rna.X, csr_matrix): + adata_rna.X = csr_matrix(adata_rna.X) + if not isinstance(adata_atac.X, csr_matrix): + adata_atac.X = csr_matrix(adata_atac.X) + + # Define functions for all processing steps: + + @beartype + def merge_RNA( + adata_rna: AnnData, + cluster_key: str, + celltype_key: Optional[str] = None, + verbose: bool = True, + )-> AnnData: + + if verbose: + print('merging RNA counts') + + X = np.concatenate([np.sum(adata_rna.X[adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) + print(X.shape) + adata_meta = sc.AnnData(X = np.array(X)) + adata_meta.obs['n_cells'] = [np.sum(adata_rna.obs[cluster_key] == c) for c in np.unique(adata_rna.obs[cluster_key])] + if celltype_key: + adata_meta.obs[celltype_key] = [adata_rna[adata_rna.obs[cluster_key] == c,:].obs[celltype_key].mode()[0] for c in np.unique(adata_rna.obs[cluster_key])] + adata_meta.obs['RNA counts'] = np.sum(adata_meta.X, axis = 1) + + if verbose: + print('Mean RNA counts per cell before: ', np.mean(adata_rna.obs['RNA counts'])) + print('Mean RNA counts per cell after: ', np.mean(adata_meta.obs['RNA counts'])) + plt.hist(adata_rna.obs['RNA counts'], bins = 10, label = 'single cells', alpha = 0.5) + plt.hist(adata_meta.obs['RNA counts'],bins = 20, label = 'meta cells', alpha = 0.5) + plt.xlabel('Total Counts') + plt.ylabel('Occurences') + plt.legend() + plt.show() + + return adata_meta + + @beartype + def merge_UMAP( + adata_rna: AnnData, + adata_meta: AnnData, + cluster_key: str, + umap_key: str = 'X_umap', + verbose: bool = True, + )-> AnnData: + + if verbose: + print('merging UMAP') + + adata_meta.obsm[umap_key] = np.concatenate([np.expand_dims(np.mean(adata_rna.obsm[umap_key][adata_rna.obs[cluster_key] == c,:], axis = 0), axis = -1) for c 
in np.unique(adata_rna.obs[cluster_key])], axis = 1).T + + return adata_meta + + @beartype + def merge_ATAC( + adata_atac: AnnData, + cluster_key: str, + celltype_key: Optional[str] = None, + verbose: bool = True + )-> AnnData: + + if verbose: + print('merging ATAC counts') + + X = np.concatenate([np.sum(adata_atac.X[adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_atac.obs[cluster_key])], axis = 0) + adata_atac_meta = sc.AnnData(X = np.array(X)) + adata_atac_meta.obs['n_cells'] = [np.sum(adata_atac.obs[cluster_key] == c) for c in np.unique(adata_atac.obs[cluster_key])] + if celltype_key: + adata_atac_meta.obs[celltype_key] = [adata_atac[adata_atac.obs[cluster_key] == c,:].obs[celltype_key].mode()[0] for c in np.unique(adata_atac.obs[cluster_key])] + adata_atac_meta.obs['ATAC counts'] = np.sum(adata_atac_meta.X, axis = 1) + + if verbose: + print('Mean ATAC counts per cell before: ', np.mean(adata_atac.obs['ATAC counts'])) + print('Mean ATAC counts per cell after: ', np.mean(adata_atac_meta.obs['ATAC counts'])) + plt.hist(adata_atac.obs['ATAC counts'], bins = 10, label = 'single cells', alpha = 0.5) + plt.hist(adata_atac_meta.obs['ATAC counts'],bins = 20, label = 'meta cells', alpha = 0.5) + plt.xlabel('Total Counts') + plt.ylabel('Occurences') + plt.legend() + plt.show() + + return adata_atac_meta + + @beartype + def merge_knn( + adata_rna: AnnData, + adata_meta: AnnData, + cluster_key: str, + n_neighbors: int = 6, + verbose: bool = True + ) -> AnnData: + + if verbose: + print('merging knn graph') + + distance_matrix = (adata_rna.obsp['distances'].toarray() != 0)*1 + clusters = adata_rna.obs[cluster_key] + for c in np.unique(clusters): + subset = np.array(clusters == c) + distance_matrix = np.concatenate([distance_matrix[~subset,:], np.expand_dims(np.sum(distance_matrix[subset,:], axis = 0), axis = 0)], axis = 0) + distance_matrix = np.concatenate([distance_matrix[:,~subset], np.expand_dims(np.sum(distance_matrix[:,subset], axis = 1), axis = 1)], axis = 1) + clusters = np.concatenate([clusters[~subset], np.expand_dims(np.array((c)), axis = 0)]) + adata_meta.obsm['N_cn'] = np.stack([np.argsort(-1*distance_matrix[i,:])[:n_neighbors+1] for i in range(len(distance_matrix[:,0]))], axis = 0) + + return adata_meta + + # Run through the metacell construction: + adata_atac = adata_atac[adata_rna.obs_names,:] + if celltype_key: + adata_atac.obs[celltype_key] = adata_rna.obs[celltype_key] + adata_rna.obs['RNA counts'] = np.sum(adata_rna.X, axis = 1) + adata_atac.obs['ATAC counts'] = np.sum(adata_atac.X, axis = 1) + + if verbose: + print('low resolution clustering cells') + sc.pp.neighbors(adata_rna, use_rep=latent_key, n_neighbors=n_neighbors) + cluster_key = "leiden" + sc.tl.leiden(adata_rna, key_added=cluster_key, resolution=resolution) + adata_atac.obs[cluster_key] = adata_rna.obs[cluster_key] + + if verbose: + print('total number of cells', len(adata_rna.obs_names)) + print('total number of meta-cells', len(np.unique(adata_rna.obs[cluster_key]))) + print('minimum cells/meta-cell', np.min(adata_rna.obs[cluster_key].value_counts())) + print('average cells/meta-cell', np.round(np.mean(adata_rna.obs[cluster_key].value_counts()),1)) + print('maximum cells/meta_cell', np.max(adata_rna.obs[cluster_key].value_counts())) + plt.hist(adata_rna.obs[cluster_key].value_counts(), bins = 10) + plt.xlabel('cells per meta-cell') + plt.ylabel('number of meta-cells') + + adata_meta = merge_RNA(adata_rna, verbose = verbose, + cluster_key = cluster_key, celltype_key = celltype_key) + + if 
merge_umap: + if not umap_key: + sc.tl.umap(adata_rna) + umap_key = 'X_umap' + elif umap_key not in adata_rna.obsm: + warnings.warn("Umap_key not found in AnnData. Computing with sc.tl.umap()...", category=UserWarning) + sc.tl.umap(adata_rna) + umap_key = 'X_umap' + adata_meta = merge_UMAP(adata_rna, adata_meta, umap_key = umap_key,verbose = verbose, cluster_key = cluster_key) + + adata_atac_meta = merge_ATAC(adata_atac, verbose = verbose, + cluster_key = cluster_key, celltype_key = celltype_key) + + adata_meta.obs['ATAC counts'] = adata_atac_meta.obs['ATAC counts'] + + if merge_knn_graph: + + adata_meta = merge_knn( + adata_rna, + adata_meta, + cluster_key = cluster_key, + n_neighbors = n_neighbors_metacell, + verbose = True) + + if verbose: + print('Done.') + + return adata_meta, adata_atac_meta # ------------------------------------------------------------------------------ From 73b1b87319ce79ca650ac5d208825b6c9fb4f198 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 23 Aug 2024 20:05:12 +0000 Subject: [PATCH 26/47] feat(tests): Added function to produce synthetic AnnData. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tests/synthetic_AnnData.py | 62 +++++++++++++++++++++ 1 file changed, 62 insertions(+) create mode 100644 src/pyrovelocity/tests/synthetic_AnnData.py diff --git a/src/pyrovelocity/tests/synthetic_AnnData.py b/src/pyrovelocity/tests/synthetic_AnnData.py new file mode 100644 index 000000000..a9bfc4916 --- /dev/null +++ b/src/pyrovelocity/tests/synthetic_AnnData.py @@ -0,0 +1,62 @@ +"""Producing synthetic AnnData for tests.""" + +import numpy as np +import pandas as pd +import anndata as ad + +def synthetic_AnnData( + n_cell_types: int = 3, + n_genes: int = 10, + cells_per_type: int = 20, + seed: int = 42 + ): + + """ + Produces a simple synthetic AnnData object. + + Args: + n_cell_types (int): Number of cell types. + n_genes (int): Number of genes. + cells_per_type (int): Number of cells per cell type. + seed (int): Random seed. + Returns: + AnnData: Synthetic AnnData object. + + Examples: + >>> synthetic_AnnData() + """ + + # Number of genes, cells, and cell types + n_genes = 10 + n_cells = cells_per_type * n_cell_types + + # Create synthetic gene expression data + # Each cell type will have slightly different expression profiles + np.random.seed(seed) # For reproducibility + + cells_per_type = int(n_cells/n_cell_types) + # Generate data with different means for different cell types + expression_data = np.vstack([ + np.random.normal(loc=i, scale=0.5, size=(cells_per_type, n_genes)) + for i in range(n_cell_types) + ]) + + # Create an AnnData object + adata = ad.AnnData(X=expression_data) + + # Add cell type annotations + cell_types = [] + for i in range(n_cell_types): + cell_types += ['Type ' + str(i)] * cells_per_type + adata.obs['cell_type'] = pd.Categorical(cell_types) + + # Add gene names (e.g., Gene1, Gene2, ..., Gene20) + gene_names = [f'Gene{i+1}' for i in range(n_genes)] + adata.var['gene_names'] = gene_names + + # Add cell names (e.g., Cell1, Cell2, ..., Cell30) + cell_names = [f'Cell{i+1}' for i in range(n_cells)] + adata.obs_names = cell_names + adata.var_names = gene_names + + return adata From 725a0bdc3b6fb50a46e12f342588c66826b7d65f Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 23 Aug 2024 20:05:33 +0000 Subject: [PATCH 27/47] feat(tests): Added test for compute_metacell function. 
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tests/test_preprocess.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 src/pyrovelocity/tests/test_preprocess.py diff --git a/src/pyrovelocity/tests/test_preprocess.py b/src/pyrovelocity/tests/test_preprocess.py new file mode 100644 index 000000000..9f706a7aa --- /dev/null +++ b/src/pyrovelocity/tests/test_preprocess.py @@ -0,0 +1,20 @@ +"""Tests for `pyrovelocity.tasks.preprocess` module.""" + + +def test_load_preprocess(): + from pyrovelocity.tasks import preprocess + + print(preprocess.__file__) + +def test_compute_metacells(): + from pyrovelocity.tasks.preprocess import compute_metacells + from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData + import scanpy as sc + adata_rna = synthetic_AnnData(seed = 1) + adata_atac = synthetic_AnnData(seed = 2) + sc.tl.pca(adata_rna, n_comps=3) + compute_metacells(adata_rna, adata_atac, + latent_key = 'X_pca', + resolution = 1, + celltype_key = 'cell_type') + From a5f31dfe2ddc3dadcc4e8c48df9c4c84314d11ad Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 23 Aug 2024 20:05:51 +0000 Subject: [PATCH 28/47] feat(tests): Added test for synthetic_AnnData function. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tests/test_synthetic_AnnData.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/pyrovelocity/tests/test_synthetic_AnnData.py diff --git a/src/pyrovelocity/tests/test_synthetic_AnnData.py b/src/pyrovelocity/tests/test_synthetic_AnnData.py new file mode 100644 index 000000000..ae0f838db --- /dev/null +++ b/src/pyrovelocity/tests/test_synthetic_AnnData.py @@ -0,0 +1,5 @@ +"""Test synthetic_AnnData function.""" + +def test_synthetic_AnnData(): + from pyrovelocity.tests.synthetic_AnnData import synthetic_AnnData + synthetic_AnnData() \ No newline at end of file From 58a48e979768273a59acd16fed7b5c6d4fe5f260 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Mon, 26 Aug 2024 21:32:08 +0000 Subject: [PATCH 29/47] feat(models): Started knn model in a new folder. 
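
The new knn_model package mirrors the layout of the existing models code
(_velocity.py, _velocity_module.py, _velocity_model.py), with its PyroVelocity
class dispatching to VelocityModule or MultiVelocityModule depending on whether
ATAC data is supplied. A minimal sketch of the intended entry point, assuming
AnnData objects that already carry the layers required by setup_anndata (for
example the metacell objects returned by compute_metacells after preprocessing):

    from pyrovelocity.models.knn_model._velocity import PyroVelocity

    PyroVelocity.setup_anndata(adata_meta, adata_atac=adata_atac_meta)
    model = PyroVelocity(adata_meta, adata_atac=adata_atac_meta)
    model.train_faster(max_epochs=5, use_gpu="auto")
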
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model.py | 1 - .../models/knn_model/_velocity.py | 714 ++++++++++++++++++ .../models/knn_model/_velocity_model.py | 342 +++++++++ .../models/knn_model/_velocity_module.py | 236 ++++++ 4 files changed, 1292 insertions(+), 1 deletion(-) delete mode 100644 src/pyrovelocity/models/knn_model.py create mode 100644 src/pyrovelocity/models/knn_model/_velocity.py create mode 100644 src/pyrovelocity/models/knn_model/_velocity_model.py create mode 100644 src/pyrovelocity/models/knn_model/_velocity_module.py diff --git a/src/pyrovelocity/models/knn_model.py b/src/pyrovelocity/models/knn_model.py deleted file mode 100644 index 8b1378917..000000000 --- a/src/pyrovelocity/models/knn_model.py +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/pyrovelocity/models/knn_model/_velocity.py b/src/pyrovelocity/models/knn_model/_velocity.py new file mode 100644 index 000000000..918f4568a --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity.py @@ -0,0 +1,714 @@ +import os +import pickle +import sys +from statistics import harmonic_mean +from typing import Dict, Optional, Sequence, Union +import mlflow +import numpy as np +import pyro +import torch +from anndata import AnnData +from beartype import beartype +from numpy import ndarray +from scvi.data import AnnDataManager +from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME +from scvi.data.fields import LayerField, NumericalObsField +from scvi.model._utils import parse_device_args +from scvi.model.base import BaseModelClass +from scvi.model.base._utils import ( + _initialize_model, + _load_saved_files, + _validate_var_names, +) +from scvi.module.base import PyroBaseModuleClass + +from pyrovelocity.analysis.analyze import ( + compute_mean_vector_field, + compute_volcano_data, + vector_field_uncertainty, +) +from pyrovelocity.logging import configure_logging +from pyrovelocity.models._trainer import VelocityTrainingMixin +from pyrovelocity.models.knn_model._velocity_module import VelocityModule, MultiVelocityModule + +__all__ = ["PyroVelocity"] + +logger = configure_logging(__name__) + +class PyroVelocity(VelocityTrainingMixin, BaseModelClass): + """ + PyroVelocity is a class for constructing and training a Pyro model for + probabilistic RNA velocity estimation. This model leverages the + probabilistic programming language Pyro to estimate the parameters of models + for the dynamics of RNA transcription, splicing, and degradation, providing + the opportunity for insight into cellular states and associated state + transitions. It makes use of AnnData, scvi-tools, and other scverse + ecosystem libraries. + + Public methods include training the model with various configurations, + generating posterior samples for further analysis, and saving/loading the + model for reproducibility and further analysis. + + Attributes: + use_gpu (str): Whether and which GPU to use. + cell_specific_kinetics (Optional[str]): Type of cell-specific kinetics. + k (Optional[int]): Number of kinetics. + layers (List[str]): List of layers in the dataset. + input_type (str): Type of input data. + module (VelocityModule): + The Pyro module used for the velocity estimation model. + num_cells (int): Number of cells in the dataset. + num_samples (int): Number of posterior samples to generate. + _model_summary_string (str): Summary string for the model. + init_params_ (Dict[str, Any]): Initial parameters for the model. 
+ + For usage examples, including training the model and generating posterior + samples, refer to the individual method docstrings. + """ + + """ + The `Methods` section is not supported by all documentation generators but + is provided detached from the class docstring for reference. Please + see the docstrings for each method for more details. This list may ignore + unused or private methods. + + Methods: + train: + Trains the PyroVelocity model using the provided data and configuration. + setup_anndata: + Set up AnnData object for compatibility with the scvi-tools + model training interface. + generate_posterior_samples: + Generates posterior samples for the given data using the trained + PyroVelocity model. + compute_statistics_from_posterior_samples: + Estimate statistics from posterior samples and add them to the + `posterior_samples` dictionary. + save_pyrovelocity_data: + Saves the PyroVelocity data to a pickle file. + save_model: + Saves the Pyro-Velocity model to a directory. + load_model: + Load the model from a directory with the same structure as that produced + by the save method. + """ + + def __init__( + self, + adata: AnnData, + adata_atac: Optional[AnnData] = None, + input_type: str = "raw", + shared_time: bool = True, + model_type: str = "auto", + guide_type: str = "auto", + likelihood: str = "Poisson", + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + inducing_point_size: int = 0, + latent_factor_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + init: bool = False, + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + ) -> None: + """ + PyroVelocity class for estimating RNA velocity and related tasks. + + Args: + adata (AnnData): An AnnData object containing the gene expression data. + adata_atac (Optional[AnnData], optional): An AnnData object containing atac data. + input_type (str, optional): Type of input data. Can be "raw", "knn", or "raw_cpm". Defaults to "raw". + shared_time (bool, optional): Whether to use shared time. Defaults to True. + model_type (str, optional): Type of model to use. Defaults to "auto". + guide_type (str, optional): Type of guide to use. Defaults to "auto". + likelihood (str, optional): Type of likelihood to use. Defaults to "Poisson". + t_scale_on (bool, optional): Whether to use t_scale. Defaults to False. + plate_size (int, optional): Size of the plate. Defaults to 2. + latent_factor (str, optional): Type of latent factor. Defaults to "none". + latent_factor_operation (str, optional): Operation to perform on the latent factor. Defaults to "selection". + inducing_point_size (int, optional): Size of inducing points. Defaults to 0. + latent_factor_size (int, optional): Size of latent factors. Defaults to 0. + include_prior (bool, optional): Whether to include prior information. Defaults to False. + use_gpu (Union[bool, int], optional): Whether and which GPU to use. Defaults to 0. Can be False. + init (bool, optional): Whether to initialize the model. Defaults to False. + num_aux_cells (int, optional): Number of auxiliary cells. Defaults to 0. + only_cell_times (bool, optional): Whether to use only cell times. Defaults to True. + decoder_on (bool, optional): Whether to use decoder. Defaults to False. 
+ add_offset (bool, optional): Whether to add offset. Defaults to False. + correct_library_size (Union[bool, str], optional): Whether to correct library size or method to correct. Defaults to True. + cell_specific_kinetics (Optional[str], optional): Type of cell-specific kinetics. Defaults to None. + kinetics_num (Optional[int], optional): Number of kinetics. Defaults to None. + + Examples: + >>> # import necessary libraries + >>> import numpy as np + >>> import anndata + >>> from pyrovelocity.utils import pretty_log_dict, print_anndata, generate_sample_data + >>> from pyrovelocity.tasks.preprocess import copy_raw_counts + >>> from pyrovelocity.models._velocity import PyroVelocity + ... + >>> # define fixtures + >>> try: + >>> tmp = getfixture("tmp_path") + >>> except NameError: + >>> import tempfile + >>> tmp = tempfile.TemporaryDirectory().name + >>> doctest_model_path = str(tmp) + "/save_pyrovelocity_doctest_model" + >>> print(doctest_model_path) + ... + >>> # setup sample data + >>> n_obs = 10 + >>> n_vars = 5 + >>> adata = generate_sample_data(n_obs=n_obs, n_vars=n_vars) + >>> copy_raw_counts(adata) + >>> print_anndata(adata) + >>> print(adata.X) + >>> print(adata.layers['spliced']) + >>> print(adata.layers['unspliced']) + >>> print(adata.obs['u_lib_size_raw']) + >>> print(adata.obs['s_lib_size_raw']) + >>> PyroVelocity.setup_anndata(adata) + ... + >>> # train model with macroscopic validation set + >>> model = PyroVelocity(adata) + >>> model.train(max_epochs=5, train_size=0.8, valid_size=0.2, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> print(posterior_samples.keys()) + >>> assert isinstance(posterior_samples, dict), f"Expected a dictionary, got {type(posterior_samples)}" + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + ... + >>> # train model with default parameters + >>> model = PyroVelocity(adata) + >>> model.train_faster(max_epochs=5, use_gpu="auto") + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> print(posterior_samples.keys()) + ... + >>> # train model with specified batch size + >>> model = PyroVelocity(adata) + >>> model.train_faster_with_batch(batch_size=24, max_epochs=5, use_gpu="auto") + >>> model.save_model(doctest_model_path, overwrite=True) + >>> model = PyroVelocity.load_model(doctest_model_path, adata, use_gpu="auto") + >>> posterior_samples = model.generate_posterior_samples(model.adata, num_samples=30) + >>> posterior_samples_log = pretty_log_dict(posterior_samples) + >>> print(posterior_samples.keys()) + ... + >>> # If running from an interactive session, the temporary directory + >>> # can be inspected to review the saved model files. When run as a + >>> # doctest it is automatically cleaned up after the test completes. 
+ >>> print(f"Output located in {doctest_model_path}") + """ + self.use_gpu = use_gpu + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + if input_type == "knn": + layers = ["Mu", "Ms"] + assert likelihood in {"Normal", "LogNormal"} + assert "Mu" in adata.layers + elif input_type == "raw_cpm": + layers = ["unspliced", "spliced"] + assert likelihood in {"Normal", "LogNormal"} + else: + layers = ["raw_unspliced", "raw_spliced"] + assert likelihood != "Normal" + + self.layers = layers + self.input_type = input_type + + super().__init__(adata) + # TODO: remove unused code + # from pyrovelocity.utils import init_with_all_cells + # if init: + # initial_values = init_with_all_cells( + # self.adata, + # input_type, + # shared_time, + # latent_factor, + # latent_factor_size, + # plate_size, + # num_aux_cells=num_aux_cells, + # ) + # else: + initial_values = {} + logger.info(self.summary_stats) + if not adata_atac: + self.module = VelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + else: + self.module = MultiVelocityModule( + self.summary_stats["n_cells"], + self.summary_stats["n_vars"], + model_type=model_type, + guide_type=guide_type, + likelihood=likelihood, + shared_time=shared_time, + t_scale_on=t_scale_on, + plate_size=plate_size, + latent_factor=latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + inducing_point_size=inducing_point_size, + include_prior=include_prior, + use_gpu=use_gpu, + num_aux_cells=num_aux_cells, + only_cell_times=only_cell_times, + decoder_on=decoder_on, + add_offset=False, + correct_library_size=correct_library_size, + cell_specific_kinetics=cell_specific_kinetics, + kinetics_num=self.k, + **initial_values, + ) + self.num_cells = self.module.num_cells + self._model_summary_string = """ + RNA velocity Pyro model with parameters: + """ + self.init_params_ = self._get_init_params(locals()) + logger.info("Model initialized") + + def train(self, **kwargs): + """ + Trains the PyroVelocity model using the provided data and configuration. + + The method leverages the Pyro library to train the model using the underlying + data. It relies on the `VelocityTrainingMixin` to define the training logic. + + Args: + + **kwargs : dict, optional + Additional keyword arguments to be passed to the underlying train method + provided by the `VelocityTrainingMixin`. + """ + pyro.enable_validation(True) + super().train(**kwargs) + + def enum_parallel_predict(self): + """work for parallel enumeration""" + return + + @classmethod + def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *args, **kwargs): + """ + Set up AnnData object for compatibility with the scvi-tools + model training interface. + + Args: + adata (AnnData): Anndata object to be used in model training. 
+ """ + setup_method_args = cls._get_setup_method_args(**locals()) + + adata.obs["ind_x"] = np.arange(adata.n_obs).astype("int64") + + anndata_fields = [ + LayerField("U", "raw_unspliced", is_count_data=True), + LayerField("X", "raw_spliced", is_count_data=True), + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) + NumericalObsField("ind_x", "ind_x"), + ] + + if adata_atac: + adata.layers['atac'] = adata_atac.X + anndata_fields += [LayerField('atac', 'atac')] + adata.uns['atac'] = True + else: + adata.uns['atac'] = None + + if 'N_cn' in adata.obsm: + anndata_fields += [LayerField('N_cn', 'N_cn')] + + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) + + adata_manager.register_fields(adata, **kwargs) + cls.register_manager(adata_manager) + + def generate_posterior_samples( + self, + adata: Optional[AnnData] = None, + indices: Optional[Sequence[int]] = None, + batch_size: Optional[int] = None, + num_samples: int = 100, + ) -> Dict[str, ndarray]: + """ + Generates posterior samples for the given data using the trained + PyroVelocity model. + + The method generates posterior samples by running the trained model on the + provided data and returns a dictionary containing samples for each parameter. + + Args: + adata (AnnData, optional): Anndata object containing the data for which posterior samples + are to be computed. If not provided, the anndata used to initialize the model will be used. + indices (Sequence[int], optional): Indices of cells in `adata` for which the posterior + samples are to be computed. + batch_size (int, optional): The size of the mini-batches used during computation. + If not provided, the entire dataset will be used. + num_samples (int, default: 100): The number of posterior samples to compute for each parameter. + + Returns: + Dict[str, ndarray]: A dictionary containing the posterior samples for each parameter. + """ + self.module.eval() + predictive = self.module.create_predictive( + model=pyro.poutine.uncondition(self.module.model), + num_samples=num_samples, + ) + + scdl = self._make_data_loader( + adata=adata, indices=indices, batch_size=batch_size + ) + + with torch.no_grad(), pyro.poutine.mask(mask=False): + posterior_samples = [] + for tensor in scdl: + args, kwargs = self.module._get_fn_args_from_batch(tensor) + posterior_sample = { + k: v.cpu().numpy() + for k, v in predictive(*args, **kwargs).items() + } + posterior_samples.append(posterior_sample) + samples = {} + for k in posterior_samples[0].keys(): + if k in [ + "ut_norm", + "st_norm", + "time_constraint", + ]: + continue + + if posterior_samples[0][k].shape[-2] == 1: + samples[k] = posterior_samples[0][k] + else: + samples[k] = np.concatenate( + [ + posterior_samples[j][k] + for j in range(len(posterior_samples)) + ], + axis=-2, + ) + + logger.debug(k, "before", sys.getsizeof(samples[k])) + self.num_samples = num_samples + return samples + + def get_mlflow_logs(self): + return + + def compute_statistics_from_posterior_samples( + self, + adata: AnnData, + posterior_samples: Dict[str, ndarray], + vector_field_basis: str = "umap", + ncpus_use: int = 1, + ) -> Dict[str, ndarray]: + """ + Estimate statistics from posterior samples and add them to the + `posterior_samples` dictionary. 
The names of the statistics incorporated into + the dictionary are: + + - `gene_ranking` + - `original_spaces_embeds_magnitude` + - `genes` + - `vector_field_posterior_samples` + - `vector_field_posterior_mean` + - `fdri` + - `embeds_magnitude` + - `embeds_angle` + - `ut_mean` + - `st_mean` + - `pca_vector_field_posterior_samples` + - `pca_embeds_angle` + - `pca_fdri` + + The following data are removed from the `posterior_samples` dictionary: + + - `u` + - `s` + - `ut` + - `st` + + Each of these sets requires further documentation. + + Args: + adata (AnnData): Anndata object containing the data for which posterior samples + were computed. + posterior_samples (Dict[str, ndarray]): Dictionary containing the posterior samples + for each parameter. + vector_field_basis (str, optional): Basis for the vector field. Defaults to "umap". + ncpus_use (int, optional): Number of CPUs to use for computation. Defaults to 1. + + Returns: + Dict[str, ndarray]: Dictionary containing the posterior samples with added statistics. + """ + if ("u_scale" in posterior_samples) and ( + "s_scale" in posterior_samples + ): + scale = posterior_samples["u_scale"] / posterior_samples["s_scale"] + elif ("u_scale" in posterior_samples) and not ( + "s_scale" in posterior_samples + ): + scale = posterior_samples["u_scale"] + else: + scale = 1 + original_spaces_velocity_samples = ( + posterior_samples["beta"] * posterior_samples["ut"] / scale + - posterior_samples["gamma"] * posterior_samples["st"] + ) + original_spaces_embeds_magnitude = np.sqrt( + (original_spaces_velocity_samples**2).sum(axis=-1) + ) + + ( + vector_field_posterior_samples, + embeds_radian, + fdri, + ) = vector_field_uncertainty( + adata, + posterior_samples, + basis=vector_field_basis, + n_jobs=ncpus_use, + ) + embeds_magnitude = np.sqrt( + (vector_field_posterior_samples**2).sum(axis=-1) + ) + + mlflow.log_metric( + "FDR_sig_frac", round((fdri < 0.05).sum() / fdri.shape[0], 3) + ) + mlflow.log_metric("FDR_HMP", harmonic_mean(fdri)) + + compute_mean_vector_field( + posterior_samples=posterior_samples, + adata=adata, + basis=vector_field_basis, + n_jobs=ncpus_use, + ) + + vector_field_posterior_mean = adata.obsm[ + f"velocity_pyro_{vector_field_basis}" + ] + + gene_ranking, genes = compute_volcano_data( + [posterior_samples], [adata], time_correlation_with="st" + ) + gene_ranking = ( + gene_ranking.sort_values("mean_mae", ascending=False) + .head(300) + .sort_values("time_correlation", ascending=False) + ) + posterior_samples["gene_ranking"] = gene_ranking + posterior_samples[ + "original_spaces_embeds_magnitude" + ] = original_spaces_embeds_magnitude + posterior_samples["genes"] = genes + posterior_samples[ + "vector_field_posterior_samples" + ] = vector_field_posterior_samples + posterior_samples[ + "vector_field_posterior_mean" + ] = vector_field_posterior_mean + posterior_samples["fdri"] = fdri + posterior_samples["embeds_magnitude"] = embeds_magnitude + posterior_samples["embeds_angle"] = embeds_radian + posterior_samples["ut_mean"] = posterior_samples["ut"].mean(0).squeeze() + posterior_samples["st_mean"] = posterior_samples["st"].mean(0).squeeze() + + ( + pca_vector_field_posterior_samples, + pca_embeds_radian, + pca_fdri, + ) = vector_field_uncertainty( + adata, + posterior_samples, + basis="pca", + n_jobs=ncpus_use, + ) + posterior_samples[ + "pca_vector_field_posterior_samples" + ] = pca_vector_field_posterior_samples + posterior_samples["pca_embeds_angle"] = pca_embeds_radian + posterior_samples["pca_fdri"] = pca_fdri + + del 
posterior_samples["u"] + del posterior_samples["s"] + del posterior_samples["ut"] + del posterior_samples["st"] + return posterior_samples + + @beartype + def save_pyrovelocity_data( + self, + posterior_samples: Dict[str, ndarray], + pyrovelocity_data_path: os.PathLike | str, + ): + """ + Save the PyroVelocity data to a pickle file. + + Args: + posterior_samples (Dict[str, ndarray]): Dictionary containing the posterior samples + pyrovelocity_data_path (os.PathLike | str): Path to save the PyroVelocity data. + """ + with open(pyrovelocity_data_path, "wb") as f: + pickle.dump(posterior_samples, f) + for k in posterior_samples: + logger.debug(k, "after", sys.getsizeof(posterior_samples[k])) + + def save_model( + self, + dir_path: str, + prefix: Optional[str] = None, + overwrite: bool = True, + save_anndata: bool = False, + **anndata_write_kwargs, + ) -> None: + """ + Save the Pyro-Velocity model to a directory. + + Dispatches to the `save` method of the inherited `BaseModelClass` which + calls `torch.save` on a model state dictionary, variable names, and user + attributes. + + Args: + dir_path (str): Path to the directory where the model will be saved. + prefix (Optional[str], optional): Prefix to add to the saved files. Defaults to None. + overwrite (bool, optional): Whether to overwrite existing files. Defaults to True. + save_anndata (bool, optional): Whether to save the AnnData object. Defaults to False. + """ + super().save( + dir_path, prefix, overwrite, save_anndata, **anndata_write_kwargs + ) + pyro.get_param_store().save( + os.path.join(dir_path, "param_store_test.pt") + ) + + @classmethod + def load_model( + cls, + dir_path: str, + adata: Optional[AnnData] = None, + use_gpu: str = "auto", + prefix: Optional[str] = None, + backup_url: Optional[str] = None, + ) -> BaseModelClass: + """ + Load the model from a directory with the same structure as that produced + by the save method. + + Args: + dir_path (str): Path to the directory where the model is saved. + adata (Optional[AnnData], optional): Anndata object to load into the model. Defaults to None. + use_gpu (str, optional): Whether and which GPU to use. Defaults to "auto". + prefix (Optional[str], optional): Prefix to add to the saved files. Defaults to None. + backup_url (Optional[str], optional): URL to download the model from. Defaults to None. + + Raises: + RuntimeError: If the model is not an instance of PyroBaseModuleClass. + + Returns: + PyroVelocity: The loaded PyroVelocity model. 
+ """ + load_adata = adata is None + _accelerator, _devices, device = parse_device_args( + accelerator=use_gpu, return_device="torch" + ) + logger.info( + f"\nLoading model with:\n" + f"\taccelerator: {_accelerator}\n" + f"\tdevices: {_devices}\n" + f"\tdevice: {device}\n" + ) + + ( + attr_dict, + var_names, + model_state_dict, + new_adata, + ) = _load_saved_files( + dir_path, + load_adata, + map_location=device, + prefix=prefix, + backup_url=backup_url, + ) + + adata = new_adata if new_adata is not None else adata + + _validate_var_names(adata, var_names) + + registry = attr_dict.pop("registry_") + method_name = registry.get(_SETUP_METHOD_NAME, "setup_anndata") + getattr(cls, method_name)( + adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] + ) + + model = _initialize_model(cls, adata, attr_dict) + + for attr, val in attr_dict.items(): + setattr(model, attr, val) + + pyro.clear_param_store() + old_history = model.history_ + try: + model.module.load_state_dict(model_state_dict) + except RuntimeError as err: + if not isinstance(model.module, PyroBaseModuleClass): + raise err + logger.info( + "Preparing underlying `PyroBaseModuleClass` module for load" + ) + try: + model.train(max_epochs=1, max_steps=1) + except Exception: + model.train( + max_epochs=1, + max_steps=1, + batch_size=adata.shape[0], + train_size=0.8, + valid_size=0.2, + ) + model.module.load_state_dict(model_state_dict) + + model.history_ = old_history + model.to_device(device) + model.module.eval() + model._validate_anndata(adata) + pyro.get_param_store().load( + os.path.join(dir_path, "param_store_test.pt"), + map_location=device, + ) + return model diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py new file mode 100644 index 000000000..aa7786621 --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -0,0 +1,342 @@ +from typing import Optional, Tuple, Union + +import pyro +import torch +from beartype import beartype +from jaxtyping import Float, jaxtyped +from pyro import poutine +from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson +from pyro.nn import PyroModule, PyroSample +from pyro.primitives import plate +from scvi.nn import Decoder +from torch.nn.functional import relu, softplus +from torch import Tensor + +from pyrovelocity.logging import configure_logging +from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters + +logger = configure_logging(__name__) + +RNAInputType = Union[ + Float[torch.Tensor, ""], + Float[torch.Tensor, "num_genes"], + Float[torch.Tensor, "samples num_genes"], +] + +RNAOutputType = Union[ + Float[torch.Tensor, "num_cells num_genes"], + Float[torch.Tensor, "samples num_cells num_genes"], +] + +__all__ = [ + "LogNormalModel", + "VelocityModelAuto", + "MultiVelocityModelAuto", +] + +class VelocityModelAuto(LogNormalModel): + """Automatically configured velocity model. + + Args: + num_cells (int): _description_ + num_genes (int): _description_ + likelihood (str, optional): _description_. Defaults to "Poisson". + shared_time (bool, optional): _description_. Defaults to True. + t_scale_on (bool, optional): _description_. Defaults to False. + plate_size (int, optional): _description_. Defaults to 2. + latent_factor (str, optional): _description_. Defaults to "none". + latent_factor_size (int, optional): _description_. Defaults to 30. + latent_factor_operation (str, optional): _description_. Defaults to "selection". 
+ include_prior (bool, optional): _description_. Defaults to False. + num_aux_cells (int, optional): _description_. Defaults to 100. + only_cell_times (bool, optional): _description_. Defaults to False. + decoder_on (bool, optional): _description_. Defaults to False. + add_offset (bool, optional): _description_. Defaults to False. + correct_library_size (Union[bool, str], optional): _description_. Defaults to True. + guide_type (str, optional): _description_. Defaults to "velocity". + cell_specific_kinetics (Optional[str], optional): _description_. Defaults to None. + kinetics_num (Optional[int], optional): _description_. Defaults to None. + + Examples: + >>> import torch + >>> from pyrovelocity.models._velocity_model import VelocityModelAuto + >>> model = VelocityModelAuto( + ... 3, + ... 4, + ... "Poisson", + ... True, + ... False, + ... 2, + ... "none", + ... latent_factor_operation="selection", + ... latent_factor_size=10, + ... include_prior=False, + ... num_aux_cells=0, + ... only_cell_times=True, + ... decoder_on=False, + ... add_offset=False, + ... correct_library_size=True, + ... guide_type="auto_t0_constraint", + ... cell_specific_kinetics=None, + ... **{} + ... ) + >>> logger.info(model) + """ + + @beartype + def __init__( + self, + num_cells: int, + num_genes: int, + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_size: int = 30, + latent_factor_operation: str = "selection", + include_prior: bool = False, + num_aux_cells: int = 100, + only_cell_times: bool = False, + decoder_on: bool = False, + add_offset: bool = False, + correct_library_size: Union[bool, str] = True, + guide_type: str = "velocity", + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + assert num_cells > 0 and num_genes > 0 + super().__init__(num_cells, num_genes, likelihood, plate_size) + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + self.guide_type = guide_type + self.cell_specific_kinetics = cell_specific_kinetics + self.k = kinetics_num + + self.mask = initial_values.get( + "mask", torch.ones(self.num_cells, self.num_genes).bool() + ) + for key in initial_values: + self.register_buffer(f"{key}_init", initial_values[key]) + + self.shared_time = shared_time + self.t_scale_on = t_scale_on + self.add_offset = add_offset + self.plate_size = plate_size + + self.latent_factor = latent_factor + self.latent_factor_size = latent_factor_size + self.latent_factor_operation = latent_factor_operation + self.include_prior = include_prior + self.decoder_on = decoder_on + self.correct_library_size = correct_library_size + if self.decoder_on: + self.decoder = Decoder(1, self.num_genes, n_layers=2) + + self.enumeration = "parallel" + # self.set_enumeration_strategy() + + @beartype + def __repr__(self) -> str: + return ( + f"\nKnnModel(\n" + f"\tnum_cells={self.num_cells}, \n" + f"\tnum_genes={self.num_genes}, \n" + f")\n" + ) + + @beartype + def forward(self, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + ind_x: torch.Tensor, + batch_index: torch.Tensor) + """ + Defines the forward model, which computes the unspliced (u) and spliced + (s) RNA expression levels given the observations and model parameters. + + Args: + u_obs (Optional[torch.Tensor], optional): Observed unspliced RNA expression. Default is None. + s_obs (Optional[torch.Tensor], optional): Observed spliced RNA expression. Default is None. 
+ ind_x (Optional[torch.Tensor], optional): Indices for the cells. + batch_index (Optional[torch.Tensor], optional): Experimental batch index of cells. + + Returns: + + Examples: + >>> . + """ + + batch_size = len(ind_x) + obs2sample = one_hot(batch_index, self.n_batch) + obs_plate = self.create_plates(u_obs, s_obs, ind_x, batch_index) + + # ===================== Kinetic Rates ======================= # + # Splicing rate: + splicing_alpha = pyro.sample('splicing_alpha', + dist.Gamma(self.splicing_rate_alpha_hyp_prior_alpha, + self.splicing_rate_alpha_hyp_prior_alpha/self.splicing_rate_alpha_hyp_prior_mean)) + splicing_mean = pyro.sample('splicing_mean', + dist.Gamma(self.splicing_rate_mean_hyp_prior_alpha, + self.splicing_rate_mean_hyp_prior_alpha/self.splicing_rate_mean_hyp_prior_mean)) + beta_g = pyro.sample('beta_g', dist.Gamma(splicing_alpha, splicing_alpha/splicing_mean).expand([1,self.n_vars]).to_event(2)) + # Degredation rate: + degredation_alpha = pyro.sample('degredation_alpha', + dist.Gamma(self.degredation_rate_alpha_hyp_prior_alpha, + self.degredation_rate_alpha_hyp_prior_alpha/self.degredation_rate_alpha_hyp_prior_mean)) + degredation_alpha = degredation_alpha + 0.001 + degredation_mean = pyro.sample('degredation_mean', + dist.Gamma(self.degredation_rate_mean_hyp_prior_alpha, + self.degredation_rate_mean_hyp_prior_alpha/self.degredation_rate_mean_hyp_prior_mean)) + gamma_g = pyro.sample('gamma_g', dist.Gamma(degredation_alpha, degredation_alpha/degredation_mean).expand([1,self.n_vars]).to_event(2)) + # Transcription rate contribution of each module: + factor_level_g = pyro.sample( + "factor_level_g", + dist.Gamma(self.factor_prior_alpha, self.factor_prior_beta) + .expand([1, self.n_vars]) + .to_event(2) + ) + g_fg = pyro.sample( # (g_fg corresponds to module's spliced counts in steady state) + "g_fg", + dist.Gamma( + self.factor_states_per_gene / self.n_factors_torch, + self.ones / factor_level_g, + ) + .expand([self.n_modules, self.n_vars]) + .to_event(2) + ) + A_mgON = pyro.deterministic('A_mgON', g_fg*gamma_g) # (transform from spliced counts to transcription rate) + A_mgOFF = self.alpha_OFFg + # Activation and Deactivation rate: + lam_mu = pyro.sample('lam_mu', dist.Gamma(G_a(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd), + G_b(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd))) + lam_sd = pyro.sample('lam_sd', dist.Gamma(G_a(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd), + G_b(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd))) + lam_m_mu = pyro.sample('lam_m_mu', dist.Gamma(G_a(lam_mu, lam_sd), + G_b(lam_mu, lam_sd)).expand([self.n_modules, 1, 1]).to_event(3)) + lam_mi = pyro.sample('lam_mi', dist.Gamma(G_a(lam_m_mu, lam_m_mu*0.05), + G_b(lam_m_mu, lam_m_mu*0.05)).expand([self.n_modules, 1, 2]).to_event(3)) + + # =====================Time======================= # + # Global time for each cell: + T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) + t_c_loc = pyro.sample('t_c_loc', dist.Gamma(self.one, self.one/0.5)) + t_c_scale = pyro.sample('t_c_scale', dist.Gamma(self.one, self.one/0.25)) + with obs_plate: + t_c = pyro.sample('t_c', dist.Normal(t_c_loc, t_c_scale).expand([batch_size, 1, 1])) + T_c = pyro.deterministic('T_c', t_c*T_max) + # Global switch on time for each gene: +# t_mON = pyro.sample('t_mON', dist.Uniform(self.zero, self.one).expand([1, 1, self.n_modules]).to_event(2)) + 
t_delta = pyro.sample('t_delta', dist.Gamma(self.one*20, self.one * 20 *self.n_modules_torch). + expand([self.n_modules]).to_event(1)) + t_mON = torch.cumsum(torch.concat([self.zero.unsqueeze(0), t_delta[:-1]]), dim = 0).unsqueeze(0).unsqueeze(0) + T_mON = pyro.deterministic('T_mON', T_max*t_mON) + # Global switch off time for each gene: + t_mOFF = pyro.sample('t_mOFF', dist.Exponential(self.n_modules_torch).expand([1, 1, self.n_modules]).to_event(2)) + T_mOFF = pyro.deterministic('T_mOFF', T_mON + T_max*t_mOFF) + + # =========== Mean expression according to RNAvelocity model ======================= # + mu_total = torch.stack([self.zeros[idx,...], self.zeros[idx,...]], axis = -1) + for m in range(self.n_modules): + mu_total += mu_mRNA_continousAlpha_globalTime_twoStates( + A_mgON[m,:], A_mgOFF, beta_g, gamma_g, lam_mi[m,...], T_c[:,:,0], T_mON[:,:,m], T_mOFF[:,:,m], self.zeros[ind_x,...]) + with obs_plate: + mu_expression = pyro.deterministic('mu_expression', mu_total) + + # =============Detection efficiency of spliced and unspliced counts =============== # + # Cell specific relative detection efficiency with hierarchical prior across batches: + detection_mean_y_e = pyro.sample( + "detection_mean_y_e", + dist.Beta( + self.ones * self.detection_mean_hyp_prior_alpha, + self.ones * self.detection_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1]) + .to_event(2), + ) + detection_hyp_prior_alpha = pyro.deterministic( + "detection_hyp_prior_alpha", + self.detection_hyp_prior_alpha, + ) + + beta = detection_hyp_prior_alpha / (obs2sample @ detection_mean_y_e) + with obs_plate: + detection_y_c = pyro.sample( + "detection_y_c", + dist.Gamma(detection_hyp_prior_alpha.unsqueeze(dim=-1), beta.unsqueeze(dim=-1)), + ) # (self.n_obs, 1) + + # Global relative detection efficiency between spliced and unspliced counts + detection_y_i = pyro.sample( + "detection_y_i", + dist.Gamma( + self.ones * self.detection_i_prior_alpha, + self.ones * self.detection_i_prior_alpha, + ) + .expand([1, 1, 2]).to_event(3) + ) + + # Gene specific relative detection efficiency between spliced and unspliced counts + detection_y_gi = pyro.sample( + "detection_y_gi", + dist.Gamma( + self.ones * self.detection_gi_prior_alpha, + self.ones * self.detection_gi_prior_alpha, + ) + .expand([1, self.n_vars, 2]) + .to_event(3), + ) + + # =======Gene-specific additive component (Ambient RNA/ "Soup") for spliced and unspliced counts ====== # + # Independently sampled for spliced and unspliced counts: + s_g_gene_add_alpha_hyp = pyro.sample( + "s_g_gene_add_alpha_hyp", + dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta).expand([2]).to_event(1), + ) + s_g_gene_add_mean = pyro.sample( + "s_g_gene_add_mean", + dist.Gamma( + self.gene_add_mean_hyp_prior_alpha, + self.gene_add_mean_hyp_prior_beta, + ) + .expand([self.n_batch, 1, 2]) + .to_event(3), + ) + s_g_gene_add_alpha_e_inv = pyro.sample( + "s_g_gene_add_alpha_e_inv", + dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1, 2]).to_event(3), + ) + s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2) + s_g_gene_add = pyro.sample( + "s_g_gene_add", + dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean) + .expand([self.n_batch, self.n_vars, 2]) + .to_event(3), + ) + + # =========Gene-specific overdispersion of spliced and unspliced counts ============== # + # Overdispersion of unspliced counts: + stochastic_v_ag_hyp = pyro.sample( + "stochastic_v_ag_hyp", + dist.Gamma( + self.stochastic_v_ag_hyp_prior_alpha, + 
self.stochastic_v_ag_hyp_prior_beta, + ).expand([1, 2]).to_event(2)) + stochastic_v_ag_hyp = stochastic_v_ag_hyp + 0.001 + stochastic_v_ag_inv = pyro.sample( + "stochastic_v_ag_inv", + dist.Exponential(stochastic_v_ag_hyp) + .expand([1, self.n_vars, 2]).to_event(3), + ) + stochastic_v_ag = (self.ones / stochastic_v_ag_inv.pow(2)) + + # =====================Expected expression ======================= # + with obs_plate: + mu = pyro.deterministic('mu', (mu_expression + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \ + detection_y_c * detection_y_i * detection_y_gi) + + # =====================DATA likelihood ======================= # + with obs_plate: + pyro.sample("data_target", dist.GammaPoisson(concentration= stochastic_v_ag, + rate= stochastic_v_ag / mu), obs=torch.stack([u_obs, s_obs], axis = 2)) \ No newline at end of file diff --git a/src/pyrovelocity/models/knn_model/_velocity_module.py b/src/pyrovelocity/models/knn_model/_velocity_module.py new file mode 100644 index 000000000..cf8f1522a --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_velocity_module.py @@ -0,0 +1,236 @@ +from typing import Any +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import Union + +import torch +from pyro import poutine +from pyro.infer.autoguide import AutoLowRankMultivariateNormal +from pyro.infer.autoguide import AutoNormal +from pyro.infer.autoguide.guides import AutoGuideList +from scvi.module.base import PyroBaseModuleClass + +from pyrovelocity.logging import configure_logging +from pyrovelocity.models._velocity_model import VelocityModelAuto, MultiVelocityModelAuto + + +logger = configure_logging(__name__) + + +class VelocityModule(PyroBaseModuleClass): + """ + VelocityModule is an scvi-tools pyro module that combines the VelocityModelAuto and pyro AutoGuideList classes. + + Args: + num_cells (int): Number of cells. + num_genes (int): Number of genes. + model_type (str, optional): Model type. Default is "auto". + guide_type (str, optional): Guide type. Default is "velocity_auto". + likelihood (str, optional): Likelihood type. Default is "Poisson". + shared_time (bool, optional): If True, a shared time parameter will be used. Default is True. + t_scale_on (bool, optional): If True, scale time parameter. Default is False. + plate_size (int, optional): Size of the plate set. Default is 2. + latent_factor (str, optional): Latent factor. Default is "none". + latent_factor_operation (str, optional): Latent factor operation mode. Default is "selection". + latent_factor_size (int, optional): Size of the latent factor. Default is 10. + inducing_point_size (int, optional): Inducing point size. Default is 0. + include_prior (bool, optional): If True, include prior in the model. Default is False. + use_gpu (str, optional): Accelerator type. Default is "auto". + num_aux_cells (int, optional): Number of auxiliary cells. Default is 0. + only_cell_times (bool, optional): If True, only model cell times. Default is True. + decoder_on (bool, optional): If True, use the decoder. Default is False. + add_offset (bool, optional): If True, add offset to the model. Default is True. + correct_library_size (Union[bool, str], optional): Library size correction method. Default is True. + cell_specific_kinetics (Optional[str], optional): Cell-specific kinetics method. Default is None. + kinetics_num (Optional[int], optional): Number of kinetics. Default is None. + **initial_values: Initial values for the model parameters. 
+ + Examples: + >>> from scvi.module.base import PyroBaseModuleClass + >>> from pyrovelocity.models._velocity_module import VelocityModule + >>> num_cells = 10 + >>> num_genes = 20 + >>> velocity_module1 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto_t0_constraint", add_offset=False + ... ) + >>> type(velocity_module1.model) + + >>> type(velocity_module1.guide) + + >>> velocity_module2 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto", add_offset=True + ... ) + >>> type(velocity_module2.model) + + >>> type(velocity_module2.guide) + + """ + + def __init__( + self, + num_cells: int, + num_genes: int, + model_type: str = "auto", + guide_type: str = "velocity_auto", + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + latent_factor_size: int = 10, + inducing_point_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = True, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + super().__init__() + self.num_cells = num_cells + self.num_genes = num_genes + self.model_type = model_type + self.guide_type = guide_type + self._model = None + self.plate_size = plate_size + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + logger.info( + f"Model type: {self.model_type}, Guide type: {self.guide_type}" + ) + + self.cell_specific_kinetics = cell_specific_kinetics + + self._model = VelocityModelAuto( + self.num_cells, + self.num_genes, + likelihood, + shared_time, + t_scale_on, + self.plate_size, + latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + include_prior=include_prior, + num_aux_cells=num_aux_cells, + only_cell_times=self.only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + guide_type=self.guide_type, + cell_specific_kinetics=self.cell_specific_kinetics, + **initial_values, + ) + + guide = AutoGuideList( + self._model, create_plates=self._model.create_plates + ) + guide.append( + AutoNormal( + poutine.block( + self._model, + expose=[ + "cell_time", + "u_read_depth", + "s_read_depth", + "kinetics_prob", + "kinetics_weights", + ], + ), + init_scale=0.1, + ) + ) + + if add_offset: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "dt_switching", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + else: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "alpha", + "beta", + "gamma", + "dt_switching", + "t0", + "u_scale", + "s_scale", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + self._guide = guide + + @property + def model(self) -> VelocityModelAuto: + return self._model + + @property + def guide(self) -> AutoGuideList: + return self._guide + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: Dict[str, torch.Tensor] + ) -> Tuple[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + Dict[Any, Any], + ]: + u_obs = tensor_dict["U"] + s_obs = tensor_dict["X"] + u_log_library = 
tensor_dict["u_lib_size"] + s_log_library = tensor_dict["s_lib_size"] + u_log_library_mean = tensor_dict["u_lib_size_mean"] + s_log_library_mean = tensor_dict["s_lib_size_mean"] + u_log_library_scale = tensor_dict["u_lib_size_scale"] + s_log_library_scale = tensor_dict["s_lib_size_scale"] + ind_x = tensor_dict["ind_x"].long().squeeze() + cell_state = tensor_dict.get("pyro_cell_state") + time_info = tensor_dict.get("time_info") + return ( + u_obs, + s_obs, + u_log_library, + s_log_library, + u_log_library_mean, + s_log_library_mean, + u_log_library_scale, + s_log_library_scale, + ind_x, + cell_state, + time_info, + ), {} \ No newline at end of file From 75f819602a7e8fc317377a9b0ce3531568fd37ad Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 29 Aug 2024 20:31:42 +0000 Subject: [PATCH 30/47] feat(knn_model): Basic vector field for knn_model. Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_vector_fields.py | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 src/pyrovelocity/models/knn_model/_vector_fields.py diff --git a/src/pyrovelocity/models/knn_model/_vector_fields.py b/src/pyrovelocity/models/knn_model/_vector_fields.py new file mode 100644 index 000000000..03969bd2e --- /dev/null +++ b/src/pyrovelocity/models/knn_model/_vector_fields.py @@ -0,0 +1,37 @@ +from beartype import beartype +from torch import Tensor +from typing import Tuple +from typing import List + +@beartype +def vector_field_1( t: float, + y: Tuple, + args: List): + """ + Vector field of mRNA dynamics of unspliced and spliced counts, based on a regulatory + function that that takes u,s as input and returns transcription (alpha), splicing (beta) + and degradation rates (gamma). + + Args: + t (Float): Integration time. Only used when vector field used in Diffrax library + and otherwise can be an arbitrary value. + y (Tuple): State of the system. Tuple of unspliced (u) and spliced counts (s). + args (List): List containing a regulatory function that takes u,s as input + and returns transcription (alpha), splicing (beta) and degradation + rates (gamma). + + Returns: + Tuple: Rates of change in y (= unspliced and spliced counts) + + Examples: + >>> + """ + + u, s = y + regulatory_function = args[0] + alpha, beta, gamma = regulatory_function(u,s) + du = alpha - beta*u + ds = beta*u - gamma*s + dy = du, ds + + return dy \ No newline at end of file From 3c840cb885b543ce7ab7bd303c134a6612c36acc Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 29 Aug 2024 20:32:19 +0000 Subject: [PATCH 31/47] feat(knn_model): Basic regulatory function, using 2 layer neural net, for knn model. 
Signed-off-by: Alexander Aivazidis --- .../knn_model/regulatory_functions_torch.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 src/pyrovelocity/models/knn_model/regulatory_functions_torch.py diff --git a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py new file mode 100644 index 000000000..e99dadd1c --- /dev/null +++ b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py @@ -0,0 +1,45 @@ +import torch +import torch.nn.functional as F +from beartype import beartype +from torch import Tensor + +@beartype +def regulatory_function_1(u: Tensor, + s: Tensor, + h1: int = 100, + h2: int = 100): + """ + Regulatory function that contains a neural net with two hidden layers that takes unspliced and spliced + counts as input and returns transcription (alpha), splicing (beta) and degradation rates (gamma). + + Args: + u (Tensor): Unspliced counts + s (Tensor): Spliced counts + h1 (int): Nodes in hidden layer 1 + h2 (int): Nodes in hidden layer 2 + + Returns: + Tuple: transcription (alpha), splicing (beta) and degradation rate (gamma). + + Examples: + >>> + """ + + input = torch.tensor(np.array([np.array(u), np.array(s)]).T) + + l1 = torch.nn.Linear(2, h1) + l2 = torch.nn.Linear(h1, h2) + l3 = torch.nn.Linear(h2, 3) + + x = l1(input) + x = F.leaky_relu(x) + x = l2(x) + x = F.leaky_relu(x) + x = l3(x) + + output = torch.sigmoid(x) + beta = output[:,0] + gamma = output[:,1] + alphas = output[:,2] + + return alphas, beta, gamma \ No newline at end of file From fbca112bfa41ed2dacf7bd4ce48ce343ab519019 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 29 Aug 2024 20:32:39 +0000 Subject: [PATCH 32/47] feat(knn_model): Initial commit for pyro model for knn model. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model/_velocity_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py index aa7786621..c46253983 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_model.py +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -135,7 +135,7 @@ def __init__( self.decoder = Decoder(1, self.num_genes, n_layers=2) self.enumeration = "parallel" - # self.set_enumeration_strategy() + # self.set_enumeration_strategy() @beartype def __repr__(self) -> str: From 33aae38d2d0a6bceaff026f269f5fbd29fe6c480 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Tue, 3 Sep 2024 17:44:11 +0000 Subject: [PATCH 33/47] feat(train.py): resolved merge conflict by keeping change proposed in beta branch. 
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/train.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/pyrovelocity/tasks/train.py b/src/pyrovelocity/tasks/train.py index 3a1a45e48..3d3b834b4 100644 --- a/src/pyrovelocity/tasks/train.py +++ b/src/pyrovelocity/tasks/train.py @@ -355,14 +355,8 @@ def train_model( >>> copy_raw_counts(adata) >>> _, model, posterior_samples = train_model(adata, use_gpu="auto", seed=99, max_epochs=200, loss_plot_path=loss_plot_path) """ -<<<<<<< HEAD + if isinstance(adata, str | Path): -||||||| parent of 9d2f67d (added atac_layer argument to train_model tasks and made tests for it) - if isinstance(adata, str): -======= - - if isinstance(adata, str): ->>>>>>> 9d2f67d (added atac_layer argument to train_model tasks and made tests for it) adata = load_anndata_from_path(adata) logger.info(f"AnnData object prior to model training") From ae19b056ef191eef7fa93259a3dd7756b42905a8 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Tue, 3 Sep 2024 23:03:19 +0000 Subject: [PATCH 34/47] fix(preprocess): Added spliced/unspliced count aggregation to metacells. Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/preprocess.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/pyrovelocity/tasks/preprocess.py b/src/pyrovelocity/tasks/preprocess.py index efd41a2b9..6d3123dfc 100644 --- a/src/pyrovelocity/tasks/preprocess.py +++ b/src/pyrovelocity/tasks/preprocess.py @@ -721,6 +721,11 @@ def merge_RNA( adata_meta.obs[celltype_key] = [adata_rna[adata_rna.obs[cluster_key] == c,:].obs[celltype_key].mode()[0] for c in np.unique(adata_rna.obs[cluster_key])] adata_meta.obs['RNA counts'] = np.sum(adata_meta.X, axis = 1) + if 'unspliced' in adata_rna.layers: + adata_meta.layers['unspliced'] = np.concatenate([np.sum(adata_rna.layers['unspliced'][adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) + if 'spliced' in adata_rna.layers: + adata_meta.layers['spliced'] = np.concatenate([np.sum(adata_rna.layers['spliced'][adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) + if verbose: print('Mean RNA counts per cell before: ', np.mean(adata_rna.obs['RNA counts'])) print('Mean RNA counts per cell after: ', np.mean(adata_meta.obs['RNA counts'])) From 567f700ecf1eca5bdd5d1176b4c7b03829261bf8 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Tue, 3 Sep 2024 23:04:09 +0000 Subject: [PATCH 35/47] fix(regulatory_function_1): needs numpy import Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model/regulatory_functions_torch.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py index e99dadd1c..ad41d705b 100644 --- a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py +++ b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py @@ -2,6 +2,7 @@ import torch.nn.functional as F from beartype import beartype from torch import Tensor +import numpy as np @beartype def regulatory_function_1(u: Tensor, From 8586e15b1417e8ce3f57a6882d1dbae0d08930b3 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Tue, 3 Sep 2024 23:16:40 +0000 Subject: [PATCH 36/47] fix(preprocess): need to return sparse matrix in layers after metacell construction Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/preprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git 
a/src/pyrovelocity/tasks/preprocess.py b/src/pyrovelocity/tasks/preprocess.py index 6d3123dfc..1720e9e72 100644 --- a/src/pyrovelocity/tasks/preprocess.py +++ b/src/pyrovelocity/tasks/preprocess.py @@ -723,8 +723,10 @@ def merge_RNA( if 'unspliced' in adata_rna.layers: adata_meta.layers['unspliced'] = np.concatenate([np.sum(adata_rna.layers['unspliced'][adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) + adata_meta.layers['unspliced'] = csr_matrix(adata_meta.layers['unspliced'], dtype=np.uint16) if 'spliced' in adata_rna.layers: adata_meta.layers['spliced'] = np.concatenate([np.sum(adata_rna.layers['spliced'][adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) + adata_meta.layers['spliced'] = csr_matrix(adata_meta.layers['spliced'], dtype=np.uint16) if verbose: print('Mean RNA counts per cell before: ', np.mean(adata_rna.obs['RNA counts'])) From 56284633f7a8a3b4b3c1c34198fa1eab4382604c Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 5 Sep 2024 21:03:40 +0000 Subject: [PATCH 37/47] fix(knn_model._velocity): Fixed various bugs in imports. Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity.py | 55 +++++++++++++------ 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity.py b/src/pyrovelocity/models/knn_model/_velocity.py index 918f4568a..dbb04e87d 100644 --- a/src/pyrovelocity/models/knn_model/_velocity.py +++ b/src/pyrovelocity/models/knn_model/_velocity.py @@ -12,9 +12,10 @@ from numpy import ndarray from scvi.data import AnnDataManager from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME -from scvi.data.fields import LayerField, NumericalObsField +from scvi.data.fields import LayerField, NumericalObsField, CategoricalObsField, ObsmField from scvi.model._utils import parse_device_args from scvi.model.base import BaseModelClass +from scvi import REGISTRY_KEYS from scvi.model.base._utils import ( _initialize_model, _load_saved_files, @@ -28,14 +29,14 @@ vector_field_uncertainty, ) from pyrovelocity.logging import configure_logging -from pyrovelocity.models._trainer import VelocityTrainingMixin +from scvi.model.base import PyroSviTrainMixin from pyrovelocity.models.knn_model._velocity_module import VelocityModule, MultiVelocityModule __all__ = ["PyroVelocity"] logger = configure_logging(__name__) -class PyroVelocity(VelocityTrainingMixin, BaseModelClass): +class PyroVelocity(PyroSviTrainMixin, BaseModelClass): """ PyroVelocity is a class for constructing and training a Pyro model for probabilistic RNA velocity estimation. This model leverages the @@ -300,20 +301,40 @@ def __init__( self.init_params_ = self._get_init_params(locals()) logger.info("Model initialized") - def train(self, **kwargs): + def train( + self, + max_epochs: int = 500, + batch_size: int = 1000, + train_size: float = 1, + lr: float = 0.01, + **kwargs, + ): """ - Trains the PyroVelocity model using the provided data and configuration. - - The method leverages the Pyro library to train the model using the underlying - data. It relies on the `VelocityTrainingMixin` to define the training logic. - - Args: - - **kwargs : dict, optional - Additional keyword arguments to be passed to the underlying train method - provided by the `VelocityTrainingMixin`. + Training function for the model. + + Parameters + ---------- + max_epochs + Number of passes through the dataset. 
If `None`, defaults to + ``np.min([round((20000 / n_cells) * 400), 400])`` + train_size + Size of training set in the range [0.0, 1.0]. + batch_size + Minibatch size to use during training. If `None`, no minibatching occurs and all + data is copied to device (e.g., GPU). + lr + Optimiser learning rate (default optimiser is :class:`~pyro.optim.ClippedAdam`). + Specifying optimiser via plan_kwargs overrides this choice of lr. + kwargs + Other arguments to :py:meth:`scvi.model.base.PyroSviTrainMixin().train` method """ - pyro.enable_validation(True) + + self.max_epochs = max_epochs + kwargs["max_epochs"] = max_epochs + kwargs["batch_size"] = batch_size + kwargs["train_size"] = train_size + kwargs["lr"] = lr + super().train(**kwargs) def enum_parallel_predict(self): @@ -336,7 +357,7 @@ def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *arg anndata_fields = [ LayerField("U", "raw_unspliced", is_count_data=True), LayerField("X", "raw_spliced", is_count_data=True), - CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) + CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), NumericalObsField("ind_x", "ind_x"), ] @@ -348,7 +369,7 @@ def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *arg adata.uns['atac'] = None if 'N_cn' in adata.obsm: - anndata_fields += [LayerField('N_cn', 'N_cn')] + anndata_fields += [ObsmField('N_cn', 'N_cn')] adata_manager = AnnDataManager( fields=anndata_fields, setup_method_args=setup_method_args From f34d7c561330c1731e7d27962e236a0f4ce90d29 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 5 Sep 2024 21:04:17 +0000 Subject: [PATCH 38/47] feat(knn_model._velocity_model): Completed first draft of the model. Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity_model.py | 236 +++++++++++------- 1 file changed, 145 insertions(+), 91 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py index c46253983..3834f6a26 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_model.py +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -1,6 +1,7 @@ from typing import Optional, Tuple, Union import pyro +from pyro.nn import PyroModule import torch from beartype import beartype from jaxtyping import Float, jaxtyped @@ -15,6 +16,9 @@ from pyrovelocity.logging import configure_logging from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters +from pyrovelocity.models.knn_model.regulatory_functions_torch import regulatory_function_1 +from pyrovelocity.models.knn_model._vector_fields import vector_field_1 + logger = configure_logging(__name__) RNAInputType = Union[ @@ -29,12 +33,11 @@ ] __all__ = [ - "LogNormalModel", "VelocityModelAuto", "MultiVelocityModelAuto", ] -class VelocityModelAuto(LogNormalModel): +class VelocityModelAuto(PyroModule): """Automatically configured velocity model. 
Args: @@ -88,6 +91,7 @@ def __init__( self, num_cells: int, num_genes: int, + n_batch: int, likelihood: str = "Poisson", shared_time: bool = True, t_scale_on: bool = False, @@ -104,6 +108,15 @@ def __init__( guide_type: str = "velocity", cell_specific_kinetics: Optional[str] = None, kinetics_num: Optional[int] = None, + stochastic_v_ag_hyp_prior={"alpha": 6.0, "beta": 3.0}, + s_overdispersion_factor_hyp_prior={'alpha_mean': 100., 'beta_mean': 1., + 'alpha_sd': 1., 'beta_sd': 0.1}, + detection_hyp_prior={"alpha": 10.0, "mean_alpha": 1.0, "mean_beta": 1.0}, + detection_i_prior={"mean": 1, "alpha": 100}, + detection_gi_prior={"mean": 1, "alpha": 200}, + gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, + gene_add_mean_hyp_prior={"alpha": 1.0, "beta": 100.0}, + Tmax_prior={"mean": 50., "sd": 50.}, **initial_values, ) -> None: assert num_cells > 0 and num_genes > 0 @@ -135,7 +148,106 @@ def __init__( self.decoder = Decoder(1, self.num_genes, n_layers=2) self.enumeration = "parallel" - # self.set_enumeration_strategy() + # self.set_enumeration_strategy() + + self.n_obs = num_cells + self.n_vars = num_genes + self.n_batch = n_batch + + self.stochastic_v_ag_hyp_prior = stochastic_v_ag_hyp_prior + self.gene_add_alpha_hyp_prior = gene_add_alpha_hyp_prior + self.gene_add_mean_hyp_prior = gene_add_mean_hyp_prior + self.detection_hyp_prior = detection_hyp_prior + self.s_overdispersion_factor_hyp_prior = s_overdispersion_factor_hyp_prior + self.detection_gi_prior = detection_gi_prior + self.detection_i_prior = detection_i_prior + + self.register_buffer( + "s_overdispersion_factor_alpha_mean", + torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_mean"]), + ) + self.register_buffer( + "s_overdispersion_factor_beta_mean", + torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_mean"]), + ) + self.register_buffer( + "s_overdispersion_factor_alpha_sd", + torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_sd"]), + ) + self.register_buffer( + "s_overdispersion_factor_beta_sd", + torch.tensor(self.s_overdispersion_factor_hyp_prior["beta_sd"]), + ) + + self.register_buffer( + "detection_gi_prior_alpha", + torch.tensor(self.detection_gi_prior["alpha"]), + ) + self.register_buffer( + "detection_gi_prior_beta", + torch.tensor(self.detection_gi_prior["alpha"] / self.detection_gi_prior["mean"]), + ) + + self.register_buffer( + "detection_i_prior_alpha", + torch.tensor(self.detection_i_prior["alpha"]), + ) + self.register_buffer( + "detection_i_prior_beta", + torch.tensor(self.detection_i_prior["alpha"] / self.detection_i_prior["mean"]), + ) + + self.register_buffer( + "Tmax_mean", + torch.tensor(Tmax_prior["mean"]), + ) + + self.register_buffer( + "Tmax_sd", + torch.tensor(Tmax_prior["sd"]), + ) + + self.register_buffer( + "detection_mean_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["mean_alpha"]), + ) + self.register_buffer( + "detection_mean_hyp_prior_beta", + torch.tensor(self.detection_hyp_prior["mean_beta"]), + ) + + self.register_buffer( + "stochastic_v_ag_hyp_prior_alpha", + torch.tensor(self.stochastic_v_ag_hyp_prior["alpha"]), + ) + self.register_buffer( + "stochastic_v_ag_hyp_prior_beta", + torch.tensor(self.stochastic_v_ag_hyp_prior["beta"]), + ) + self.register_buffer( + "gene_add_alpha_hyp_prior_alpha", + torch.tensor(self.gene_add_alpha_hyp_prior["alpha"]), + ) + self.register_buffer( + "gene_add_alpha_hyp_prior_beta", + torch.tensor(self.gene_add_alpha_hyp_prior["beta"]), + ) + self.register_buffer( + "gene_add_mean_hyp_prior_alpha", + 
torch.tensor(self.gene_add_mean_hyp_prior["alpha"]), + ) + self.register_buffer( + "gene_add_mean_hyp_prior_beta", + torch.tensor(self.gene_add_mean_hyp_prior["beta"]), + ) + + self.register_buffer( + "detection_hyp_prior_alpha", + torch.tensor(self.detection_hyp_prior["alpha"]), + ) + + self.register_buffer("one", torch.tensor(1.)) + @beartype def __repr__(self) -> str: @@ -150,8 +262,9 @@ def __repr__(self) -> str: def forward(self, u_obs: torch.Tensor, s_obs: torch.Tensor, + N_cn: torch.Tensor, ind_x: torch.Tensor, - batch_index: torch.Tensor) + batch_index: torch.Tensor): """ Defines the forward model, which computes the unspliced (u) and spliced (s) RNA expression levels given the observations and model parameters. @@ -171,80 +284,38 @@ def forward(self, batch_size = len(ind_x) obs2sample = one_hot(batch_index, self.n_batch) obs_plate = self.create_plates(u_obs, s_obs, ind_x, batch_index) + k = N_cn.shape[1] - # ===================== Kinetic Rates ======================= # - # Splicing rate: - splicing_alpha = pyro.sample('splicing_alpha', - dist.Gamma(self.splicing_rate_alpha_hyp_prior_alpha, - self.splicing_rate_alpha_hyp_prior_alpha/self.splicing_rate_alpha_hyp_prior_mean)) - splicing_mean = pyro.sample('splicing_mean', - dist.Gamma(self.splicing_rate_mean_hyp_prior_alpha, - self.splicing_rate_mean_hyp_prior_alpha/self.splicing_rate_mean_hyp_prior_mean)) - beta_g = pyro.sample('beta_g', dist.Gamma(splicing_alpha, splicing_alpha/splicing_mean).expand([1,self.n_vars]).to_event(2)) - # Degredation rate: - degredation_alpha = pyro.sample('degredation_alpha', - dist.Gamma(self.degredation_rate_alpha_hyp_prior_alpha, - self.degredation_rate_alpha_hyp_prior_alpha/self.degredation_rate_alpha_hyp_prior_mean)) - degredation_alpha = degredation_alpha + 0.001 - degredation_mean = pyro.sample('degredation_mean', - dist.Gamma(self.degredation_rate_mean_hyp_prior_alpha, - self.degredation_rate_mean_hyp_prior_alpha/self.degredation_rate_mean_hyp_prior_mean)) - gamma_g = pyro.sample('gamma_g', dist.Gamma(degredation_alpha, degredation_alpha/degredation_mean).expand([1,self.n_vars]).to_event(2)) - # Transcription rate contribution of each module: - factor_level_g = pyro.sample( - "factor_level_g", - dist.Gamma(self.factor_prior_alpha, self.factor_prior_beta) - .expand([1, self.n_vars]) - .to_event(2) - ) - g_fg = pyro.sample( # (g_fg corresponds to module's spliced counts in steady state) - "g_fg", - dist.Gamma( - self.factor_states_per_gene / self.n_factors_torch, - self.ones / factor_level_g, - ) - .expand([self.n_modules, self.n_vars]) - .to_event(2) - ) - A_mgON = pyro.deterministic('A_mgON', g_fg*gamma_g) # (transform from spliced counts to transcription rate) - A_mgOFF = self.alpha_OFFg - # Activation and Deactivation rate: - lam_mu = pyro.sample('lam_mu', dist.Gamma(G_a(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd), - G_b(self.activation_rate_mean_hyp_prior_mean, self.activation_rate_mean_hyp_prior_sd))) - lam_sd = pyro.sample('lam_sd', dist.Gamma(G_a(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd), - G_b(self.activation_rate_sd_hyp_prior_mean, self.activation_rate_sd_hyp_prior_sd))) - lam_m_mu = pyro.sample('lam_m_mu', dist.Gamma(G_a(lam_mu, lam_sd), - G_b(lam_mu, lam_sd)).expand([self.n_modules, 1, 1]).to_event(3)) - lam_mi = pyro.sample('lam_mi', dist.Gamma(G_a(lam_m_mu, lam_m_mu*0.05), - G_b(lam_m_mu, lam_m_mu*0.05)).expand([self.n_modules, 1, 2]).to_event(3)) - - # =====================Time======================= # - # 
Global time for each cell: + # ============= Expression Model =============== # T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) t_c_loc = pyro.sample('t_c_loc', dist.Gamma(self.one, self.one/0.5)) t_c_scale = pyro.sample('t_c_scale', dist.Gamma(self.one, self.one/0.25)) with obs_plate: t_c = pyro.sample('t_c', dist.Normal(t_c_loc, t_c_scale).expand([batch_size, 1, 1])) T_c = pyro.deterministic('T_c', t_c*T_max) - # Global switch on time for each gene: -# t_mON = pyro.sample('t_mON', dist.Uniform(self.zero, self.one).expand([1, 1, self.n_modules]).to_event(2)) - t_delta = pyro.sample('t_delta', dist.Gamma(self.one*20, self.one * 20 *self.n_modules_torch). - expand([self.n_modules]).to_event(1)) - t_mON = torch.cumsum(torch.concat([self.zero.unsqueeze(0), t_delta[:-1]]), dim = 0).unsqueeze(0).unsqueeze(0) - T_mON = pyro.deterministic('T_mON', T_max*t_mON) - # Global switch off time for each gene: - t_mOFF = pyro.sample('t_mOFF', dist.Exponential(self.n_modules_torch).expand([1, 1, self.n_modules]).to_event(2)) - T_mOFF = pyro.deterministic('T_mOFF', T_mON + T_max*t_mOFF) - # =========== Mean expression according to RNAvelocity model ======================= # - mu_total = torch.stack([self.zeros[idx,...], self.zeros[idx,...]], axis = -1) - for m in range(self.n_modules): - mu_total += mu_mRNA_continousAlpha_globalTime_twoStates( - A_mgON[m,:], A_mgOFF, beta_g, gamma_g, lam_mi[m,...], T_c[:,:,0], T_mON[:,:,m], T_mOFF[:,:,m], self.zeros[ind_x,...]) + # Time difference between neighbors: + delta_cn = T_c.unsqueeze(-1) - T_c[N_cn, :] + + # Counts in each cell: + mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, self.one*1.0).expand([batch_size, n_genes, 2])) + + # Weight of each nearest neighbor: + wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.1, self.one*10.0).expand([1,batch_size])) + wdash5_nc = pyro.sample('wdash_nc', dist.Gamma(self.one*10.0, self.one*10.0).expand([k-1,batch_size])) + wdash_nc = torch.concat([wdash0_nc, wdash5_nc], axis = 0) + w_nc = wdash_nc/torch.sum(wdash_nc, axis = 0) + + # Predicted counts from each neighbor: + y = (mu0_cg[...,0], mu0_cg[...,1]) + dy_cn = torch.stack(vector_field_1(0.0,y,[regulatory_function_1]), axis = -1)[N_cn,...] 
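+        # Predicted counts for each cell: its own latent counts y plus the neighbour-weighted
+        # average over the k nearest neighbours of (time difference delta_cn) times the rate of
+        # change dy_cn evaluated at the neighbour's state by vector_field_1/regulatory_function_1.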
+ muhat_cg = torch.stack(y, axis = -1) + torch.sum((w_nc.T.unsqueeze(-1).unsqueeze(-1) * delta_cn * dy_cn), axis = 1)/k + + # Initial conditions and predicted counts need to match up: with obs_plate: - mu_expression = pyro.deterministic('mu_expression', mu_total) + pyro.sample("data_target", dist.Gamma(muhat_cg, self.one, obs=mu0_cg)) - # =============Detection efficiency of spliced and unspliced counts =============== # + # ============= Measurement Model =============== # # Cell specific relative detection efficiency with hierarchical prior across batches: detection_mean_y_e = pyro.sample( "detection_mean_y_e", @@ -288,8 +359,7 @@ def forward(self, .to_event(3), ) - # =======Gene-specific additive component (Ambient RNA/ "Soup") for spliced and unspliced counts ====== # - # Independently sampled for spliced and unspliced counts: + # Gene-specific additive component (Ambient RNA/ "Soup") # s_g_gene_add_alpha_hyp = pyro.sample( "s_g_gene_add_alpha_hyp", dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta).expand([2]).to_event(1), @@ -313,30 +383,14 @@ def forward(self, dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean) .expand([self.n_batch, self.n_vars, 2]) .to_event(3), - ) - - # =========Gene-specific overdispersion of spliced and unspliced counts ============== # - # Overdispersion of unspliced counts: - stochastic_v_ag_hyp = pyro.sample( - "stochastic_v_ag_hyp", - dist.Gamma( - self.stochastic_v_ag_hyp_prior_alpha, - self.stochastic_v_ag_hyp_prior_beta, - ).expand([1, 2]).to_event(2)) - stochastic_v_ag_hyp = stochastic_v_ag_hyp + 0.001 - stochastic_v_ag_inv = pyro.sample( - "stochastic_v_ag_inv", - dist.Exponential(stochastic_v_ag_hyp) - .expand([1, self.n_vars, 2]).to_event(3), - ) - stochastic_v_ag = (self.ones / stochastic_v_ag_inv.pow(2)) + ) - # =====================Expected expression ======================= # + # =====================Expected observed expression ======================= # with obs_plate: - mu = pyro.deterministic('mu', (mu_expression + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \ + mu = pyro.deterministic('mu', (mu0_cg + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \ detection_y_c * detection_y_i * detection_y_gi) # =====================DATA likelihood ======================= # with obs_plate: - pyro.sample("data_target", dist.GammaPoisson(concentration= stochastic_v_ag, - rate= stochastic_v_ag / mu), obs=torch.stack([u_obs, s_obs], axis = 2)) \ No newline at end of file + pyro.sample("data_target", dist.Poisson(rate = mu, + obs=torch.stack([u_obs, s_obs], axis = 2))) \ No newline at end of file From 3963a4dec5cee80b5fecd806e1c878c39d16f0b6 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 5 Sep 2024 21:04:45 +0000 Subject: [PATCH 39/47] fix(_velocity_module): Fixed various imports. 
Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity_module.py | 223 ++++++++++++++++-- 1 file changed, 205 insertions(+), 18 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity_module.py b/src/pyrovelocity/models/knn_model/_velocity_module.py index cf8f1522a..f95694bc6 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_module.py +++ b/src/pyrovelocity/models/knn_model/_velocity_module.py @@ -12,7 +12,7 @@ from scvi.module.base import PyroBaseModuleClass from pyrovelocity.logging import configure_logging -from pyrovelocity.models._velocity_model import VelocityModelAuto, MultiVelocityModelAuto +from pyrovelocity.models.knn_model._velocity_model import VelocityModelAuto logger = configure_logging(__name__) @@ -212,25 +212,212 @@ def _get_fn_args_from_batch( ]: u_obs = tensor_dict["U"] s_obs = tensor_dict["X"] - u_log_library = tensor_dict["u_lib_size"] - s_log_library = tensor_dict["s_lib_size"] - u_log_library_mean = tensor_dict["u_lib_size_mean"] - s_log_library_mean = tensor_dict["s_lib_size_mean"] - u_log_library_scale = tensor_dict["u_lib_size_scale"] - s_log_library_scale = tensor_dict["s_lib_size_scale"] + N_cn = tensor_dict["N_cn"] ind_x = tensor_dict["ind_x"].long().squeeze() - cell_state = tensor_dict.get("pyro_cell_state") - time_info = tensor_dict.get("time_info") return ( u_obs, s_obs, - u_log_library, - s_log_library, - u_log_library_mean, - s_log_library_mean, - u_log_library_scale, - s_log_library_scale, - ind_x, - cell_state, - time_info, + N_cn + ), {} + +class MultiVelocityModule(PyroBaseModuleClass): + """ + VelocityModule is an scvi-tools pyro module that combines the VelocityModelAuto and pyro AutoGuideList classes. + + Args: + num_cells (int): Number of cells. + num_genes (int): Number of genes. + model_type (str, optional): Model type. Default is "auto". + guide_type (str, optional): Guide type. Default is "velocity_auto". + likelihood (str, optional): Likelihood type. Default is "Poisson". + shared_time (bool, optional): If True, a shared time parameter will be used. Default is True. + t_scale_on (bool, optional): If True, scale time parameter. Default is False. + plate_size (int, optional): Size of the plate set. Default is 2. + latent_factor (str, optional): Latent factor. Default is "none". + latent_factor_operation (str, optional): Latent factor operation mode. Default is "selection". + latent_factor_size (int, optional): Size of the latent factor. Default is 10. + inducing_point_size (int, optional): Inducing point size. Default is 0. + include_prior (bool, optional): If True, include prior in the model. Default is False. + use_gpu (str, optional): Accelerator type. Default is "auto". + num_aux_cells (int, optional): Number of auxiliary cells. Default is 0. + only_cell_times (bool, optional): If True, only model cell times. Default is True. + decoder_on (bool, optional): If True, use the decoder. Default is False. + add_offset (bool, optional): If True, add offset to the model. Default is True. + correct_library_size (Union[bool, str], optional): Library size correction method. Default is True. + cell_specific_kinetics (Optional[str], optional): Cell-specific kinetics method. Default is None. + kinetics_num (Optional[int], optional): Number of kinetics. Default is None. + **initial_values: Initial values for the model parameters. 
+ + Examples: + >>> from scvi.module.base import PyroBaseModuleClass + >>> from pyrovelocity.models._velocity_module import VelocityModule + >>> num_cells = 10 + >>> num_genes = 20 + >>> velocity_module1 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto_t0_constraint", add_offset=False + ... ) + >>> type(velocity_module1.model) + + >>> type(velocity_module1.guide) + + >>> velocity_module2 = VelocityModule( + ... num_cells, num_genes, model_type="auto", + ... guide_type="auto", add_offset=True + ... ) + >>> type(velocity_module2.model) + + >>> type(velocity_module2.guide) + + """ + + def __init__( + self, + num_cells: int, + num_genes: int, + model_type: str = "auto", + guide_type: str = "velocity_auto", + likelihood: str = "Poisson", + shared_time: bool = True, + t_scale_on: bool = False, + plate_size: int = 2, + latent_factor: str = "none", + latent_factor_operation: str = "selection", + latent_factor_size: int = 10, + inducing_point_size: int = 0, + include_prior: bool = False, + use_gpu: str = "auto", + num_aux_cells: int = 0, + only_cell_times: bool = True, + decoder_on: bool = False, + add_offset: bool = True, + correct_library_size: Union[bool, str] = True, + cell_specific_kinetics: Optional[str] = None, + kinetics_num: Optional[int] = None, + **initial_values, + ) -> None: + super().__init__() + self.num_cells = num_cells + self.num_genes = num_genes + self.model_type = model_type + self.guide_type = guide_type + self._model = None + self.plate_size = plate_size + self.num_aux_cells = num_aux_cells + self.only_cell_times = only_cell_times + logger.info( + f"Model type: {self.model_type}, Guide type: {self.guide_type}" + ) + + self.cell_specific_kinetics = cell_specific_kinetics + + self._model = VelocityModelAuto( + self.num_cells, + self.num_genes, + likelihood, + shared_time, + t_scale_on, + self.plate_size, + latent_factor, + latent_factor_operation=latent_factor_operation, + latent_factor_size=latent_factor_size, + include_prior=include_prior, + num_aux_cells=num_aux_cells, + only_cell_times=self.only_cell_times, + decoder_on=decoder_on, + add_offset=add_offset, + correct_library_size=correct_library_size, + guide_type=self.guide_type, + cell_specific_kinetics=self.cell_specific_kinetics, + **initial_values, + ) + + guide = AutoGuideList( + self._model, create_plates=self._model.create_plates + ) + guide.append( + AutoNormal( + poutine.block( + self._model, + expose=[ + "cell_time", + "u_read_depth", + "s_read_depth", + "kinetics_prob", + "kinetics_weights", + ], + ), + init_scale=0.1, + ) + ) + + if add_offset: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "dt_switching", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + else: + guide.append( + AutoLowRankMultivariateNormal( + poutine.block( + self._model, + expose=[ + "alpha", + "beta", + "gamma", + "dt_switching", + "t0", + "u_scale", + "s_scale", + ], + ), + rank=10, + init_scale=0.1, + ) + ) + self._guide = guide + + @property + def model(self) -> VelocityModelAuto: + return self._model + + @property + def guide(self) -> AutoGuideList: + return self._guide + + @staticmethod + def _get_fn_args_from_batch( + tensor_dict: Dict[str, torch.Tensor] + ) -> Tuple[ + Tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + ], + Dict[Any, Any], + ]: + u_obs = tensor_dict["U"] + s_obs = tensor_dict["X"] + N_cn = 
tensor_dict["N_cn"] + ind_x = tensor_dict["ind_x"].long().squeeze() + return ( + u_obs, + s_obs, + N_cn ), {} \ No newline at end of file From a2e301ef55b36008fefad2a996bf5bdfb3e10d05 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 5 Sep 2024 21:05:13 +0000 Subject: [PATCH 40/47] feat(regulatory_function_1): Adapted function to deal with more than one cell at a time. Signed-off-by: Alexander Aivazidis --- .../models/knn_model/regulatory_functions_torch.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py index ad41d705b..1f2695b9e 100644 --- a/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py +++ b/src/pyrovelocity/models/knn_model/regulatory_functions_torch.py @@ -39,8 +39,8 @@ def regulatory_function_1(u: Tensor, x = l3(x) output = torch.sigmoid(x) - beta = output[:,0] - gamma = output[:,1] - alphas = output[:,2] + beta = output[...,0].T + gamma = output[...,1].T + alphas = output[...,2].T return alphas, beta, gamma \ No newline at end of file From a69fda0e29f1bc792b6b64056e3b0b351b1362d4 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Thu, 5 Sep 2024 21:05:29 +0000 Subject: [PATCH 41/47] feat(knn_model): Added init file Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model/__init__.py | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 src/pyrovelocity/models/knn_model/__init__.py diff --git a/src/pyrovelocity/models/knn_model/__init__.py b/src/pyrovelocity/models/knn_model/__init__.py new file mode 100644 index 000000000..a8d28fa24 --- /dev/null +++ b/src/pyrovelocity/models/knn_model/__init__.py @@ -0,0 +1,5 @@ +from pyrovelocity.models._velocity import PyroVelocity + +__all__ = [ + PyroVelocity +] From 39da15c379046d700e843b5827f2d998d78158bf Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 6 Sep 2024 20:38:47 +0000 Subject: [PATCH 42/47] feat(knn_model): Completed first version of knn_model that trains without errors. 
Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity.py | 20 +--- .../models/knn_model/_velocity_model.py | 94 +++++++++++++++++-- .../models/knn_model/_velocity_module.py | 55 ++++------- 3 files changed, 102 insertions(+), 67 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity.py b/src/pyrovelocity/models/knn_model/_velocity.py index dbb04e87d..4ff328b27 100644 --- a/src/pyrovelocity/models/knn_model/_velocity.py +++ b/src/pyrovelocity/models/knn_model/_velocity.py @@ -248,25 +248,7 @@ def __init__( self.module = VelocityModule( self.summary_stats["n_cells"], self.summary_stats["n_vars"], - model_type=model_type, - guide_type=guide_type, - likelihood=likelihood, - shared_time=shared_time, - t_scale_on=t_scale_on, - plate_size=plate_size, - latent_factor=latent_factor, - latent_factor_operation=latent_factor_operation, - latent_factor_size=latent_factor_size, - inducing_point_size=inducing_point_size, - include_prior=include_prior, - use_gpu=use_gpu, - num_aux_cells=num_aux_cells, - only_cell_times=only_cell_times, - decoder_on=decoder_on, - add_offset=add_offset, - correct_library_size=correct_library_size, - cell_specific_kinetics=cell_specific_kinetics, - kinetics_num=self.k, + self.summary_stats["n_batch"], **initial_values, ) else: diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py index 3834f6a26..b94d73e9a 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_model.py +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -9,7 +9,9 @@ from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson from pyro.nn import PyroModule, PyroSample from pyro.primitives import plate +import pyro.distributions as dist from scvi.nn import Decoder +from scvi.nn import one_hot from torch.nn.functional import relu, softplus from torch import Tensor @@ -37,6 +39,43 @@ "MultiVelocityModelAuto", ] +def G_a(mu, sd): + """ + Converts mean and standard deviation for a Gamma distribution into the shape parameter. + + Parameters + ---------- + mu + The mean of the Gamma distribution. + sd + The standard deviation of the Gamma distribution. + + Returns + ------- + Float + The shape parameter of the Gamma distribution. + """ + return mu**2/sd**2 + +def G_b(mu, sd): + """ + Converts mean and standard deviation for a Gamma distribution into the scale parameter. + + Parameters + ---------- + mu + The mean of the Gamma distribution. + sd + The standard deviation of the Gamma distribution. + + Returns + ------- + Float + The scale parameter of the Gamma distribution. + """ + + return mu/sd**2 + class VelocityModelAuto(PyroModule): """Automatically configured velocity model. 
@@ -119,13 +158,18 @@ def __init__( Tmax_prior={"mean": 50., "sd": 50.}, **initial_values, ) -> None: + + super().__init__() + assert num_cells > 0 and num_genes > 0 - super().__init__(num_cells, num_genes, likelihood, plate_size) self.num_aux_cells = num_aux_cells self.only_cell_times = only_cell_times self.guide_type = guide_type self.cell_specific_kinetics = cell_specific_kinetics self.k = kinetics_num + self.num_cells = num_cells + self.num_genes = num_genes + self.n_genes = num_genes self.mask = initial_values.get( "mask", torch.ones(self.num_cells, self.num_genes).bool() @@ -247,7 +291,36 @@ def __init__( ) self.register_buffer("one", torch.tensor(1.)) + self.register_buffer("ones", torch.ones((1, 1))) + @beartype + def create_plates(self, + u_obs: torch.Tensor, + s_obs: torch.Tensor, + N_cn: torch.Tensor, + ind_x: torch.Tensor, + batch_index: torch.Tensor): + """ + Creates a Pyro plate for observations. + + Parameters + ---------- + u_obs + Unspliced count data. + s_obs + Spliced count data. + ind_x + Index tensor to subsample. + batch_index + Index tensor indicating batch assignments. + + Returns + ------- + Pyro.plate + A Pyro plate representing the observations in the dataset. + """ + + return pyro.plate("obs_plate", size=self.n_obs, dim=-3, subsample=ind_x) @beartype def __repr__(self) -> str: @@ -283,8 +356,9 @@ def forward(self, batch_size = len(ind_x) obs2sample = one_hot(batch_index, self.n_batch) - obs_plate = self.create_plates(u_obs, s_obs, ind_x, batch_index) - k = N_cn.shape[1] + obs_plate = self.create_plates(u_obs, s_obs, N_cn, ind_x, batch_index) + k = N_cn.shape[1] + N_cn = N_cn.long() # ============= Expression Model =============== # T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) @@ -298,11 +372,12 @@ def forward(self, delta_cn = T_c.unsqueeze(-1) - T_c[N_cn, :] # Counts in each cell: - mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, self.one*1.0).expand([batch_size, n_genes, 2])) + with obs_plate: + mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, self.one*1.0).expand([batch_size, self.n_genes, 2])) # Weight of each nearest neighbor: - wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.1, self.one*10.0).expand([1,batch_size])) - wdash5_nc = pyro.sample('wdash_nc', dist.Gamma(self.one*10.0, self.one*10.0).expand([k-1,batch_size])) + wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.1, self.one*10.0).expand([1,batch_size]).to_event(2)) + wdash5_nc = pyro.sample('wdash_nc', dist.Gamma(self.one*10.0, self.one*10.0).expand([k-1,batch_size]).to_event(2)) wdash_nc = torch.concat([wdash0_nc, wdash5_nc], axis = 0) w_nc = wdash_nc/torch.sum(wdash_nc, axis = 0) @@ -313,7 +388,8 @@ def forward(self, # Initial conditions and predicted counts need to match up: with obs_plate: - pyro.sample("data_target", dist.Gamma(muhat_cg, self.one, obs=mu0_cg)) + pyro.sample("constrain", dist.Normal(muhat_cg, self.one*torch.abs(muhat_cg)*0.2), + obs=mu0_cg) # ============= Measurement Model =============== # # Cell specific relative detection efficiency with hierarchical prior across batches: @@ -392,5 +468,5 @@ def forward(self, # =====================DATA likelihood ======================= # with obs_plate: - pyro.sample("data_target", dist.Poisson(rate = mu, - obs=torch.stack([u_obs, s_obs], axis = 2))) \ No newline at end of file + pyro.sample("data_target", dist.Poisson(rate = mu), + obs=torch.stack([u_obs, s_obs], axis = 2)) \ No newline at end of file diff --git 
a/src/pyrovelocity/models/knn_model/_velocity_module.py b/src/pyrovelocity/models/knn_model/_velocity_module.py index f95694bc6..33e983f81 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_module.py +++ b/src/pyrovelocity/models/knn_model/_velocity_module.py @@ -10,6 +10,7 @@ from pyro.infer.autoguide import AutoNormal from pyro.infer.autoguide.guides import AutoGuideList from scvi.module.base import PyroBaseModuleClass +from scvi import REGISTRY_KEYS from pyrovelocity.logging import configure_logging from pyrovelocity.models.knn_model._velocity_model import VelocityModelAuto @@ -73,6 +74,7 @@ def __init__( self, num_cells: int, num_genes: int, + n_batch: int, model_type: str = "auto", guide_type: str = "velocity_auto", likelihood: str = "Poisson", @@ -97,6 +99,7 @@ def __init__( super().__init__() self.num_cells = num_cells self.num_genes = num_genes + self.n_batch = n_batch self.model_type = model_type self.guide_type = guide_type self._model = None @@ -112,21 +115,7 @@ def __init__( self._model = VelocityModelAuto( self.num_cells, self.num_genes, - likelihood, - shared_time, - t_scale_on, - self.plate_size, - latent_factor, - latent_factor_operation=latent_factor_operation, - latent_factor_size=latent_factor_size, - include_prior=include_prior, - num_aux_cells=num_aux_cells, - only_cell_times=self.only_cell_times, - decoder_on=decoder_on, - add_offset=add_offset, - correct_library_size=correct_library_size, - guide_type=self.guide_type, - cell_specific_kinetics=self.cell_specific_kinetics, + self.n_batch, **initial_values, ) @@ -138,7 +127,7 @@ def __init__( poutine.block( self._model, expose=[ - "cell_time", + "Tmax", "u_read_depth", "s_read_depth", "kinetics_prob", @@ -154,8 +143,7 @@ def __init__( AutoLowRankMultivariateNormal( poutine.block( self._model, - expose=[ - "dt_switching", + expose=["detection_y_c", ], ), rank=10, @@ -168,7 +156,6 @@ def __init__( poutine.block( self._model, expose=[ - "alpha", "beta", "gamma", "dt_switching", @@ -201,12 +188,6 @@ def _get_fn_args_from_batch( torch.Tensor, torch.Tensor, torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, ], Dict[Any, Any], ]: @@ -214,10 +195,13 @@ def _get_fn_args_from_batch( s_obs = tensor_dict["X"] N_cn = tensor_dict["N_cn"] ind_x = tensor_dict["ind_x"].long().squeeze() + batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] return ( u_obs, s_obs, - N_cn + N_cn, + ind_x, + batch_index ), {} class MultiVelocityModule(PyroBaseModuleClass): @@ -299,6 +283,7 @@ def __init__( super().__init__() self.num_cells = num_cells self.num_genes = num_genes + self.n_genes = num_genes self.model_type = model_type self.guide_type = guide_type self._model = None @@ -340,11 +325,8 @@ def __init__( poutine.block( self._model, expose=[ - "cell_time", - "u_read_depth", - "s_read_depth", - "kinetics_prob", - "kinetics_weights", + "Tmax", + 't_c_loc' ], ), init_scale=0.1, @@ -357,7 +339,8 @@ def __init__( poutine.block( self._model, expose=[ - "dt_switching", + "Tmax", + ], ), rank=10, @@ -370,13 +353,7 @@ def __init__( poutine.block( self._model, expose=[ - "alpha", - "beta", - "gamma", - "dt_switching", - "t0", - "u_scale", - "s_scale", + "Tmax", ], ), rank=10, From e5a715bcde5606c8c0d4f76b8854e2dae45b0ec7 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 13 Sep 2024 17:13:20 +0000 Subject: [PATCH 43/47] fix(_velocity): Provide number of cells in each metacell to model. 
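The metacell aggregation in tasks/preprocess.py stores the number of pooled
cells per metacell in adata.obs["n_cells"]. Registering that column as a
numerical obs field under the key "M_c" exposes it in every minibatch
tensor_dict, so the model can rescale summed metacell counts back to an
average per-cell expression (the division itself happens in the model,
roughly mu0_cg = stack([u_obs, s_obs]) / M_c). A short sketch of the field
registration, using only the scvi-tools field class already imported here:

    from scvi.data.fields import NumericalObsField

    # per-observation covariate read from adata.obs["n_cells"] and made
    # available to the Pyro model under the registry key "M_c"
    m_c_field = NumericalObsField("M_c", "n_cells")
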
Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity.py | 158 +++++++++++------- 1 file changed, 95 insertions(+), 63 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity.py b/src/pyrovelocity/models/knn_model/_velocity.py index 4ff328b27..99ca3155f 100644 --- a/src/pyrovelocity/models/knn_model/_velocity.py +++ b/src/pyrovelocity/models/knn_model/_velocity.py @@ -14,8 +14,9 @@ from scvi.data._constants import _SETUP_ARGS_KEY, _SETUP_METHOD_NAME from scvi.data.fields import LayerField, NumericalObsField, CategoricalObsField, ObsmField from scvi.model._utils import parse_device_args -from scvi.model.base import BaseModelClass +from scvi.model.base import BaseModelClass, PyroSampleMixin from scvi import REGISTRY_KEYS +from datetime import date from scvi.model.base._utils import ( _initialize_model, _load_saved_files, @@ -36,7 +37,7 @@ logger = configure_logging(__name__) -class PyroVelocity(PyroSviTrainMixin, BaseModelClass): +class PyroVelocity(PyroSviTrainMixin, BaseModelClass, PyroSampleMixin): """ PyroVelocity is a class for constructing and training a Pyro model for probabilistic RNA velocity estimation. This model leverages the @@ -341,6 +342,7 @@ def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *arg LayerField("X", "raw_spliced", is_count_data=True), CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key), NumericalObsField("ind_x", "ind_x"), + NumericalObsField("M_c", "n_cells") ] if adata_atac: @@ -360,74 +362,104 @@ def setup_anndata(cls, adata: AnnData, adata_atac = None, batch_key = None, *arg adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) - def generate_posterior_samples( + def _export2adata(self, samples): + r""" + Export key model variables and samples + + Parameters + ---------- + samples + Dictionary with posterior mean, 5%/95% quantiles, SD, samples, generated by ``.sample_posterior()`` + + Returns + ------- + Dict + Updated dictionary with additional details is saved to ``adata.uns['mod']``. + """ + # add factor filter and samples of all parameters to unstructured data + results = { + "model_name": str(self.module.__class__.__name__), + "date": str(date.today()), + "var_names": self.adata.var_names.tolist(), + "obs_names": self.adata.obs_names.tolist(), + "post_sample_means": samples["post_sample_means"], + "post_sample_stds": samples["post_sample_stds"], + "post_sample_q05": samples["post_sample_q05"], + "post_sample_q95": samples["post_sample_q95"], + } + + return results + + def export_posterior( self, - adata: Optional[AnnData] = None, - indices: Optional[Sequence[int]] = None, - batch_size: Optional[int] = None, - num_samples: int = 100, - ) -> Dict[str, ndarray]: + adata, + sample_kwargs = {"num_samples": 30, "batch_size" : None, + 'return_samples': True}, + export_slot: str = "mod", + full_velocity_posterior = False, + normalize = True): """ - Generates posterior samples for the given data using the trained - PyroVelocity model. + Summarises posterior distribution and exports results to anndata object. Also computes RNAvelocity (based on posterior of rates) and normalized counts (based on posterior of technical variables). + + - **adata.obs:** Latent time, sequencing depth constant. + + - **adata.var:** transcription/splicing/degredation rates, switch on and off times. + + - **adata.uns:** Posterior of all parameters ('mean', 'sd', 'q05', 'q95' and optionally all samples), model name, date. 
+ + - **adata.layers:** ``velocity`` (expected gradient of spliced counts), ``velocity_sd`` (uncertainty in this gradient), ``spliced_norm``, ``unspliced_norm`` (normalized counts). + + - **adata.uns:** If ``return_samples: True`` and ``full_velocity_posterior = True`` full posterior distribution for velocity is saved in ``adata.uns['velocity_posterior']``. + + Parameters + ---------- + adata + AnnData object where results should be saved. + sample_kwargs + Optionally a dictionary of arguments for self.sample_posterior, namely: + + - **num_sample:s** Number of samples to use (Default = 1000). + - **batch_size:** Data batch size (keep low enough to fit on GPU, default 2048). + - **use_gpu:** Use gpu for generating samples. + - **return_samples:** Export all posterior samples (Otherwise just summary statistics). + export_slot + adata.uns slot where to export results. + full_velocity_posterior + Whether to save full posterior of velocity (only possible if "return_samples: True"). + normalize + Whether to compute normalized spliced and unspliced counts based on posterior of technical variables. + Returns + ------- + AnnData + AnnData object with posterior added in adata.obs, adata.var and adata.uns. + + """ + + if sample_kwargs['batch_size'] == None: + sample_kwargs['batch_size'] = adata.n_obs - The method generates posterior samples by running the trained model on the - provided data and returns a dictionary containing samples for each parameter. + # generate samples from posterior distributions for all parameters + # and compute mean, 5%/95% quantiles and standard deviation + self.samples = self.sample_posterior(**sample_kwargs) - Args: - adata (AnnData, optional): Anndata object containing the data for which posterior samples - are to be computed. If not provided, the anndata used to initialize the model will be used. - indices (Sequence[int], optional): Indices of cells in `adata` for which the posterior - samples are to be computed. - batch_size (int, optional): The size of the mini-batches used during computation. - If not provided, the entire dataset will be used. - num_samples (int, default: 100): The number of posterior samples to compute for each parameter. + # export posterior distribution summary for all parameters and + # annotation (model, date, var, obs and cell type names) to anndata object + adata.uns[export_slot] = self._export2adata(self.samples) - Returns: - Dict[str, ndarray]: A dictionary containing the posterior samples for each parameter. - """ - self.module.eval() - predictive = self.module.create_predictive( - model=pyro.poutine.uncondition(self.module.model), - num_samples=num_samples, - ) + if sample_kwargs['return_samples']: + print('Warning: Saving ALL posterior samples. 
Specify "return_samples: False" to save just summary statistics.') + adata.uns[export_slot]['post_samples'] = self.samples['posterior_samples'] - scdl = self._make_data_loader( - adata=adata, indices=indices, batch_size=batch_size - ) + adata.obs['Time (hours)'] = self.samples['post_sample_means']['T_c'].flatten() - np.min(self.samples['post_sample_means']['T_c'].flatten()) + adata.obs['Time Uncertainty (sd)'] = self.samples['post_sample_stds']['T_c'].flatten() + +# adata.layers['spliced mean'] = self.samples['post_sample_means']['mu_expression'][...,1] +# adata.layers['velocity'] = torch.tensor(self.samples['post_sample_means']['beta_g']) * \ +# self.samples['post_sample_means']['mu_expression'][...,0] - \ +# torch.tensor(self.samples['post_sample_means']['gamma_g']) * \ +# self.samples['post_sample_means']['mu_expression'][...,1] - with torch.no_grad(), pyro.poutine.mask(mask=False): - posterior_samples = [] - for tensor in scdl: - args, kwargs = self.module._get_fn_args_from_batch(tensor) - posterior_sample = { - k: v.cpu().numpy() - for k, v in predictive(*args, **kwargs).items() - } - posterior_samples.append(posterior_sample) - samples = {} - for k in posterior_samples[0].keys(): - if k in [ - "ut_norm", - "st_norm", - "time_constraint", - ]: - continue - - if posterior_samples[0][k].shape[-2] == 1: - samples[k] = posterior_samples[0][k] - else: - samples[k] = np.concatenate( - [ - posterior_samples[j][k] - for j in range(len(posterior_samples)) - ], - axis=-2, - ) - - logger.debug(k, "before", sys.getsizeof(samples[k])) - self.num_samples = num_samples - return samples + return adata def get_mlflow_logs(self): return From f51a3dd313058e4a52c7060fe017ca74bd6219f4 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 13 Sep 2024 17:14:04 +0000 Subject: [PATCH 44/47] fix(_velocity_model): Various fixes to model architecture. 
Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity_model.py | 154 +++++++++--------- 1 file changed, 80 insertions(+), 74 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py index b94d73e9a..ecc26c550 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_model.py +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -14,6 +14,7 @@ from scvi.nn import one_hot from torch.nn.functional import relu, softplus from torch import Tensor +import torch.nn.functional as F from pyrovelocity.logging import configure_logging from pyrovelocity.models._transcription_dynamics import mrna_dynamics, atac_mrna_dynamics, get_initial_states, get_cell_parameters @@ -155,7 +156,7 @@ def __init__( detection_gi_prior={"mean": 1, "alpha": 200}, gene_add_alpha_hyp_prior={"alpha": 9.0, "beta": 3.0}, gene_add_mean_hyp_prior={"alpha": 1.0, "beta": 100.0}, - Tmax_prior={"mean": 50., "sd": 50.}, + Tmax_prior={"mean": 50., "sd": 20.}, **initial_values, ) -> None: @@ -206,6 +207,13 @@ def __init__( self.detection_gi_prior = detection_gi_prior self.detection_i_prior = detection_i_prior + self.l1 = PyroModule[torch.nn.Linear](self.n_vars, 10) + self.l2 = PyroModule[torch.nn.Linear](10, self.n_vars) + self.l3 = PyroModule[torch.nn.Linear](10, self.n_vars) + self.l4 = PyroModule[torch.nn.Linear](10, self.n_vars) + self.dropout = torch.nn.Dropout(p=0.1) + + self.register_buffer( "s_overdispersion_factor_alpha_mean", torch.tensor(self.s_overdispersion_factor_hyp_prior["alpha_mean"]), @@ -298,6 +306,7 @@ def create_plates(self, u_obs: torch.Tensor, s_obs: torch.Tensor, N_cn: torch.Tensor, + M_c: torch.Tensor, ind_x: torch.Tensor, batch_index: torch.Tensor): """ @@ -336,6 +345,7 @@ def forward(self, u_obs: torch.Tensor, s_obs: torch.Tensor, N_cn: torch.Tensor, + M_c: torch.Tensor, ind_x: torch.Tensor, batch_index: torch.Tensor): """ @@ -356,9 +366,10 @@ def forward(self, batch_size = len(ind_x) obs2sample = one_hot(batch_index, self.n_batch) - obs_plate = self.create_plates(u_obs, s_obs, N_cn, ind_x, batch_index) k = N_cn.shape[1] - N_cn = N_cn.long() + N_cn = N_cn.long() + M_c = M_c.long().unsqueeze(-1) + obs_plate = self.create_plates(u_obs, s_obs, N_cn, M_c, ind_x, batch_index) # ============= Expression Model =============== # T_max = pyro.sample('Tmax', dist.Gamma(G_a(self.Tmax_mean, self.Tmax_sd), G_b(self.Tmax_mean, self.Tmax_sd))) @@ -368,30 +379,17 @@ def forward(self, t_c = pyro.sample('t_c', dist.Normal(t_c_loc, t_c_scale).expand([batch_size, 1, 1])) T_c = pyro.deterministic('T_c', t_c*T_max) - # Time difference between neighbors: - delta_cn = T_c.unsqueeze(-1) - T_c[N_cn, :] - - # Counts in each cell: + # Time difference between neighbors (previously: T_c.unsqueeze(-1) - T_c[N_cn, :]): with obs_plate: - mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, self.one*1.0).expand([batch_size, self.n_genes, 2])) - - # Weight of each nearest neighbor: - wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.1, self.one*10.0).expand([1,batch_size]).to_event(2)) - wdash5_nc = pyro.sample('wdash_nc', dist.Gamma(self.one*10.0, self.one*10.0).expand([k-1,batch_size]).to_event(2)) - wdash_nc = torch.concat([wdash0_nc, wdash5_nc], axis = 0) - w_nc = wdash_nc/torch.sum(wdash_nc, axis = 0) - - # Predicted counts from each neighbor: - y = (mu0_cg[...,0], mu0_cg[...,1]) - dy_cn = torch.stack(vector_field_1(0.0,y,[regulatory_function_1]), axis = -1)[N_cn,...] 
- muhat_cg = torch.stack(y, axis = -1) + torch.sum((w_nc.T.unsqueeze(-1).unsqueeze(-1) * delta_cn * dy_cn), axis = 1)/k + delta_cn = pyro.sample('delta_cn', dist.Gamma(self.one, self.one).expand([batch_size, k, 1])) - # Initial conditions and predicted counts need to match up: - with obs_plate: - pyro.sample("constrain", dist.Normal(muhat_cg, self.one*torch.abs(muhat_cg)*0.2), - obs=mu0_cg) + # Counts in each cell: + # with obs_plate: + # mu0_cg = pyro.sample('mu0_cg', dist.Gamma(self.one*5.0, self.one*1.0).expand([batch_size, self.n_genes, 2])) + + mu0_cg = pyro.deterministic('mu0_cg', torch.stack([u_obs, s_obs], axis = 2)/M_c) - # ============= Measurement Model =============== # + # ============= Measurement Model =============== # # Cell specific relative detection efficiency with hierarchical prior across batches: detection_mean_y_e = pyro.sample( "detection_mean_y_e", @@ -413,60 +411,68 @@ def forward(self, "detection_y_c", dist.Gamma(detection_hyp_prior_alpha.unsqueeze(dim=-1), beta.unsqueeze(dim=-1)), ) # (self.n_obs, 1) + + # =====================Expected observed expression ======================= # + with obs_plate: + mu = pyro.deterministic('mu', (self.one*10**(-5) + mu0_cg * detection_y_c)) - # Global relative detection efficiency between spliced and unspliced counts - detection_y_i = pyro.sample( - "detection_y_i", - dist.Gamma( - self.ones * self.detection_i_prior_alpha, - self.ones * self.detection_i_prior_alpha, - ) - .expand([1, 1, 2]).to_event(3) - ) + # Weight of each nearest neighbor: + wdash0_nc = pyro.sample('wdash0_nc', dist.Gamma(self.one*0.000001, self.one*1000000.0).expand([1,batch_size]).to_event(2)) + wdash5_nc = pyro.sample('wdash5_nc', dist.Gamma(self.one*0.1, self.one*0.1).expand([k-1,batch_size]).to_event(2)) + wdash_nc = pyro.deterministic('wdash_nc',torch.concat([wdash0_nc, wdash5_nc], axis = 0)) + w_nc = pyro.deterministic('w_nc', wdash_nc/torch.sum(wdash_nc, axis = 0)) - # Gene specific relative detection efficiency between spliced and unspliced counts - detection_y_gi = pyro.sample( - "detection_y_gi", - dist.Gamma( - self.ones * self.detection_gi_prior_alpha, - self.ones * self.detection_gi_prior_alpha, - ) - .expand([1, self.n_vars, 2]) - .to_event(3), - ) + # Vector field: + + # x_u = self.l1(torch.log(mu[...,0])) + # x_u = F.leaky_relu(x_u) - # Gene-specific additive component (Ambient RNA/ "Soup") # - s_g_gene_add_alpha_hyp = pyro.sample( - "s_g_gene_add_alpha_hyp", - dist.Gamma(self.gene_add_alpha_hyp_prior_alpha, self.gene_add_alpha_hyp_prior_beta).expand([2]).to_event(1), - ) - s_g_gene_add_mean = pyro.sample( - "s_g_gene_add_mean", - dist.Gamma( - self.gene_add_mean_hyp_prior_alpha, - self.gene_add_mean_hyp_prior_beta, - ) - .expand([self.n_batch, 1, 2]) - .to_event(3), - ) - s_g_gene_add_alpha_e_inv = pyro.sample( - "s_g_gene_add_alpha_e_inv", - dist.Exponential(s_g_gene_add_alpha_hyp).expand([self.n_batch, 1, 2]).to_event(3), - ) - s_g_gene_add_alpha_e = self.ones / s_g_gene_add_alpha_e_inv.pow(2) - s_g_gene_add = pyro.sample( - "s_g_gene_add", - dist.Gamma(s_g_gene_add_alpha_e, s_g_gene_add_alpha_e / s_g_gene_add_mean) - .expand([self.n_batch, self.n_vars, 2]) - .to_event(3), - ) + alpha0_g = pyro.sample('alpha0_g', dist.Gamma(self.one, self.one).expand([1,self.n_vars]).to_event(2)) + beta0_g = pyro.sample('beta0_g', dist.Gamma(self.one, self.one/2.0).expand([1,self.n_vars]).to_event(2)) + gamma0_g = pyro.sample('gamma0_g', dist.Gamma(self.one, self.one).expand([1,self.n_vars]).to_event(2)) + + x = self.l1(torch.log(mu[...,1])) + x = 
self.dropout(x) + x = F.leaky_relu(x) + + # x = torch.concat([x_u, x_s], axis = -1) + + x_alpha = self.l2(x) + x_alpha = F.leaky_relu(x_alpha) + alpha = torch.sigmoid(x_alpha)*alpha0_g + + x_beta = self.l3(x) + x_beta = F.leaky_relu(x_beta) + beta = torch.sigmoid(x_beta)*beta0_g + + x_gamma = self.l4(x) + x_gamma = F.leaky_relu(x_gamma) + gamma = torch.sigmoid(x_gamma)*gamma0_g + + pyro.deterministic('alpha', alpha) + pyro.deterministic('beta', beta) + pyro.deterministic('gamma', gamma) + + # print('alpha', alpha.shape) + # print('beta', beta.shape) + # print('gamma', gamma.shape) + + du = alpha - beta*mu[...,0] + ds = beta*mu[...,0] - gamma*mu[...,1] + dy = du, ds - # =====================Expected observed expression ======================= # - with obs_plate: - mu = pyro.deterministic('mu', (mu0_cg + torch.einsum('cbi,bgi->cgi', obs2sample.unsqueeze(dim=-1), s_g_gene_add)) * \ - detection_y_c * detection_y_i * detection_y_gi) + # Predicted counts from each neighbor: + y = (mu[...,0], mu[...,1]) + # dy = vector_field_1(0.0,y,[regulatory_function_1]) + velocity = pyro.deterministic('velocity', dy[1]) + dy_cn = pyro.deterministic('dy_cn', torch.stack(dy, axis = -1)[N_cn,...]) + muhat_cg = pyro.deterministic('muhat_cg', (torch.stack(y, axis = -1) + torch.sum((w_nc.T.unsqueeze(-1).unsqueeze(-1) * delta_cn.unsqueeze(-1) * dy_cn), axis = 1)/k)) # =====================DATA likelihood ======================= # with obs_plate: - pyro.sample("data_target", dist.Poisson(rate = mu), - obs=torch.stack([u_obs, s_obs], axis = 2)) \ No newline at end of file + # pyro.sample("data_target", dist.Poisson(rate = mu), + # obs=torch.stack([u_obs, s_obs], axis = 2)) + pyro.sample("constrain", dist.Normal(muhat_cg, 0.01), obs=mu) + + # print('MAE', torch.sum((torch.abs(muhat_cg - mu)))) + # print('1', torch.sum((mu - torch.stack([u_obs, s_obs], axis = 2))**2)) \ No newline at end of file From 46cb17de6f69cecb3f6aaae994c18f60cd6db720 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 13 Sep 2024 17:14:27 +0000 Subject: [PATCH 45/47] feat(_velocity_module): Provide number of cells in each metacell to model. 
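M_c is unpacked from the minibatch alongside the existing tensors. The tuple
returned by _get_fn_args_from_batch is passed positionally to the model's
forward (and to create_plates), so its order has to stay in sync with those
signatures; schematically:

    # illustrative: the positional contract between module and model
    args, kwargs = VelocityModule._get_fn_args_from_batch(tensor_dict)
    # args == (u_obs, s_obs, N_cn, M_c, ind_x, batch_index)
    # the Pyro training step then calls module.model(*args, **kwargs)
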
Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/models/knn_model/_velocity_module.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pyrovelocity/models/knn_model/_velocity_module.py b/src/pyrovelocity/models/knn_model/_velocity_module.py index 33e983f81..bd53796bf 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_module.py +++ b/src/pyrovelocity/models/knn_model/_velocity_module.py @@ -194,12 +194,14 @@ def _get_fn_args_from_batch( u_obs = tensor_dict["U"] s_obs = tensor_dict["X"] N_cn = tensor_dict["N_cn"] + M_c = tensor_dict["M_c"] ind_x = tensor_dict["ind_x"].long().squeeze() batch_index = tensor_dict[REGISTRY_KEYS.BATCH_KEY] return ( u_obs, s_obs, N_cn, + M_c, ind_x, batch_index ), {} From 9fec6836a7e398cf7961822b61fe4870ae3ebd68 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 13 Sep 2024 19:01:57 +0000 Subject: [PATCH 46/47] fix(compute_metacell): copying over var_names Signed-off-by: Alexander Aivazidis --- src/pyrovelocity/tasks/preprocess.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pyrovelocity/tasks/preprocess.py b/src/pyrovelocity/tasks/preprocess.py index 1720e9e72..d80b9c4c5 100644 --- a/src/pyrovelocity/tasks/preprocess.py +++ b/src/pyrovelocity/tasks/preprocess.py @@ -716,6 +716,7 @@ def merge_RNA( X = np.concatenate([np.sum(adata_rna.X[adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_rna.obs[cluster_key])], axis = 0) print(X.shape) adata_meta = sc.AnnData(X = np.array(X)) + adata_meta.var = adata_rna.var adata_meta.obs['n_cells'] = [np.sum(adata_rna.obs[cluster_key] == c) for c in np.unique(adata_rna.obs[cluster_key])] if celltype_key: adata_meta.obs[celltype_key] = [adata_rna[adata_rna.obs[cluster_key] == c,:].obs[celltype_key].mode()[0] for c in np.unique(adata_rna.obs[cluster_key])] @@ -769,6 +770,7 @@ def merge_ATAC( X = np.concatenate([np.sum(adata_atac.X[adata_rna.obs[cluster_key] == c,:], axis = 0) for c in np.unique(adata_atac.obs[cluster_key])], axis = 0) adata_atac_meta = sc.AnnData(X = np.array(X)) + adata_atac_meta.var = adata_atac.var adata_atac_meta.obs['n_cells'] = [np.sum(adata_atac.obs[cluster_key] == c) for c in np.unique(adata_atac.obs[cluster_key])] if celltype_key: adata_atac_meta.obs[celltype_key] = [adata_atac[adata_atac.obs[cluster_key] == c,:].obs[celltype_key].mode()[0] for c in np.unique(adata_atac.obs[cluster_key])] From bba3a2fd52eafc1e2aa9bb3a240b22a6a5e276c8 Mon Sep 17 00:00:00 2001 From: Alexander Aivazidis Date: Fri, 13 Sep 2024 19:02:33 +0000 Subject: [PATCH 47/47] feat(velocity_model): changed neural networks to be similar to celldancer Signed-off-by: Alexander Aivazidis --- .../models/knn_model/_velocity_model.py | 49 +++++++++++-------- 1 file changed, 28 insertions(+), 21 deletions(-) diff --git a/src/pyrovelocity/models/knn_model/_velocity_model.py b/src/pyrovelocity/models/knn_model/_velocity_model.py index ecc26c550..df95ce651 100644 --- a/src/pyrovelocity/models/knn_model/_velocity_model.py +++ b/src/pyrovelocity/models/knn_model/_velocity_model.py @@ -207,10 +207,16 @@ def __init__( self.detection_gi_prior = detection_gi_prior self.detection_i_prior = detection_i_prior - self.l1 = PyroModule[torch.nn.Linear](self.n_vars, 10) - self.l2 = PyroModule[torch.nn.Linear](10, self.n_vars) - self.l3 = PyroModule[torch.nn.Linear](10, self.n_vars) - self.l4 = PyroModule[torch.nn.Linear](10, self.n_vars) + h1 = 10 + h2 = 10 + self.l1 = [] + self.l2 = [] + self.l3 = [] + for i in range(self.n_vars): + self.l1 += [PyroModule[torch.nn.Linear](2, h1)] + self.l2 += 
[PyroModule[torch.nn.Linear](h1, h2)]
+            self.l3 += [PyroModule[torch.nn.Linear](h2, 3)]
+        
         self.dropout = torch.nn.Dropout(p=0.1)
 
 
@@ -431,23 +437,24 @@ def forward(self,
         beta0_g = pyro.sample('beta0_g', dist.Gamma(self.one, self.one/2.0).expand([1,self.n_vars]).to_event(2))
         gamma0_g = pyro.sample('gamma0_g', dist.Gamma(self.one, self.one).expand([1,self.n_vars]).to_event(2))
         
-        x = self.l1(torch.log(mu[...,1]))
-        x = self.dropout(x)
-        x = F.leaky_relu(x)
-        
-        # x = torch.concat([x_u, x_s], axis = -1)
-        
-        x_alpha = self.l2(x)
-        x_alpha = F.leaky_relu(x_alpha)
-        alpha = torch.sigmoid(x_alpha)*alpha0_g
-        
-        x_beta = self.l3(x)
-        x_beta = F.leaky_relu(x_beta)
-        beta = torch.sigmoid(x_beta)*beta0_g
-        
-        x_gamma = self.l4(x)
-        x_gamma = F.leaky_relu(x_gamma)
-        gamma = torch.sigmoid(x_gamma)*gamma0_g
+        print('mu', mu.shape)
+        betas = []
+        alphas = []
+        gammas = []
+        for i in range(self.n_vars):
+            x = self.l1[i](mu[:,i,...])
+            x = F.leaky_relu(x)
+            x = self.l2[i](x)
+            x = F.leaky_relu(x)
+            x = self.l3[i](x)
+            output = torch.sigmoid(x)
+            betas += [output[:,0]]
+            gammas += [output[:,1]]
+            alphas += [output[:,2]]
+        
+        alpha = torch.stack(alphas, axis = -1) * alpha0_g
+        beta = torch.stack(betas, axis = -1) * beta0_g
+        gamma = torch.stack(gammas, axis = -1) * gamma0_g
 
         pyro.deterministic('alpha', alpha)
         pyro.deterministic('beta', beta)
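For reference, a registration-safe sketch of the per-gene rate networks
introduced above (layer sizes follow the patch; class and variable names are
placeholders, and in the Pyro model the layers would be PyroModule-wrapped).
Holding the per-gene Linear layers in torch.nn.ModuleList keeps their
parameters registered with the module, which a plain Python list does not do:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class PerGeneRates(nn.Module):
        """One small MLP per gene mapping (u, s) to (beta, gamma, alpha) in (0, 1)."""

        def __init__(self, n_genes, h1=10, h2=10):
            super().__init__()
            # ModuleList, not a plain list, so parameters are tracked
            self.l1 = nn.ModuleList(nn.Linear(2, h1) for _ in range(n_genes))
            self.l2 = nn.ModuleList(nn.Linear(h1, h2) for _ in range(n_genes))
            self.l3 = nn.ModuleList(nn.Linear(h2, 3) for _ in range(n_genes))

        def forward(self, mu):
            # mu: (cells, genes, 2) average unspliced/spliced expression
            outs = []
            for i, (l1, l2, l3) in enumerate(zip(self.l1, self.l2, self.l3)):
                x = F.leaky_relu(l1(mu[:, i, :]))
                x = F.leaky_relu(l2(x))
                outs.append(torch.sigmoid(l3(x)))   # (cells, 3)
            out = torch.stack(outs, dim=1)          # (cells, genes, 3)
            beta, gamma, alpha = out.unbind(dim=-1)
            return alpha, beta, gamma

Each gene's network maps its (unspliced, spliced) mean to three values in
(0, 1), which the model then rescales by the gene-level scales alpha0_g,
beta0_g and gamma0_g sampled from Gamma priors.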