Skip to content

Commit

Permalink
adding LRGW
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Oct 8, 2023
1 parent 62cdf59 commit c5476ef
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 13 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ dependencies = [
"docrep>=0.3.2",
"ott-jax>=0.4.3",
"cloudpickle>=2.2.0",
"rich>=13.5",
]

[project.optional-dependencies]
Expand Down
18 changes: 14 additions & 4 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpy as np
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -29,7 +29,13 @@ class OTTOutput(BaseSolverOutput):
_NOT_COMPUTED = -1.0 # sentinel value used in `ott`

def __init__(
self, output: Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput]
self,
output: Union[
sinkhorn.SinkhornOutput,
sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput,
gromov_wasserstein_lr.LRGWOutput,
],
):
super().__init__()
self._output = output
Expand Down Expand Up @@ -218,8 +224,12 @@ def potentials(self) -> Optional[Tuple[ArrayLike, ArrayLike]]: # noqa: D102

@property
def rank(self) -> int: # noqa: D102
lin_output = self._output if self.is_linear else self._output.linear_state
return len(lin_output.g) if isinstance(lin_output, sinkhorn_lr.LRSinkhornOutput) else -1
output = self._output.linear_state if isinstance(self._output, gromov_wasserstein.GWOutput) else self._output
return (
len(output.g)
if isinstance(output, (sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein_lr.LRGWOutput))
else -1
)

def _ones(self, n: int) -> ArrayLike: # noqa: D102
return jnp.ones((n,))
25 changes: 16 additions & 9 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
Expand All @@ -19,7 +19,7 @@

__all__ = ["SinkhornSolver", "GWSolver"]

OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein]
OTTSolver_t = Union[sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein, gromov_wasserstein_lr.LRGromovWasserstein]
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]

Expand Down Expand Up @@ -248,16 +248,23 @@ def __init__(
linear_solver_kwargs.setdefault("gamma_rescale", True)
linear_ot_solver = sinkhorn_lr.LRSinkhorn(rank=rank, **linear_solver_kwargs)
initializer = "rank2" if initializer is None else initializer
self._solver = gromov_wasserstein.LRGromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
else:
linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs)
initializer = None
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)
self._solver = gromov_wasserstein.GromovWasserstein(
rank=rank,
linear_ot_solver=linear_ot_solver,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
**kwargs,
)

def _prepare(
self,
Expand Down

0 comments on commit c5476ef

Please sign in to comment.