diff --git a/src/cellrank/estimators/terminal_states/_gpcca.py b/src/cellrank/estimators/terminal_states/_gpcca.py index 1dbc077a0..2b525a6d2 100644 --- a/src/cellrank/estimators/terminal_states/_gpcca.py +++ b/src/cellrank/estimators/terminal_states/_gpcca.py @@ -1,8 +1,10 @@ +import collections import datetime import enum import pathlib import types -from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union +from pathlib import Path +from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union import numpy as np import pandas as pd @@ -10,6 +12,7 @@ from pandas.api.types import infer_dtype import matplotlib.pyplot as plt +import seaborn as sns from matplotlib.axes import Axes from matplotlib.colorbar import ColorbarBase from matplotlib.colors import ListedColormap, Normalize @@ -86,6 +89,7 @@ def __init__( self._coarse_init_dist: Optional[pd.Series] = None self._coarse_stat_dist: Optional[pd.Series] = None self._coarse_tmat: Optional[pd.DataFrame] = None + self._tsi: Optional[AnnData] = None @property @d.get_summary(base="gpcca_macro") @@ -532,6 +536,165 @@ def set_initial_states( ) return self + # TODO: Add definition/link to paper. + def tsi( + self, + n_macrostates: int, + terminal_states: Optional[List[str]] = None, + cluster_key: Optional[str] = None, + **kwargs: Any, + ) -> float: + """Compute terminal state identificiation (TSI) score. + + Parameters + ---------- + n_macrostates + Maximum number of macrostates to consider. + terminal_states + List of terminal states. + cluster_key + Key in :attr:`~anndata.AnnData.obs` defining cluster labels including terminal states. + kwargs + Keyword arguments passed to :meth:`compute_macrostates` function. + + Returns + ------- + Returns TSI score. + """ + tsi_precomputed = (self._tsi is not None) and (self._tsi[:, "number_of_macrostates"].X.max() >= n_macrostates) + if terminal_states is not None: + tsi_precomputed = tsi_precomputed and (set(self._tsi.uns["terminal_states"]) == set(terminal_states)) + if cluster_key is not None: + tsi_precomputed = tsi_precomputed and (self._tsi.uns["cluster_key"] == cluster_key) + + if not tsi_precomputed: + if terminal_states is None: + raise RuntimeError("`terminal_states` needs to be specified to compute TSI.") + if cluster_key is None: + raise RuntimeError("`cluster_key` needs to be specified to compute TSI.") + + macrostates = {} + for n_states in range(n_macrostates, 0, -1): + self.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs) + macrostates[n_states] = self.macrostates.cat.categories + + max_terminal_states = len(terminal_states) + + tsi_df = collections.defaultdict(list) + for n_states, states in macrostates.items(): + n_terminal_states = ( + states.str.replace(r"(_).*", "", regex=True).drop_duplicates().isin(terminal_states).sum() + ) + tsi_df["number_of_macrostates"].append(n_states) + tsi_df["identified_terminal_states"].append(n_terminal_states) + + tsi_df["optimal_identification"].append(min(n_states, max_terminal_states)) + + tsi_df = AnnData(pd.DataFrame(tsi_df), uns={"terminal_states": terminal_states, "cluster_key": cluster_key}) + self._tsi = tsi_df + + tsi_df = self._tsi.to_df() + row_mask = tsi_df["number_of_macrostates"] <= n_macrostates + optimal_score = tsi_df.loc[row_mask, "optimal_identification"].sum() + + return tsi_df.loc[row_mask, "identified_terminal_states"].sum() / optimal_score + + @d.dedent + def plot_tsi( + self, + n_macrostates: Optional[int] = None, + x_offset: Tuple[float, float] = (0.2, 0.2), + y_offset: Tuple[float, float] = (0.1, 0.1), + figsize: Tuple[float, float] = (6, 4), + dpi: Optional[int] = None, + save: Optional[Union[str, Path]] = None, + **kwargs: Any, + ) -> Axes: + """Plot terminal state identificiation (TSI). + + Requires computing TSI with :meth:`tsi` first. + + Parameters + ---------- + n_macrostates + Maximum number of macrostates to consider. Defaults to using all. + x_offset + Offset of x-axis. + y_offset + Offset of y-axis. + %(plotting)s + kwargs + Keyword arguments for :func:`~seaborn.lineplot`. + + Returns + ------- + Plot TSI of the kernel and an optimal identification strategy. + """ + if self._tsi is None: + raise RuntimeError("Compute TSI with `tsi` first as `.tsi()`.") + + tsi_df = self._tsi.to_df() + if n_macrostates is not None: + tsi_df = tsi_df.loc[tsi_df["number_of_macrostates"] <= n_macrostates, :] + + optimal_identification = tsi_df[["number_of_macrostates", "optimal_identification"]] + optimal_identification = optimal_identification.rename( + columns={"optimal_identification": "identified_terminal_states"} + ) + optimal_identification["method"] = "Optimal identification" + optimal_identification["line_style"] = "--" + + df = tsi_df[["number_of_macrostates", "identified_terminal_states"]] + df["method"] = self.kernel.__class__.__name__ + df["line_style"] = "-" + + df = pd.concat([df, optimal_identification]) + + fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True) + sns.lineplot( + data=df, + x="number_of_macrostates", + y="identified_terminal_states", + hue="method", + style="line_style", + drawstyle="steps-post", + ax=ax, + **kwargs, + ) + + ax.set_xticks(df["number_of_macrostates"].unique().astype(int)) + # Plot is generated from large to small values on the x-axis + for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]): + if ((label_id + 1) % 5 != 0) and label_id != 0: + label.set_visible(False) + ax.set_yticks(df["identified_terminal_states"].unique()) + + x_min = df["number_of_macrostates"].min() - x_offset[0] + x_max = df["number_of_macrostates"].max() + x_offset[1] + y_min = df["identified_terminal_states"].min() - y_offset[0] + y_max = df["identified_terminal_states"].max() + y_offset[1] + ax.set( + xlim=[x_min, x_max], + ylim=[y_min, y_max], + xlabel="Number of macrostates", + ylabel="Identified terminal states", + ) + + ax.get_legend().remove() + + n_methods = len(df["method"].unique()) + handles, labels = ax.get_legend_handles_labels() + handles[n_methods].set_linestyle("--") + handles = handles[: (n_methods + 1)] + labels = labels[: (n_methods + 1)] + labels[0] = "Method" + fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1)) + + if save is not None: + save_fig(fig=fig, path=save) + + return ax + @d.dedent def fit( self, diff --git a/tests/_ground_truth_adatas/adata_200.h5ad b/tests/_ground_truth_adatas/adata_200.h5ad index d1d52a309..31842a7fb 100644 Binary files a/tests/_ground_truth_adatas/adata_200.h5ad and b/tests/_ground_truth_adatas/adata_200.h5ad differ diff --git a/tests/_ground_truth_figures/plot_tsi.png b/tests/_ground_truth_figures/plot_tsi.png new file mode 100644 index 000000000..d84e18a5a Binary files /dev/null and b/tests/_ground_truth_figures/plot_tsi.png differ diff --git a/tests/test_gpcca.py b/tests/test_gpcca.py index 7de29217c..8af963dee 100644 --- a/tests/test_gpcca.py +++ b/tests/test_gpcca.py @@ -828,6 +828,24 @@ def test_plot_lineage_drivers_normal_run(self, adata_large: AnnData): mc.plot_lineage_drivers("0", use_raw=False) + def test_tsi(self, adata_large: AnnData): + groundtruth_adata = adata_large.uns["tsi"].copy() + + vk = VelocityKernel(adata_large).compute_transition_matrix() + estimator = cr.estimators.GPCCA(vk) + estimator.compute_schur(n_components=5) + + terminal_states = ["Neuroblast", "Astrocyte", "Granule mature"] + cluster_key = "clusters" + tsi_score = estimator.tsi(n_macrostates=3, terminal_states=terminal_states, cluster_key=cluster_key, n_cells=10) + + np.testing.assert_almost_equal(tsi_score, groundtruth_adata.uns["score"]) + assert isinstance(estimator._tsi.uns["terminal_states"], list) + assert len(estimator._tsi.uns["terminal_states"]) == len(groundtruth_adata.uns["terminal_states"]) + assert (estimator._tsi.uns["terminal_states"] == groundtruth_adata.uns["terminal_states"]).all() + assert estimator._tsi.uns["cluster_key"] == groundtruth_adata.uns["cluster_key"] + pd.testing.assert_frame_equal(estimator._tsi.to_df(), groundtruth_adata.to_df()) + def test_compute_priming_clusters(self, adata_large: AnnData): vk = VelocityKernel(adata_large).compute_transition_matrix(softmax_scale=4) ck = ConnectivityKernel(adata_large).compute_transition_matrix() diff --git a/tests/test_plotting.py b/tests/test_plotting.py index ec724eac5..aef689cce 100644 --- a/tests/test_plotting.py +++ b/tests/test_plotting.py @@ -2257,6 +2257,14 @@ def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str): save=fpath, ) + @compare(kind="gpcca") + def test_plot_tsi(self, mc: GPCCA, fpath: str): + mc = mc.copy(deep=True) + terminal_states = ["Neuroblast", "Astrocyte", "Granule mature"] + cluster_key = "clusters" + _ = mc.tsi(n_macrostates=3, terminal_states=terminal_states, cluster_key=cluster_key, n_cells=10) + mc.plot_tsi(dpi=DPI, save=fpath) + class TestLineage: @compare(kind="lineage")