From 6db9806303cadf70b3ffb1edd7d1fc989b178b08 Mon Sep 17 00:00:00 2001 From: Arina Danilina Date: Fri, 6 Oct 2023 08:47:43 +0200 Subject: [PATCH] fix imports --- src/moscot/backends/ott/output.py | 4 ++-- src/moscot/backends/ott/solver.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index d8eb9e5c2..0b7c81a98 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -6,7 +6,7 @@ import jax.numpy as jnp import numpy as np from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr import matplotlib as mpl import matplotlib.pyplot as plt @@ -34,7 +34,7 @@ def __init__( sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput, gromov_wasserstein.GWOutput, - gromov_wasserstein.LRGWOutput, + gromov_wasserstein_lr.LRGWOutput, ], ): super().__init__() diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 622d5870c..d8bc86b0c 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -8,7 +8,7 @@ from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr -from ott.solvers.quadratic import gromov_wasserstein +from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d @@ -23,7 +23,7 @@ sinkhorn.Sinkhorn, sinkhorn_lr.LRSinkhorn, gromov_wasserstein.GromovWasserstein, - gromov_wasserstein.LRGromovWasserstein, + gromov_wasserstein_lr.LRGromovWasserstein, ] OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem] Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]] @@ -252,7 +252,7 @@ def __init__( linear_solver_kwargs.setdefault("gamma", 10) linear_solver_kwargs.setdefault("gamma_rescale", True) initializer = "rank2" if initializer is None else initializer - self._solver = gromov_wasserstein.LRGromovWasserstein( + self._solver = gromov_wasserstein_lr.LRGromovWasserstein( rank=rank, quad_initializer=initializer, kwargs_init=initializer_kwargs,