Skip to content

Commit

Permalink
update spatial domain modules to inherit from BaseClusteringMethod (
Browse files Browse the repository at this point in the history
#170)

* update SpaGCN to inherit BaseClusteringMethod

* use fit_predict

* use dance seeding

* use logger

* move calculate_p and search_l to spagcn since it is specific to spagcn

* update Louvain to inherit from BaseClusteringMethod

* use logger

* update stlearn to inherit from BaseClusteringMethod

* add fit_kwargs to fit_score

* adapt Stagate to BaseClusteringMethod class

* minor format fixes

* use absolute imports, sort imports
  • Loading branch information
RemyLau authored Feb 20, 2023
1 parent a72b021 commit f5f34b7
Show file tree
Hide file tree
Showing 12 changed files with 182 additions and 252 deletions.
26 changes: 21 additions & 5 deletions dance/modules/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,34 @@ def predict_proba(self, x):
def predict(self, x):
...

def fit_predict(self, x, y, **fit_kwargs):
self.fit(x, y, **fit_kwargs)
pred = self.predict(x)
return pred
@property
def default_score_func(self) -> Mapping[Any, float]:
return resolve_score_func(self._DEFAULT_METRIC)

def score(self, x, y, score_func: Optional[Union[str, Mapping[Any, float]]] = None,
def score(self, x, y, *, score_func: Optional[Union[str, Mapping[Any, float]]] = None,
return_pred: bool = False) -> Union[float, Tuple[float, Any]]:
y_pred = self.predict(x)
func = resolve_score_func(score_func or self._DEFAULT_METRIC)
score = func(y, y_pred)
return (score, y_pred) if return_pred else score

def fit_predict(self, x, y=None, **fit_kwargs):
self.fit(x, y, **fit_kwargs)
pred = self.predict(x)
return pred

def fit_score(self, x, y, *, score_func: Optional[Union[str, Mapping[Any, float]]] = None,
return_pred: bool = False, **fit_kwargs) -> Union[float, Tuple[float, Any]]:
"""Shortcut for fitting data using the input feature and return eval.
Note
----
Only work for models where the fitting does not require labeled data, i.e. unsupervised methods.
"""
self.fit(x, **fit_kwargs)
return self.score(x, y, score_func=score_func, return_pred=return_pred)


class BasePretrain(ABC):

Expand Down
10 changes: 5 additions & 5 deletions dance/modules/spatial/spatial_domain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from .louvain import Louvain
from .spagcn import SpaGCN
from .stagate import Stagate
from .stlearn import StLouvain
from dance.modules.spatial.spatial_domain.louvain import Louvain
from dance.modules.spatial.spatial_domain.spagcn import SpaGCN
from dance.modules.spatial.spatial_domain.stagate import Stagate
from dance.modules.spatial.spatial_domain.stlearn import StLouvain

__all__ = [
"Louvain",
"SpaGCN",
"Stagate",
"StLouvain",
"Stagate",
]
49 changes: 15 additions & 34 deletions dance/modules/spatial/spatial_domain/louvain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
Blondel, V. D., et al. "Fast Unfolding of Community Hierarchies in Large Networks, 1–6 (2008)." arXiv:0803.0476.
"""

import array
import numbers
import warnings
Expand All @@ -16,6 +15,8 @@
import numpy as np
import scanpy as sc

from dance import logger
from dance.modules.base import BaseClusteringMethod
from dance.transforms import AnnDataTransform, CellPCA, Compose, FilterGenesMatch, SetConfig
from dance.transforms.graph import NeighborGraph
from dance.typing import LogLevel
Expand Down Expand Up @@ -325,8 +326,8 @@ def best_partition(graph, partition=None, weight="weight", resolution=1., random
return partition_at_level(dendo, len(dendo) - 1)


class Louvain:
"""Louvain class.
class Louvain(BaseClusteringMethod):
"""Louvain classBaseClassificationMethod.
Parameters
----------
Expand Down Expand Up @@ -378,47 +379,27 @@ def fit(self, adj, partition=None, weight="weight", randomize=None, random_state
"""
# convert adata,adj into networkx
print("adj to networkx graph .... ")
logger.info("Converting adjacency matrix to networkx graph...")
if (adj - adj.T).sum() != 0:
ValueError("louvain use no direction graph, but the input is not")
g = nx.from_numpy_array(adj)
print("convert over")
print("start fit ... ")
logger.info("Conversion done. Start fitting...")
self.dendo = generate_dendrogram(g, partition, weight, self.resolution, randomize, random_state)
logger.info("Fitting done.")

print("fit over ")

def predict(self):
"""Prediction function."""
self.predict_result = partition_at_level(self.dendo, len(self.dendo) - 1)
self.y_pred = self.predict_result
return self.predict_result

def score(self, y_true):
"""Score function to evaluate the prediction performance.
def predict(self, x=None):
"""Prediction function.
Parameters
----------
y_true
Ground truth label.
Returns
-------
float
Evaluation score.
x
Not used. For compatibility with :func:`dance.modules.base.BaseMethod.fit_score`, which calls :meth:`fit`
with ``x``.
"""
pred_val = []
for key in self.y_pred:
pred_val.append(self.y_pred[key])

from sklearn.metrics.cluster import adjusted_rand_score
score = adjusted_rand_score(y_true, np.array(pred_val))
print("ARI {}".format(adjusted_rand_score(y_true, np.array(pred_val))))
return score


# add by us
pred_dict = partition_at_level(self.dendo, len(self.dendo) - 1)
pred = np.array(list(map(pred_dict.get, sorted(pred_dict))))
return pred


def generate_dendrogram(graph, part_init=None, weight="weight", resolution=1., randomize=None, random_state=None):
Expand Down
Loading

0 comments on commit f5f34b7

Please sign in to comment.