diff --git a/requirements.txt b/requirements.txt
index f1279cd..d059cf7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ leidenalg
 pyensembl
 seaborn
 harmonypy
-scanpy
\ No newline at end of file
+scanpy
+xgboost
\ No newline at end of file
diff --git a/src/grinch/aliases.py b/src/grinch/aliases.py
index dcfccad..f7ba2da 100644
--- a/src/grinch/aliases.py
+++ b/src/grinch/aliases.py
@@ -23,6 +23,7 @@ class OBS:
     GAUSSIAN_MIXTURE_SCORE = auto()
     LEIDEN = auto()
     LOG_REG = auto()
+    XGB_CLASSIFIER = auto()
 
 class VAR:
     N_COUNTS = auto()
@@ -45,6 +46,7 @@ class OBSM:
     X_UMAP = auto()
     GAUSSIAN_MIXTURE_PROBA = auto()
     X_HARMONY = auto()
+    XGB_CLASSIFIER_PROBA = auto()
 
 class VARM:
     LOG_REG_COEF = auto()
@@ -59,6 +61,7 @@ class UNS:
     N_GENE_ID_TO_NAME_FAILED = auto()
     ALL_CUSTOM_LEAD_GENES = auto()
     PIPELINE = auto()
+    XGB_CLASSIFIER_SCORE = auto()
 
 class OBSP:
     KNN_CONNECTIVITY = auto()
diff --git a/src/grinch/processors/__init__.py b/src/grinch/processors/__init__.py
index 311c01a..0d0dc41 100644
--- a/src/grinch/processors/__init__.py
+++ b/src/grinch/processors/__init__.py
@@ -18,6 +18,7 @@
     KMeans,
     Leiden,
     LogisticRegression,
+    XGBClassifier,
 )
 from .repeat import RepeatProcessor
 from .splitter import DataSplitter, Splitter
@@ -50,6 +51,7 @@
     'FuzzySimplicialSetGraph',
     'Leiden',
     'LogisticRegression',
+    'XGBClassifier',
     'DataSplitter',
     'RepeatProcessor',
     'Splitter',
diff --git a/src/grinch/processors/base_processor.py b/src/grinch/processors/base_processor.py
index 00554b4..77f4c9a 100644
--- a/src/grinch/processors/base_processor.py
+++ b/src/grinch/processors/base_processor.py
@@ -16,7 +16,7 @@
 )
 
 from anndata import AnnData
-from pydantic import field_validator, validate_call
+from pydantic import Field, field_validator, validate_call
 
 from ..base import StorageMixin
 from ..conf import BaseConfigurable
@@ -95,7 +95,7 @@ class Config(BaseConfigurable.Config):
 
         create: Callable[..., 'BaseProcessor']
         attrs_key: WriteKey | None = None
-        kwargs: Dict[str, ProcessorParam] = {}  # Processor kwargs
+        kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict)  # Processor kwargs
         # Kwargs used by the processor, but are not ProcessorParam's
         __extra_processor_params__: List[str] = []
 
diff --git a/src/grinch/processors/predictors.py b/src/grinch/processors/predictors.py
index c9e4a61..9dfc594 100644
--- a/src/grinch/processors/predictors.py
+++ b/src/grinch/processors/predictors.py
@@ -5,14 +5,21 @@
 import numpy as np
 import pandas as pd
 from anndata import AnnData
-from pydantic import Field, PositiveFloat, PositiveInt, validate_call
+from pydantic import (
+    Field,
+    NonNegativeInt,
+    PositiveFloat,
+    PositiveInt,
+    validate_call,
+)
 from sklearn.cluster import KMeans as _KMeans
 from sklearn.linear_model import LogisticRegression as _LogisticRegression
 from sklearn.mixture import BayesianGaussianMixture as _BayesianGaussianMixture
 from sklearn.mixture import GaussianMixture as _GaussianMixture
 from sklearn.utils import indexable
+from xgboost import XGBClassifier as _XGBClassifier
 
-from ..aliases import OBS, OBSM, OBSP
+from ..aliases import OBS, OBSM, OBSP, UNS
 from ..base import StorageMixin
 from ..custom_types import NP1D_Any, NP1D_float
 from ..utils.ops import group_indices
@@ -208,7 +215,7 @@ def _post_process(self, adata: AnnData) -> None:
 
 
 class BaseSupervisedPredictor(BasePredictor, abc.ABC):
-    """A base class for unsupervised predictors, e.g., clustering."""
+    """A base class for supervised predictors, e.g., logistic regression."""
     __processor_reqs__ = ['fit']
 
     class Config(BasePredictor.Config):
@@ -263,3 +270,44 @@ def __init__(self, cfg: Config, /):
             random_state=self.cfg.seed,
             **self.cfg.kwargs,
         )
+
+
+class XGBClassifier(BaseSupervisedPredictor):
+    """XGBoost classifier."""
+    __processor_attrs__ = ['feature_importances_', 'n_features_in_']
+
+    class Config(BaseSupervisedPredictor.Config):
+
+        if TYPE_CHECKING:
+            create: Callable[..., 'XGBClassifier']
+
+        labels_key: WriteKey = f"obs.{OBS.XGB_CLASSIFIER}"
+        proba_key: WriteKey = f"obsm.{OBSM.XGB_CLASSIFIER_PROBA}"
+        score_key: WriteKey = f"uns.{UNS.XGB_CLASSIFIER_SCORE}"
+        # XGBoost kwargs
+        n_estimators: ProcessorParam[PositiveInt | None] = 2
+        max_depth: ProcessorParam[PositiveInt | None] = 1
+        max_leaves: ProcessorParam[NonNegativeInt] = 0  # 0 == no limit
+        learning_rate: ProcessorParam[PositiveFloat | None] = 1.0
+
+    cfg: Config
+
+    def __init__(self, cfg: Config, /):
+        super().__init__(cfg)
+
+        self.processor: _XGBClassifier = _XGBClassifier(
+            n_estimators=self.cfg.n_estimators,
+            max_depth=self.cfg.max_depth,
+            max_leaves=self.cfg.max_leaves,
+            learning_rate=self.cfg.learning_rate,
+            random_state=self.cfg.seed,
+            **self.cfg.kwargs,
+        )
+
+    def _post_process(self, adata: AnnData) -> None:
+        x = self.read(adata, self.cfg.x_key)
+        y = self.read(adata, self.cfg.y_key)
+        proba = self.processor.predict_proba(x)
+        score = self.processor.score(x, y)
+        self.store_item(self.cfg.proba_key, proba)
+        self.store_item(self.cfg.score_key, score)
diff --git a/tests/test_predictors.py b/tests/test_predictors.py
index 0cefa4a..6d530e8 100644
--- a/tests/test_predictors.py
+++ b/tests/test_predictors.py
@@ -12,11 +12,18 @@
 X = np.array([
     [6, 8, 0, 0, 0],
     [5, 7, 0, 0, 0],
-    [0, 1, 5, 6, 5],
+    [6, 8, 1, 0, 0],
+    [4, 7, 0, 0, 0],
+    [0, 1, 5, 6, 8],
     [2, 1, 7, 9, 8],
-    [0, 1, 5, 6, 7],
+    [0, 1, 8, 6, 7],
+    [0, 1, 8, 6, 5],
+    [2, 1, 7, 8, 8],
+    [0, 1, 9, 6, 7],
 ], dtype=np.float32)
 
+K_plus = 4  # size of the first cluster in X
+
 X_test = np.array([
     [0, -1, 5, 6, 5],
     [5, 6, 0, 1, 0],
@@ -42,9 +49,9 @@ def test_kmeans_x(X):
     kmeans = cfg.create()
     adata = AnnData(X)
     kmeans(adata)
-    outp = adata.obs[OBS.KMEANS]
-    assert np.unique(outp[:2]).size == 1
-    assert np.unique(outp[2:]).size == 1
+    outp = adata.obs[OBS.KMEANS].to_numpy()
+    assert np.unique(outp[:K_plus]).size == 1
+    assert np.unique(outp[K_plus:]).size == 1
     assert outp[0] != outp[-1]
 
 
@@ -74,15 +81,15 @@ def test_kmeans_x_pca(X):
     cfg = instantiate(cfg)
     kmeans = cfg.create()
     kmeans(adata)
-    outp = adata.obs[OBS.KMEANS]
-    assert np.unique(outp[:2]).size == 1
-    assert np.unique(outp[2:]).size == 1
+    outp = adata.obs[OBS.KMEANS].to_numpy()
+    assert np.unique(outp[:K_plus]).size == 1
+    assert np.unique(outp[K_plus:]).size == 1
     assert outp[0] != outp[-1]
 
     adata_test = AnnData(X_test)
     pca.transform(adata_test)
     kmeans.predict(adata_test)
-    outp = adata_test.obs[OBS.KMEANS]
+    outp = adata_test.obs[OBS.KMEANS].to_numpy()
     assert outp[0] == outp[2]
     assert outp[0] != outp[1]
 
@@ -101,17 +108,20 @@ def test_gmix_x(X):
     kmeans = cfg.create()
     adata = AnnData(X)
     kmeans(adata)
-    outp = adata.obs[OBS.GAUSSIAN_MIXTURE]
-    assert np.unique(outp[:2]).size == 1
-    assert np.unique(outp[2:]).size == 1
+    outp = adata.obs[OBS.GAUSSIAN_MIXTURE].to_numpy()
+    assert np.unique(outp[:K_plus]).size == 1
+    assert np.unique(outp[K_plus:]).size == 1
     assert outp[0] != outp[-1]
     proba = adata.obsm[OBSM.GAUSSIAN_MIXTURE_PROBA]
-    assert (proba[:2, 0] > proba[:2, 1]).all()
-    assert (proba[2:, 0] < proba[2:, 1]).all()
+    assert (proba[:K_plus, 0] > proba[:K_plus, 1]).all()
+    assert (proba[K_plus:, 0] < proba[K_plus:, 1]).all()
 
 
 @pytest.mark.parametrize("X", X_mods_no_sparse)
-def test_log_reg_x(X):
+@pytest.mark.parametrize(
+    "classifier, key", [("LogisticRegression", OBS.LOG_REG)]
+)
+def test_classifiers_x(X, classifier, key):
     adata = AnnData(X)
     cfg_pca = OmegaConf.create(
         {
@@ -139,30 +149,55 @@
 
     cfg = OmegaConf.create(
         {
-            "_target_": "src.grinch.LogisticRegression.Config",
+            "_target_": f"src.grinch.{classifier}.Config",
             "x_key": f"obsm.{OBSM.X_PCA}",
             "y_key": f"obs.{OBS.KMEANS}",
             "seed": 42,
-            "labels_key": f"obs.{OBS.LOG_REG}",
+            "labels_key": f"obs.{key}",
         }
     )
     # Need to start using convert all for lists and dicts
     cfg = instantiate(cfg, _convert_='all')
     lr = cfg.create()
     lr(adata)
-    outp = adata.obs[OBS.LOG_REG]
-    assert np.unique(outp[:2]).size == 1
-    assert np.unique(outp[2:]).size == 1
+    outp = adata.obs[key].to_numpy()
+    assert np.unique(outp[:K_plus]).size == 1
+    assert np.unique(outp[K_plus:]).size == 1
     assert outp[0] != outp[-1]
 
     adata_test = AnnData(X_test)
     pca.transform(adata_test)
     lr.predict(adata_test)
-    outp = adata_test.obs[OBS.LOG_REG]
+    outp = adata_test.obs[key].to_numpy()
     assert outp[0] == outp[2]
     assert outp[0] != outp[1]
 
 
+def test_xgboost():
+    from sklearn.datasets import make_classification
+    X, y = make_classification(
+        n_samples=100, n_features=2, n_informative=2, n_redundant=0,
+        random_state=42, n_clusters_per_class=1, flip_y=False, class_sep=2.0)
+    adata = AnnData(X)
+    adata.obs['y'] = y
+
+    cfg = OmegaConf.create(
+        {
+            "_target_": "src.grinch.XGBClassifier.Config",
+            "seed": 42,
+            "x_key": "X",
+            "y_key": "obs.y",
+        }
+    )
+
+    cfg = instantiate(cfg)
+    obj = cfg.create()
+    obj(adata)
+    outp = adata.obs[OBS.XGB_CLASSIFIER].to_numpy()
+    # < 5% training error
+    assert (y != outp).mean() < 0.05 or (y != 1 - outp).mean() < 0.05
+
+
 @pytest.mark.parametrize("X", X_mods)
 def test_leiden(X):
     adata = AnnData(X)
@@ -170,7 +205,7 @@ def test_leiden(X):
         {
             "_target_": "src.grinch.KNNGraph.Config",
             "x_key": "X",
-            "n_neighbors": 1,
+            "n_neighbors": 3,
         }
     )
     cfg_knn = instantiate(cfg_knn)
@@ -182,20 +217,22 @@
             "_target_": "src.grinch.Leiden.Config",
             "x_key": f"obsp.{OBSP.KNN_DISTANCE}",
             "seed": 42,
+            "resolution": 0.5,
         }
     )
     cfg = instantiate(cfg)
     leiden = cfg.create()
     leiden(adata)
-    pred = adata.obs[OBS.LEIDEN]
-    true = np.array([0, 0, 1, 1, 1])
+    pred = adata.obs[OBS.LEIDEN].to_numpy()
+    true = np.ones(X.shape[0])
+    true[:K_plus] = 0
     if pred[0] == 1:
         true = 1 - true
     assert_allclose(pred, true)
 
     centroids = {
-        pred[0]: np.ravel(X[:2].mean(axis=0)),
-        1 - pred[0]: np.ravel(X[2:].mean(axis=0)),
+        pred[0]: np.ravel(X[:K_plus].mean(axis=0)),
+        1 - pred[0]: np.ravel(X[K_plus:].mean(axis=0)),
     }
     pred_centroid = adata.uns['leiden_']["cluster_centers_"]
     assert_allclose(centroids[0], pred_centroid['0'])
diff --git a/tests/test_transformers.py b/tests/test_transformers.py
index 185f5e1..7321c8b 100644
--- a/tests/test_transformers.py
+++ b/tests/test_transformers.py
@@ -87,6 +87,7 @@ def test_umap(X):
             # things happening with spectral initialization and reproducibility
             'kwargs': {
                 'init': 'random',
+                'n_jobs': 1,  # single-threaded for reproducibility with a fixed seed
             }
         }
     )
@@ -98,6 +99,7 @@
         random_state=SEED,
         transform_seed=SEED,
         init='random',
+        n_jobs=1,
     )
     adata = AnnData(X)
     up = cfg.create()
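
Usage sketch (not part of the patch): a minimal end-to-end run of the new XGBClassifier processor, mirroring tests/test_predictors.py::test_xgboost. The config/create workflow and the output keys come from this diff; the `hydra.utils.instantiate` import is an assumption based on the `_convert_` kwarg used in the tests, and the toy data is hypothetical.

    import numpy as np
    from anndata import AnnData
    from hydra.utils import instantiate  # assumed source of `instantiate` in the tests
    from omegaconf import OmegaConf

    from src.grinch.aliases import OBS, OBSM, UNS

    # Hypothetical toy data: the label is the sign of feature 0.
    rng = np.random.default_rng(42)
    X = rng.normal(size=(100, 5)).astype(np.float32)
    adata = AnnData(X)
    adata.obs["y"] = (X[:, 0] > 0).astype(int)

    # Same config pattern as test_xgboost; unset fields fall back to the
    # Config defaults added in predictors.py (n_estimators=2, max_depth=1, ...).
    cfg = OmegaConf.create(
        {
            "_target_": "src.grinch.XGBClassifier.Config",
            "seed": 42,
            "x_key": "X",
            "y_key": "obs.y",
        }
    )
    xgb = instantiate(cfg).create()
    xgb(adata)  # fits on (X, y), predicts, and stores the outputs below

    labels = adata.obs[OBS.XGB_CLASSIFIER]         # predicted labels
    proba = adata.obsm[OBSM.XGB_CLASSIFIER_PROBA]  # per-class probabilities
    score = adata.uns[UNS.XGB_CLASSIFIER_SCORE]    # mean accuracy on the fit data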