Skip to content

Commit

Permalink
Add CBC score (#1168)
Browse files Browse the repository at this point in the history
* Add `KernelExpression::_get_boundary`

Adds function to identify observations at the boundary of two clusters.

* Add `_get_empirical_velocity_field`

Add `KernelExpression` class method to estimate the empirical velocity
field of boundary cells in the source region towards a target cluster.

* Add `_get_vector_field_estimate`

Adds `KernelExpression` class method to estimate velocity field based on
a single step under the transition matrix.

* Add `KernelExpression::cbc`

Add class method for computing cross-boundary correctness score.

* Add `TestKernel::test_cbc`
  • Loading branch information
WeilerP authored Mar 4, 2024
1 parent f0e09d2 commit 696a23b
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 0 deletions.
125 changes: 125 additions & 0 deletions src/cellrank/kernels/_base_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,131 @@ def _reuse_cache(self, expected_params: Dict[str, Any], *, time: Optional[Any] =
self._params = expected_params
# fmt: on

def _get_boundary(self, source: str, target: str, cluster_key: str, graph_key: str = "distances") -> List[int]:
    """Identify source observations at boundary to target cluster.

    A source observation is at the boundary if at least one of its non-zero entries
    in the graph points to an observation of the target cluster.

    Parameters
    ----------
    source
        Name of source cluster. A collection of cluster names is accepted as well.
    target
        Name of target cluster. A collection of cluster names is accepted as well.
    cluster_key
        Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
    graph_key
        Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

    Returns
    -------
    List of observation IDs at boundary to target cluster.
    """
    clusters = self.adata.obs[cluster_key]
    source_names = [source] if isinstance(source, str) else source
    target_names = [target] if isinstance(target, str) else target

    # Positional indices, so the graph is addressed by row/column number and never by obs name.
    source_ids = np.flatnonzero(np.asarray(clusters.isin(source_names)))
    target_ids = np.flatnonzero(np.asarray(clusters.isin(target_names)))
    if source_ids.size == 0 or target_ids.size == 0:
        return []

    # Slice the source-rows x target-columns sub-graph once instead of densifying one full row per
    # source observation. `!= 0` ignores explicitly stored zeros, matching a dense `astype(bool)`.
    edges = self.adata.obsp[graph_key][source_ids][:, target_ids]
    n_target_neighbors = np.asarray((edges != 0).sum(axis=1)).ravel()

    return source_ids[n_target_neighbors > 0].tolist()

def _get_empirical_velocity_field(
    self, boundary_ids: List[int], target_obs_mask, rep: str, graph_key: str = "distances"
) -> np.ndarray:
    """Compute an empirical estimate of velocity field between two clusters.

    For each boundary observation, the estimate is the sum of the displacements in ``rep`` towards
    its graph neighbors selected by ``target_obs_mask``, each weighted by the matching graph entry.

    Parameters
    ----------
    boundary_ids
        List of observation IDs at boundary to target cluster.
    target_obs_mask
        Boolean indicator identifying relevant observations from target.
    rep
        Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
    graph_key
        Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

    Returns
    -------
    Empirical velocity estimate with one row per boundary observation. Rows containing NaN are dropped.
    """
    graph = self.adata.obsp[graph_key]
    features = self.adata.obsm[rep]
    # Convert once to a plain boolean array so the mask is combined positionally in every iteration.
    target_mask = np.asarray(target_obs_mask, dtype=bool)

    empirical_velo = np.empty(shape=(len(boundary_ids), features.shape[1]))
    for idx, boundary_id in enumerate(boundary_ids):
        row = graph[boundary_id, :].toarray().squeeze()
        # Neighbors of the boundary observation that belong to the target.
        neighbor_mask = row.astype(bool) & target_mask
        displacements = features[neighbor_mask, :] - features[boundary_id, :]
        empirical_velo[idx, :] = np.sum(row[neighbor_mask].reshape(-1, 1) * displacements, axis=0)

    has_nan = np.isnan(empirical_velo).any(axis=1)
    return empirical_velo[~has_nan, :]

def _get_vector_field_estimate(self, rep: str) -> np.ndarray:
    """Estimate the vector field implied by a single step of the transition matrix.

    Parameters
    ----------
    rep
        Key in :attr:`~anndata.AnnData.obsm` to use as data representation.

    Returns
    -------
    Difference between the representation propagated by the transition matrix and the representation itself.
    """
    features = self.adata.obsm[rep]
    # Representation of each observation after one application of the transition matrix.
    propagated = self.transition_matrix @ features
    return propagated - features

# TODO: Add definition/reference to paper
def cbc(self, source: str, target: str, cluster_key: str, rep: str, graph_key: str = "distances") -> np.ndarray:
    """Compute cross-boundary correctness score between source and target cluster.

    For every source observation at the boundary to the target, the score is the Pearson correlation
    (across the dimensions of ``rep``) between the vector field estimated from the transition matrix
    and the empirical velocity field pointing towards the target.

    Parameters
    ----------
    source
        Name of the source cluster.
    target
        Name of the target cluster.
    cluster_key
        Key in :attr:`~anndata.AnnData.obs` to obtain cluster annotations.
    rep
        Key in :attr:`~anndata.AnnData.obsm` to use as data representation.
    graph_key
        Name of graph representation to use from :attr:`~anndata.AnnData.obsp`.

    Returns
    -------
    Cross-boundary correctness score for each boundary observation.
    """

    def _pearsonr(x: np.ndarray, y: np.ndarray) -> np.ndarray:
        # Row-wise Pearson correlation. A row that is constant in `x` or `y` has zero norm and yields NaN.
        x_centered = x - np.mean(x, axis=1, keepdims=True)
        y_centered = y - np.mean(y, axis=1, keepdims=True)
        denom = np.linalg.norm(x_centered, axis=1) * np.linalg.norm(y_centered, axis=1)

        return np.sum(x_centered * y_centered, axis=1) / denom

    # Build the mask the same way `_get_boundary` does, so that a collection of target names selects
    # the same observations here as it does when the boundary is determined.
    target_names = [target] if isinstance(target, str) else target
    target_obs_mask = self.adata.obs[cluster_key].isin(target_names)

    boundary_ids = self._get_boundary(source=source, target=target, cluster_key=cluster_key, graph_key=graph_key)
    empirical_velo = self._get_empirical_velocity_field(
        boundary_ids=boundary_ids, target_obs_mask=target_obs_mask, rep=rep, graph_key=graph_key
    )
    # NOTE(review): `_get_empirical_velocity_field` drops rows containing NaN, whereas the estimate below
    # keeps one row per boundary observation; if rows were ever dropped the two would no longer be aligned.
    estimated_velo = self._get_vector_field_estimate(rep=rep)[boundary_ids, :]

    return _pearsonr(x=estimated_velo, y=empirical_velo)


@d.dedent
class Kernel(KernelExpression, abc.ABC):
Expand Down
Binary file modified tests/_ground_truth_adatas/adata_50.h5ad
Binary file not shown.
22 changes: 22 additions & 0 deletions tests/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,28 @@ def test_connectivities_key_kernel(self, adata: AnnData):
assert T_cr is not adata.obsp[key]
np.testing.assert_array_equal(T_cr.A, adata.obsp[key])

@pytest.mark.parametrize("cluster_pair", [("Granule immature", "Granule mature"), ("nIPC", "Neuroblast")])
@pytest.mark.parametrize("graph_key", ["distances", "connectivities"])
def test_cbc(self, adata: AnnData, cluster_pair: Tuple[str, str], graph_key: str):
    """Check CBC scores of a velocity kernel and a combined kernel against stored ground truth."""
    cluster_key = "clusters"
    rep = "X_pca"
    source, target = cluster_pair

    vk = cr.kernels.VelocityKernel(adata)
    vk.compute_transition_matrix()

    ck = cr.kernels.ConnectivityKernel(adata)
    ck.compute_transition_matrix()
    combined_kernel = 0.8 * vk + 0.2 * ck

    # The expected values are keyed by `graph_key`, so it has to be forwarded; otherwise every
    # parametrization would silently run on the default graph.
    # NOTE(review): confirm the stored ground truth was generated with the matching `graph_key`.
    cbc_vk = vk.cbc(source=source, target=target, cluster_key=cluster_key, rep=rep, graph_key=graph_key)
    np.testing.assert_almost_equal(cbc_vk, adata.uns["cbc"][f"{source}-{target}-{graph_key}-vk"])

    cbc_combined_kernel = combined_kernel.cbc(
        source=source, target=target, cluster_key=cluster_key, rep=rep, graph_key=graph_key
    )
    np.testing.assert_almost_equal(
        cbc_combined_kernel, adata.uns["cbc"][f"{source}-{target}-{graph_key}-0.8vk+0.2ck"]
    )


class TestVelocityKernelReadData:
@pytest.mark.parametrize("attr", ["layers", "obsm"])
Expand Down

0 comments on commit 696a23b

Please sign in to comment.