Skip to content

Commit

Permalink
Feature/generic solvers (#208)
Browse files Browse the repository at this point in the history
* added generic problems
  • Loading branch information
MUCDK authored May 16, 2022
1 parent 1444f5c commit dbf0ddf
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 14 deletions.
11 changes: 10 additions & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,16 @@ Biological Problems
Generic Problems
~~~~~~~~~~~~~~~~

TODO
.. module:: moscot.problems.generic
.. currentmodule:: moscot.problems.generic

.. autosummary::
:toctree: api

moscot.problems.generic.SinkhornProblem
moscot.problems.generic.GWProblem
moscot.problems.generic.FGWProblem


Solvers
~~~~~~~
Expand Down
16 changes: 14 additions & 2 deletions moscot/_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,16 @@
# Reusable parameter descriptions. Each snippet is passed further down in this
# module under a keyword of the same name (``policy=_policy``, ``key=_key``,
# ``joint_attr=_joint_attr``) and is referenced from docstrings through
# ``%(policy)s`` / ``%(key)s`` / ``%(joint_attr)s`` placeholders.
# NOTE(review): the receiving call is outside this hunk - presumably the docrep
# processor exported as ``d``; confirm.

# Description of the ``policy`` parameter.
_policy = """\
policy
Defines the rule according to which pairs of distributions are selected to compute the transport map between."""
# Description of the ``key`` parameter (a column of ``adata.obs``).
_key = """\
key
Key in :attr:`anndata.AnnData.obs` allocating the cell to a certain cell distribution."""
# Description of the ``joint_attr`` parameter, covering its ``None`` / string /
# dictionary forms.
_joint_attr = """\
joint_attr
Parameter defining how to allocate the data needed to compute the transport maps. If None, the data is read
from :attr:`anndata.AnnData.X` and for each time point the corresponding PCA space is computed. If
`joint_attr` is a string the data is assumed to be found in :attr:`anndata.AnnData.obsm`.
If `joint_attr` is a dictionary the dictionary is supposed to contain the attribute of
:attr:`anndata.AnnData` as a key and the corresponding attribute as a value."""


def inject_docs(**kwargs: Any):
Expand Down Expand Up @@ -152,11 +162,13 @@ def decorator2(obj):
subset=_subset,
marginal_kwargs=_marginal_kwargs,
shape=_shape,
transport_matrix = _transport_matrix,
converged = _converged,
transport_matrix=_transport_matrix,
converged=_converged,
a=_a,
b=_b,
time_key=_time_key,
spatial_key=_spatial_key,
policy=_policy,
key=_key,
joint_attr=_joint_attr,
)
1 change: 0 additions & 1 deletion moscot/analysis_mixins/_base_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
# TODO(michalk8): need to think about this a bit more
# TODO(MUCDK): remove ABC?
class AnalysisMixin(ABC):

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion moscot/analysis_mixins/_spatial_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def spatial_key(self) -> Optional[str]:
def spatial_key(self, value: Optional[str] = None) -> None:
if value not in self.adata.obs.columns:
raise KeyError(f"TODO: {value} not found in `adata.obs.columns`")
#TODO(@MUCDK) check data type -> which ones do we allow
# TODO(@MUCDK) check data type -> which ones do we allow
self._spatial_key = value


