Adding xgboost #112

Merged: 12 commits, Dec 9, 2023
3 changes: 2 additions & 1 deletion requirements.txt
@@ -14,4 +14,5 @@ leidenalg
pyensembl
seaborn
harmonypy
scanpy
scanpy
xgboost
3 changes: 3 additions & 0 deletions src/grinch/aliases.py
@@ -23,6 +23,7 @@ class OBS:
GAUSSIAN_MIXTURE_SCORE = auto()
LEIDEN = auto()
LOG_REG = auto()
XGB_CLASSIFIER = auto()

class VAR:
N_COUNTS = auto()
@@ -45,6 +46,7 @@ class OBSM:
X_UMAP = auto()
GAUSSIAN_MIXTURE_PROBA = auto()
X_HARMONY = auto()
XGB_CLASSIFIER_PROBA = auto()

class VARM:
LOG_REG_COEF = auto()
@@ -59,6 +61,7 @@ class UNS:
N_GENE_ID_TO_NAME_FAILED = auto()
ALL_CUSTOM_LEAD_GENES = auto()
PIPELINE = auto()
XGB_CLASSIFIER_SCORE = auto()

class OBSP:
KNN_CONNECTIVITY = auto()
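For reference, a minimal sketch (not part of the diff; the import path is an assumption based on the file location) of how the new aliases are used as storage keys, matching the defaults declared in predictors.py further down:

```python
# Illustrative only: the "obs." / "obsm." / "uns." prefix names the AnnData
# slot that the processor writes to; the alias supplies the column/key name.
from grinch.aliases import OBS, OBSM, UNS  # assumed import path

labels_key = f"obs.{OBS.XGB_CLASSIFIER}"          # predicted label per cell
proba_key = f"obsm.{OBSM.XGB_CLASSIFIER_PROBA}"   # class probability matrix
score_key = f"uns.{UNS.XGB_CLASSIFIER_SCORE}"     # scalar accuracy score
```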
2 changes: 2 additions & 0 deletions src/grinch/processors/__init__.py
@@ -18,6 +18,7 @@
KMeans,
Leiden,
LogisticRegression,
XGBClassifier,
)
from .repeat import RepeatProcessor
from .splitter import DataSplitter, Splitter
@@ -50,6 +51,7 @@
'FuzzySimplicialSetGraph',
'Leiden',
'LogisticRegression',
'XGBClassifier',
'DataSplitter',
'RepeatProcessor',
'Splitter',
4 changes: 2 additions & 2 deletions src/grinch/processors/base_processor.py
@@ -16,7 +16,7 @@
)

from anndata import AnnData
from pydantic import field_validator, validate_call
from pydantic import Field, field_validator, validate_call

from ..base import StorageMixin
from ..conf import BaseConfigurable
@@ -95,7 +95,7 @@ class Config(BaseConfigurable.Config):
create: Callable[..., 'BaseProcessor']

attrs_key: WriteKey | None = None
kwargs: Dict[str, ProcessorParam] = {} # Processor kwargs
kwargs: Dict[str, ProcessorParam] = Field(default_factory=dict) # Processor kwargs

# Kwargs used by the processor, but are not ProcessorParam's
__extra_processor_params__: List[str] = []
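For context on the switch to `Field(default_factory=dict)`, here is a minimal pydantic sketch (illustrative only, with a simplified hypothetical Config, not the project's class): `default_factory` makes it explicit that every Config instance receives its own fresh dict.

```python
# Minimal sketch of the default_factory pattern (simplified, hypothetical Config).
from typing import Any, Dict

from pydantic import BaseModel, Field


class Config(BaseModel):
    kwargs: Dict[str, Any] = Field(default_factory=dict)  # new dict per instance


a, b = Config(), Config()
a.kwargs["n_estimators"] = 10
assert b.kwargs == {}  # b's default dict is independent of a's
```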
54 changes: 51 additions & 3 deletions src/grinch/processors/predictors.py
@@ -5,14 +5,21 @@
import numpy as np
import pandas as pd
from anndata import AnnData
from pydantic import Field, PositiveFloat, PositiveInt, validate_call
from pydantic import (
Field,
NonNegativeInt,
PositiveFloat,
PositiveInt,
validate_call,
)
from sklearn.cluster import KMeans as _KMeans
from sklearn.linear_model import LogisticRegression as _LogisticRegression
from sklearn.mixture import BayesianGaussianMixture as _BayesianGaussianMixture
from sklearn.mixture import GaussianMixture as _GaussianMixture
from sklearn.utils import indexable
from xgboost import XGBClassifier as _XGBClassifier

from ..aliases import OBS, OBSM, OBSP
from ..aliases import OBS, OBSM, OBSP, UNS
from ..base import StorageMixin
from ..custom_types import NP1D_Any, NP1D_float
from ..utils.ops import group_indices
@@ -208,7 +215,7 @@


class BaseSupervisedPredictor(BasePredictor, abc.ABC):
"""A base class for unsupervised predictors, e.g., clustering."""
"""A base class for supervised predictors, e.g., logistic regression."""
__processor_reqs__ = ['fit']

class Config(BasePredictor.Config):
@@ -263,3 +270,44 @@
random_state=self.cfg.seed,
**self.cfg.kwargs,
)


class XGBClassifier(BaseSupervisedPredictor):
"""XGBoostClassifier"""
__processor_attrs__ = ['feature_importances_', 'n_features_in_']

class Config(BaseSupervisedPredictor.Config):

if TYPE_CHECKING:
create: Callable[..., 'XGBClassifier']

[Codecov / codecov/patch] Warning: added line src/grinch/processors/predictors.py#L282 was not covered by tests.

labels_key: WriteKey = f"obs.{OBS.XGB_CLASSIFIER}"
proba_key: WriteKey = f"obsm.{OBSM.XGB_CLASSIFIER_PROBA}"
score_key: WriteKey = f"uns.{UNS.XGB_CLASSIFIER_SCORE}"
# XGBoost kwargs
n_estimators: ProcessorParam[PositiveInt | None] = 2
max_depth: ProcessorParam[PositiveInt | None] = 1
max_leaves: ProcessorParam[NonNegativeInt] = 0 # 0 == no limit
learning_rate: ProcessorParam[PositiveFloat | None] = 1.0

cfg: Config

def __init__(self, cfg: Config, /):
super().__init__(cfg)

self.processor: _XGBClassifier = _XGBClassifier(
n_estimators=self.cfg.n_estimators,
max_depth=self.cfg.max_depth,
max_leaves=self.cfg.max_leaves,
learning_rate=self.cfg.learning_rate,
random_state=self.cfg.seed,
**self.cfg.kwargs,
)

def _post_process(self, adata: AnnData) -> None:
x = self.read(adata, self.cfg.x_key)
y = self.read(adata, self.cfg.y_key)
proba = self.processor.predict_proba(x)
score = self.processor.score(x, y)
self.store_item(self.cfg.proba_key, proba)
self.store_item(self.cfg.score_key, score)
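A minimal end-to-end sketch of the new processor (it mirrors the test added below; the import paths and the Hydra `instantiate` pattern are assumptions taken from the test files), showing where the outputs land:

```python
# Illustrative sketch only; follows the config pattern used in the tests.
import numpy as np
from anndata import AnnData
from hydra.utils import instantiate   # assumed to match the instantiate used in tests
from omegaconf import OmegaConf

from src.grinch.aliases import OBS, OBSM, UNS  # assumed import path

rng = np.random.default_rng(42)
X = rng.normal(size=(100, 5)).astype(np.float32)
y = (X[:, 0] > 0).astype(int)  # simple separable labels

adata = AnnData(X)
adata.obs["y"] = y

cfg = OmegaConf.create({
    "_target_": "src.grinch.XGBClassifier.Config",
    "x_key": "X",
    "y_key": "obs.y",
    "seed": 42,
})
clf = instantiate(cfg).create()
clf(adata)  # fit on (X, y), then predict and store results

labels = adata.obs[OBS.XGB_CLASSIFIER]          # predicted label per observation
proba = adata.obsm[OBSM.XGB_CLASSIFIER_PROBA]   # per-class probabilities
score = adata.uns[UNS.XGB_CLASSIFIER_SCORE]     # accuracy on the fitted data
```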
89 changes: 63 additions & 26 deletions tests/test_predictors.py
@@ -12,11 +12,18 @@
X = np.array([
[6, 8, 0, 0, 0],
[5, 7, 0, 0, 0],
[0, 1, 5, 6, 5],
[6, 8, 1, 0, 0],
[4, 7, 0, 0, 0],
[0, 1, 5, 6, 8],
[2, 1, 7, 9, 8],
[0, 1, 5, 6, 7],
[0, 1, 8, 6, 7],
[0, 1, 8, 6, 5],
[2, 1, 7, 8, 8],
[0, 1, 9, 6, 7],
], dtype=np.float32)

K_plus = 4

X_test = np.array([
[0, -1, 5, 6, 5],
[5, 6, 0, 1, 0],
@@ -42,9 +49,9 @@ def test_kmeans_x(X):
kmeans = cfg.create()
adata = AnnData(X)
kmeans(adata)
outp = adata.obs[OBS.KMEANS]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.KMEANS].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]


@@ -74,15 +81,15 @@ def test_kmeans_x_pca(X):
cfg = instantiate(cfg)
kmeans = cfg.create()
kmeans(adata)
outp = adata.obs[OBS.KMEANS]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.KMEANS].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]

adata_test = AnnData(X_test)
pca.transform(adata_test)
kmeans.predict(adata_test)
outp = adata_test.obs[OBS.KMEANS]
outp = adata_test.obs[OBS.KMEANS].to_numpy()
assert outp[0] == outp[2]
assert outp[0] != outp[1]

@@ -101,17 +108,20 @@ def test_gmix_x(X):
kmeans = cfg.create()
adata = AnnData(X)
kmeans(adata)
outp = adata.obs[OBS.GAUSSIAN_MIXTURE]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[OBS.GAUSSIAN_MIXTURE].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]
proba = adata.obsm[OBSM.GAUSSIAN_MIXTURE_PROBA]
assert (proba[:2, 0] > proba[:2, 1]).all()
assert (proba[2:, 0] < proba[2:, 1]).all()
assert (proba[:K_plus, 0] > proba[:K_plus, 1]).all()
assert (proba[K_plus:, 0] < proba[K_plus:, 1]).all()


@pytest.mark.parametrize("X", X_mods_no_sparse)
def test_log_reg_x(X):
@pytest.mark.parametrize(
"classifier, key", [("LogisticRegression", OBS.LOG_REG)]
)
def test_classifiers_x(X, classifier, key):
adata = AnnData(X)
cfg_pca = OmegaConf.create(
{
@@ -139,38 +149,63 @@ def test_log_reg_x(X):

cfg = OmegaConf.create(
{
"_target_": "src.grinch.LogisticRegression.Config",
"_target_": f"src.grinch.{classifier}.Config",
"x_key": f"obsm.{OBSM.X_PCA}",
"y_key": f"obs.{OBS.KMEANS}",
"seed": 42,
"labels_key": f"obs.{OBS.LOG_REG}",
"labels_key": f"obs.{key}",
}
)
# Need to start using convert all for lists and dicts
cfg = instantiate(cfg, _convert_='all')
lr = cfg.create()
lr(adata)
outp = adata.obs[OBS.LOG_REG]
assert np.unique(outp[:2]).size == 1
assert np.unique(outp[2:]).size == 1
outp = adata.obs[key].to_numpy()
assert np.unique(outp[:K_plus]).size == 1
assert np.unique(outp[K_plus:]).size == 1
assert outp[0] != outp[-1]

adata_test = AnnData(X_test)
pca.transform(adata_test)
lr.predict(adata_test)
outp = adata_test.obs[OBS.LOG_REG]
outp = adata_test.obs[key].to_numpy()
assert outp[0] == outp[2]
assert outp[0] != outp[1]


def test_xgboost():
from sklearn.datasets import make_classification
X, y = make_classification(
n_samples=100, n_features=2, n_informative=2, n_redundant=0,
random_state=42, n_clusters_per_class=1, flip_y=False, class_sep=2.0)
adata = AnnData(X)
adata.obs['y'] = y

cfg = OmegaConf.create(
{
"_target_": "src.grinch.XGBClassifier.Config",
"seed": 42,
"x_key": "X",
"y_key": "obs.y",
}
)

cfg = instantiate(cfg)
obj = cfg.create()
obj(adata)
outp = adata.obs[OBS.XGB_CLASSIFIER].to_numpy()
# < 5% error
assert (y != outp).mean() < 0.05 or (y != 1 - outp).mean() < 0.05


@pytest.mark.parametrize("X", X_mods)
def test_leiden(X):
adata = AnnData(X)
cfg_knn = OmegaConf.create(
{
"_target_": "src.grinch.KNNGraph.Config",
"x_key": "X",
"n_neighbors": 1,
"n_neighbors": 3,
}
)
cfg_knn = instantiate(cfg_knn)
@@ -182,20 +217,22 @@ def test_leiden(X):
"_target_": "src.grinch.Leiden.Config",
"x_key": f"obsp.{OBSP.KNN_DISTANCE}",
"seed": 42,
"resolution": 0.5,
}
)
cfg = instantiate(cfg)
leiden = cfg.create()
leiden(adata)
pred = adata.obs[OBS.LEIDEN]
true = np.array([0, 0, 1, 1, 1])
pred = adata.obs[OBS.LEIDEN].to_numpy()
true = np.ones(X.shape[0])
true[:K_plus] = 0
if pred[0] == 1:
true = 1 - true
assert_allclose(pred, true)

centroids = {
pred[0]: np.ravel(X[:2].mean(axis=0)),
1 - pred[0]: np.ravel(X[2:].mean(axis=0)),
pred[0]: np.ravel(X[:K_plus].mean(axis=0)),
1 - pred[0]: np.ravel(X[K_plus:].mean(axis=0)),
}
pred_centroid = adata.uns['leiden_']["cluster_centers_"]
assert_allclose(centroids[0], pred_centroid['0'])
2 changes: 2 additions & 0 deletions tests/test_transformers.py
@@ -87,6 +87,7 @@ def test_umap(X):
# things happening with spectral initialization and reproducibility
'kwargs': {
'init': 'random',
'n_jobs': 1, # since using random seed
}
}
)
@@ -98,6 +99,7 @@
random_state=SEED,
transform_seed=SEED,
init='random',
n_jobs=1,
)
adata = AnnData(X)
up = cfg.create()
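For background on the `n_jobs: 1` additions above, a minimal umap-learn sketch (illustrative; this describes the library's behavior, not code in this PR): recent umap-learn versions override `n_jobs` to 1 whenever `random_state` is set, so pinning it explicitly on both the config and the reference keeps the two embeddings comparable and silences the override warning.

```python
# Illustrative only: mirrors the settings used in test_umap above.
import umap

reducer = umap.UMAP(
    random_state=42,    # fixed seed -> deterministic embedding
    transform_seed=42,
    init="random",      # avoid the spectral-init nondeterminism noted in the test
    n_jobs=1,           # explicit; recent umap-learn forces 1 anyway when seeded
)
```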