Refactor/handle solve args #748

Merged Dec 12, 2024

Changes from all commits (70 commits):
310b273
move the alpha check to GWSolver level instead of parent classes
selmanozleyen Sep 22, 2024
6c8d42d
add tests to check if alpha fails or not
selmanozleyen Sep 22, 2024
0b06cc4
remove kwargs on a more public problem class
selmanozleyen Sep 22, 2024
feca7ec
pre-commit
selmanozleyen Sep 22, 2024
c7cceb1
add test that asserts type error when unrecognized args are given
selmanozleyen Sep 22, 2024
9edaafb
set default according to the data provided
selmanozleyen Sep 22, 2024
376eccf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 22, 2024
01c89a2
Revert "remove kwargs on a more public problem class"
selmanozleyen Sep 22, 2024
49c8bbc
improve the tests to also use other rank solvers
selmanozleyen Sep 23, 2024
f303a14
remove skipped tests and add link to other skipped test
selmanozleyen Sep 23, 2024
9457676
Merge branch 'main' into refactor/handle-solve-args
selmanozleyen Oct 21, 2024
1582baa
adapt to was solvers new api
selmanozleyen Oct 23, 2024
a21c45f
update the tests
selmanozleyen Oct 23, 2024
a1dfd64
update tests for solvers new api
selmanozleyen Oct 23, 2024
e4dea94
adapt tests for solvers new api
selmanozleyen Oct 23, 2024
608ca73
check if it's callable for solvers initializer instance instead of st…
selmanozleyen Oct 23, 2024
ae2c681
again simply linear_ot_solver -> linear_solver
selmanozleyen Oct 23, 2024
895085d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
7e472dc
fix test for also fgw tests
selmanozleyen Oct 23, 2024
19d7f35
lint
selmanozleyen Oct 23, 2024
9e069bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
23c7caa
fix linting errors
selmanozleyen Oct 23, 2024
aeb1621
fix test_backend tests. There were some ignored args
selmanozleyen Oct 23, 2024
6543d57
format
selmanozleyen Oct 23, 2024
0b517a6
update test_pass_arguments
selmanozleyen Oct 23, 2024
9cb9c36
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 23, 2024
aa593c6
Merge branch 'main' into refactor/handle-solve-args
selmanozleyen Nov 27, 2024
b47be25
let util function cast everything to float64
selmanozleyen Nov 27, 2024
1a595bf
set the version to new release
selmanozleyen Dec 9, 2024
4cc42ca
fix test initializers
selmanozleyen Dec 10, 2024
fb90d56
fix types
selmanozleyen Dec 10, 2024
7de8496
fix tests
selmanozleyen Dec 10, 2024
124c0df
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
14e96e6
fix some more tests
selmanozleyen Dec 10, 2024
b014233
adapt to new epsilon scheduler class
selmanozleyen Dec 10, 2024
1b56904
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
f0d4809
remove unused lse_mode
selmanozleyen Dec 10, 2024
f7c5d02
remove unused arg again
selmanozleyen Dec 10, 2024
fec5b0f
ignore return of _get_data
selmanozleyen Dec 10, 2024
b1ef471
remove lse_mode
selmanozleyen Dec 10, 2024
9ccdfea
remove lse mode
selmanozleyen Dec 10, 2024
6998fe4
fix kwargs init
selmanozleyen Dec 10, 2024
39e6c03
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
3fd0fac
fix initializer_kwargs test
selmanozleyen Dec 10, 2024
e22318e
update solution files
selmanozleyen Dec 10, 2024
d5b31b7
fix test
selmanozleyen Dec 10, 2024
11aa435
fix docs
selmanozleyen Dec 10, 2024
64ab534
make alpha mandatory
selmanozleyen Dec 10, 2024
f6f7cfb
fix test
selmanozleyen Dec 10, 2024
97846e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
f15c4e3
fix test
selmanozleyen Dec 10, 2024
e13665e
fix alpha tests
selmanozleyen Dec 10, 2024
ddfb89c
remove the skip
selmanozleyen Dec 10, 2024
ff13556
upgrade version
selmanozleyen Dec 10, 2024
0f28ec0
remove skips since we removed python 3.9 support
selmanozleyen Dec 10, 2024
ef50b55
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
45509e0
refactor the gwproblem and fgwproblem inheritance
selmanozleyen Dec 10, 2024
69703af
Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"
selmanozleyen Dec 10, 2024
3766a04
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 10, 2024
f06d4c6
Revert "refactor the gwproblem and fgwproblem inheritance"
selmanozleyen Dec 10, 2024
3ec3328
update tests and implement str as initializer input
selmanozleyen Dec 12, 2024
2474cf7
fix tests
selmanozleyen Dec 12, 2024
101b248
fix backend test
selmanozleyen Dec 12, 2024
09a3b49
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 12, 2024
90b93bb
fix other tests
selmanozleyen Dec 12, 2024
c334e47
fix other tests
selmanozleyen Dec 12, 2024
c9295ea
update tau arguments
selmanozleyen Dec 12, 2024
7f6d1cf
fix errors
selmanozleyen Dec 12, 2024
aed1866
rename the initializer factor method class
selmanozleyen Dec 12, 2024
792d913
update tests
selmanozleyen Dec 12, 2024
Files changed
3 changes: 3 additions & 0 deletions docs/conf.py
@@ -72,6 +72,9 @@
("py:class", "None. Update D from dict/iterable E and F."),
("py:class", "an object providing a view on D's values"),
("py:class", "a shallow copy of D"),
# ignore these classes until ott-jax adds them to their docs
("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"),
("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"),
]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -54,7 +54,7 @@ dependencies = [
"scanpy>=1.9.3",
"wrapt>=1.13.2",
"docrep>=0.3.2",
"ott-jax[neural]>=0.4.6,<=0.4.8",
"ott-jax[neural]>=0.5.0",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0",
14 changes: 9 additions & 5 deletions src/moscot/_types.py
@@ -2,6 +2,9 @@
from typing import Any, Literal, Mapping, Optional, Sequence, Union

import numpy as np
from ott.initializers.linear.initializers import SinkhornInitializer
from ott.initializers.linear.initializers_lr import LRInitializer
from ott.initializers.quadratic.initializers import BaseQuadraticInitializer

# TODO(michalk8): polish

@@ -17,13 +20,14 @@
Numeric_t = Union[int, float] # type of `time_key` arguments
Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata
Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition`
SinkFullRankInit = Literal["default", "gaussian", "sorting"]
LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"]
SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"]
LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"]

SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]]
QuadInitializer_t = Optional[LRInitializer_t]
LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]]
SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]]
QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]]

Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t]
Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t]
ProblemStage_t = Literal["prepared", "solved"]
Device_t = Union[Literal["cpu", "gpu", "tpu"], str]

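The net effect of the new aliases is that an initializer can be given either as a string tag or as a pre-built ott object. A minimal sketch, assuming ott-jax >= 0.5.0 and using only the classes imported in the diff above:

    from ott.initializers.linear.initializers import GaussianInitializer
    from ott.initializers.linear.initializers_lr import RandomInitializer

    # All four values below satisfy Initializer_t after this change:
    sinkhorn_by_tag = "gaussian"                  # SinkhornInitializerTag_t
    sinkhorn_by_instance = GaussianInitializer()  # SinkhornInitializer instance
    lr_by_tag = "rank2"                           # LRInitializerTag_t
    lr_by_instance = RandomInitializer(rank=2)    # LRInitializer instance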
95 changes: 92 additions & 3 deletions src/moscot/backends/ott/_utils.py
@@ -8,6 +8,8 @@
import numpy as np
import scipy.sparse as sp
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
@@ -21,6 +23,90 @@
__all__ = ["sinkhorn_divergence"]


class InitializerResolver:
"""Class for creating various OT solver initializers.

This class provides static methods to create and manage different types of
initializers used in optimal transport solvers, including low-rank, k-means,
and standard Sinkhorn initializers.
"""

@staticmethod
def lr_from_str(
initializer: str,
rank: int,
**kwargs: Any,
) -> lr_init_lib.LRInitializer:
"""Create a low-rank initializer from a string specification.

Parameters
----------
initializer : str
Either existing initializer instance or string specifier.
rank : int
Rank for the initialization.
**kwargs : Any
Additional keyword arguments for initializer creation.

Returns
-------
LRInitializer
Configured low-rank initializer.

Raises
------
NotImplementedError
If requested initializer type is not implemented.
"""
if isinstance(initializer, lr_init_lib.LRInitializer):
return initializer
if initializer == "k-means":
return lr_init_lib.KMeansInitializer(rank=rank, **kwargs)
if initializer == "generalized-k-means":
return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs)
if initializer == "random":
return lr_init_lib.RandomInitializer(rank=rank, **kwargs)
if initializer == "rank2":
return lr_init_lib.Rank2Initializer(rank=rank, **kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not implemented.")

@staticmethod
def from_str(
initializer: str,
**kwargs: Any,
) -> init_lib.SinkhornInitializer:
"""Create a Sinkhorn initializer from a string specification.

Parameters
----------
initializer : str
String specifier for initializer type.
**kwargs : Any
Additional keyword arguments for initializer creation.

Returns
-------
SinkhornInitializer
Configured Sinkhorn initializer.

Raises
------
NotImplementedError
If requested initializer type is not implemented.
"""
if isinstance(initializer, init_lib.SinkhornInitializer):
return initializer
if initializer == "default":
return init_lib.DefaultInitializer(**kwargs)
if initializer == "gaussian":
return init_lib.GaussianInitializer(**kwargs)
if initializer == "sorting":
return init_lib.SortingInitializer(**kwargs)
if initializer == "subsample":
return init_lib.SubsampleInitializer(**kwargs)
raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.")


def sinkhorn_divergence(
point_cloud_1: ArrayLike,
point_cloud_2: ArrayLike,
@@ -45,11 +131,14 @@ def sinkhorn_divergence(
batch_size=batch_size,
a=a,
b=b,
sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b},
scale_cost=scale_cost,
epsilon=epsilon,
solve_kwargs={
"tau_a": tau_a,
"tau_b": tau_b,
},
**kwargs,
)
)[1]
xy_conv, xx_conv, *yy_conv = output.converged

if not xy_conv:
Expand Down Expand Up @@ -132,7 +221,7 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array:
return jnp.reshape(arr, (-1, 1))
if arr.ndim != 2:
raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.")
return arr
return arr.astype(jnp.float64)


def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO:
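The resolver added above normalizes user input before it reaches ott: strings are mapped to the matching ott class, and already-constructed instances pass through unchanged. A usage sketch under that assumption:

    from moscot.backends.ott._utils import InitializerResolver

    init = InitializerResolver.from_str("gaussian")               # SinkhornInitializer
    lr_init = InitializerResolver.lr_from_str("k-means", rank=5)  # KMeansInitializer
    # Instances short-circuit the string dispatch:
    assert InitializerResolver.lr_from_str(lr_init, rank=5) is lr_init

One caveat worth noting (an observation, not from the diff): the ensure_2d change casts to jnp.float64, which only yields true 64-bit arrays when JAX's x64 mode is enabled, e.g. via jax.config.update("jax_enable_x64", True); without it, JAX silently keeps float32.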
50 changes: 30 additions & 20 deletions src/moscot/backends/ott/solver.py
@@ -36,11 +36,13 @@
from moscot._logging import logger
from moscot._types import (
ArrayLike,
LRInitializer_t,
ProblemKind_t,
QuadInitializer_t,
SinkhornInitializer_t,
)
from moscot.backends.ott._utils import (
InitializerResolver,
Loader,
MultiLoader,
_instantiate_geodesic_cost,
@@ -88,16 +90,20 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC):
----------
jit
Whether to :func:`~jax.jit` the :attr:`solver`.
initializer_kwargs
Keyword arguments for the initializer.
"""

def __init__(self, jit: bool = True):
def __init__(self, jit: bool = True, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({})):
super().__init__()
self._solver: Optional[OTTSolver_t] = None
self._problem: Optional[OTTProblem_t] = None
self._jit = jit
self._a: Optional[jnp.ndarray] = None
self._b: Optional[jnp.ndarray] = None

self.initializer_kwargs = initializer_kwargs

def _create_geometry(
self,
x: TaggedArray,
@@ -170,7 +176,7 @@ def _solve( # type: ignore[override]
**kwargs: Any,
) -> Union[OTTOutput, GraphOTTOutput]:
solver = jax.jit(self.solver) if self._jit else self.solver
out = solver(prob, **kwargs)
out = solver(prob, **self.initializer_kwargs, **kwargs)
if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic):
return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type]
return OTTOutput(out)
@@ -275,20 +281,20 @@ def __init__(
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
):
super().__init__(jit=jit)
super().__init__(jit=jit, initializer_kwargs=initializer_kwargs)
if rank > -1:
kwargs.setdefault("gamma", 500)
kwargs.setdefault("gamma_rescale", True)
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
initializer = "rank2" if initializer is None else initializer
self._solver = sinkhorn_lr.LRSinkhorn(
rank=rank, epsilon=epsilon, initializer=initializer, kwargs_init=initializer_kwargs, **kwargs
)
if isinstance(initializer, str):
initializer = InitializerResolver.lr_from_str(initializer, rank=rank)
self._solver = sinkhorn_lr.LRSinkhorn(rank=rank, epsilon=epsilon, initializer=initializer, **kwargs)
else:
initializer = "default" if initializer is None else initializer
self._solver = sinkhorn.Sinkhorn(initializer=initializer, kwargs_init=initializer_kwargs, **kwargs)
if isinstance(initializer, str):
initializer = InitializerResolver.from_str(initializer)
self._solver = sinkhorn.Sinkhorn(initializer=initializer, **kwargs)

def _prepare(
self,
@@ -389,40 +395,40 @@ def __init__(
self,
jit: bool = True,
rank: int = -1,
initializer: QuadInitializer_t = None,
initializer: QuadInitializer_t | LRInitializer_t = None,
initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}),
**kwargs: Any,
):
super().__init__(jit=jit)
super().__init__(jit=jit, initializer_kwargs=initializer_kwargs)
if rank > -1:
kwargs.setdefault("gamma", 10)
kwargs.setdefault("gamma_rescale", True)
eps = kwargs.get("epsilon")
if eps is not None and eps > 0.0:
logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.")
initializer = "rank2" if initializer is None else initializer
if isinstance(initializer, str):
initializer = InitializerResolver.lr_from_str(initializer, rank=rank)
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
else:
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
initializer = None
linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
if isinstance(initializer, str):
raise ValueError("Expected `initializer` to be `None` or `ott.initializers.quadratic.initializers`.")
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
linear_solver=linear_solver,
initializer=initializer,
**kwargs,
)

def _prepare(
self,
a: jnp.ndarray,
b: jnp.ndarray,
alpha: float,
xy: Optional[TaggedArray] = None,
x: Optional[TaggedArray] = None,
y: Optional[TaggedArray] = None,
@@ -435,7 +441,6 @@ def _prepare(
cost_matrix_rank: Optional[int] = None,
time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None,
# problem
alpha: float = 0.5,
**kwargs: Any,
) -> quadratic_problem.QuadraticProblem:
self._a = a
@@ -456,6 +461,11 @@
geom_kwargs["cost_matrix_rank"] = cost_matrix_rank
geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs)
geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs)
if alpha <= 0.0:
raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.")

if alpha == 1.0 or xy is None: # GW
# arbitrary fused penalty; must be positive
geom_xy, fused_penalty = None, 1.0
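Together, the two guards added to GWSolver._prepare encode the new contract: alpha must be strictly positive (the error message documents the interval (0, 1]), and the linear term xy must be supplied exactly when alpha != 1.0, i.e. only in the fused case. A self-contained restatement of the checks, for illustration only:

    def check_alpha_xy(alpha, xy):
        # Mirrors the guards added to GWSolver._prepare in the diff above.
        if alpha <= 0.0:
            raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.")
        if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None):
            raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.")

    check_alpha_xy(1.0, None)    # pure GW: passes
    check_alpha_xy(0.5, "xy")    # fused GW with a linear term: passes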
18 changes: 1 addition & 17 deletions src/moscot/base/problems/problem.py
@@ -518,24 +518,8 @@ def solve(
solver_class = backends.get_solver(
self.problem_kind, solver_name=solver_name, backend=backend, return_class=True
)
init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
# if linear problem, then alpha is 0.0 by default
# if quadratic problem, then alpha is 1.0 by default
alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0)
if alpha < 0.0 or alpha > 1.0:
raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.")
if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)):
raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.")
if self.problem_kind == "quadratic":
if self.x is None or self.y is None:
raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.")
if alpha != 1.0 and self.xy is None: # means FGW case
raise ValueError(
"`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class."
)
if alpha == 1.0 and self.xy is not None:
raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.")

init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs)
self._solver = solver_class(**init_kwargs)

# note that the solver call consists of solver._prepare and solver._solve
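With these checks removed, BaseProblem.solve only partitions kwargs and delegates; invalid alpha/xy combinations now surface from the backend solver instead. A hedged sketch of the remaining flow (the get_solver call follows the signature visible in this diff, but the kwargs shown and how they are split are illustrative and solver-dependent):

    from moscot import backends

    # Resolve the backend solver class the way solve() does, then let the
    # class itself split kwargs between __init__ and __call__.
    solver_class = backends.get_solver("quadratic", backend="ott", return_class=True)
    init_kwargs, call_kwargs = solver_class._partition_kwargs(rank=-1, alpha=0.5)
    solver = solver_class(**init_kwargs)  # e.g. rank, initializer, jit
    # solver(..., **call_kwargs) would then run _prepare/_solve, where the
    # alpha/xy contract from the solver.py diff is enforced.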
9 changes: 3 additions & 6 deletions src/moscot/problems/time/_lineage.py
@@ -185,7 +185,6 @@ def solve(
jit: bool = True,
threshold: float = 1e-3,
lse_mode: bool = True,
inner_iterations: int = 10,
min_iterations: Optional[int] = None,
max_iterations: Optional[int] = None,
device: Optional[Literal["cpu", "gpu", "tpu"]] = None,
@@ -233,9 +232,7 @@ def solve(
lse_mode
Whether to use `log-sum-exp (LSE)
<https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations>`_
computations for numerical stability.
inner_iterations
Compute the convergence criterion every ``inner_iterations``.
computations for numerical stability. Valid only for the :term:`linear problem`.
min_iterations
Minimum number of :term:`Sinkhorn` iterations.
max_iterations
@@ -253,6 +250,8 @@
- :attr:`solutions` - the :term:`OT` solutions for each subproblem.
- :attr:`stage` - set to ``'solved'``.
"""
if self.problem_kind == "linear":
kwargs["lse_mode"] = lse_mode
return super().solve( # type:ignore[return-value]
epsilon=epsilon,
tau_a=tau_a,
@@ -265,8 +264,6 @@
initializer_kwargs=initializer_kwargs,
jit=jit,
threshold=threshold,
lse_mode=lse_mode,
inner_iterations=inner_iterations,
min_iterations=min_iterations,
max_iterations=max_iterations,
device=device,
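The conditional forwarding added above reflects that lse_mode is a Sinkhorn argument; the updated docstring says it is valid only for the linear problem, so it is no longer passed down the quadratic path. A small illustrative sketch of the same pattern:

    # Forward `lse_mode` only on the linear path, as the diff above does.
    def collect_solve_kwargs(problem_kind: str, lse_mode: bool = True, **kwargs):
        if problem_kind == "linear":
            kwargs["lse_mode"] = lse_mode
        return kwargs

    assert "lse_mode" in collect_solve_kwargs("linear")
    assert "lse_mode" not in collect_solve_kwargs("quadratic")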