From a0c1c1c2928cf147967c256dd066b3030d4737df Mon Sep 17 00:00:00 2001 From: Michal Klein <46717574+michalk8@users.noreply.github.com> Date: Mon, 4 Mar 2024 18:35:56 +0100 Subject: [PATCH] Copy the estimator instead --- src/cellrank/estimators/terminal_states/_gpcca.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index 93d4eff9f..152bd76de 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -573,13 +573,11 @@ def tsi( if cluster_key is None: raise RuntimeError("`cluster_key` needs to be specified to compute TSI.") - # copy the `adata` object to avoid overwrites, as the estimator overrides some keys - adata = self.adata.copy() + # create a new GPCCA object to avoid unsetting attributes + # that depend on the macrostates, e.g. the terminal states + g = self.copy(deep=True) macrostates = {} - for n_states in range(1, n_macrostates): - # create a new GPCCA object to avoid unsetting attributes that depend on the macrostates, - # e.g. the terminal states - g = GPCCA(self.transition_matrix, adata=adata) + for n_states in range(n_macrostates, 0, -1): g = g.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs) macrostates[n_states] = g.macrostates.cat.categories