Skip to content

Commit

Permalink
datastore.py: updated metric_lisi, metric_silhouette, and metric_inte…
Browse files Browse the repository at this point in the history
…gration to use latest KNN location and option to provide KNN location as input; assay.py and metrics.py: ruff formatting
  • Loading branch information
Gautam8387 committed Jan 13, 2025
1 parent ab08c6e commit 60ae0b1
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 119 deletions.
18 changes: 8 additions & 10 deletions scarf/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from .metadata import MetaData
from .utils import controlled_compute, logger, show_dask_progress

zarrGroup = z_hierarchy.Group

__all__ = ["Assay", "RNAassay", "ATACassay", "ADTassay"]


Expand Down Expand Up @@ -102,7 +100,7 @@ class Assay:
for later KNN graph construction.
Args:
z (zarrGroup): Zarr hierarchy where raw data is located
z (z_hierarchy.Group): Zarr hierarchy where raw data is located
name (str): A label/name for assay.
cell_data: Metadata class object for the cell attributes.
nthreads: number for threads to use for dask parallel computations
Expand All @@ -122,7 +120,7 @@ class Assay:

def __init__(
self,
z: zarrGroup,
z: z_hierarchy.Group,
workspace: Union[str, None],
name: str, # FIXME change to assay_name
cell_data: MetaData,
Expand Down Expand Up @@ -757,7 +755,7 @@ class RNAassay(Assay):
normalization of scRNA-Seq data.
Args:
z (zarrGroup): Zarr hierarchy where raw data is located
z (z_hierarchy.Group): Zarr hierarchy where raw data is located
name (str): A label/name for assay.
cell_data: Metadata class object for the cell attributes.
**kwargs: kwargs to be passed to the Assay class
Expand All @@ -769,7 +767,7 @@ class RNAassay(Assay):
It is set to None until normed method is called.
"""

def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
super().__init__(z=z, name=name, cell_data=cell_data, **kwargs)
self.normMethod = norm_lib_size
if "size_factor" in self.attrs:
Expand Down Expand Up @@ -1076,12 +1074,12 @@ class ATACassay(Assay):
"""This subclass of Assay is designed for feature selection and
normalization of scATAC-Seq data."""

def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
"""This Assay subclass is designed for feature selection and
normalization of scATAC-Seq data.
Args:
z (zarrGroup): Zarr hierarchy where raw data is located
z (z_hierarchy.Group): Zarr hierarchy where raw data is located
name (str): A label/name for assay.
cell_data: Metadata class object for the cell attributes.
**kwargs:
Expand Down Expand Up @@ -1208,7 +1206,7 @@ class ADTassay(Assay):
(feature-barcodes library) data from CITE-Seq experiments.
Args:
z (zarrGroup): Zarr hierarchy where raw data is located
z (z_hierarchy.Group): Zarr hierarchy where raw data is located
name (str): A label/name for assay.
cell_data: Metadata class object for the cell attributes.
**kwargs:
Expand All @@ -1217,7 +1215,7 @@ class ADTassay(Assay):
normMethod: Pointer to the function to be used for normalization of the raw data
"""

def __init__(self, z: zarrGroup, name: str, cell_data: MetaData, **kwargs):
def __init__(self, z: z_hierarchy.Group, name: str, cell_data: MetaData, **kwargs):
"""This subclass of Assay is designed for normalization of ADT/HTO
(feature-barcodes library) data from CITE-Seq experiments."""
super().__init__(z=z, name=name, cell_data=cell_data, **kwargs)
Expand Down
132 changes: 30 additions & 102 deletions scarf/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def mark_hvgs(
if cell_key is None:
cell_key = "I"
assay = self._get_assay(from_assay)
if type(assay) != RNAassay:
if type(assay) != RNAassay: # noqa: E721
raise TypeError(
f"ERROR: This method of feature selection can only be applied to RNAassay type of assay. "
f"The provided assay is {type(assay)} type"
Expand Down Expand Up @@ -334,7 +334,7 @@ def mark_prevalent_peaks(
if cell_key is None:
cell_key = "I"
assay = self._get_assay(from_assay)
if type(assay) != ATACassay:
if type(assay) != ATACassay: # noqa: E721
raise TypeError(
f"ERROR: This method of feature selection can only be applied to ATACassay type of assay. "
f"The provided assay is {type(assay)} type"
Expand Down Expand Up @@ -1286,7 +1286,7 @@ def plot_cells_dists(
pass

if cols is not None:
if type(cols) != list:
if type(cols) != list: # noqa: E721
raise ValueError("ERROR: 'cols' argument must be of type list")
plot_cols = []
for i in cols:
Expand Down Expand Up @@ -2052,17 +2052,7 @@ def metric_lisi(
label_colnames: Iterable[str],
use_latest_knn: bool = True,
from_assay: Optional[str] = None,
cell_key: Optional[str] = None,
feat_key: Optional[str] = None,
dims: Optional[str] = None,
reduction_method: Optional[str] = None,
pca_cell_key: Optional[str] = None,
ann_metric: Optional[str] = None,
ann_efc: Optional[int] = None,
ann_ef: Optional[int] = None,
ann_m: Optional[int] = None,
rand_state: Optional[int] = 4466,
k: Optional[int] = None,
knn_loc: Optional[str] = None,
save_result: bool = False,
return_lisi: bool = True,
) -> Optional[List[Tuple[str, np.ndarray]]]:
Expand All @@ -2075,17 +2065,7 @@ def metric_lisi(
label_colnames: Column names from cell metadata containing population labels
use_latest_knn: Whether to use the most recent KNN graph (default: True)
from_assay: Name of assay to use if not using latest KNN
cell_key: Cell filtering key for normalization
feat_key: Feature selection key for normalization
dims: Number of dimensions used for reduction
reduction_method: Name of dimensionality reduction method
pca_cell_key: Cell key used for PCA
ann_metric: Metric used for approximate nearest neighbors
ann_efc: Construction time/accuracy trade-off for ANN index
ann_ef: Query time/accuracy trade-off for ANN index
ann_m: Max number of connections in ANN graph
rand_state: Random seed for reproducibility (default: 4466)
k: Number of nearest neighbors
knn_loc: Location of KNN graph if not using latest (default: None)
save_result: Whether to save LISI scores to cell metadata (default: True)
return_lisi: Whether to return LISI scores (default: False)
Expand All @@ -2105,43 +2085,24 @@ def metric_lisi(
Higher scores indicate more mixing between different labels.
"""

if use_latest_knn:
if use_latest_knn and knn_loc is None:
knn_loc = self._get_latest_knn_loc(from_assay)
cell_key = self.zw[self._load_default_assay()].attrs["latest_cell_key"]
logger.info(f"Using the latest knn graph at location: {knn_loc}")
else:
if None in [
from_assay,
cell_key,
feat_key,
dims,
k,
reduction_method,
pca_cell_key,
ann_metric,
ann_efc,
ann_ef,
ann_m,
rand_state,
]:
raise ValueError(
"Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state"
)
normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}"
reduction_loc = (
f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}"
)
ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}"
knn_loc = f"{ann_loc}/knn__{k}"