Expand Down
4 changes: 3 additions & 1 deletion moscot/problems/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,4 +106,6 @@ def ensure_2D(arr: npt.ArrayLike, *, allow_reshape: bool = True) -> np.ndarray:
if self.key not in container:
raise KeyError(f"TODO: unable to find `adata.{self.attr}['{self.key}']`.")
container = container[self.key]
return TaggedArray(container.A if scipy.sparse.issparse(container) else container, tag=self.tag, loss=backend_losses[self.loss])
return TaggedArray(
container.A if scipy.sparse.issparse(container) else container, tag=self.tag, loss=backend_losses[self.loss]
)
3 changes: 1 addition & 2 deletions moscot/problems/_compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,7 @@ def prepare(
Parameters
----------
key
Key in :attr:`anndata.AnnData.obs` allocating the cell to a certain cell distribution.
%(key)s
policy
Defines which transport maps to compute given different cell distributions.
subset
Expand Down
1 change: 1 addition & 0 deletions moscot/problems/generic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from moscot.problems.generic._generic import GWProblem, FGWProblem, SinkhornProblem
217 changes: 217 additions & 0 deletions moscot/problems/generic/_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
from types import MappingProxyType
from typing import Any, Literal, Mapping, Optional, Tuple, Type, Union

from anndata import AnnData

from moscot._docs import d
from moscot.problems import OTProblem, SingleCompoundProblem
from moscot.analysis_mixins import AnalysisMixin
from moscot.problems._compound_problem import B


@d.dedent
class SinkhornProblem(SingleCompoundProblem, AnalysisMixin):
    """
    Class for solving linear OT problems.

    Parameters
    ----------
    %(adata)s

    Examples
    --------
    See notebook TODO(@MUCDK) LINK NOTEBOOK for how to use it
    """

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(adata, **kwargs)

    @d.dedent
    def prepare(
        self,
        key: str,
        joint_attr: Optional[Union[str, Mapping[str, Any]]] = None,
        policy: Literal["sequential", "pairwise", "explicit"] = "sequential",
        **kwargs: Any,
    ) -> "SinkhornProblem":
        """
        Prepare the :class:`moscot.problems.generic.SinkhornProblem`.

        This method executes multiple steps to prepare the optimal transport problems.

        Parameters
        ----------
        %(key)s
        %(joint_attr)s
        %(policy)s
        %(marginal_kwargs)s
        %(a)s
        %(b)s
        %(subset)s
        %(reference)s
        %(axis)s
        %(callback)s
        %(callback_kwargs)s
        kwargs
            Keyword arguments for :meth:`moscot.problems.CompoundBaseProblem._create_problems`.

        Returns
        -------
        :class:`moscot.problems.generic.SinkhornProblem`

        Raises
        ------
        TypeError
            If `joint_attr` is neither `None`, a string nor a mapping.

        Notes
        -----
        If `a` and `b` are provided `marginal_kwargs` are ignored.
        """
        if joint_attr is None:
            # No data location given: request the "local-pca" callback. This
            # overwrites any `callback` supplied by the caller, while
            # caller-supplied `callback_kwargs` are kept and only
            # `return_linear` is forced to True.
            kwargs["callback"] = "local-pca"
            kwargs["callback_kwargs"] = {**kwargs.get("callback_kwargs", {}), "return_linear": True}
        elif isinstance(joint_attr, str):
            # A plain string is used as the `obsm` key for both `x` and `y`.
            kwargs["xy"] = {
                "x_attr": "obsm",
                "x_key": joint_attr,
                "y_attr": "obsm",
                "y_key": joint_attr,
            }
        elif isinstance(joint_attr, Mapping):
            # A mapping is forwarded untouched as the `xy` specification.
            kwargs["xy"] = joint_attr
        else:
            raise TypeError(
                "TODO: expected `joint_attr` to be `None`, a `str` or a mapping, "
                f"found `{type(joint_attr).__name__}`."
            )

        return super().prepare(
            key=key,
            policy=policy,
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[B]:
        """Problem class used for the individual sub-problems."""
        return OTProblem

    @property
    def _valid_policies(self) -> Tuple[str, ...]:
        """Names of the policies accepted by this problem."""
        return "sequential", "pairwise", "explicit"


@d.get_sections(base="GWProblem", sections=["Parameters"])
@d.dedent
class GWProblem(SingleCompoundProblem, AnalysisMixin):
    """
    Class for solving Gromov-Wasserstein problems.

    Parameters
    ----------
    %(adata)s

    Examples
    --------
    See notebook TODO(@MUCDK) LINK NOTEBOOK for how to use it
    """

    def __init__(self, adata: AnnData, **kwargs: Any):
        super().__init__(adata, **kwargs)

    @d.dedent
    def prepare(
        self,
        key: str,
        GW_attr: Mapping[str, Any] = MappingProxyType({}),
        policy: Literal["sequential", "pairwise", "explicit"] = "sequential",
        **kwargs: Any,
    ) -> "GWProblem":
        """
        Prepare the :class:`moscot.problems.generic.GWProblem`.

        This method executes multiple steps to prepare the problem for the Optimal Transport solver to be ready
        to solve it.

        Parameters
        ----------
        %(key)s
        GW_attr
            Specifies the way the GW information is processed. TODO: Specify.
            Entries that are not given default to `attr='obsp'`, `key='cost_matrices'`, `loss=None`,
            `tag='cost'` and `loss_kwargs={}`.
        %(joint_attr)s
        %(policy)s
        %(marginal_kwargs)s
        %(a)s
        %(b)s
        %(subset)s
        %(reference)s
        %(callback)s
        %(callback_kwargs)s
        kwargs
            Keyword arguments for :meth:`moscot.problems.CompoundBaseProblem._create_problems`

        Returns
        -------
        :class:`moscot.problems.generic.GWProblem`

        Raises
        ------
        ValueError
            If `GW_attr` is empty and `adata.obsp` has no `cost_matrices` entry.

        Notes
        -----
        If `a` and `b` are provided `marginal_kwargs` are ignored.
        """
        # Only the all-defaults case can be validated here: with an empty
        # `GW_attr` the data must live in `adata.obsp['cost_matrices']`.
        if not GW_attr and "cost_matrices" not in self.adata.obsp:
            # Implicit concatenation keeps source indentation out of the message.
            raise ValueError(
                "TODO: default location for quadratic loss is `adata.obsp[`cost_matrices`]` "
                "but adata has no key `cost_matrices` in `obsp`."
            )

        # TODO(michalk8): refactor me
        # Copy into a plain dict so the (possibly read-only) argument is never
        # mutated, then fill in whatever the caller left unspecified.
        GW_attr = dict(GW_attr)
        GW_attr.setdefault("attr", "obsp")
        GW_attr.setdefault("key", "cost_matrices")
        GW_attr.setdefault("loss", None)
        GW_attr.setdefault("tag", "cost")
        GW_attr.setdefault("loss_kwargs", {})
        # The same specification object is used for both `x` and `y`.
        x = y = GW_attr

        return super().prepare(
            key,
            x=x,
            y=y,
            policy=policy,
            **kwargs,
        )

    @property
    def _base_problem_type(self) -> Type[B]:
        """Problem class used for the individual sub-problems."""
        return OTProblem

    @property
    def _valid_policies(self) -> Tuple[str, ...]:
        """Names of the policies accepted by this problem."""
        return "sequential", "pairwise", "explicit"


@d.dedent
class FGWProblem(GWProblem):
    """
    Class for solving Fused Gromov-Wasserstein problems.

    Parameters
    ----------
    %(adata)s

    Examples
    --------
    See notebook TODO(@MUCDK) LINK NOTEBOOK for how to use it
    """

    @d.dedent
    def prepare(
        self,
        *args: Any,
        joint_attr: Mapping[str, Any] = MappingProxyType({}),
        **kwargs: Any,
    ) -> "FGWProblem":
        """
        Prepare the :class:`moscot.problems.generic.FGWProblem`.

        Parameters
        ----------
        %(GWProblem.parameters)s
        %(joint_attr)s

        Returns
        -------
        :class:`moscot.problems.generic.FGWProblem`
        """
        # `joint_attr` is keyword-only here, so it can never already be inside
        # `kwargs`. Forward it exactly once: supplying it both explicitly and
        # through `**kwargs` raises
        # "TypeError: got multiple values for keyword argument 'joint_attr'".
        return super().prepare(*args, joint_attr=joint_attr, **kwargs)
1 change: 1 addition & 0 deletions moscot/problems/mixins/_temporal_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
class MultiMarginalMixin(ABC):
"""Mixin class for biological problems based on :class:`moscot.problems.MultiMarginalProblem`."""


@d.dedent
class BirthDeathMixin(MultiMarginalMixin):
"""Mixin class for biological problems based on :class:`moscot.problems.mixins.BirthDeathBaseProblem`."""
Expand Down
8 changes: 2 additions & 6 deletions moscot/problems/time/_lineage.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,7 @@ def prepare(
Parameters
----------
%(time_key)s
joint_attr
Parameter defining how to allocate the data needed to compute the transport maps. If None, the data is read
from :attr:`anndata.AnnData.X` and for each time point the corresponding PCA space is computed. If
`joint_attr` is a string the data is assumed to be found in :attr:`anndata.AnnData.obsm`.
If `joint_attr` is a dictionary the dictionary is supposed to contain the attribute of
:attr:`anndata.AnnData` as a key and the corresponding attribute as a value.
%(joint_attr)s
policy
Defines which transport maps to compute given different cell distributions.
%(marginal_kwargs)s
Expand Down Expand Up @@ -236,6 +231,7 @@ class LineageProblem(TemporalProblem):
--------
See notebook TODO(@MUCDK) LINK NOTEBOOK for how to use it
"""

@d.dedent
def prepare(
self,
Expand Down

0 comments on commit dbf0ddf

Please sign in to comment.