Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TSI score code #1166

Merged
merged 20 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion src/cellrank/estimators/terminal_states/_gpcca.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import collections
import datetime
import enum
import pathlib
import types
from typing import Any, Dict, Literal, Mapping, Optional, Sequence, Tuple, Union
from pathlib import Path
from typing import Any, Dict, List, Literal, Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import pandas as pd
import scipy.sparse as sp
from pandas.api.types import infer_dtype

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.colorbar import ColorbarBase
from matplotlib.colors import ListedColormap, Normalize
Expand Down Expand Up @@ -86,6 +89,7 @@ def __init__(
self._coarse_init_dist: Optional[pd.Series] = None
self._coarse_stat_dist: Optional[pd.Series] = None
self._coarse_tmat: Optional[pd.DataFrame] = None
self._tsi: Optional[AnnData] = None

@property
@d.get_summary(base="gpcca_macro")
Expand Down Expand Up @@ -532,6 +536,165 @@ def set_initial_states(
)
return self

# TODO: Add definition/link to paper.
def tsi(
    self,
    n_macrostates: int,
    terminal_states: Optional[List[str]] = None,
    cluster_key: Optional[str] = None,
    **kwargs: Any,
) -> float:
    """Compute terminal state identification (TSI) score.

    For every number of macrostates ``k`` from ``n_macrostates`` down to ``1``, macrostates are computed and the
    number of distinct terminal states among them is counted. The score is the sum of these counts divided by the
    sum of the optimal counts ``min(k, len(terminal_states))``.

    The per-``k`` counts are cached on the estimator; a later call with a smaller or equal ``n_macrostates`` (and,
    if given, the same ``terminal_states`` and ``cluster_key``) reuses the cache instead of recomputing.

    Parameters
    ----------
    n_macrostates
        Maximum number of macrostates to consider. Must be at least 1.
    terminal_states
        List of terminal states. Required unless a matching result has already been computed.
    cluster_key
        Key in :attr:`~anndata.AnnData.obs` defining cluster labels including terminal states. Required unless a
        matching result has already been computed.
    kwargs
        Keyword arguments passed to :meth:`compute_macrostates` function.

    Returns
    -------
    Returns TSI score.

    Raises
    ------
    ValueError
        If ``n_macrostates`` is smaller than 1 or ``terminal_states`` is empty.
    RuntimeError
        If the score needs to be (re)computed but ``terminal_states`` or ``cluster_key`` is not given.

    Notes
    -----
    When the score is recomputed, :meth:`compute_macrostates` is called repeatedly on this estimator; the last call
    uses ``n_states=1``.
    """
    # Without at least one macrostate, the optimal score is 0 and the ratio is undefined.
    if n_macrostates < 1:
        raise ValueError(f"Expected `n_macrostates` to be at least 1, found `{n_macrostates}`.")

    tsi_precomputed = (self._tsi is not None) and (self._tsi[:, "number_of_macrostates"].X.max() >= n_macrostates)
    if terminal_states is not None:
        tsi_precomputed = tsi_precomputed and (set(self._tsi.uns["terminal_states"]) == set(terminal_states))
    if cluster_key is not None:
        tsi_precomputed = tsi_precomputed and (self._tsi.uns["cluster_key"] == cluster_key)

    if not tsi_precomputed:
        if terminal_states is None:
            raise RuntimeError("`terminal_states` needs to be specified to compute TSI.")
        if cluster_key is None:
            raise RuntimeError("`cluster_key` needs to be specified to compute TSI.")

        # Copy so that later mutation of the caller's list cannot invalidate the cached metadata.
        terminal_states = list(terminal_states)
        max_terminal_states = len(terminal_states)
        if max_terminal_states == 0:
            raise ValueError("`terminal_states` must contain at least one terminal state.")

        tsi_df = {
            "number_of_macrostates": [],
            "identified_terminal_states": [],
            "optimal_identification": [],
        }
        for n_states in range(n_macrostates, 0, -1):
            self.compute_macrostates(n_states=n_states, cluster_key=cluster_key, **kwargs)
            states = self.macrostates.cat.categories
            # Everything from the first underscore onwards is dropped from the macrostate name before matching.
            # NOTE(review): cluster names that themselves contain `_` would be truncated as well - confirm.
            n_terminal_states = (
                states.str.replace(r"(_).*", "", regex=True).drop_duplicates().isin(terminal_states).sum()
            )
            tsi_df["number_of_macrostates"].append(n_states)
            tsi_df["identified_terminal_states"].append(n_terminal_states)
            # At most one new terminal state can be identified per macrostate.
            tsi_df["optimal_identification"].append(min(n_states, max_terminal_states))

        self._tsi = AnnData(
            pd.DataFrame(tsi_df), uns={"terminal_states": terminal_states, "cluster_key": cluster_key}
        )

    tsi_df = self._tsi.to_df()
    # The cache may hold more rows than requested; restrict to the requested range.
    row_mask = tsi_df["number_of_macrostates"] <= n_macrostates
    optimal_score = tsi_df.loc[row_mask, "optimal_identification"].sum()

    return float(tsi_df.loc[row_mask, "identified_terminal_states"].sum() / optimal_score)

@d.dedent
def plot_tsi(
    self,
    n_macrostates: Optional[int] = None,
    x_offset: Tuple[float, float] = (0.2, 0.2),
    y_offset: Tuple[float, float] = (0.1, 0.1),
    figsize: Tuple[float, float] = (6, 4),
    dpi: Optional[int] = None,
    save: Optional[Union[str, Path]] = None,
    **kwargs: Any,
) -> Axes:
    """Plot terminal state identification (TSI).

    Requires computing TSI with :meth:`tsi` first.

    Parameters
    ----------
    n_macrostates
        Maximum number of macrostates to consider. Defaults to using all.
    x_offset
        Offset of x-axis, as ``(left, right)`` margins added around the data range.
    y_offset
        Offset of y-axis, as ``(bottom, top)`` margins added around the data range.
    %(plotting)s
    kwargs
        Keyword arguments for :func:`~seaborn.lineplot`.

    Returns
    -------
    The axes showing the TSI of the kernel together with an optimal identification strategy.

    Raises
    ------
    RuntimeError
        If :meth:`tsi` has not been called yet.
    """
    if self._tsi is None:
        raise RuntimeError("Compute TSI with `tsi` first as `.tsi()`.")

    tsi_df = self._tsi.to_df()
    if n_macrostates is not None:
        tsi_df = tsi_df.loc[tsi_df["number_of_macrostates"] <= n_macrostates, :]

    # `assign` returns a new frame, which avoids writing into a column subset of `tsi_df`
    # (and the `SettingWithCopyWarning` that comes with it).
    optimal_identification = (
        tsi_df[["number_of_macrostates", "optimal_identification"]]
        .rename(columns={"optimal_identification": "identified_terminal_states"})
        .assign(method="Optimal identification", line_style="--")
    )
    df = tsi_df[["number_of_macrostates", "identified_terminal_states"]].assign(
        method=self.kernel.__class__.__name__, line_style="-"
    )

    df = pd.concat([df, optimal_identification])

    fig, ax = plt.subplots(figsize=figsize, dpi=dpi, tight_layout=True)
    sns.lineplot(
        data=df,
        x="number_of_macrostates",
        y="identified_terminal_states",
        hue="method",
        style="line_style",
        drawstyle="steps-post",
        ax=ax,
        **kwargs,
    )

    ax.set_xticks(df["number_of_macrostates"].unique().astype(int))
    # Plot is generated from large to small values on the x-axis
    # Keep only the first reversed tick label and every fifth one visible to avoid clutter.
    for label_id, label in enumerate(ax.xaxis.get_ticklabels()[::-1]):
        if ((label_id + 1) % 5 != 0) and label_id != 0:
            label.set_visible(False)
    ax.set_yticks(df["identified_terminal_states"].unique())

    x_min = df["number_of_macrostates"].min() - x_offset[0]
    x_max = df["number_of_macrostates"].max() + x_offset[1]
    y_min = df["identified_terminal_states"].min() - y_offset[0]
    y_max = df["identified_terminal_states"].max() + y_offset[1]
    ax.set(
        xlim=[x_min, x_max],
        ylim=[y_min, y_max],
        xlabel="Number of macrostates",
        ylabel="Identified terminal states",
    )

    # Replace the axes-level legend by a figure-level one that lists the methods only.
    ax.get_legend().remove()

    n_methods = len(df["method"].unique())
    handles, labels = ax.get_legend_handles_labels()
    # NOTE(review): relies on the legend entries being ordered as [hue title, *methods, ...] with
    # "Optimal identification" last among the methods - confirm against the seaborn version in use.
    handles[n_methods].set_linestyle("--")
    handles = handles[: (n_methods + 1)]
    labels = labels[: (n_methods + 1)]
    labels[0] = "Method"
    fig.legend(handles=handles, labels=labels, loc="lower center", ncol=(n_methods + 1), bbox_to_anchor=(0.5, -0.1))

    if save is not None:
        save_fig(fig=fig, path=save)

    return ax

@d.dedent
def fit(
self,
Expand Down
Binary file modified tests/_ground_truth_adatas/adata_200.h5ad
Binary file not shown.
Binary file added tests/_ground_truth_figures/plot_tsi.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
18 changes: 18 additions & 0 deletions tests/test_gpcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -828,6 +828,24 @@ def test_plot_lineage_drivers_normal_run(self, adata_large: AnnData):

mc.plot_lineage_drivers("0", use_raw=False)

def test_tsi(self, adata_large: AnnData):
    """Check the TSI score and the cached TSI table against the ground truth stored in ``adata_large.uns["tsi"]``."""
    expected = adata_large.uns["tsi"].copy()

    kernel = VelocityKernel(adata_large).compute_transition_matrix()
    gpcca = cr.estimators.GPCCA(kernel)
    gpcca.compute_schur(n_components=5)

    score = gpcca.tsi(
        n_macrostates=3,
        terminal_states=["Neuroblast", "Astrocyte", "Granule mature"],
        cluster_key="clusters",
        n_cells=10,
    )

    # The returned score must match the stored reference value.
    np.testing.assert_almost_equal(score, expected.uns["score"])

    # The metadata cached alongside the table must match the reference as well.
    assert isinstance(gpcca._tsi.uns["terminal_states"], list)
    assert len(gpcca._tsi.uns["terminal_states"]) == len(expected.uns["terminal_states"])
    assert (gpcca._tsi.uns["terminal_states"] == expected.uns["terminal_states"]).all()
    assert gpcca._tsi.uns["cluster_key"] == expected.uns["cluster_key"]

    # Finally, compare the per-macrostate table itself.
    pd.testing.assert_frame_equal(gpcca._tsi.to_df(), expected.to_df())

def test_compute_priming_clusters(self, adata_large: AnnData):
vk = VelocityKernel(adata_large).compute_transition_matrix(softmax_scale=4)
ck = ConnectivityKernel(adata_large).compute_transition_matrix()
Expand Down
8 changes: 8 additions & 0 deletions tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2257,6 +2257,14 @@ def test_scvelo_transition_matrix_projection(self, mc: GPCCA, fpath: str):
save=fpath,
)

@compare(kind="gpcca")
def test_plot_tsi(self, mc: GPCCA, fpath: str):
    """Compute TSI on a deep copy of the estimator fixture and save the resulting plot to ``fpath``."""
    # Work on a deep copy of the fixture rather than the fixture object itself.
    mc = mc.copy(deep=True)
    # The plot requires the TSI table to be computed first; the score itself is not needed here.
    mc.tsi(
        n_macrostates=3,
        terminal_states=["Neuroblast", "Astrocyte", "Granule mature"],
        cluster_key="clusters",
        n_cells=10,
    )
    mc.plot_tsi(dpi=DPI, save=fpath)


class TestLineage:
@compare(kind="lineage")
Expand Down
Loading