else:
if knn_loc is None:
raise ValueError("Please provide values for the KNN graph location.")
if knn_loc not in self.zw:
raise ValueError(f"Could not find the knn graph at location: {knn_loc}")

logger.info(f"Using the knn graph at location: {knn_loc}")

knn = self.zw[knn_loc]

distances = knn["distances"]
indices = knn["indices"]

try:
metadata = self.cells.to_pandas_dataframe(
columns=label_colnames + [cell_key]
Expand Down Expand Up @@ -2171,17 +2132,7 @@ def metric_silhouette(
use_latest_knn: bool = True,
res_label: str = "leiden_cluster",
from_assay: Optional[str] = None,
cell_key: Optional[str] = None,
feat_key: Optional[str] = None,
dims: Optional[str] = None,
reduction_method: Optional[str] = None,
pca_cell_key: Optional[str] = None,
ann_metric: Optional[str] = None,
ann_efc: Optional[int] = None,
ann_ef: Optional[int] = None,
ann_m: Optional[int] = None,
rand_state: Optional[int] = 4466,
k: Optional[int] = None,
knn_loc: Optional[str] = None,
) -> Optional[np.ndarray]:
"""Calculate modified silhouette scores for evaluating cluster separation.
Expand All @@ -2191,18 +2142,8 @@ def metric_silhouette(
Args:
use_latest_knn: Whether to use most recent KNN graph (default: True)
res_label: Column name containing cluster labels (default: "leiden_cluster")
from_assay: Name of assay to use if not using latest KNN
cell_key: Cell filtering key for normalization
feat_key: Feature selection key for normalization
dims: Number of dimensions used for reduction
reduction_method: Name of dimensionality reduction method
pca_cell_key: Cell key used for PCA
ann_metric: Metric used for approximate nearest neighbors
ann_efc: Construction time/accuracy trade-off for ANN index
ann_ef: Query time/accuracy trade-off for ANN index
ann_m: Max number of connections in ANN graph
rand_state: Random seed for reproducibility (default: 4466)
k: Number of nearest neighbors
from_assay: Name of assay to use if not using latest KNN (default: None)
knn_loc: Location of KNN graph if not using latest (default: None)
Returns:
numpy array of silhouette scores for each cluster, or None if computation fails
Expand All @@ -2220,46 +2161,34 @@ def metric_silhouette(
NaN values indicate clusters that couldn't be scored due to size constraints.
"""

if use_latest_knn:
knn_loc = self._get_latest_knn_loc(from_assay)
from_assay = self._load_default_assay()
logger.info(f"Using the latest knn graph at location: {knn_loc}")
def compute_graph_feats(knn_loc: str):
k = knn_loc.rsplit("/", 1)[-1].split("__")[-1]
dims = knn_loc.rsplit("/", 2)[0].split("__")[-2]
feat_key = knn_loc.split("/")[1].split("__")[-1]
return k, dims, feat_key

else:
if None in [
from_assay,
cell_key,
feat_key,
dims,
k,
reduction_method,
pca_cell_key,
ann_metric,
ann_efc,
ann_ef,
ann_m,
rand_state,
]:
raise ValueError(
"Please provide values for all the parameters: from_assay, cell_key, feat_key, dims, k, reduction_method, pca_cell_key, ann_metric, ann_efc, ann_ef, ann_m, rand_state"
)
normed_loc = f"{from_assay}/normed__{cell_key}__{feat_key}"
reduction_loc = (
f"{normed_loc}/reduction__{reduction_method}__{dims}__{pca_cell_key}"
if from_assay is None:
from_assay = self._load_default_assay()

if use_latest_knn and knn_loc is None:
knn_loc = self._get_latest_knn_loc(from_assay)
k, dims, feat_key = compute_graph_feats(knn_loc)
logger.info(
f"Using the latest knn graph at location: {knn_loc} for assay: {from_assay}"
)
ann_loc = f"{reduction_loc}/ann__{ann_metric}__{ann_efc}__{ann_ef}__{ann_m}__{rand_state}"
knn_loc = f"{ann_loc}/knn__{k}"

else:
if knn_loc is None:
raise ValueError("Please provide values for the KNN graph location.")
if knn_loc not in self.zw:
raise ValueError(f"Could not find the knn graph at location: {knn_loc}")
k, dims, feat_key, from_assay = compute_graph_feats(knn_loc)
logger.info(f"Using the knn graph at location: {knn_loc}")

from ..metrics import knn_to_csr_matrix, silhouette_scoring

isHarmonized = self.zw[knn_loc.rsplit("/", 1)[0]].attrs["isHarmonized"]

batches = None
if isHarmonized:
batches = self.zw[knn_loc.rsplit("/", 2)[0] + "/harmonizedData"].attrs[
Expand All @@ -2274,10 +2203,9 @@ def metric_silhouette(
harmonize=isHarmonized,
batch_columns=batches,
)
graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances)

graph = knn_to_csr_matrix(self.z[knn_loc].indices, self.z[knn_loc].distances)
hvg_data = self.z[knn_loc.rsplit("/", 3)[0] + "/data"]

scores = silhouette_scoring(
self, ann_obj, graph, hvg_data, from_assay, res_label
)
Expand Down
14 changes: 7 additions & 7 deletions scarf/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,9 @@ def calculate_weighted_cluster_similarity(
"""
unique_cluster_ids = np.unique(cluster_labels)
expected_cluster_ids = np.arange(0, len(unique_cluster_ids))
assert np.array_equal(
unique_cluster_ids, expected_cluster_ids
), "Cluster labels must be contiguous integers starting at 1"
assert np.array_equal(unique_cluster_ids, expected_cluster_ids), (
"Cluster labels must be contiguous integers starting at 1"
)

num_clusters = len(unique_cluster_ids)
inter_cluster_weights = np.zeros((num_clusters, num_clusters))
Expand All @@ -212,7 +212,7 @@ def calculate_weighted_cluster_similarity(
):
inter_cluster_weights[cluster_id, neighbor_cluster] += edge_weight

assert inter_cluster_weights.sum() == knn_graph.data.sum()
# assert inter_cluster_weights.sum() == knn_graph.data.sum()

# Ensure symmetry
inter_cluster_weights = (inter_cluster_weights + inter_cluster_weights.T) / 2
Expand Down Expand Up @@ -263,9 +263,9 @@ def calculate_top_k_neighbor_distances(
AssertionError: If matrices don't have the same number of features
"""
# Check if the matrices have the same number of features (d)
assert (
matrix_a.shape[1] == matrix_b.shape[1]
), "Matrices must have the same number of features"
assert matrix_a.shape[1] == matrix_b.shape[1], (
"Matrices must have the same number of features"
)

# Ensure k is not larger than the number of points in matrix_b
k = min(k, matrix_b.shape[0])
Expand Down

0 comments on commit 60ae0b1

Please sign in to comment.