Skip to content

Commit

Permalink
fix imports
Browse files Browse the repository at this point in the history
  • Loading branch information
Arina Danilina committed Oct 6, 2023
1 parent b7e5f46 commit 6db9806
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import jax.numpy as jnp
import numpy as np
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -34,7 +34,7 @@ def __init__(
sinkhorn.SinkhornOutput,
sinkhorn_lr.LRSinkhornOutput,
gromov_wasserstein.GWOutput,
gromov_wasserstein.LRGWOutput,
gromov_wasserstein_lr.LRGWOutput,
],
):
super().__init__()
Expand Down
6 changes: 3 additions & 3 deletions src/moscot/backends/ott/solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ott.problems.linear import linear_problem
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

from moscot._types import ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t
from moscot.backends.ott._utils import alpha_to_fused_penalty, check_shapes, ensure_2d
Expand All @@ -23,7 +23,7 @@
sinkhorn.Sinkhorn,
sinkhorn_lr.LRSinkhorn,
gromov_wasserstein.GromovWasserstein,
gromov_wasserstein.LRGromovWasserstein,
gromov_wasserstein_lr.LRGromovWasserstein,
]
OTTProblem_t = Union[linear_problem.LinearProblem, quadratic_problem.QuadraticProblem]
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
Expand Down Expand Up @@ -252,7 +252,7 @@ def __init__(
linear_solver_kwargs.setdefault("gamma", 10)
linear_solver_kwargs.setdefault("gamma_rescale", True)
initializer = "rank2" if initializer is None else initializer
self._solver = gromov_wasserstein.LRGromovWasserstein(
self._solver = gromov_wasserstein_lr.LRGromovWasserstein(
rank=rank,
quad_initializer=initializer,
kwargs_init=initializer_kwargs,
Expand Down

0 comments on commit 6db9806

Please sign in to comment.