From 09947b65bf440e0d3c9553a055970a26312a55bb Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 13 Feb 2025 13:59:58 -0500 Subject: [PATCH 01/14] fix aggregated posterior computation as described in #3188 --- src/scvi/external/mrvi/_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index e7e7b7c2fa..23b1316cb7 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -789,11 +789,11 @@ def get_aggregated_posterior( qu_locs.append(outputs["qu"].loc) qu_scales.append(outputs["qu"].scale) - qu_loc = jnp.concatenate(qu_locs, axis=0).T - qu_scale = jnp.concatenate(qu_scales, axis=0).T + qu_loc = jnp.concatenate(qu_locs, axis=0) + qu_scale = jnp.concatenate(qu_scales, axis=0) return MixtureSameFamily( - Categorical(probs=jnp.ones(qu_loc.shape[1]) / qu_loc.shape[1]), - Normal(qu_loc, qu_scale), + Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]), + Normal(qu_loc, qu_scale).to_event(1), ) def differential_abundance( From d87bb96a112115869fe7c0aca18eeaa4093fc335 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 13 Feb 2025 16:02:56 -0500 Subject: [PATCH 02/14] fix dimension issue --- src/scvi/external/mrvi/_model.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 23b1316cb7..db7ba848b3 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -866,11 +866,11 @@ def differential_abundance( n_splits = max(adata.n_obs // batch_size, 1) log_probs_ = [] for u_rep in np.array_split(us, n_splits): - log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True))) + log_probs_.append(jax.device_get(ap.log_prob(u_rep))) - log_probs.append(np.concatenate(log_probs_, axis=0)) # (n_cells, 1) + log_probs.append(np.concatenate(log_probs_)) # (n_cells,) - log_probs = np.concatenate(log_probs, 1) + log_probs = np.concatenate(log_probs) coords = { "cell_name": adata.obs_names.to_numpy(), @@ -1002,7 +1002,7 @@ def get_outlier_cell_sample_pairs( ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) - ).sum(axis=1) + ) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) log_probs_ = [] From a6cee88faef2da49c142e660a6a08d61b6e38913 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 13 Feb 2025 16:04:34 -0500 Subject: [PATCH 03/14] fix mixture distribution issue for MoG prior --- src/scvi/external/mrvi/_module.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/scvi/external/mrvi/_module.py b/src/scvi/external/mrvi/_module.py index 303f4589d5..132f929a0e 100644 --- a/src/scvi/external/mrvi/_module.py +++ b/src/scvi/external/mrvi/_module.py @@ -268,7 +268,10 @@ def __call__( sample_covariate: jax.typing.ArrayLike, training: bool | None = None, ) -> dist.Normal: - from scvi.external.mrvi._components import ConditionalNormalization, NormalDistOutputNN + from scvi.external.mrvi._components import ( + ConditionalNormalization, + NormalDistOutputNN, + ) training = nn.merge_param("training", self.training, training) x_feat = jnp.log1p(x) @@ -424,10 +427,10 @@ def setup(self): "u_prior_logits", nn.initializers.zeros, (u_prior_mixture_k,) ) self.u_prior_means = self.param( - "u_prior_means", jax.random.normal, (u_dim, u_prior_mixture_k) + "u_prior_means", 
jax.random.normal, (u_prior_mixture_k, u_dim) ) self.u_prior_scales = self.param( - "u_prior_scales", nn.initializers.zeros, (u_dim, u_prior_mixture_k) + "u_prior_scales", nn.initializers.zeros, (u_prior_mixture_k, u_dim) ) @property @@ -517,7 +520,9 @@ def generative( 10.0 * jax.nn.one_hot(label_index, self.n_labels) if self.n_labels >= 2 else 0.0 ) cats = dist.Categorical(logits=self.u_prior_logits + offset) - normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)) + normal_dists = dist.Normal(self.u_prior_means, jnp.exp(self.u_prior_scales)).to_event( + 1 + ) pu = dist.MixtureSameFamily(cats, normal_dists) else: pu = dist.Normal(0, jnp.exp(self.u_prior_scale)) @@ -539,9 +544,8 @@ def loss( kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]) - generative_outputs[ "pu" ].log_prob(inference_outputs["u"]) - kl_u = kl_u.sum(-1) else: - kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1) + kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]) inference_outputs["qeps"] kl_z = 0.0 From 973952c8a061fb571717c232a7a21acc52610401 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 13 Feb 2025 16:26:30 -0500 Subject: [PATCH 04/14] fix loss term dimension --- src/scvi/external/mrvi/_module.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/scvi/external/mrvi/_module.py b/src/scvi/external/mrvi/_module.py index 132f929a0e..519c9b96b9 100644 --- a/src/scvi/external/mrvi/_module.py +++ b/src/scvi/external/mrvi/_module.py @@ -541,12 +541,11 @@ def loss( ) if self.u_prior_mixture: - kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]) - generative_outputs[ - "pu" - ].log_prob(inference_outputs["u"]) + kl_u = inference_outputs["qu"].log_prob(inference_outputs["u"]).sum( + -1 + ) - generative_outputs["pu"].log_prob(inference_outputs["u"]) else: - kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]) - inference_outputs["qeps"] + kl_u = dist.kl_divergence(inference_outputs["qu"], generative_outputs["pu"]).sum(-1) kl_z = 0.0 eps = inference_outputs["z"] - inference_outputs["z_base"] From 1660bbb92273500b7361d375dd12c4192709bb8e Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 13 Feb 2025 17:03:28 -0500 Subject: [PATCH 05/14] fix dim issues in get outlier cell sample pairs --- src/scvi/external/mrvi/_model.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index db7ba848b3..a43c822d5d 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -866,11 +866,11 @@ def differential_abundance( n_splits = max(adata.n_obs // batch_size, 1) log_probs_ = [] for u_rep in np.array_split(us, n_splits): - log_probs_.append(jax.device_get(ap.log_prob(u_rep))) + log_probs_.append(jax.device_get(ap.log_prob(u_rep))[..., np.newaxis]) - log_probs.append(np.concatenate(log_probs_)) # (n_cells,) + log_probs.append(np.concatenate(log_probs_, axis=0)) # (n_cells, 1) - log_probs = np.concatenate(log_probs) + log_probs = np.concatenate(log_probs, 1) coords = { "cell_name": adata.obs_names.to_numpy(), @@ -1001,8 +1001,8 @@ def get_outlier_cell_sample_pairs( ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( - np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) - ) + np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) # (n_cells_ap, 1, n_latent_dim) + ) # (n_cells_ap, 
n_cells_ap) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) log_probs_ = [] @@ -1011,10 +1011,12 @@ def get_outlier_cell_sample_pairs( log_probs_.append( jax.device_get( ap.component_distribution.log_prob( - np.expand_dims(u_rep, ap.mixture_dim) - ) # (n_cells_batch, n_cells_ap, n_latent_dim) - .sum(axis=1) # (n_cells_batch, n_latent_dim) - .max(axis=1, keepdims=True) # (n_cells_batch, 1) + np.expand_dims( + u_rep, ap.mixture_dim + ) # (n_cells_batch, 1, n_latent_dim) + ).max( # (n_cells_batch, n_cells_ap) + axis=1, keepdims=True + ) # (n_cells_batch, 1) ) ) From 0c40f542515b4f329671790590dbc1c42dc487b3 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 20 Feb 2025 13:15:26 -0500 Subject: [PATCH 06/14] omit sample option for diff abundance --- src/scvi/external/mrvi/_model.py | 163 +++++++++++++++++++++++------- tests/external/mrvi/test_model.py | 9 +- 2 files changed, 136 insertions(+), 36 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index a43c822d5d..654b9adb9f 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -133,7 +133,9 @@ def to_device(self, device): # TODO(jhong): remove this once we have a better way to handle device. pass - def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]: + def _generate_stacked_rngs( + self, n_sets: int | tuple + ) -> dict[str, jax.random.KeyArray]: return_1d = isinstance(n_sets, int) if return_1d: n_sets_1d = n_sets @@ -191,7 +193,9 @@ def setup_anndata( fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -405,12 +409,16 @@ def per_sample_inference_fn(pair): for ur in reqs.ungrouped_reductions: ungrouped_data_arrs[ur.name] = [] for gr in reqs.grouped_reductions: - grouped_data_arrs[gr.name] = {} # Will map group category to running group sum. + grouped_data_arrs[gr.name] = ( + {} + ) # Will map group category to running group sum. 
for array_dict in tqdm(scdl): indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)) + cf_sample = np.broadcast_to( + np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1) + ) inf_inputs = self.module._get_inference_input( array_dict, ) @@ -478,8 +486,11 @@ def per_sample_inference_fn(pair): normalization_means = normalization_means.reshape(-1, 1, 1, 1) normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) normalized_dists = ( - (sampled_dists - normalization_means) / (normalization_vars**0.5) - ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) + (sampled_dists - normalization_means) + / (normalization_vars**0.5) + ).mean( + dim="mc_sample" + ) # (n_cells, n_samples, n_samples) # Compute each reduction for r in reductions: @@ -503,7 +514,9 @@ def per_sample_inference_fn(pair): group_by_cats = group_by.unique() for cat in group_by_cats: cat_summed_outputs = outputs.sel( - cell_name=self.adata.obs_names[indices][group_by == cat].values + cell_name=self.adata.obs_names[indices][ + group_by == cat + ].values ).sum(dim="cell_name") cat_summed_outputs = cat_summed_outputs.assign_coords( {f"{r.group_by}_name": cat} @@ -525,8 +538,12 @@ def per_sample_inference_fn(pair): group_by_counts = group_by.value_counts() averaged_grouped_data_arrs = [] for cat, count in group_by_counts.items(): - averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count) - final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name") + averaged_grouped_data_arrs.append( + grouped_data_arrs[gr.name][cat] / count + ) + final_data_arr = xr.concat( + averaged_grouped_data_arrs, dim=f"{gr.group_by}_name" + ) final_data_arrs[gr.name] = final_data_arr return xr.Dataset(data_vars=final_data_arrs) @@ -712,7 +729,9 @@ def get_local_sample_distances( reductions = [] if not keep_cell and not groupby: - raise ValueError("Undefined computation because not keep_cell and no groupby.") + raise ValueError( + "Undefined computation because not keep_cell and no groupby." + ) if keep_cell: reductions.append( MRVIReduction( @@ -743,6 +762,7 @@ def get_aggregated_posterior( self, adata: AnnData | None = None, sample: str | int | None = None, + use_student_t: bool = False, indices: npt.ArrayLike | None = None, batch_size: int = 256, ) -> Distribution: @@ -757,6 +777,8 @@ def get_aggregated_posterior( AnnData object to use. Defaults to the AnnData object used to initialize the model. sample Name or index of the sample to filter on. If ``None``, uses all cells. + use_student_t + Whether to use a student-t distribution for the aggregated posterior. indices Indices of cells to use. batch_size @@ -766,7 +788,12 @@ def get_aggregated_posterior( ------- A mixture distribution of the aggregated posterior. 
""" - from numpyro.distributions import Categorical, MixtureSameFamily, Normal + from numpyro.distributions import ( + Categorical, + MixtureSameFamily, + Normal, + StudentT, + ) self._check_if_trained(warn=False) adata = self._validate_anndata(adata) @@ -782,18 +809,24 @@ def get_aggregated_posterior( qu_locs = [] qu_scales = [] - jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True}) + jit_inference_fn = self.module.get_jit_inference_fn( + inference_kwargs={"use_mean": True} + ) for array_dict in scdl: outputs = jit_inference_fn(self.module.rngs, array_dict) qu_locs.append(outputs["qu"].loc) qu_scales.append(outputs["qu"].scale) - qu_loc = jnp.concatenate(qu_locs, axis=0) - qu_scale = jnp.concatenate(qu_scales, axis=0) + qu_loc = jnp.concatenate(qu_locs, axis=0) # n_cells x n_latent_u + qu_scale = jnp.concatenate(qu_scales, axis=0) # n_cells x n_latent_u return MixtureSameFamily( Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]), - Normal(qu_loc, qu_scale).to_event(1), + ( + Normal(qu_loc, qu_scale).to_event(1) + if use_student_t + else StudentT(qu_loc, qu_scale).to_event(1) + ), ) def differential_abundance( @@ -802,6 +835,7 @@ def differential_abundance( sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, compute_log_enrichment: bool = False, + omit_original_sample: bool = False, batch_size: int = 128, ) -> xr.Dataset: """Compute the differential abundance between samples. @@ -822,6 +856,9 @@ def differential_abundance( Only compute differential abundance for these sample labels. compute_log_enrichment Whether to compute the log enrichment scores for each covariate value. + omit_original_sample + Whether to omit the original sample from the differential abundance computation. + Only relevant if sample_cov_keys is not None. batch_size Minibatch size for computing the differential abundance. 
@@ -892,15 +929,32 @@ def differential_abundance( per_val_log_enrichs = {} for sample_cov_value in sample_cov_unique_values: cov_samples = ( - self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value] + self.sample_info[ + self.sample_info[sample_cov_key] == sample_cov_value + ] )[self.sample_key].to_numpy() if sample_subset is not None: cov_samples = np.intersect1d(cov_samples, np.array(sample_subset)) if len(cov_samples) == 0: continue - sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}] - val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log(sel_log_probs.shape[1]) + sel_log_probs = log_probs_arr.log_probs.loc[ + {"sample": cov_samples} + ].values + if omit_original_sample: + sample_cov_one_hot = np.zeros((adata.n_obs, len(cov_samples))) + for i, sample in enumerate(cov_samples): + sample_cov_one_hot[adata.obs[self.sample_key] == sample, i] = 1 + sel_log_probs_no_original = np.where( + sample_cov_one_hot, sel_log_probs, -np.inf + ) + val_log_probs = logsumexp( + sel_log_probs_no_original, axis=1 + ) - np.log((1 - sample_cov_one_hot).sum(axis=1)) + else: + val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log( + sel_log_probs.shape[1] + ) per_val_log_probs[sample_cov_value] = val_log_probs if compute_log_enrichment: @@ -913,13 +967,32 @@ def differential_abundance( stacklevel=2, ) continue - rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}] - rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log( - rest_log_probs.shape[1] - ) + rest_log_probs = log_probs_arr.log_probs.loc[ + {"sample": rest_samples} + ].values + if omit_original_sample: + rest_samples_one_hot = np.zeros( + (adata.n_obs, len(rest_samples)) + ) + for i, sample in enumerate(rest_samples): + rest_samples_one_hot[ + adata.obs[self.sample_key] == sample, i + ] = 1 + rest_log_probs_no_original = np.where( + rest_samples_one_hot, rest_log_probs, -np.inf + ) + rest_val_log_probs = logsumexp( + rest_log_probs_no_original, axis=1 + ) - np.log((1 - rest_samples_one_hot).sum(axis=1)) + else: + rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log( + rest_log_probs.shape[1] + ) enrichment_scores = val_log_probs - rest_val_log_probs per_val_log_enrichs[sample_cov_value] = enrichment_scores - sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs) + sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict( + per_val_log_probs + ) if compute_log_enrichment and len(per_val_log_enrichs) > 0: sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict( per_val_log_enrichs @@ -996,12 +1069,16 @@ def get_outlier_cell_sample_pairs( for sample_name in tqdm(unique_samples): sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0] if subsample_size is not None and sample_idxs.shape[0] > subsample_size: - sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False) + sample_idxs = np.random.choice( + sample_idxs, size=subsample_size, replace=False + ) adata_s = adata[sample_idxs] ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( - np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) # (n_cells_ap, 1, n_latent_dim) + np.expand_dims( + adata_s.obsm["U"], ap.mixture_dim + ) # (n_cells_ap, 1, n_latent_dim) ) # (n_cells_ap, n_cells_ap) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) @@ -1309,7 +1386,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) mc_samples, _, n_cells_, n_latent = 
betas_covariates.shape betas_offset_ = ( - jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)) + jnp.zeros( + (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent) + ) + eps_mean_ ) # batch_offset shape (mc_samples, n_batch, n_cells, n_latent) @@ -1317,7 +1396,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): f_ = jax.vmap( h_inference_fn, in_axes=(0, None, 0), out_axes=0 ) # fn over MC samples - f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates + f_ = jax.vmap( + f_, in_axes=(1, None, None), out_axes=1 + ) # fn over covariates f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches h_fn = jax.jit(f_) @@ -1327,7 +1408,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc) lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0) if delta is not None: - lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0)) + lfc_std = jnp.sqrt( + jnp.average(lfcs.var(1), weights=batch_weights, axis=0) + ) pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0) if store_baseline: @@ -1363,7 +1446,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) rngs_de = self.module.rngs if store_lfc else None - admissible_samples_mat = jnp.array(admissible_samples[indices]) # (n_cells, n_samples) + admissible_samples_mat = jnp.array( + admissible_samples[indices] + ) # (n_cells, n_samples) n_samples_per_cell = admissible_samples_mat.sum(axis=1) admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype( float @@ -1400,7 +1485,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): effect_size = np.concatenate(effect_size, axis=0) pvalue = np.concatenate(pvalue, axis=0) pvalue_shape = pvalue.shape - padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape) + padj = false_discovery_control(pvalue.flatten(), method="bh").reshape( + pvalue_shape + ) coords = { "cell_name": (("cell_name"), adata.obs_names), @@ -1519,19 +1606,27 @@ def _construct_design_matrix( Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key) if normalize_design_matrix: - Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)) + Xmat = (Xmat - Xmat.min(axis=0)) / ( + 1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0) + ) if add_batch_specific_offsets: cov = sample_info["_scvi_batch"] if cov.nunique() == self.summary_stats.n_batch: - cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values] - cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)] + cov = np.eye(self.summary_stats.n_batch)[ + sample_info["_scvi_batch"].values + ] + cov_names = [ + "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch) + ] Xmat = np.concatenate([cov, Xmat], axis=1) Xmat_names = np.concatenate([np.array(cov_names), Xmat_names]) Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key]) # Retrieve indices of offset covariates in the right order offset_indices = ( - Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values + Series(np.arange(len(Xmat_names)), index=Xmat_names) + .loc[cov_names] + .values ) offset_indices = jnp.array(offset_indices) else: diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 05edb27496..64f3b93415 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -117,6 
+117,7 @@ def test_mrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str, [ {"sample_cov_keys": ["meta1_cat"]}, {"sample_cov_keys": ["meta1_cat", "batch"]}, + {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": True}, {"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True}, {"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True}, ], @@ -155,7 +156,9 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): +def test_mrvi_model_kwargs( + adata: AnnData, model_kwargs: dict[str, Any], save_path: str +): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -173,7 +176,9 @@ def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_pa def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) + model.differential_expression( + sample_cov_keys=sample_cov_keys, sample_subset=sample_subset + ) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 6d568996dc7c7557500c36741bf1c7b460d43bd0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Feb 2025 18:16:02 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/mrvi/_model.py | 117 ++++++++---------------------- tests/external/mrvi/test_model.py | 8 +- 2 files changed, 33 insertions(+), 92 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 654b9adb9f..ba4140bfe2 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -133,9 +133,7 @@ def to_device(self, device): # TODO(jhong): remove this once we have a better way to handle device. pass - def _generate_stacked_rngs( - self, n_sets: int | tuple - ) -> dict[str, jax.random.KeyArray]: + def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]: return_1d = isinstance(n_sets, int) if return_1d: n_sets_1d = n_sets @@ -193,9 +191,7 @@ def setup_anndata( fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -409,16 +405,12 @@ def per_sample_inference_fn(pair): for ur in reqs.ungrouped_reductions: ungrouped_data_arrs[ur.name] = [] for gr in reqs.grouped_reductions: - grouped_data_arrs[gr.name] = ( - {} - ) # Will map group category to running group sum. + grouped_data_arrs[gr.name] = {} # Will map group category to running group sum. 
for array_dict in tqdm(scdl): indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to( - np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1) - ) + cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)) inf_inputs = self.module._get_inference_input( array_dict, ) @@ -486,11 +478,8 @@ def per_sample_inference_fn(pair): normalization_means = normalization_means.reshape(-1, 1, 1, 1) normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) normalized_dists = ( - (sampled_dists - normalization_means) - / (normalization_vars**0.5) - ).mean( - dim="mc_sample" - ) # (n_cells, n_samples, n_samples) + (sampled_dists - normalization_means) / (normalization_vars**0.5) + ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) # Compute each reduction for r in reductions: @@ -514,9 +503,7 @@ def per_sample_inference_fn(pair): group_by_cats = group_by.unique() for cat in group_by_cats: cat_summed_outputs = outputs.sel( - cell_name=self.adata.obs_names[indices][ - group_by == cat - ].values + cell_name=self.adata.obs_names[indices][group_by == cat].values ).sum(dim="cell_name") cat_summed_outputs = cat_summed_outputs.assign_coords( {f"{r.group_by}_name": cat} @@ -538,12 +525,8 @@ def per_sample_inference_fn(pair): group_by_counts = group_by.value_counts() averaged_grouped_data_arrs = [] for cat, count in group_by_counts.items(): - averaged_grouped_data_arrs.append( - grouped_data_arrs[gr.name][cat] / count - ) - final_data_arr = xr.concat( - averaged_grouped_data_arrs, dim=f"{gr.group_by}_name" - ) + averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count) + final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name") final_data_arrs[gr.name] = final_data_arr return xr.Dataset(data_vars=final_data_arrs) @@ -729,9 +712,7 @@ def get_local_sample_distances( reductions = [] if not keep_cell and not groupby: - raise ValueError( - "Undefined computation because not keep_cell and no groupby." 
- ) + raise ValueError("Undefined computation because not keep_cell and no groupby.") if keep_cell: reductions.append( MRVIReduction( @@ -809,9 +790,7 @@ def get_aggregated_posterior( qu_locs = [] qu_scales = [] - jit_inference_fn = self.module.get_jit_inference_fn( - inference_kwargs={"use_mean": True} - ) + jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True}) for array_dict in scdl: outputs = jit_inference_fn(self.module.rngs, array_dict) @@ -929,18 +908,14 @@ def differential_abundance( per_val_log_enrichs = {} for sample_cov_value in sample_cov_unique_values: cov_samples = ( - self.sample_info[ - self.sample_info[sample_cov_key] == sample_cov_value - ] + self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value] )[self.sample_key].to_numpy() if sample_subset is not None: cov_samples = np.intersect1d(cov_samples, np.array(sample_subset)) if len(cov_samples) == 0: continue - sel_log_probs = log_probs_arr.log_probs.loc[ - {"sample": cov_samples} - ].values + sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}].values if omit_original_sample: sample_cov_one_hot = np.zeros((adata.n_obs, len(cov_samples))) for i, sample in enumerate(cov_samples): @@ -948,9 +923,9 @@ def differential_abundance( sel_log_probs_no_original = np.where( sample_cov_one_hot, sel_log_probs, -np.inf ) - val_log_probs = logsumexp( - sel_log_probs_no_original, axis=1 - ) - np.log((1 - sample_cov_one_hot).sum(axis=1)) + val_log_probs = logsumexp(sel_log_probs_no_original, axis=1) - np.log( + (1 - sample_cov_one_hot).sum(axis=1) + ) else: val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log( sel_log_probs.shape[1] @@ -967,17 +942,11 @@ def differential_abundance( stacklevel=2, ) continue - rest_log_probs = log_probs_arr.log_probs.loc[ - {"sample": rest_samples} - ].values + rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}].values if omit_original_sample: - rest_samples_one_hot = np.zeros( - (adata.n_obs, len(rest_samples)) - ) + rest_samples_one_hot = np.zeros((adata.n_obs, len(rest_samples))) for i, sample in enumerate(rest_samples): - rest_samples_one_hot[ - adata.obs[self.sample_key] == sample, i - ] = 1 + rest_samples_one_hot[adata.obs[self.sample_key] == sample, i] = 1 rest_log_probs_no_original = np.where( rest_samples_one_hot, rest_log_probs, -np.inf ) @@ -990,9 +959,7 @@ def differential_abundance( ) enrichment_scores = val_log_probs - rest_val_log_probs per_val_log_enrichs[sample_cov_value] = enrichment_scores - sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict( - per_val_log_probs - ) + sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs) if compute_log_enrichment and len(per_val_log_enrichs) > 0: sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict( per_val_log_enrichs @@ -1069,16 +1036,12 @@ def get_outlier_cell_sample_pairs( for sample_name in tqdm(unique_samples): sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0] if subsample_size is not None and sample_idxs.shape[0] > subsample_size: - sample_idxs = np.random.choice( - sample_idxs, size=subsample_size, replace=False - ) + sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False) adata_s = adata[sample_idxs] ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( - np.expand_dims( - adata_s.obsm["U"], ap.mixture_dim - ) # (n_cells_ap, 1, n_latent_dim) + np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) # 
(n_cells_ap, 1, n_latent_dim) ) # (n_cells_ap, n_cells_ap) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) @@ -1386,9 +1349,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) mc_samples, _, n_cells_, n_latent = betas_covariates.shape betas_offset_ = ( - jnp.zeros( - (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent) - ) + jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)) + eps_mean_ ) # batch_offset shape (mc_samples, n_batch, n_cells, n_latent) @@ -1396,9 +1357,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): f_ = jax.vmap( h_inference_fn, in_axes=(0, None, 0), out_axes=0 ) # fn over MC samples - f_ = jax.vmap( - f_, in_axes=(1, None, None), out_axes=1 - ) # fn over covariates + f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches h_fn = jax.jit(f_) @@ -1408,9 +1367,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc) lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0) if delta is not None: - lfc_std = jnp.sqrt( - jnp.average(lfcs.var(1), weights=batch_weights, axis=0) - ) + lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0)) pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0) if store_baseline: @@ -1446,9 +1403,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) rngs_de = self.module.rngs if store_lfc else None - admissible_samples_mat = jnp.array( - admissible_samples[indices] - ) # (n_cells, n_samples) + admissible_samples_mat = jnp.array(admissible_samples[indices]) # (n_cells, n_samples) n_samples_per_cell = admissible_samples_mat.sum(axis=1) admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype( float @@ -1485,9 +1440,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): effect_size = np.concatenate(effect_size, axis=0) pvalue = np.concatenate(pvalue, axis=0) pvalue_shape = pvalue.shape - padj = false_discovery_control(pvalue.flatten(), method="bh").reshape( - pvalue_shape - ) + padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape) coords = { "cell_name": (("cell_name"), adata.obs_names), @@ -1606,27 +1559,19 @@ def _construct_design_matrix( Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key) if normalize_design_matrix: - Xmat = (Xmat - Xmat.min(axis=0)) / ( - 1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0) - ) + Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)) if add_batch_specific_offsets: cov = sample_info["_scvi_batch"] if cov.nunique() == self.summary_stats.n_batch: - cov = np.eye(self.summary_stats.n_batch)[ - sample_info["_scvi_batch"].values - ] - cov_names = [ - "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch) - ] + cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values] + cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)] Xmat = np.concatenate([cov, Xmat], axis=1) Xmat_names = np.concatenate([np.array(cov_names), Xmat_names]) Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key]) # Retrieve indices of offset covariates in the right order offset_indices = ( - Series(np.arange(len(Xmat_names)), index=Xmat_names) - .loc[cov_names] - .values + Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values ) 
offset_indices = jnp.array(offset_indices) else: diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 64f3b93415..58ae229e7a 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -156,9 +156,7 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs( - adata: AnnData, model_kwargs: dict[str, Any], save_path: str -): +def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -176,9 +174,7 @@ def test_mrvi_model_kwargs( def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression( - sample_cov_keys=sample_cov_keys, sample_subset=sample_subset - ) + model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From aa53513dc017d7fedfd0cd9c1500184dd68287bf Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 20 Feb 2025 14:19:20 -0500 Subject: [PATCH 08/14] fix np where bug, remove student t option --- src/scvi/external/mrvi/_model.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 654b9adb9f..13a184218e 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -762,7 +762,6 @@ def get_aggregated_posterior( self, adata: AnnData | None = None, sample: str | int | None = None, - use_student_t: bool = False, indices: npt.ArrayLike | None = None, batch_size: int = 256, ) -> Distribution: @@ -777,8 +776,6 @@ def get_aggregated_posterior( AnnData object to use. Defaults to the AnnData object used to initialize the model. sample Name or index of the sample to filter on. If ``None``, uses all cells. - use_student_t - Whether to use a student-t distribution for the aggregated posterior. indices Indices of cells to use. batch_size @@ -824,8 +821,6 @@ def get_aggregated_posterior( Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]), ( Normal(qu_loc, qu_scale).to_event(1) - if use_student_t - else StudentT(qu_loc, qu_scale).to_event(1) ), ) @@ -835,7 +830,7 @@ def differential_abundance( sample_cov_keys: list[str] | None = None, sample_subset: list[str] | None = None, compute_log_enrichment: bool = False, - omit_original_sample: bool = False, + omit_original_sample: bool = True, batch_size: int = 128, ) -> xr.Dataset: """Compute the differential abundance between samples. 
@@ -946,7 +941,7 @@ def differential_abundance( for i, sample in enumerate(cov_samples): sample_cov_one_hot[adata.obs[self.sample_key] == sample, i] = 1 sel_log_probs_no_original = np.where( - sample_cov_one_hot, sel_log_probs, -np.inf + sample_cov_one_hot, -np.inf, sel_log_probs ) val_log_probs = logsumexp( sel_log_probs_no_original, axis=1 @@ -979,7 +974,7 @@ def differential_abundance( adata.obs[self.sample_key] == sample, i ] = 1 rest_log_probs_no_original = np.where( - rest_samples_one_hot, rest_log_probs, -np.inf + rest_samples_one_hot, -np.inf, rest_log_probs ) rest_val_log_probs = logsumexp( rest_log_probs_no_original, axis=1 From 53284956f14ed7c3a8d592821cef9dbbb21c983c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Feb 2025 19:20:22 +0000 Subject: [PATCH 09/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/mrvi/_model.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index c482787fe3..2842ce2ff4 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -770,7 +770,6 @@ def get_aggregated_posterior( Categorical, MixtureSameFamily, Normal, - StudentT, ) self._check_if_trained(warn=False) @@ -798,9 +797,7 @@ def get_aggregated_posterior( qu_scale = jnp.concatenate(qu_scales, axis=0) # n_cells x n_latent_u return MixtureSameFamily( Categorical(probs=jnp.ones(qu_loc.shape[0]) / qu_loc.shape[0]), - ( - Normal(qu_loc, qu_scale).to_event(1) - ), + (Normal(qu_loc, qu_scale).to_event(1)), ) def differential_abundance( From 6be768e9c3a83b6a05f7a0bde3c1fcd8187f225f Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 20 Feb 2025 14:22:46 -0500 Subject: [PATCH 10/14] Fix test parameters --- tests/external/mrvi/test_model.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 58ae229e7a..736b4e9b01 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -117,7 +117,7 @@ def test_mrvi_de(model: MRVI, setup_kwargs: dict[str, Any], de_kwargs: dict[str, [ {"sample_cov_keys": ["meta1_cat"]}, {"sample_cov_keys": ["meta1_cat", "batch"]}, - {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": True}, + {"sample_cov_keys": ["meta1_cat"], "omit_original_sample": False}, {"sample_cov_keys": ["meta1_cat"], "compute_log_enrichment": True}, {"sample_cov_keys": ["meta1_cat", "batch"], "compute_log_enrichment": True}, ], @@ -156,7 +156,9 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): +def test_mrvi_model_kwargs( + adata: AnnData, model_kwargs: dict[str, Any], save_path: str +): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -174,7 +176,9 @@ def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_pa def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) + model.differential_expression( + sample_cov_keys=sample_cov_keys, sample_subset=sample_subset + ) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) 
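
A note on the np.where argument order fixed in PATCH 08: np.where(cond, a, b)
returns a where cond is truthy and b elsewhere, so masking each cell's
sample-of-origin out of the aggregate requires -np.inf in the second slot,
not the third (PATCH 06 had the operands swapped). Below is a minimal
self-contained sketch of the resulting leave-one-out log-mean; the shapes and
data are toys, and none of the names belong to the scvi-tools API:

    import numpy as np
    from scipy.special import logsumexp

    rng = np.random.default_rng(0)
    n_cells, n_samples = 5, 3
    # Per-cell log-densities under each sample's aggregated posterior.
    log_probs = rng.normal(size=(n_cells, n_samples))
    # Each cell's sample-of-origin, one-hot encoded over the sample axis.
    origin = rng.integers(n_samples, size=n_cells)
    one_hot = np.zeros((n_cells, n_samples))
    one_hot[np.arange(n_cells), origin] = 1

    # Mask the originating sample with -inf, then take the mean density in
    # log space: logsumexp over samples minus log of the number kept.
    masked = np.where(one_hot, -np.inf, log_probs)
    loo = logsumexp(masked, axis=1) - np.log((1 - one_hot).sum(axis=1))
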
From 057d07364274297ab20235585cf5d9ae7d696cf3 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Feb 2025 19:23:22 +0000 Subject: [PATCH 11/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/external/mrvi/test_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/external/mrvi/test_model.py b/tests/external/mrvi/test_model.py index 736b4e9b01..75561589f2 100644 --- a/tests/external/mrvi/test_model.py +++ b/tests/external/mrvi/test_model.py @@ -156,9 +156,7 @@ def test_mrvi_da(model, sample_key, da_kwargs): }, ], ) -def test_mrvi_model_kwargs( - adata: AnnData, model_kwargs: dict[str, Any], save_path: str -): +def test_mrvi_model_kwargs(adata: AnnData, model_kwargs: dict[str, Any], save_path: str): MRVI.setup_anndata( adata, sample_key="sample_str", @@ -176,9 +174,7 @@ def test_mrvi_model_kwargs( def test_mrvi_sample_subset(model: MRVI, adata: AnnData, save_path: str): sample_cov_keys = ["meta1_cat", "meta2", "cont_cov"] sample_subset = [chr(i + ord("a")) for i in range(8)] - model.differential_expression( - sample_cov_keys=sample_cov_keys, sample_subset=sample_subset - ) + model.differential_expression(sample_cov_keys=sample_cov_keys, sample_subset=sample_subset) model_path = os.path.join(save_path, "mrvi_model") model.save(model_path, save_anndata=False, overwrite=True) From 6471a4382ca1e6783593060d480cd113d5f52544 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 20 Feb 2025 22:25:44 -0500 Subject: [PATCH 12/14] address pr comments --- src/scvi/external/mrvi/_model.py | 160 ++++++++++++++++++++----------- 1 file changed, 104 insertions(+), 56 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 7cbbd232cb..c949ab1be4 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -131,7 +131,9 @@ def to_device(self, device): # TODO(jhong): remove this once we have a better way to handle device. pass - def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]: + def _generate_stacked_rngs( + self, n_sets: int | tuple + ) -> dict[str, jax.random.KeyArray]: return_1d = isinstance(n_sets, int) if return_1d: n_sets_1d = n_sets @@ -189,7 +191,9 @@ def setup_anndata( fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -406,12 +410,16 @@ def per_sample_inference_fn(pair): for ur in reqs.ungrouped_reductions: ungrouped_data_arrs[ur.name] = [] for gr in reqs.grouped_reductions: - grouped_data_arrs[gr.name] = {} # Will map group category to running group sum. + grouped_data_arrs[gr.name] = ( + {} + ) # Will map group category to running group sum. 
for array_dict in tqdm(scdl): indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)) + cf_sample = np.broadcast_to( + np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1) + ) inf_inputs = self.module._get_inference_input( array_dict, ) @@ -488,8 +496,11 @@ def per_sample_inference_fn(pair): normalization_means = normalization_means.reshape(-1, 1, 1, 1) normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) normalized_dists = ( - (sampled_dists - normalization_means) / (normalization_vars**0.5) - ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) + (sampled_dists - normalization_means) + / (normalization_vars**0.5) + ).mean( + dim="mc_sample" + ) # (n_cells, n_samples, n_samples) # Compute each reduction for r in reductions: @@ -513,7 +524,9 @@ def per_sample_inference_fn(pair): group_by_cats = group_by.unique() for cat in group_by_cats: cat_summed_outputs = outputs.sel( - cell_name=self.adata.obs_names[indices][group_by == cat].values + cell_name=self.adata.obs_names[indices][ + group_by == cat + ].values ).sum(dim="cell_name") cat_summed_outputs = cat_summed_outputs.assign_coords( {f"{r.group_by}_name": cat} @@ -535,8 +548,12 @@ def per_sample_inference_fn(pair): group_by_counts = group_by.value_counts() averaged_grouped_data_arrs = [] for cat, count in group_by_counts.items(): - averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count) - final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name") + averaged_grouped_data_arrs.append( + grouped_data_arrs[gr.name][cat] / count + ) + final_data_arr = xr.concat( + averaged_grouped_data_arrs, dim=f"{gr.group_by}_name" + ) final_data_arrs[gr.name] = final_data_arr return xr.Dataset(data_vars=final_data_arrs) @@ -729,7 +746,9 @@ def get_local_sample_distances( reductions = [] if not keep_cell and not groupby: - raise ValueError("Undefined computation because not keep_cell and no groupby.") + raise ValueError( + "Undefined computation because not keep_cell and no groupby." + ) if keep_cell: reductions.append( MRVIReduction( @@ -803,7 +822,9 @@ def get_aggregated_posterior( qu_locs = [] qu_scales = [] - jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True}) + jit_inference_fn = self.module.get_jit_inference_fn( + inference_kwargs={"use_mean": True} + ) for array_dict in scdl: outputs = jit_inference_fn(self.module.rngs, array_dict) @@ -845,7 +866,7 @@ def differential_abundance( compute_log_enrichment Whether to compute the log enrichment scores for each covariate value. omit_original_sample - Whether to omit the original sample from the differential abundance computation. + If true, each cell's sample-of-origin is discarded to compute aggregate posteriors. Only relevant if sample_cov_keys is not None. batch_size Minibatch size for computing the differential abundance. 
@@ -909,6 +930,25 @@ def differential_abundance( if sample_cov_keys is None or len(sample_cov_keys) == 0: return log_probs_arr + def aggregate_log_probs(log_probs, samples, omit_original_sample=False): + sample_log_probs = log_probs.loc[ + {"sample": rest_samples} + ].values # (n_cells, n_samples_in_group) + if omit_original_sample: + sample_one_hot = np.zeros((adata.n_obs, len(samples))) + for i, sample in enumerate(samples): + sample_one_hot[adata.obs[self.sample_key] == sample, i] = 1 + log_probs_no_original = np.where( + sample_one_hot, -np.inf, sample_log_probs + ) # virtually discards samples-of-origin from aggregate posteriors + return logsumexp(log_probs_no_original, axis=1) - np.log( + (1 - sample_one_hot).sum(axis=1) + ) + else: + return logsumexp(sample_log_probs, axis=1) - np.log( + sample_log_probs.shape[1] + ) + sample_cov_log_probs_map = {} sample_cov_log_enrichs_map = {} for sample_cov_key in sample_cov_keys: @@ -917,28 +957,20 @@ def differential_abundance( per_val_log_enrichs = {} for sample_cov_value in sample_cov_unique_values: cov_samples = ( - self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value] + self.sample_info[ + self.sample_info[sample_cov_key] == sample_cov_value + ] )[self.sample_key].to_numpy() if sample_subset is not None: cov_samples = np.intersect1d(cov_samples, np.array(sample_subset)) if len(cov_samples) == 0: continue - sel_log_probs = log_probs_arr.log_probs.loc[{"sample": cov_samples}].values - if omit_original_sample: - sample_cov_one_hot = np.zeros((adata.n_obs, len(cov_samples))) - for i, sample in enumerate(cov_samples): - sample_cov_one_hot[adata.obs[self.sample_key] == sample, i] = 1 - sel_log_probs_no_original = np.where( - sample_cov_one_hot, -np.inf, sel_log_probs - ) - val_log_probs = logsumexp(sel_log_probs_no_original, axis=1) - np.log( - (1 - sample_cov_one_hot).sum(axis=1) - ) - else: - val_log_probs = logsumexp(sel_log_probs, axis=1) - np.log( - sel_log_probs.shape[1] - ) + val_log_probs = aggregate_log_probs( + log_probs_arr.log_probs, + cov_samples, + omit_original_sample=omit_original_sample, + ) per_val_log_probs[sample_cov_value] = val_log_probs if compute_log_enrichment: @@ -951,24 +983,16 @@ def differential_abundance( stacklevel=2, ) continue - rest_log_probs = log_probs_arr.log_probs.loc[{"sample": rest_samples}].values - if omit_original_sample: - rest_samples_one_hot = np.zeros((adata.n_obs, len(rest_samples))) - for i, sample in enumerate(rest_samples): - rest_samples_one_hot[adata.obs[self.sample_key] == sample, i] = 1 - rest_log_probs_no_original = np.where( - rest_samples_one_hot, -np.inf, rest_log_probs - ) - rest_val_log_probs = logsumexp( - rest_log_probs_no_original, axis=1 - ) - np.log((1 - rest_samples_one_hot).sum(axis=1)) - else: - rest_val_log_probs = logsumexp(rest_log_probs, axis=1) - np.log( - rest_log_probs.shape[1] - ) + rest_val_log_probs = aggregate_log_probs( + log_probs_arr.log_probs, + rest_samples, + omit_original_sample=omit_original_sample, + ) enrichment_scores = val_log_probs - rest_val_log_probs per_val_log_enrichs[sample_cov_value] = enrichment_scores - sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs) + sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict( + per_val_log_probs + ) if compute_log_enrichment and len(per_val_log_enrichs) > 0: sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict( per_val_log_enrichs @@ -1045,12 +1069,16 @@ def get_outlier_cell_sample_pairs( for sample_name in tqdm(unique_samples): sample_idxs = 
np.where(adata.obs[self.sample_key] == sample_name)[0] if subsample_size is not None and sample_idxs.shape[0] > subsample_size: - sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False) + sample_idxs = np.random.choice( + sample_idxs, size=subsample_size, replace=False + ) adata_s = adata[sample_idxs] ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( - np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) # (n_cells_ap, 1, n_latent_dim) + np.expand_dims( + adata_s.obsm["U"], ap.mixture_dim + ) # (n_cells_ap, 1, n_latent_dim) ) # (n_cells_ap, n_cells_ap) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) @@ -1360,7 +1388,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) mc_samples, _, n_cells_, n_latent = betas_covariates.shape betas_offset_ = ( - jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)) + jnp.zeros( + (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent) + ) + eps_mean_ ) # batch_offset shape (mc_samples, n_batch, n_cells, n_latent) @@ -1368,7 +1398,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): f_ = jax.vmap( h_inference_fn, in_axes=(0, None, 0), out_axes=0 ) # fn over MC samples - f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates + f_ = jax.vmap( + f_, in_axes=(1, None, None), out_axes=1 + ) # fn over covariates f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches h_fn = jax.jit(f_) @@ -1378,7 +1410,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc) lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0) if delta is not None: - lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0)) + lfc_std = jnp.sqrt( + jnp.average(lfcs.var(1), weights=batch_weights, axis=0) + ) pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0) if store_baseline: @@ -1414,7 +1448,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) rngs_de = self.module.rngs if store_lfc else None - admissible_samples_mat = jnp.array(admissible_samples[indices]) # (n_cells, n_samples) + admissible_samples_mat = jnp.array( + admissible_samples[indices] + ) # (n_cells, n_samples) n_samples_per_cell = admissible_samples_mat.sum(axis=1) admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype( float @@ -1440,7 +1476,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) except jax.errors.JaxRuntimeError as e: if use_vmap: - raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e + raise RuntimeError( + "JAX ran out of memory. Try setting use_vmap=False." 
+ ) from e else: raise e @@ -1458,7 +1496,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): effect_size = np.concatenate(effect_size, axis=0) pvalue = np.concatenate(pvalue, axis=0) pvalue_shape = pvalue.shape - padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape) + padj = false_discovery_control(pvalue.flatten(), method="bh").reshape( + pvalue_shape + ) coords = { "cell_name": (("cell_name"), adata.obs_names), @@ -1577,19 +1617,27 @@ def _construct_design_matrix( Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key) if normalize_design_matrix: - Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)) + Xmat = (Xmat - Xmat.min(axis=0)) / ( + 1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0) + ) if add_batch_specific_offsets: cov = sample_info["_scvi_batch"] if cov.nunique() == self.summary_stats.n_batch: - cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values] - cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)] + cov = np.eye(self.summary_stats.n_batch)[ + sample_info["_scvi_batch"].values + ] + cov_names = [ + "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch) + ] Xmat = np.concatenate([cov, Xmat], axis=1) Xmat_names = np.concatenate([np.array(cov_names), Xmat_names]) Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key]) # Retrieve indices of offset covariates in the right order offset_indices = ( - Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values + Series(np.arange(len(Xmat_names)), index=Xmat_names) + .loc[cov_names] + .values ) offset_indices = jnp.array(offset_indices) else: From ff68af1a38d5f1f0fa00569210fce30d6ffbf979 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Feb 2025 03:25:59 +0000 Subject: [PATCH 13/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/scvi/external/mrvi/_model.py | 103 ++++++++----------------------- 1 file changed, 26 insertions(+), 77 deletions(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index c949ab1be4..45ef9677de 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -131,9 +131,7 @@ def to_device(self, device): # TODO(jhong): remove this once we have a better way to handle device. pass - def _generate_stacked_rngs( - self, n_sets: int | tuple - ) -> dict[str, jax.random.KeyArray]: + def _generate_stacked_rngs(self, n_sets: int | tuple) -> dict[str, jax.random.KeyArray]: return_1d = isinstance(n_sets, int) if return_1d: n_sets_1d = n_sets @@ -191,9 +189,7 @@ def setup_anndata( fields.NumericalObsField(REGISTRY_KEYS.INDICES_KEY, "_indices"), ] - adata_manager = AnnDataManager( - fields=anndata_fields, setup_method_args=setup_method_args - ) + adata_manager = AnnDataManager(fields=anndata_fields, setup_method_args=setup_method_args) adata_manager.register_fields(adata, **kwargs) cls.register_manager(adata_manager) @@ -410,16 +406,12 @@ def per_sample_inference_fn(pair): for ur in reqs.ungrouped_reductions: ungrouped_data_arrs[ur.name] = [] for gr in reqs.grouped_reductions: - grouped_data_arrs[gr.name] = ( - {} - ) # Will map group category to running group sum. + grouped_data_arrs[gr.name] = {} # Will map group category to running group sum. 
for array_dict in tqdm(scdl): indices = array_dict[REGISTRY_KEYS.INDICES_KEY].astype(int).flatten() n_cells = array_dict[REGISTRY_KEYS.X_KEY].shape[0] - cf_sample = np.broadcast_to( - np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1) - ) + cf_sample = np.broadcast_to(np.arange(n_sample)[:, None, None], (n_sample, n_cells, 1)) inf_inputs = self.module._get_inference_input( array_dict, ) @@ -496,11 +488,8 @@ def per_sample_inference_fn(pair): normalization_means = normalization_means.reshape(-1, 1, 1, 1) normalization_vars = normalization_vars.reshape(-1, 1, 1, 1) normalized_dists = ( - (sampled_dists - normalization_means) - / (normalization_vars**0.5) - ).mean( - dim="mc_sample" - ) # (n_cells, n_samples, n_samples) + (sampled_dists - normalization_means) / (normalization_vars**0.5) + ).mean(dim="mc_sample") # (n_cells, n_samples, n_samples) # Compute each reduction for r in reductions: @@ -524,9 +513,7 @@ def per_sample_inference_fn(pair): group_by_cats = group_by.unique() for cat in group_by_cats: cat_summed_outputs = outputs.sel( - cell_name=self.adata.obs_names[indices][ - group_by == cat - ].values + cell_name=self.adata.obs_names[indices][group_by == cat].values ).sum(dim="cell_name") cat_summed_outputs = cat_summed_outputs.assign_coords( {f"{r.group_by}_name": cat} @@ -548,12 +535,8 @@ def per_sample_inference_fn(pair): group_by_counts = group_by.value_counts() averaged_grouped_data_arrs = [] for cat, count in group_by_counts.items(): - averaged_grouped_data_arrs.append( - grouped_data_arrs[gr.name][cat] / count - ) - final_data_arr = xr.concat( - averaged_grouped_data_arrs, dim=f"{gr.group_by}_name" - ) + averaged_grouped_data_arrs.append(grouped_data_arrs[gr.name][cat] / count) + final_data_arr = xr.concat(averaged_grouped_data_arrs, dim=f"{gr.group_by}_name") final_data_arrs[gr.name] = final_data_arr return xr.Dataset(data_vars=final_data_arrs) @@ -746,9 +729,7 @@ def get_local_sample_distances( reductions = [] if not keep_cell and not groupby: - raise ValueError( - "Undefined computation because not keep_cell and no groupby." 
- ) + raise ValueError("Undefined computation because not keep_cell and no groupby.") if keep_cell: reductions.append( MRVIReduction( @@ -822,9 +803,7 @@ def get_aggregated_posterior( qu_locs = [] qu_scales = [] - jit_inference_fn = self.module.get_jit_inference_fn( - inference_kwargs={"use_mean": True} - ) + jit_inference_fn = self.module.get_jit_inference_fn(inference_kwargs={"use_mean": True}) for array_dict in scdl: outputs = jit_inference_fn(self.module.rngs, array_dict) @@ -945,9 +924,7 @@ def aggregate_log_probs(log_probs, samples, omit_original_sample=False): (1 - sample_one_hot).sum(axis=1) ) else: - return logsumexp(sample_log_probs, axis=1) - np.log( - sample_log_probs.shape[1] - ) + return logsumexp(sample_log_probs, axis=1) - np.log(sample_log_probs.shape[1]) sample_cov_log_probs_map = {} sample_cov_log_enrichs_map = {} @@ -957,9 +934,7 @@ def aggregate_log_probs(log_probs, samples, omit_original_sample=False): per_val_log_enrichs = {} for sample_cov_value in sample_cov_unique_values: cov_samples = ( - self.sample_info[ - self.sample_info[sample_cov_key] == sample_cov_value - ] + self.sample_info[self.sample_info[sample_cov_key] == sample_cov_value] )[self.sample_key].to_numpy() if sample_subset is not None: cov_samples = np.intersect1d(cov_samples, np.array(sample_subset)) @@ -990,9 +965,7 @@ def aggregate_log_probs(log_probs, samples, omit_original_sample=False): ) enrichment_scores = val_log_probs - rest_val_log_probs per_val_log_enrichs[sample_cov_value] = enrichment_scores - sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict( - per_val_log_probs - ) + sample_cov_log_probs_map[sample_cov_key] = DataFrame.from_dict(per_val_log_probs) if compute_log_enrichment and len(per_val_log_enrichs) > 0: sample_cov_log_enrichs_map[sample_cov_key] = DataFrame.from_dict( per_val_log_enrichs @@ -1069,16 +1042,12 @@ def get_outlier_cell_sample_pairs( for sample_name in tqdm(unique_samples): sample_idxs = np.where(adata.obs[self.sample_key] == sample_name)[0] if subsample_size is not None and sample_idxs.shape[0] > subsample_size: - sample_idxs = np.random.choice( - sample_idxs, size=subsample_size, replace=False - ) + sample_idxs = np.random.choice(sample_idxs, size=subsample_size, replace=False) adata_s = adata[sample_idxs] ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs) in_max_comp_log_probs = ap.component_distribution.log_prob( - np.expand_dims( - adata_s.obsm["U"], ap.mixture_dim - ) # (n_cells_ap, 1, n_latent_dim) + np.expand_dims(adata_s.obsm["U"], ap.mixture_dim) # (n_cells_ap, 1, n_latent_dim) ) # (n_cells_ap, n_cells_ap) log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs) @@ -1388,9 +1357,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) mc_samples, _, n_cells_, n_latent = betas_covariates.shape betas_offset_ = ( - jnp.zeros( - (mc_samples, self.summary_stats.n_batch, n_cells_, n_latent) - ) + jnp.zeros((mc_samples, self.summary_stats.n_batch, n_cells_, n_latent)) + eps_mean_ ) # batch_offset shape (mc_samples, n_batch, n_cells, n_latent) @@ -1398,9 +1365,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): f_ = jax.vmap( h_inference_fn, in_axes=(0, None, 0), out_axes=0 ) # fn over MC samples - f_ = jax.vmap( - f_, in_axes=(1, None, None), out_axes=1 - ) # fn over covariates + f_ = jax.vmap(f_, in_axes=(1, None, None), out_axes=1) # fn over covariates f_ = jax.vmap(f_, in_axes=(None, 0, 1), out_axes=0) # fn over batches h_fn = jax.jit(f_) @@ -1410,9 +1375,7 @@ def h_inference_fn(extra_eps, 
batch_index_cf, batch_offset_eps): lfcs = jnp.log2(x_1 + eps_lfc) - jnp.log2(x_0 + eps_lfc) lfc_mean = jnp.average(lfcs.mean(1), weights=batch_weights, axis=0) if delta is not None: - lfc_std = jnp.sqrt( - jnp.average(lfcs.var(1), weights=batch_weights, axis=0) - ) + lfc_std = jnp.sqrt(jnp.average(lfcs.var(1), weights=batch_weights, axis=0)) pde = (jnp.abs(lfcs) >= delta).mean(1).mean(0) if store_baseline: @@ -1448,9 +1411,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0]) rngs_de = self.module.rngs if store_lfc else None - admissible_samples_mat = jnp.array( - admissible_samples[indices] - ) # (n_cells, n_samples) + admissible_samples_mat = jnp.array(admissible_samples[indices]) # (n_cells, n_samples) n_samples_per_cell = admissible_samples_mat.sum(axis=1) admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype( float @@ -1476,9 +1437,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): ) except jax.errors.JaxRuntimeError as e: if use_vmap: - raise RuntimeError( - "JAX ran out of memory. Try setting use_vmap=False." - ) from e + raise RuntimeError("JAX ran out of memory. Try setting use_vmap=False.") from e else: raise e @@ -1496,9 +1455,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps): effect_size = np.concatenate(effect_size, axis=0) pvalue = np.concatenate(pvalue, axis=0) pvalue_shape = pvalue.shape - padj = false_discovery_control(pvalue.flatten(), method="bh").reshape( - pvalue_shape - ) + padj = false_discovery_control(pvalue.flatten(), method="bh").reshape(pvalue_shape) coords = { "cell_name": (("cell_name"), adata.obs_names), @@ -1617,27 +1574,19 @@ def _construct_design_matrix( Xmat_dim_to_key = np.concatenate(Xmat_dim_to_key) if normalize_design_matrix: - Xmat = (Xmat - Xmat.min(axis=0)) / ( - 1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0) - ) + Xmat = (Xmat - Xmat.min(axis=0)) / (1e-6 + Xmat.max(axis=0) - Xmat.min(axis=0)) if add_batch_specific_offsets: cov = sample_info["_scvi_batch"] if cov.nunique() == self.summary_stats.n_batch: - cov = np.eye(self.summary_stats.n_batch)[ - sample_info["_scvi_batch"].values - ] - cov_names = [ - "offset_batch_" + str(i) for i in range(self.summary_stats.n_batch) - ] + cov = np.eye(self.summary_stats.n_batch)[sample_info["_scvi_batch"].values] + cov_names = ["offset_batch_" + str(i) for i in range(self.summary_stats.n_batch)] Xmat = np.concatenate([cov, Xmat], axis=1) Xmat_names = np.concatenate([np.array(cov_names), Xmat_names]) Xmat_dim_to_key = np.concatenate([np.array(cov_names), Xmat_dim_to_key]) # Retrieve indices of offset covariates in the right order offset_indices = ( - Series(np.arange(len(Xmat_names)), index=Xmat_names) - .loc[cov_names] - .values + Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values ) offset_indices = jnp.array(offset_indices) else: From 798eee62afa07a3cdb0480326308277fc80d5680 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Fri, 21 Feb 2025 12:13:18 -0500 Subject: [PATCH 14/14] fix nameerror bug --- src/scvi/external/mrvi/_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/scvi/external/mrvi/_model.py b/src/scvi/external/mrvi/_model.py index 45ef9677de..3569448bae 100644 --- a/src/scvi/external/mrvi/_model.py +++ b/src/scvi/external/mrvi/_model.py @@ -911,7 +911,7 @@ def differential_abundance( def aggregate_log_probs(log_probs, samples, omit_original_sample=False): sample_log_probs = log_probs.loc[ - {"sample": rest_samples} + 
{"sample": samples} ].values # (n_cells, n_samples_in_group) if omit_original_sample: sample_one_hot = np.zeros((adata.n_obs, len(samples)))