Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding LRGW #611

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies = [
"docrep>=0.3.2",
"ott-jax>=0.4.3",
"cloudpickle>=2.2.0",
"rich>=13.5",
]

[project.optional-dependencies]
Expand Down
18 changes: 14 additions & 4 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpy as np
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput):
_NOT_COMPUTED = -1.0 # sentinel value used in `ott`

def __init__(
self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
self,
output: Union[
sinkhorn.SinkhornOutput,
sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput,
gromov_wasserstein_lr.LRGWOutput,
],
):
super().__init__()
self._output = output
Expand Down Expand Up @@ -218,8 +224,12 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102

@property
def rank(self) -> int: # noqa: D102
lin_output = self._output if self.is_linear else self._output.linear_state
return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output
return (
len(output.g)
if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput))
else -1
)

def _ones(self, n: int) -> ArrayLike: # noqa: D102
return jnp.ones((n,))
35 changes: 22 additions & 13 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
Expand All @@ -19,7 +19,12 @@

__all__ = ["SinkhornSolver", "GWSolver"]

OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein]
OTTSolver_t = Union[
sinkhorn.Sinkhorn,
sinkhorn_lr.LRSinkhorn,
gromov_wasserstein.GromovWasserstein,
gromov_wasserstein_lr.LRGromovWasserstein,
]
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]

Expand Down Expand Up @@ -243,21 +248,25 @@ def __init__(
):
super().__init__(jit=jit)
if rank > -1:
linear_solver_kwargs = dict(linear_solver_kwargs)
linear_solver_kwargs.setdefault("gamma", 10)
linear_solver_kwargs.setdefault("gamma_rescale", True)
linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs)
kwargs.setdefault("gamma", 10)
kwargs.setdefault("gamma_rescale", True)
initializer = "rank2" if initializer is None else initializer
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
else:
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
initializer = None
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)

def _prepare(
self,
Expand Down
6 changes: 4 additions & 2 deletions src/moscot/base/problems/compound_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _(
):
problem = self.problems[src, tgt]
fun = problem.push if forward else problem.pull
res[src] = fun(data=data, scale_by_marginals=scale_by_marginals)
res[src] = fun(data=data, scale_by_marginals=scale_by_marginals, **kwargs)
return res if return_all else res[src]

@_apply.register(ExplicitPolicy)
Expand Down Expand Up @@ -382,7 +382,9 @@ def _(
for _src, _tgt in [(src, tgt)] + rest:
problem = self.problems[_src, _tgt]
fun = problem.push if forward else problem.pull
res[_tgt if forward else _src] = current_mass = fun(current_mass, scale_by_marginals=scale_by_marginals)
res[_tgt if forward else _src] = current_mass = fun(
current_mass, scale_by_marginals=scale_by_marginals, **kwargs
)

return res if return_all else current_mass

Expand Down
13 changes: 10 additions & 3 deletions tests/backends/ott/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ott.solvers.linear.sinkhorn_lr import LRSinkhorn
from ott.solvers.quadratic.gromov_wasserstein import GromovWasserstein
from ott.solvers.quadratic.gromov_wasserstein import solve as gromov_wasserstein
from ott.solvers.quadratic.gromov_wasserstein_lr import LRGromovWasserstein

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott import GWSolver, SinkhornSolver
Expand Down Expand Up @@ -131,9 +132,15 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f
@pytest.mark.parametrize("rank", [-1, 7])
def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None:
thresh, eps = 1e-2, 1e-2
gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)
if rank > -1:
gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)

else:
gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)(
QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps))
)

solver = GWSolver(rank=rank, epsilon=eps, threshold=thresh)
pred = solver(x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"})
Expand Down
1 change: 0 additions & 1 deletion tests/problems/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def adata_time_with_tmap(adata_time: AnnData) -> AnnData:
"gw_unbalanced_correction": False,
"ranks": 3,
"tolerances": 3e-2,
"warm_start": True,
"linear_solver_kwargs": linear_solver_kwargs2,
}

Expand Down
Loading