Adding Harmony and multiread #108

Merged 2 commits on Nov 13, 2023
1 change: 1 addition & 0 deletions .gitignore
@@ -16,3 +16,4 @@ figs
.coverage
docs/_build
test-results
docs
13 changes: 0 additions & 13 deletions codecov.yml

This file was deleted.

4 changes: 3 additions & 1 deletion requirements.txt
@@ -12,4 +12,6 @@ diptest
phenotype_cover
leidenalg
pyensembl
seaborn
seaborn
harmonypy
scanpy
8 changes: 6 additions & 2 deletions src/grinch/__init__.py
@@ -13,12 +13,13 @@
from . import processors as pr
from . import shortcuts
from .aliases import ADK, OBS, OBSM, OBSP, UNS, VAR, VARM, VARP, AnnDataKeys
from .base import StorageMixin
from .cond_filter import Filter, StackedFilter
from .conf import BaseConfigurable
from .filters import FilterCells, FilterGenes, VarianceFilter
from .main import instantiate_config
from .normalizers import Combat, Log1P, NormalizeTotal, Scale
from .pipeline import GRPipeline
from .normalizers import Combat, Harmony, Log1P, NormalizeTotal, Scale
from .pipeline import GRPipeline, MultiRead
from .processors import * # noqa
from .reporter import Report, Reporter
from .shortcuts import * # noqa
@@ -33,9 +34,11 @@
'OBSP',
'VARP',
'UNS',
'StorageMixin',
'AnnDataKeys',
'BaseConfigurable',
'GRPipeline',
'MultiRead',
'FilterCells',
'FilterGenes',
'VarianceFilter',
@@ -44,6 +47,7 @@
'Filter',
'StackedFilter',
'Combat',
'Harmony',
'Log1P',
'Scale',
'NormalizeTotal',
1 change: 1 addition & 0 deletions src/grinch/aliases.py
@@ -44,6 +44,7 @@ class OBSM:
X_TRUNCATED_SVD = auto()
X_UMAP = auto()
GAUSSIAN_MIXTURE_PROBA = auto()
X_HARMONY = auto()

class VARM:
LOG_REG_COEF = auto()
20 changes: 19 additions & 1 deletion src/grinch/cond_filter.py
@@ -64,7 +64,10 @@ class Filter(BaseModel, Generic[T]):
>>> r([3, 4, 5, 6, 7], as_mask=True)
array([False, False, False, False, False])
"""
__conditions__ = ['ge', 'le', 'gt', 'lt', 'top_k', 'bot_k', 'top_ratio', 'bot_ratio']
__conditions__ = ['ge', 'le', 'gt', 'lt',
'equal', 'not_equal',
'top_k', 'bot_k',
'top_ratio', 'bot_ratio']

model_config = {
'validate_assignment': True,
@@ -80,6 +83,9 @@ class Filter(BaseModel, Generic[T]):
gt: T | None = None # greater than
lt: T | None = None # less than

equal: T | None = None # exactly equal to
not_equal: T | None = None # not equal to

top_k: NonNegativeInt | None = None # top k items after sorting
bot_k: NonNegativeInt | None = None # bottom k items after sorting
# These will be rounded up to the nearest item
@@ -167,6 +173,15 @@ def _take_ratio(self, arr, as_mask: bool = True):
k = int(np.ceil(ratio * len(arr))) # round up
return self._take_k_functional(arr, k, as_mask, self.is_top)

def _take_equal(self, arr, as_mask: bool = True):
"""Take elements exactly equal to `self.cfg.equal`.
"""
if self.equal is not None:
mask = arr == self.equal
elif self.not_equal is not None:
mask = arr != self.not_equal
return mask if as_mask else arr[mask]

def _take_cutoff(self, arr, as_mask: bool = True):
"""Takes the elements which are greater than or less than cutoff.
"""
@@ -233,6 +248,8 @@ def __call__(self, obj, as_mask=True):

if any_not_None(self.ge, self.gt, self.le, self.lt):
return self._take_cutoff(arr, as_mask)
if any_not_None(self.equal, self.not_equal):
return self._take_equal(arr, as_mask)
if any_not_None(self.top_k, self.bot_k):
return self._take_k(arr, as_mask)
if any_not_None(self.top_ratio, self.bot_ratio):
@@ -254,6 +271,7 @@ class StackedFilter(UserList):
*filters: iterable
An iterable of Filter's or StackedFilter's.
"""

def __init__(self, *filters: Filter | StackedFilter):
__filters__: List[Filter] = []

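The new `equal` / `not_equal` conditions plug into the same dispatch as the existing cutoff and top-k conditions. A minimal sketch of the behavior, assuming the pydantic `Filter` constructor from the docstring above:

import numpy as np
from grinch import Filter

# Keep entries exactly equal to 5; as_mask=True returns a boolean mask.
eq = Filter(equal=5)
eq(np.array([3, 4, 5, 6, 5]), as_mask=True)
# array([False, False,  True, False,  True])

# Keep entries different from 5; as_mask=False returns the values themselves.
ne = Filter(not_equal=5)
ne(np.array([3, 4, 5, 6, 5]), as_mask=False)
# array([3, 4, 6])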
12 changes: 2 additions & 10 deletions src/grinch/main.py
100644 → 100755
@@ -1,3 +1,4 @@
#!/Users/ehasanaj/mambaforge/envs/m10/bin/python
import argparse
import logging
import os
@@ -18,18 +19,9 @@
logging.captureWarnings(True)


try:
import grinch
src_dir = os.path.dirname(grinch.__file__)
except ImportError:
src_dir = os.path.dirname(os.path.realpath(__file__))

root_dir = os.path.abspath(os.path.join(src_dir, os.pardir, os.pardir))


def instantiate_config(config_name):
head, tail = os.path.split(config_name)
config_dir = os.path.join(root_dir, head)
config_dir = os.path.join(os.getcwd(), 'conf')
# context initialization
with hydra.initialize_config_dir(version_base=None, config_dir=config_dir):
cfg = hydra.compose(config_name=tail)
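With the `root_dir` lookup gone, `instantiate_config` now resolves configs against a `conf/` directory under the current working directory. A minimal sketch, assuming a hypothetical `conf/pipeline.yaml` exists where the script is launched:

from grinch import instantiate_config

# Composes ./conf/pipeline.yaml via hydra.initialize_config_dir + compose.
cfg = instantiate_config("pipeline")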
40 changes: 35 additions & 5 deletions src/grinch/normalizers.py
@@ -1,6 +1,7 @@
import abc
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional

import harmonypy
import numpy as np
import pandas as pd
import scipy.sparse as sp
@@ -9,9 +10,11 @@
from sklearn.preprocessing import normalize
from sklearn.utils.validation import check_array, check_non_negative

from .aliases import OBSM
from .base import StorageMixin
from .conf import BaseConfigurable
from .external.combat import combat # type: ignore
from .processors import BaseProcessor
from .processors import ReadKey, WriteKey
from .utils.stats import mean_var


@@ -78,7 +81,7 @@ def _normalize(self, adata: AnnData) -> None:
raise NotImplementedError


class Combat(BaseNormalizer):
class Combat(BaseNormalizer, StorageMixin):
"""Performs batch correction using Combat
Source:
https://academic.oup.com/biostatistics/article/8/1/118/252073?login=false
@@ -90,12 +93,12 @@ class Config(BaseNormalizer.Config):
if TYPE_CHECKING:
create: Callable[..., 'Combat']

batch_key: str
batch_key: ReadKey

cfg: Config

def _normalize(self, adata: AnnData) -> None:
batch: pd.Series = BaseProcessor.read(adata, self.cfg.batch_key)
batch: pd.Series = self.read(adata, self.cfg.batch_key)
if not isinstance(batch, pd.Series):
raise ValueError("Batch should be a pandas series")

@@ -110,6 +113,33 @@ def _normalize(self, adata: AnnData) -> None:
adata.X = corrected_data.to_numpy()


class Harmony(BaseNormalizer, StorageMixin):
"""Performs batch correction based on Harmony.
https://www.nature.com/articles/s41592-019-0619-0

Uses the `harmonypy` port.
"""
class Config(BaseNormalizer.Config):

if TYPE_CHECKING:
create: Callable[..., 'Harmony']

batch: str
x_key: ReadKey = f"obsm.{OBSM.X_PCA}"
write_key: WriteKey = f"obsm.{OBSM.X_HARMONY}"
kwargs: Dict[str, Any] = {}

cfg: Config

@StorageMixin.lazy_writer
def _normalize(self, adata: AnnData) -> None:
X = self.read(adata, self.cfg.x_key)
hm_out = harmonypy.run_harmony(
X, adata.obs, self.cfg.batch, **self.cfg.kwargs
)
self.store_item(self.cfg.write_key, hm_out.Z_corr.T)


class NormalizeTotal(BaseNormalizer):
"""Normalizes each cell so that total counts are equal."""

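`Harmony._normalize` is a thin wrapper around `harmonypy.run_harmony`. A minimal sketch of the equivalent direct call, assuming a hypothetical input file and obsm key names standing in for the `x_key`/`write_key` aliases above:

import anndata
import harmonypy

adata = anndata.read_h5ad("pbmc.h5ad")  # hypothetical input file
X = adata.obsm["X_pca"]                 # cells x components embedding

# run_harmony takes the embedding, the obs metadata frame, and the batch column name.
hm_out = harmonypy.run_harmony(X, adata.obs, "batch_ID")

# Z_corr is components x cells, hence the transpose before storing.
adata.obsm["X_harmony"] = hm_out.Z_corr.T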
91 changes: 81 additions & 10 deletions src/grinch/pipeline.py
@@ -1,34 +1,102 @@
import gc
import logging
import traceback
from os.path import expanduser
from pathlib import Path
from typing import TYPE_CHECKING, Callable, List, Optional
from typing import TYPE_CHECKING, Any, Callable, Dict, List

import anndata
import scanpy as sc
from anndata import AnnData
from pydantic import Field, FilePath, field_validator, validate_call
from tqdm.auto import tqdm

from .base import StorageMixin
from .conf import BaseConfigurable
from .processors import (
BasePredictor,
BaseTransformer,
DataSplitter,
GroupProcess,
Splitter,
WriteKey,
)

logger = logging.getLogger(__name__)


class GRPipeline(BaseConfigurable):
class ReadMixin:
"""Mixin class for reading data files."""

@staticmethod
def read(filepath: FilePath) -> AnnData:
"""Reads AnnData from filepath"""
if filepath.suffix == '.h5':
return sc.read_10x_h5(filepath)
return anndata.read(filepath)


class MultiRead(BaseConfigurable, ReadMixin):
"""Reads multiple adatas and concatenates them."""

class Config(BaseConfigurable.Config):
"""MultiRead.Config

Parameters
----------
paths: Dict
Maps the ID of a dataset to the path of its AnnData file.
id_key: str
The ID will be stored as a key under `id_key` if not None.
[obs|var]_names_make_unique: bool
If True, will make the corresponding axis labels unique.
kwargs: Dict
Arguments to pass to `concat`.
"""

if TYPE_CHECKING:
create: Callable[..., 'MultiRead']

paths: Dict[str, FilePath] = {}
id_key: WriteKey | None = 'obs.batch_ID'
obs_names_make_unique: bool = True
var_names_make_unique: bool = True
kwargs: Dict[str, Any] = {}

@field_validator('paths', mode='before')
def expand_paths(cls, val):
return {k: expanduser(v) for k, v in val.items()}

cfg: Config

def __call__(self) -> AnnData:
adatas = []
for idx, readpath in self.cfg.paths.items():
logger.info(f"Reading AnnData from '{readpath}'...")
adata = self.read(readpath)
if self.cfg.obs_names_make_unique:
adata.obs_names_make_unique()
if self.cfg.var_names_make_unique:
adata.var_names_make_unique()
if self.cfg.id_key is not None:
StorageMixin.write(adata, self.cfg.id_key, idx)
adatas.append(adata)
adata = anndata.concat(adatas, **self.cfg.kwargs)
del adatas
gc.collect()
adata.obs_names_make_unique()
return adata


class GRPipeline(BaseConfigurable, ReadMixin):

class Config(BaseConfigurable.Config):

if TYPE_CHECKING:
create: Callable[..., 'GRPipeline']

data_readpath: FilePath | None = None # FilePath ensures file exists
# FilePath ensures file exists
data_readpath: FilePath | MultiRead.Config | None = None
data_writepath: Path | None = None
processors: List[BaseConfigurable.Config]
verbose: bool = Field(True, exclude=True)
@@ -41,7 +109,9 @@ class Config(BaseConfigurable.Config):

@field_validator('data_readpath', 'data_writepath', mode='before')
def expand_paths(cls, val):
return expanduser(val) if val is not None else None
if not isinstance(val, MultiRead.Config):
return expanduser(val) if val is not None else None
return val

cfg: Config

@@ -52,13 +122,10 @@ def __init__(self, cfg: Config, /) -> None:
for c in self.cfg.processors:
if self.cfg.seed is not None:
c.seed = self.cfg.seed
path = self.cfg.data_writepath or self.cfg.data_readpath
if path is not None:
c.logs_path = c.logs_path / path.stem
self.processors.append(c.create())

@validate_call(config=dict(arbitrary_types_allowed=True))
def __call__(self, adata: Optional[AnnData] = None, **kwargs) -> DataSplitter:
def __call__(self, adata: AnnData | None = None, **kwargs) -> DataSplitter:
"""Applies processor to the different data splits in DataSplitter.
It differentiates between predictors (calls processor.predict),
transformers (calls processor.transform) and it defaults to
@@ -67,8 +134,12 @@ def __call__(self, adata: Optional[AnnData] = None, **kwargs) -> DataSplitter:
if adata is None:
if self.cfg.data_readpath is None:
raise ValueError("A path to adata or an adata object is required.")
logger.info(f"Reading AnnData from '{self.cfg.data_readpath}'...")
adata = anndata.read_h5ad(self.cfg.data_readpath)
if isinstance(self.cfg.data_readpath, MultiRead.Config):
multi_read = self.cfg.data_readpath.create()
adata = multi_read()
else:
logger.info(f"Reading AnnData from '{self.cfg.data_readpath}'...")
adata = self.read(self.cfg.data_readpath)
logger.info(adata)
ds = DataSplitter(adata) if not isinstance(adata, DataSplitter) else adata

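A `MultiRead.Config` can now stand in anywhere `data_readpath` previously took a single file. A minimal sketch, assuming hypothetical file paths and the `Config.create()` factory used throughout this diff:

from grinch import GRPipeline, MultiRead

# Each dataset ID is written under id_key ('obs.batch_ID' by default)
# before the individual AnnDatas are concatenated.
multi_cfg = MultiRead.Config(
    paths={
        "lane1": "~/data/lane1.h5ad",  # hypothetical files; FilePath
        "lane2": "~/data/lane2.h5ad",  # validation requires they exist
    },
)

pipeline = GRPipeline.Config(
    data_readpath=multi_cfg,
    processors=[],  # no processors: just read, concatenate, and split
).create()
ds = pipeline()  # returns a DataSplitter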
4 changes: 3 additions & 1 deletion src/grinch/processors/__init__.py
@@ -1,4 +1,4 @@
from .base_processor import BaseProcessor
from .base_processor import BaseProcessor, ReadKey, WriteKey
from .de import KSTest, TTest, UnimodalityTest
from .feature_selection import PhenotypeCover
from .graphs import BaseGraphConstructor, FuzzySimplicialSetGraph, KNNGraph
@@ -26,6 +26,8 @@

__all__ = [
'BaseProcessor',
'ReadKey',
'WriteKey',
'TTest',
'KSTest',
'UnimodalityTest',