diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py
index baad88751d..3569448bae 100644
--- a/src/scvi/external/mrvi/_model.py
+++ b/src/scvi/external/mrvi/_model.py
@@ -783,7 +783,11 @@ def get_aggregated_posterior(
         -------
         A mixture distribution of the aggregated posterior.
         """
-        from numpyro.distributions import Categorical, MixtureSameFamily, Normal
+        from numpyro.distributions import (
+            Categorical,
+            MixtureSameFamily,
+            Normal,
+        )
 
         self._check_if_trained(warn=False)
         adata = self._validate_anndata(adata)
@@ -806,11 +810,11 @@ def get_aggregated_posterior(
             qu_locs.append(outputs["qu"].loc)
             qu_scales.append(outputs["qu"].scale)
 
-        qu_loc = jnp.concatenate(qu_locs, axis=0).T
-        qu_scale = jnp.concatenate(qu_scales, axis=0).T
+        qu_loc = jnp.concatenate(qu_locs, axis=0)  # n_cells x n_latent_u
+        qu_scale = jnp.concatenate(qu_scales, axis=0)  # n_cells x n_latent_u
 
         return MixtureSameFamily(
-            Categorical(probs=jnp.ones(qu_loc.shape[1]) / qu_loc.shape[1]),
-            Normal(qu_loc, qu_scale),
+            Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]),
+            Normal(qu_loc, qu_scale).to_event(1),
         )
 
     def differential_abundance(
@@ -819,6 +823,7 @@ def differential_abundance(
         sample_cov_keys: list[str] | None = None,
         sample_subset: list[str] | None = None,
         compute_log_enrichment: bool = False,
+        omit_original_sample: bool = True,
         batch_size: int = 128,
     ) -> xr.Dataset:
         """Compute the differential abundance between samples.
@@ -839,6 +844,9 @@ def differential_abundance(
             Only compute differential abundance for these sample labels.
         compute_log_enrichment
             Whether to compute the log enrichment scores for each covariate value.
+        omit_original_sample
+            If True, each cell's sample-of-origin is excluded when computing the aggregated
+            posteriors it is scored against. Only relevant if ``sample_cov_keys`` is not None.
         batch_size
             Minibatch size for computing the differential abundance.
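For context (a minimal sketch with toy shapes, not part of the patch): after this change each of the `n_cells` mixture components is a `d`-dimensional diagonal Normal (via `.to_event(1)`), so `log_prob` returns one joint log-density per query point instead of per-dimension values. That is what lets the hunks below drop the `.sum(-1)` calls.

```python
# Sketch only: the shape behavior of the reworked aggregated posterior.
import jax.numpy as jnp
from numpyro.distributions import Categorical, MixtureSameFamily, Normal

n_cells, d = 5, 3
qu_loc = jnp.zeros((n_cells, d))   # n_cells x n_latent_u
qu_scale = jnp.ones((n_cells, d))

ap = MixtureSameFamily(
    Categorical(probs=jnp.ones(n_cells) / n_cells),  # uniform weight per cell
    Normal(qu_loc, qu_scale).to_event(1),            # event dim = latent dim
)

u = jnp.zeros((7, d))                # a batch of query points
assert ap.log_prob(u).shape == (7,)  # one joint log-density each; no .sum(-1) needed
```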
@@ -883,7 +891,7 @@ def differential_abundance(
             n_splits = max(adata.n_obs // batch_size, 1)
             log_probs_ = []
             for u_rep in np.array_split(us, n_splits):
-                log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))
+                log_probs_.append(jax.device_get(ap.log_prob(u_rep))[..., np.newaxis])
 
             log_probs.append(np.concatenate(log_probs_, axis=0))  # (n_cells, 1)
 
@@ -901,6 +909,23 @@ def differential_abundance(
         if sample_cov_keys is None or len(sample_cov_keys) == 0:
             return log_probs_arr
 
+        def aggregate_log_probs(log_probs, samples, omit_original_sample=False):
+            sample_log_probs = log_probs.loc[
+                {"sample": samples}
+            ].values  # (n_cells, n_samples_in_group)
+            if omit_original_sample:
+                sample_one_hot = np.zeros((adata.n_obs, len(samples)))
+                for i, sample in enumerate(samples):
+                    sample_one_hot[adata.obs[self.sample_key] == sample, i] = 1
+                log_probs_no_original = np.where(
+                    sample_one_hot, -np.inf, sample_log_probs
+                )  # effectively removes each cell's sample-of-origin from the aggregate posterior
+                return logsumexp(log_probs_no_original, axis=1) - np.log(
+                    (1 - sample_one_hot).sum(axis=1)
+                )
+            else:
+                return logsumexp(sample_log_probs, axis=1) - np.log(sample_log_probs.shape[1])
+
         sample_cov_log_probs_map = {}
         sample_cov_log_enrichs_map = {}
         for sample_cov_key in sample_cov_keys:
@@ -916,8 +941,11 @@ def differential_abundance(
                 if len(cov_samples) == 0:
                     continue
 
-                sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}]
-                val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1])
+                val_log_probs = aggregate_log_probs(
+                    log_probs_arr.log_probs,
+                    cov_samples,
+                    omit_original_sample=omit_original_sample,
+                )
                 per_val_log_probs[sample_cov_value] = val_log_probs
 
                 if compute_log_enrichment:
@@ -930,9 +958,10 @@ def differential_abundance(
                             stacklevel=2,
                         )
                         continue
-                    rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}]
-                    rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log(
-                        rest_log_probs.shape[1]
-                    )
+                    rest_val_log_probs = aggregate_log_probs(
+                        log_probs_arr.log_probs,
+                        rest_samples,
+                        omit_original_sample=omit_original_sample,
+                    )
                     enrichment_scores = val_log_probs - rest_val_log_probs
                     per_val_log_enrichs[sample_cov_value] = enrichment_scores
@@ -1018,8 +1047,8 @@ def get_outlier_cell_sample_pairs(
             ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
             in_max_comp_log_probs = ap.component_distribution.log_prob(
-                np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)
-            ).sum(axis=1)
+                np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)  # (n_cells_ap, 1, n_latent_dim)
+            )  # (n_cells_ap, n_cells_ap)
             log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs)
 
             log_probs_ = []
@@ -1028,10 +1057,12 @@ def get_outlier_cell_sample_pairs(
                 log_probs_.append(
                     jax.device_get(
                         ap.component_distribution.log_prob(
-                            np.expand_dims(u_rep, ap.mixture_dim)
-                        )  # (n_cells_batch, n_cells_ap, n_latent_dim)
-                        .sum(axis=1)  # (n_cells_batch, n_latent_dim)
-                        .max(axis=1, keepdims=True)  # (n_cells_batch, 1)
+                            np.expand_dims(
+                                u_rep, ap.mixture_dim
+                            )  # (n_cells_batch, 1, n_latent_dim)
+                        ).max(  # (n_cells_batch, n_cells_ap)
+                            axis=1, keepdims=True
+                        )  # (n_cells_batch, 1)
                     )
                 )
 
diff --git a/src/scvi/external/mrvi/_module.py b/src/scvi/external/mrvi/_module.py
index 303f4589d5..519c9b96b9 100644
--- a/src/scvi/external/mrvi/_module.py
+++ b/src/scvi/external/mrvi/_module.py
@@ -268,7 +268,10 @@ def __call__(
         sample_covariate: jax.typing.ArrayLike,
         training: bool | None = None,
     ) -> dist.Normal:
-        from scvi.external.mrvi._components import ConditionalNormalization, NormalDistOutputNN
+        from scvi.external.mrvi._components import (
+            ConditionalNormalization,
+            NormalDistOutputNN,
+        )
 
         training = nn.merge_param("training", self.training, training)
         x_feat = jnp.log1p(x)
@@ -424,10 +427,10 @@ def setup(self):
            "u_prior_logits", nn.initializers.zeros, (u_prior_mixture_k,)
        )
        self.u_prior_means = self.param(
-            "u_prior_means", jax.random.normal, (u_dim, u_prior_mixture_k)
+            "u_prior_means", jax.random.normal, (u_prior_mixture_k, u_dim)
        )
        self.u_prior_scales = self.param(
-            "u_prior_scales", nn.initializers.zeros, (u_dim, u_prior_mixture_k)
+            "u_prior_scales", nn.initializers.zeros, (u_prior_mixture_k, u_dim)
        )
 
    @property
@@ -517,7 +520,9 @@ def generative(
                 10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0
             )
             cats = dist.Categorical(logits=self.u_prior_logits + offset)
-            normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales))
+            normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event(
+                1
+            )
             pu = dist.MixtureSameFamily(cats, normal_dists)
         else:
             pu = dist.Normal(0, jnp.exp(self.u_prior_scale))
@@ -536,13 +541,11 @@ def loss(
         )
 
         if self.u_prior_mixture:
-            kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]) - generative_outputs[
-                "pu"
-            ].log_prob(inference_outputs["u"])
-            kl_u = kl_u.sum(-1)
+            kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]).sum(
+                -1
+            ) - generative_outputs["pu"].log_prob(inference_outputs["u"])
         else:
             kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1)
-            inference_outputs["qeps"]
         kl_z = 0.0
         eps = inference_outputs["z"] - inference_outputs["z_base"]
diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py
index 05edb27496..75561589f2 100644
--- a/tests/external/mrvi/test_model.py
+++ b/tests/external/mrvi/test_model.py
@@ -117,6 +117,7 @@ def test_mrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str,
     [
         {"sample_cov_keys": ["meta1_cat"]},
         {"sample_cov_keys": ["meta1_cat", "batch"]},
+        {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": False},
         {"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True},
         {"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True},
     ],
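Two illustrative sketches follow; neither is part of the patch. First, the sample-of-origin masking inside `aggregate_log_probs`: setting a cell's own sample to `-inf` before the `logsumexp` drops it from the mixture average, and the normalizer shrinks to the number of remaining samples. The toy inputs below stand in for the real code, which builds `sample_one_hot` from `adata.obs[self.sample_key]`.

```python
# Toy illustration of the omit_original_sample masking.
import numpy as np
from scipy.special import logsumexp

log_probs = np.log([[0.5, 0.2, 0.3], [0.1, 0.6, 0.3]])  # (n_cells, n_samples)
one_hot = np.array([[1, 0, 0], [0, 1, 0]])  # cell i originated from sample i

masked = np.where(one_hot, -np.inf, log_probs)  # discard each cell's own sample
# log of the mean density over the remaining samples
agg = logsumexp(masked, axis=1) - np.log((1 - one_hot).sum(axis=1))
```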
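Second, the KL term in `loss()`: with `pu` now an event-dim-1 mixture, its `log_prob` is already a joint density over the latent dimensions, so only the factorized `qu` side still needs a `.sum(-1)`. A toy single-sample Monte Carlo estimate of that KL, with assumed shapes:

```python
# Toy Monte Carlo KL estimate matching the loss() change.
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

k, d = 4, 3
qu = dist.Normal(jnp.zeros(d), jnp.ones(d))  # factorized posterior: per-dim log_prob
pu = dist.MixtureSameFamily(
    dist.Categorical(logits=jnp.zeros(k)),
    dist.Normal(jnp.zeros((k, d)), jnp.ones((k, d))).to_event(1),  # (k, d) means/scales
)

u = qu.sample(jax.random.PRNGKey(0))  # one MC sample, shape (d,)
kl_u = qu.log_prob(u).sum(-1) - pu.log_prob(u)  # sum only the factorized side
```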