diff --git a/docs/conf.py b/docs/conf.py index 6ac6bff5e..7806d7038 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -72,6 +72,9 @@ ("py:class", "None. Update D from dict/iterable E and F."), ("py:class", "an object providing a view on D's values"), ("py:class", "a shallow copy of D"), + # ignore these classes until ott-jax adds them to their docs + ("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"), + ("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"), ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ diff --git a/pyproject.toml b/pyproject.toml index 9fdbc3fa3..af308d131 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ dependencies = [ "scanpy>=1.9.3", "wrapt>=1.13.2", "docrep>=0.3.2", - "ott-jax[neural]>=0.4.6,<=0.4.8", + "ott-jax[neural]>=0.5.0", "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", diff --git a/src/moscot/_types.py b/src/moscot/_types.py index 1c60884a2..1f72953d8 100644 --- a/src/moscot/_types.py +++ b/src/moscot/_types.py @@ -2,6 +2,9 @@ from typing import Any, Literal, Mapping, Optional, Sequence, Union import numpy as np +from ott.initializers.linear.initializers import SinkhornInitializer +from ott.initializers.linear.initializers_lr import LRInitializer +from ott.initializers.quadratic.initializers import BaseQuadraticInitializer # TODO(michalk8): polish @@ -17,13 +20,14 @@ Numeric_t = Union[int, float] # type of `time_key` arguments Filter_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type how to filter adata Str_Dict_t = Optional[Union[str, Mapping[str, Sequence[Any]]]] # type for `cell_transition` -SinkFullRankInit = Literal["default", "gaussian", "sorting"] -LRInitializer_t = Literal["random", "rank2", "k-means", "generalized-k-means"] +SinkhornInitializerTag_t = Literal["default", "gaussian", "sorting"] +LRInitializerTag_t = Literal["random", "rank2", "k-means", "generalized-k-means"] -SinkhornInitializer_t = Optional[Union[SinkFullRankInit, LRInitializer_t]] -QuadInitializer_t = Optional[LRInitializer_t] +LRInitializer_t = Optional[Union[LRInitializer, LRInitializerTag_t]] +SinkhornInitializer_t = Optional[Union[SinkhornInitializer, SinkhornInitializerTag_t]] +QuadInitializer_t = Optional[Union[BaseQuadraticInitializer]] -Initializer_t = Union[SinkhornInitializer_t, LRInitializer_t] +Initializer_t = Union[SinkhornInitializer_t, QuadInitializer_t, LRInitializer_t] ProblemStage_t = Literal["prepared", "solved"] Device_t = Union[Literal["cpu", "gpu", "tpu"], str] diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 2cac53b30..9f71e2d5a 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -8,6 +8,8 @@ import numpy as np import scipy.sparse as sp from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud +from ott.initializers.linear import initializers as init_lib +from ott.initializers.linear import initializers_lr as lr_init_lib from ott.neural import datasets from ott.solvers import utils as solver_utils from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div @@ -21,6 +23,90 @@ __all__ = ["sinkhorn_divergence"] +class InitializerResolver: + """Class for creating various OT solver initializers. + + This class provides static methods to create and manage different types of + initializers used in optimal transport solvers, including low-rank, k-means, + and standard Sinkhorn initializers. + """ + + @staticmethod + def lr_from_str( + initializer: str, + rank: int, + **kwargs: Any, + ) -> lr_init_lib.LRInitializer: + """Create a low-rank initializer from a string specification. + + Parameters + ---------- + initializer : str + Either existing initializer instance or string specifier. + rank : int + Rank for the initialization. + **kwargs : Any + Additional keyword arguments for initializer creation. + + Returns + ------- + LRInitializer + Configured low-rank initializer. + + Raises + ------ + NotImplementedError + If requested initializer type is not implemented. + """ + if isinstance(initializer, lr_init_lib.LRInitializer): + return initializer + if initializer == "k-means": + return lr_init_lib.KMeansInitializer(rank=rank, **kwargs) + if initializer == "generalized-k-means": + return lr_init_lib.GeneralizedKMeansInitializer(rank=rank, **kwargs) + if initializer == "random": + return lr_init_lib.RandomInitializer(rank=rank, **kwargs) + if initializer == "rank2": + return lr_init_lib.Rank2Initializer(rank=rank, **kwargs) + raise NotImplementedError(f"Initializer `{initializer}` is not implemented.") + + @staticmethod + def from_str( + initializer: str, + **kwargs: Any, + ) -> init_lib.SinkhornInitializer: + """Create a Sinkhorn initializer from a string specification. + + Parameters + ---------- + initializer : str + String specifier for initializer type. + **kwargs : Any + Additional keyword arguments for initializer creation. + + Returns + ------- + SinkhornInitializer + Configured Sinkhorn initializer. + + Raises + ------ + NotImplementedError + If requested initializer type is not implemented. + """ + if isinstance(initializer, init_lib.SinkhornInitializer): + return initializer + if initializer == "default": + return init_lib.DefaultInitializer(**kwargs) + if initializer == "gaussian": + return init_lib.GaussianInitializer(**kwargs) + if initializer == "sorting": + return init_lib.SortingInitializer(**kwargs) + if initializer == "subsample": + return init_lib.SubsampleInitializer(**kwargs) + raise NotImplementedError(f"Initializer `{initializer}` is not yet implemented.") + + def sinkhorn_divergence( point_cloud_1: ArrayLike, point_cloud_2: ArrayLike, @@ -45,11 +131,14 @@ def sinkhorn_divergence( batch_size=batch_size, a=a, b=b, - sinkhorn_kwargs={"tau_a": tau_a, "tau_b": tau_b}, scale_cost=scale_cost, epsilon=epsilon, + solve_kwargs={ + "tau_a": tau_a, + "tau_b": tau_b, + }, **kwargs, - ) + )[1] xy_conv, xx_conv, *yy_conv = output.converged if not xy_conv: @@ -132,7 +221,7 @@ def ensure_2d(arr: ArrayLike, *, reshape: bool = False) -> jax.Array: return jnp.reshape(arr, (-1, 1)) if arr.ndim != 2: raise ValueError(f"Expected array to have 2 dimensions, found `{arr.ndim}`.") - return arr + return arr.astype(jnp.float64) def convert_scipy_sparse(arr: Union[sp.spmatrix, jesp.BCOO]) -> jesp.BCOO: diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index dba12b5ac..9a21acf86 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -36,11 +36,13 @@ from moscot._logging import logger from moscot._types import ( ArrayLike, + LRInitializer_t, ProblemKind_t, QuadInitializer_t, SinkhornInitializer_t, ) from moscot.backends.ott._utils import ( + InitializerResolver, Loader, MultiLoader, _instantiate_geodesic_cost, @@ -88,9 +90,11 @@ class OTTJaxSolver(OTSolver[OTTOutput], abc.ABC): ---------- jit Whether to :func:`~jax.jit` the :attr:`solver`. + initializer_kwargs + Keyword arguments for the initializer. """ - def __init__(self, jit: bool = True): + def __init__(self, jit: bool = True, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({})): super().__init__() self._solver: Optional[OTTSolver_t] = None self._problem: Optional[OTTProblem_t] = None @@ -98,6 +102,8 @@ def __init__(self, jit: bool = True): self._a: Optional[jnp.ndarray] = None self._b: Optional[jnp.ndarray] = None + self.initializer_kwargs = initializer_kwargs + def _create_geometry( self, x: TaggedArray, @@ -170,7 +176,7 @@ def _solve( # type: ignore[override] **kwargs: Any, ) -> Union[OTTOutput, GraphOTTOutput]: solver = jax.jit(self.solver) if self._jit else self.solver - out = solver(prob, **kwargs) + out = solver(prob, **self.initializer_kwargs, **kwargs) if isinstance(prob, linear_problem.LinearProblem) and isinstance(prob.geom, geodesic.Geodesic): return GraphOTTOutput(out, shape=(len(self._a), len(self._b))) # type: ignore[arg-type] return OTTOutput(out) @@ -275,20 +281,20 @@ def __init__( initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Any, ): - super().__init__(jit=jit) + super().__init__(jit=jit, initializer_kwargs=initializer_kwargs) if rank > -1: kwargs.setdefault("gamma", 500) kwargs.setdefault("gamma_rescale", True) eps = kwargs.get("epsilon") if eps is not None and eps > 0.0: logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.") - initializer = "rank2" if initializer is None else initializer - self._solver = sinkhorn_lr.LRSinkhorn( - rank=rank, epsilon=epsilon, initializer=initializer, kwargs_init=initializer_kwargs, **kwargs - ) + if isinstance(initializer, str): + initializer = InitializerResolver.lr_from_str(initializer, rank=rank) + self._solver = sinkhorn_lr.LRSinkhorn(rank=rank, epsilon=epsilon, initializer=initializer, **kwargs) else: - initializer = "default" if initializer is None else initializer - self._solver = sinkhorn.Sinkhorn(initializer=initializer, kwargs_init=initializer_kwargs, **kwargs) + if isinstance(initializer, str): + initializer = InitializerResolver.from_str(initializer) + self._solver = sinkhorn.Sinkhorn(initializer=initializer, **kwargs) def _prepare( self, @@ -389,33 +395,32 @@ def __init__( self, jit: bool = True, rank: int = -1, - initializer: QuadInitializer_t = None, + initializer: QuadInitializer_t | LRInitializer_t = None, initializer_kwargs: Mapping[str, Any] = types.MappingProxyType({}), linear_solver_kwargs: Mapping[str, Any] = types.MappingProxyType({}), **kwargs: Any, ): - super().__init__(jit=jit) + super().__init__(jit=jit, initializer_kwargs=initializer_kwargs) if rank > -1: kwargs.setdefault("gamma", 10) kwargs.setdefault("gamma_rescale", True) eps = kwargs.get("epsilon") if eps is not None and eps > 0.0: logger.info(f"Found `epsilon`={eps}>0. We recommend setting `epsilon`=0 for the low-rank solver.") - initializer = "rank2" if initializer is None else initializer + if isinstance(initializer, str): + initializer = InitializerResolver.lr_from_str(initializer, rank=rank) self._solver = gromov_wasserstein_lr.LRGromovWasserstein( rank=rank, initializer=initializer, - kwargs_init=initializer_kwargs, **kwargs, ) else: - linear_ot_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) - initializer = None + linear_solver = sinkhorn.Sinkhorn(**linear_solver_kwargs) + if isinstance(initializer, str): + raise ValueError("Expected `initializer` to be `None` or `ott.initializers.quadratic.initializers`.") self._solver = gromov_wasserstein.GromovWasserstein( - rank=rank, - linear_ot_solver=linear_ot_solver, - quad_initializer=initializer, - kwargs_init=initializer_kwargs, + linear_solver=linear_solver, + initializer=initializer, **kwargs, ) @@ -423,6 +428,7 @@ def _prepare( self, a: jnp.ndarray, b: jnp.ndarray, + alpha: float, xy: Optional[TaggedArray] = None, x: Optional[TaggedArray] = None, y: Optional[TaggedArray] = None, @@ -435,7 +441,6 @@ def _prepare( cost_matrix_rank: Optional[int] = None, time_scales_heat_kernel: Optional[TimeScalesHeatKernel] = None, # problem - alpha: float = 0.5, **kwargs: Any, ) -> quadratic_problem.QuadraticProblem: self._a = a @@ -456,6 +461,11 @@ def _prepare( geom_kwargs["cost_matrix_rank"] = cost_matrix_rank geom_xx = self._create_geometry(x, t=time_scales_heat_kernel.x, is_linear_term=False, **geom_kwargs) geom_yy = self._create_geometry(y, t=time_scales_heat_kernel.y, is_linear_term=False, **geom_kwargs) + if alpha <= 0.0: + raise ValueError(f"Expected `alpha` to be in interval `(0, 1]`, found `{alpha}`.") + if (alpha == 1.0 and xy is not None) or (alpha != 1.0 and xy is None): + raise ValueError(f"Expected `xy` to be `None` if `alpha` is not 1.0, found xy={xy}, alpha={alpha}.") + if alpha == 1.0 or xy is None: # GW # arbitrary fused penalty; must be positive geom_xy, fused_penalty = None, 1.0 diff --git a/src/moscot/base/problems/problem.py b/src/moscot/base/problems/problem.py index 1c71c0beb..8cebbb639 100644 --- a/src/moscot/base/problems/problem.py +++ b/src/moscot/base/problems/problem.py @@ -518,24 +518,8 @@ def solve( solver_class = backends.get_solver( self.problem_kind, solver_name=solver_name, backend=backend, return_class=True ) - init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) - # if linear problem, then alpha is 0.0 by default - # if quadratic problem, then alpha is 1.0 by default - alpha = call_kwargs.get("alpha", 0.0 if self.problem_kind == "linear" else 1.0) - if alpha < 0.0 or alpha > 1.0: - raise ValueError("Expected `alpha` to be in the range `[0, 1]`, found `{alpha}`.") - if self.problem_kind == "linear" and (alpha != 0.0 or not (self.x is None or self.y is None)): - raise ValueError("Unable to solve a linear problem with `alpha != 0` or `x` and `y` supplied.") - if self.problem_kind == "quadratic": - if self.x is None or self.y is None: - raise ValueError("Unable to solve a quadratic problem without `x` and `y` supplied.") - if alpha != 1.0 and self.xy is None: # means FGW case - raise ValueError( - "`alpha` must be 1.0 for quadratic problems without `xy` supplied. See `FGWProblem` class." - ) - if alpha == 1.0 and self.xy is not None: - raise ValueError("Unable to solve a quadratic problem with `alpha = 1` and `xy` supplied.") + init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) self._solver = solver_class(**init_kwargs) # note that the solver call consists of solver._prepare and solver._solve diff --git a/src/moscot/problems/time/_lineage.py b/src/moscot/problems/time/_lineage.py index 809b7f8b9..3fcfd53c9 100644 --- a/src/moscot/problems/time/_lineage.py +++ b/src/moscot/problems/time/_lineage.py @@ -185,7 +185,6 @@ def solve( jit: bool = True, threshold: float = 1e-3, lse_mode: bool = True, - inner_iterations: int = 10, min_iterations: Optional[int] = None, max_iterations: Optional[int] = None, device: Optional[Literal["cpu", "gpu", "tpu"]] = None, @@ -233,9 +232,7 @@ def solve( lse_mode Whether to use `log-sum-exp (LSE) `_ - computations for numerical stability. - inner_iterations - Compute the convergence criterion every ``inner_iterations``. + computations for numerical stability. Valid only for the :term:`linear problem`. min_iterations Minimum number of :term:`Sinkhorn` iterations. max_iterations @@ -253,6 +250,8 @@ def solve( - :attr:`solutions` - the :term:`OT` solutions for each subproblem. - :attr:`stage` - set to ``'solved'``. """ + if self.problem_kind == "linear": + kwargs["lse_mode"] = lse_mode return super().solve( # type:ignore[return-value] epsilon=epsilon, tau_a=tau_a, @@ -265,8 +264,6 @@ def solve( initializer_kwargs=initializer_kwargs, jit=jit, threshold=threshold, - lse_mode=lse_mode, - inner_iterations=inner_iterations, min_iterations=min_iterations, max_iterations=max_iterations, device=device, diff --git a/src/moscot/problems/time/_mixins.py b/src/moscot/problems/time/_mixins.py index c2537f6b4..dee020f35 100644 --- a/src/moscot/problems/time/_mixins.py +++ b/src/moscot/problems/time/_mixins.py @@ -518,7 +518,7 @@ def _get_data( if src == source: source_data = self.problems[src, tgt].xy.data_src if only_start: - return source_data, self.problems[src, tgt].adata_src + return source_data.astype(np.float64), self.problems[src, tgt].adata_src # TODO(michalk8): posterior marginals attr = "posterior_growth_rates" if posterior_marginals else "prior_growth_rates" growth_rates_source = getattr(self.problems[src, tgt], attr) @@ -540,12 +540,12 @@ def _get_data( raise ValueError(f"No data found for `{target}` time point.") return ( - source_data, - growth_rates_source, - intermediate_data, + source_data.astype(np.float64) if source_data is not None else None, + growth_rates_source.astype(np.float64) if growth_rates_source is not None else None, + intermediate_data.astype(np.float64) if intermediate_data is not None else None, intermediate_adata, - target_data, - ) + target_data.astype(np.float64) if target_data is not None else None, + ) # type: ignore[return-value] def compute_interpolated_distance( self, diff --git a/tests/backends/ott/test_backend.py b/tests/backends/ott/test_backend.py index 3962c53d1..3a60308e4 100644 --- a/tests/backends/ott/test_backend.py +++ b/tests/backends/ott/test_backend.py @@ -20,7 +20,7 @@ from moscot._types import ArrayLike, Device_t from moscot.backends.ott import GWSolver, SinkhornSolver -from moscot.backends.ott._utils import alpha_to_fused_penalty +from moscot.backends.ott._utils import InitializerResolver, alpha_to_fused_penalty from moscot.base.output import BaseDiscreteSolverOutput from moscot.base.solver import O, OTSolver from moscot.utils.tagged_array import Tag, TaggedArray @@ -52,6 +52,7 @@ def test_matches_ott(self, x: Geom_t, eps: Optional[float], jit: bool): def test_solver_rank(self, y: Geom_t, rank: Optional[int], initializer: str): eps = 1e-2 default_gamma_lr_sinhorn = 500 + initializer = InitializerResolver.lr_from_str(initializer, rank=rank) lr_sinkhorn = LRSinkhorn(rank=rank, initializer=initializer, gamma=default_gamma_lr_sinhorn) problem = LinearProblem(PointCloud(y, epsilon=eps)) gt = lr_sinkhorn(problem) @@ -99,7 +100,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool thresh = 1e-2 pc_x, pc_y = PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps) prob = quadratic_problem.QuadraticProblem(pc_x, pc_y) - sol = GromovWasserstein(epsilon=eps, threshold=thresh) + sol = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) solver = jax.jit(sol, static_argnames=["threshold", "epsilon"]) if jit else sol gt = solver(prob) @@ -114,6 +115,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, eps: Optional[float], jit: bool x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"}, + alpha=1.0, ) assert solver.is_fused is False @@ -130,7 +132,7 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f problem = QuadraticProblem( geom_xx=Geometry(cost_matrix=x_cost, epsilon=eps), geom_yy=Geometry(cost_matrix=y_cost, epsilon=eps) ) - gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem) + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem) solver = GWSolver(epsilon=eps, threshold=thresh) pred = solver( @@ -139,6 +141,7 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f x=x_cost, y=y_cost, tags={"x": Tag.COST_MATRIX, "y": Tag.COST_MATRIX}, + alpha=1.0, ) assert solver.is_fused is False @@ -152,12 +155,13 @@ def test_epsilon(self, x_cost: jnp.ndarray, y_cost: jnp.ndarray, eps: Optional[f def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None: thresh, eps = 1e-2, 1e-2 if rank > -1: - gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer="rank2")( + initializer = InitializerResolver.lr_from_str("random", rank=rank) + gt = LRGromovWasserstein(epsilon=eps, rank=rank, threshold=thresh, initializer=initializer)( QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) ) else: - gt = GromovWasserstein(epsilon=eps, rank=rank, threshold=thresh)( + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn(threshold=thresh))( QuadraticProblem(PointCloud(x, epsilon=eps), PointCloud(y, epsilon=eps)) ) @@ -168,6 +172,7 @@ def test_solver_rank(self, x: Geom_t, y: Geom_t, rank: int) -> None: x=x, y=y, tags={"x": "point_cloud", "y": "point_cloud"}, + alpha=1.0, ) assert solver.is_fused is False @@ -183,7 +188,7 @@ def test_matches_ott(self, x: Geom_t, y: Geom_t, xy: Geom_t, eps: Optional[float thresh = 1e-2 xx, yy = xy - ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh) + ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) problem = quadratic_problem.QuadraticProblem( geom_xx=PointCloud(x, epsilon=eps), geom_yy=PointCloud(y, epsilon=eps), @@ -218,7 +223,7 @@ def test_alpha(self, x: Geom_t, y: Geom_t, xy: Geom_t, alpha: float) -> None: thresh, eps = 5e-2, 1e-1 xx, yy = xy - ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh) + ott_solver = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn()) problem = quadratic_problem.QuadraticProblem( geom_xx=PointCloud(x, epsilon=eps), geom_yy=PointCloud(y, epsilon=eps), @@ -256,7 +261,7 @@ def test_epsilon( geom_xy=Geometry(cost_matrix=xy_cost, epsilon=eps), fused_penalty=alpha_to_fused_penalty(alpha), ) - gt = GromovWasserstein(epsilon=eps, threshold=thresh)(problem) + gt = GromovWasserstein(epsilon=eps, threshold=thresh, linear_solver=Sinkhorn())(problem) solver = GWSolver(epsilon=eps, threshold=thresh) pred = solver( @@ -344,8 +349,7 @@ def test_pull( b, ndim = (b, b.shape[1]) if batched else (b[:, 0], None) xx, yy = xy solver = solver_t() - - out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy)) + out = solver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, xy=(xx, yy), alpha=0.5) p = out.pull(b, scale_by_marginals=False) assert isinstance(out, BaseDiscreteSolverOutput) @@ -386,11 +390,11 @@ def test_to_device(self, x: Geom_t, device: Optional[Device_t]) -> None: class TestOutputPlotting(PlotTester, metaclass=PlotTesterMeta): def test_plot_costs(self, x: Geom_t, y: Geom_t): - out = GWSolver()(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y) + out = GWSolver()(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0) out.plot_costs() def test_plot_costs_last(self, x: Geom_t, y: Geom_t): - out = GWSolver(rank=2)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y) + out = GWSolver(rank=2)(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0) out.plot_costs(last=3) def test_plot_errors_sink(self, x: Geom_t, y: Geom_t): @@ -398,7 +402,7 @@ def test_plot_errors_sink(self, x: Geom_t, y: Geom_t): out.plot_errors() def test_plot_errors_gw(self, x: Geom_t, y: Geom_t): - out = GWSolver(a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), store_inner_errors=True)( - a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y + out = GWSolver(store_inner_errors=True)( + a=jnp.ones(len(x)) / len(x), b=jnp.ones(len(y)) / len(y), x=x, y=y, alpha=1.0 ) out.plot_errors() diff --git a/tests/data/alignment_solutions.pkl b/tests/data/alignment_solutions.pkl index 6e8f2be8b..145468edd 100644 Binary files a/tests/data/alignment_solutions.pkl and b/tests/data/alignment_solutions.pkl differ diff --git a/tests/data/mapping_solutions.pkl b/tests/data/mapping_solutions.pkl index b9a2a05f1..63a4d0f29 100644 Binary files a/tests/data/mapping_solutions.pkl and b/tests/data/mapping_solutions.pkl differ diff --git a/tests/problems/base/test_general_problem.py b/tests/problems/base/test_general_problem.py index 5e17b6dec..7cd907ba1 100644 --- a/tests/problems/base/test_general_problem.py +++ b/tests/problems/base/test_general_problem.py @@ -1,3 +1,4 @@ +import re from typing import Literal, Optional, Tuple import pytest @@ -29,6 +30,29 @@ def test_simple_run(self, adata_x: AnnData, adata_y: AnnData): assert isinstance(prob.solution, BaseDiscreteSolverOutput) + @pytest.mark.parametrize( + ("kind", "rank"), + [ + ("linear", -1), + ("linear", 5), + ("quadratic", -1), + ("quadratic", 5), + ], + ) + def test_unrecognized_args( + self, adata_x: AnnData, adata_y: AnnData, kind: Literal["linear", "quadratic"], rank: int + ): + prob = OTProblem(adata_x, adata_y) + data = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"}, + } + if "quadratic" in kind: + data["x"] = {"attr": "X"} + data["y"] = {"attr": "X"} + + with pytest.raises(TypeError): + prob.prepare(**data).solve(epsilon=5e-1, rank=rank, dummy=42) + @pytest.mark.fast def test_output(self, adata_x: AnnData, x: Geom_t): problem = OTProblem(adata_x) @@ -264,7 +288,7 @@ def test_set_graph_x_y(self, adata_x: AnnData, adata_y: AnnData, ts: Tuple[Optio assert ta2.tag == Tag.GRAPH assert ta2.cost == "geodesic" - prob1 = prob1.solve(lse_mode=False, epsilon=10.0) + prob1 = prob1.solve(epsilon=10.0, alpha=1.0) prob2 = OTProblem(adata_x, adata_y) prob2 = prob2.prepare( @@ -289,7 +313,7 @@ def test_set_graph_x_y(self, adata_x: AnnData, adata_y: AnnData, ts: Tuple[Optio assert ta2.tag == Tag.GRAPH assert ta2.cost == "geodesic" - prob2 = prob2.solve(lse_mode=False, epsilon=10.0) + prob2 = prob2.solve(epsilon=10.0, alpha=1.0) assert not np.allclose(prob1.solution._output.geom.cost_matrix, prob2.solution._output.geom.cost_matrix) @@ -346,3 +370,35 @@ def test_set_graph_xy_test_t(self, adata_x: AnnData, adata_y: AnnData, t: float) assert pushed_0.shape == pushed_1.shape assert np.all(np.abs(pushed_0 - pushed_1).sum() > np.abs(pushed_2 - pushed_1).sum()) assert np.all(np.abs(pushed_0 - pushed_2).sum() > np.abs(pushed_1 - pushed_2).sum()) + + @pytest.mark.parametrize( + ("attrs", "alpha", "raise_msg"), + [ + ({"xy"}, 0.5, "type-error"), + ({"xy", "x", "y"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")), + ({"xy", "x", "y"}, 1.1, re.escape("Expected `alpha` to be in interval `(0, 1]`, found")), + ({"xy", "x", "y"}, 0.5, None), + ({"x", "y"}, 1.0, None), + ({"x", "y"}, 0.5, re.escape("Expected `xy` to be `None` if `alpha` is not 1.0, found")), + ], + ) + def test_xy_alpha_raises(self, adata_x: AnnData, adata_y: AnnData, attrs, alpha, raise_msg): + prob = OTProblem(adata_x, adata_y) + data = { + "xy": {"x_attr": "obsm", "x_key": "X_pca", "y_attr": "obsm", "y_key": "X_pca"} if "xy" in attrs else {}, + "x": {"attr": "X"} if "x" in attrs else {}, + "y": {"attr": "X"} if "y" in attrs else {}, + } + prob = prob.prepare( + **data, + ) + if raise_msg is not None: + if raise_msg == "type-error": + with pytest.raises(TypeError): + prob.solve(epsilon=5e-1, alpha=alpha) + else: + with pytest.raises(ValueError, match=raise_msg): + prob.solve(epsilon=5e-1, alpha=alpha) + else: + prob.solve(epsilon=5e-1, alpha=alpha) + assert isinstance(prob.solution, BaseDiscreteSolverOutput) diff --git a/tests/problems/conftest.py b/tests/problems/conftest.py index 9ee32e693..b4f883bba 100644 --- a/tests/problems/conftest.py +++ b/tests/problems/conftest.py @@ -183,9 +183,8 @@ def marginal_keys(request): "threshold": "threshold", "min_iterations": "min_iterations", "max_iterations": "max_iterations", - "initializer_kwargs": "kwargs_init", - "warm_start": "_warm_start", - "initializer": "quad_initializer", + "warm_start": "warm_start", + "initializer": "initializer", } gw_lr_solver_args = { @@ -194,7 +193,6 @@ def marginal_keys(request): "threshold": "threshold", "min_iterations": "min_iterations", "max_iterations": "max_iterations", - "initializer_kwargs": "kwargs_init", "initializer": "initializer", } @@ -246,7 +244,7 @@ def marginal_keys(request): "min_iterations": "min_iterations", "max_iterations": "max_iterations", "initializer": "initializer", - "initializer_kwargs": "kwargs_init", + "initializer_kwargs": "initializer_kwargs", } lr_sinkhorn_solver_args = sinkhorn_solver_args.copy() diff --git a/tests/problems/cross_modality/test_translation_problem.py b/tests/problems/cross_modality/test_translation_problem.py index b048fa9cf..4af87e3b1 100644 --- a/tests/problems/cross_modality/test_translation_problem.py +++ b/tests/problems/cross_modality/test_translation_problem.py @@ -1,10 +1,9 @@ from contextlib import nullcontext -from typing import Any, Literal, Mapping, Optional, Tuple +from typing import Any, Callable, Literal, Mapping, Optional, Tuple import pytest import numpy as np -from ott.geometry import epsilon_scheduler from anndata import AnnData @@ -144,12 +143,12 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], tp = tp.solve(**args_to_check) solver = tp[key].solver.solver - args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): - assert getattr(solver, val) == args_to_check[arg], arg + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -171,8 +170,7 @@ def test_pass_arguments(self, adata_translation_split: Tuple[AnnData, AnnData], el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/generic/test_fgw_problem.py b/tests/problems/generic/test_fgw_problem.py index 3ee7b4f30..08d0a75bf 100644 --- a/tests/problems/generic/test_fgw_problem.py +++ b/tests/problems/generic/test_fgw_problem.py @@ -1,10 +1,9 @@ -from typing import Any, Literal, Mapping +from typing import Any, Callable, Literal, Mapping import pytest import numpy as np import pandas as pd -from ott.geometry import epsilon_scheduler from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm from ott.solvers.linear import acceleration @@ -112,9 +111,10 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin solver = problem[key].solver.solver args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): - assert getattr(solver, val, object()) == args_to_check[arg], arg + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -136,8 +136,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] @@ -342,7 +341,7 @@ def test_passing_ott_kwargs_linear(self, adata_space_rotate: AnnData, memory: in }, ) - sinkhorn_solver = problem[("0", "1")].solver.solver.linear_ot_solver + sinkhorn_solver = problem[("0", "1")].solver.solver.linear_solver anderson = sinkhorn_solver.anderson assert isinstance(anderson, acceleration.AndersonAcceleration) diff --git a/tests/problems/generic/test_gw_problem.py b/tests/problems/generic/test_gw_problem.py index 7ae0cb7e5..5fa815990 100644 --- a/tests/problems/generic/test_gw_problem.py +++ b/tests/problems/generic/test_gw_problem.py @@ -1,10 +1,9 @@ -from typing import Any, Literal, Mapping +from typing import Any, Callable, Literal, Mapping import pytest import numpy as np import pandas as pd -from ott.geometry import epsilon_scheduler from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm from ott.solvers.linear import acceleration @@ -117,9 +116,10 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -141,8 +141,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] @@ -307,7 +306,7 @@ def test_passing_ott_kwargs_linear(self, adata_space_rotate: AnnData, memory: in }, ) - sinkhorn_solver = problem[("0", "1")].solver.solver.linear_ot_solver + sinkhorn_solver = problem[("0", "1")].solver.solver.linear_solver anderson = sinkhorn_solver.anderson assert isinstance(anderson, acceleration.AndersonAcceleration) diff --git a/tests/problems/generic/test_sinkhorn_problem.py b/tests/problems/generic/test_sinkhorn_problem.py index 1badbf49b..755dc1541 100644 --- a/tests/problems/generic/test_sinkhorn_problem.py +++ b/tests/problems/generic/test_sinkhorn_problem.py @@ -1,10 +1,9 @@ -from typing import Any, Literal, Mapping +from typing import Any, Callable, Literal, Mapping import pytest import numpy as np import pandas as pd -from ott.geometry import epsilon_scheduler from ott.geometry.costs import Cosine, Euclidean, PNormP, SqEuclidean, SqPNorm from ott.solvers.linear import acceleration @@ -161,9 +160,13 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A solver = problem[(0, 1)].solver.solver args = sinkhorn_solver_args if args_to_check["rank"] == -1 else lr_sinkhorn_solver_args for arg, val in args.items(): - assert hasattr(solver, val), val - el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) - assert el == args_to_check[arg], arg + if arg != "initializer_kwargs": + assert hasattr(solver, val), val + el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) + if arg == "initializer": + assert isinstance(el, Callable) + else: + assert el == args_to_check[arg], arg lin_prob = problem[(0, 1)]._solver._problem for arg, val in lin_prob_args.items(): @@ -177,8 +180,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/space/test_alignment_problem.py b/tests/problems/space/test_alignment_problem.py index 0d1b21a16..166cafad3 100644 --- a/tests/problems/space/test_alignment_problem.py +++ b/tests/problems/space/test_alignment_problem.py @@ -1,12 +1,11 @@ from pathlib import Path -from typing import Any, Literal, Mapping, Optional +from typing import Any, Callable, Literal, Mapping, Optional import pytest import numpy as np import pandas as pd import scipy.sparse as sp -from ott.geometry import epsilon_scheduler import scanpy as sc from anndata import AnnData @@ -75,10 +74,16 @@ def test_prepare_star(self, adata_space_rotate: AnnData, reference: str): assert ref == reference assert isinstance(ap[prob_key], ap._base_problem_type) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( - ("epsilon", "alpha", "rank", "initializer"), - [(1, 0.9, -1, None), (1, 0.5, 10, "random"), (1, 0.5, 10, "rank2"), (0.1, 0.1, -1, None)], + ("epsilon", "alpha", "rank", "initializer", "should_raise"), + [ + (1, 0.9, -1, None, False), + (1, 0.5, 10, "random", False), + (1, 0.5, 10, "rank2", False), + (0.1, 0.1, -1, None, False), + (0.1, -0.1, -1, None, True), # Invalid alpha + (0.1, 1.1, -1, None, True), # Invalid alpha + ], ) def test_solve_balanced( self, @@ -87,6 +92,7 @@ def test_solve_balanced( alpha: float, rank: int, initializer: Optional[Literal["random", "rank2"]], + should_raise: bool, ): kwargs = {} if rank > -1: @@ -95,22 +101,23 @@ def test_solve_balanced( # kwargs["kwargs_init"] = {"key": 0} # kwargs["key"] = 0 return # TODO(@MUCDK) fix after refactoring - ap = ( - AlignmentProblem(adata=adata_space_rotate) - .prepare(batch_key="batch") - .solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) - ) - for prob_key in ap: - assert ap[prob_key].solution.rank == rank - if initializer != "random": # TODO: is this valid? - assert ap[prob_key].solution.converged - - # TODO(michalk8): use np.testing - assert np.allclose(*(sol.cost for sol in ap.solutions.values())) - assert np.all([sol.converged for sol in ap.solutions.values()]) - np.testing.assert_array_equal( - [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True - ) + ap = AlignmentProblem(adata=adata_space_rotate).prepare(batch_key="batch") + if should_raise: + with pytest.raises(ValueError, match=r"Expected `alpha`"): + ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) + else: + ap = ap.solve(epsilon=epsilon, alpha=alpha, rank=rank, **kwargs) + for prob_key in ap: + assert ap[prob_key].solution.rank == rank + if initializer != "random": # TODO: is this valid? + assert ap[prob_key].solution.converged + + # TODO(michalk8): use np.testing + assert np.allclose(*(sol.cost for sol in ap.solutions.values())) + assert np.all([sol.converged for sol in ap.solutions.values()]) + np.testing.assert_array_equal( + [np.all(np.isfinite(sol.transport_matrix)) for sol in ap.solutions.values()], True + ) def test_solve_unbalanced(self, adata_space_rotate: AnnData): tau_a, tau_b = [0.8, 1] @@ -162,7 +169,7 @@ def test_geodesic_cost_xy(self, adata_space_rotate: AnnData, key: str, dense_inp ap[("0", "1")].set_graph_xy(dfs[0], cost="geodesic") ap[("1", "2")].set_graph_xy(dfs[1], cost="geodesic") - ap = ap.solve(max_iterations=2, lse_mode=False) + ap = ap.solve(max_iterations=2) ta = ap[("0", "1")].xy assert isinstance(ta, TaggedArray) @@ -190,9 +197,12 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) + else: + assert getattr(solver, val) == args_to_check[arg] - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -216,8 +226,7 @@ def test_pass_arguments(self, adata_space_rotate: AnnData, args_to_check: Mappin el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/space/test_mapping_problem.py b/tests/problems/space/test_mapping_problem.py index a51bf4ed9..b0645e28d 100644 --- a/tests/problems/space/test_mapping_problem.py +++ b/tests/problems/space/test_mapping_problem.py @@ -1,12 +1,12 @@ +import re from pathlib import Path -from typing import Any, List, Literal, Mapping, Optional, Union +from typing import Any, Callable, List, Literal, Mapping, Optional, Union import pytest import numpy as np import pandas as pd import scipy.sparse as sp -from ott.geometry import epsilon_scheduler from ott.solvers.linear.sinkhorn import SinkhornOutput from ott.solvers.quadratic.gromov_wasserstein import GWOutput @@ -96,7 +96,6 @@ def test_prepare_varnames(self, adata_mapping: AnnData, var_names: Optional[List assert prob.x.data_src.shape == (n_obs, x_n_var) assert prob.y.data_src.shape == (n_obs, y_n_var) - @pytest.mark.skip(reason="See https://github.com/theislab/moscot/issues/678") @pytest.mark.parametrize( ("epsilon", "alpha", "rank", "initializer"), [(1e-2, 0.9, -1, None), (2, 0.5, 10, "random"), (2, 0.5, 10, "rank2"), (2, 0.1, -1, None)], @@ -189,7 +188,7 @@ def test_geodesic_cost_xy(self, adata_mapping: AnnData, key: str, geodesic_y: bo if geodesic_y: mp[("1", "ref")].set_graph_y(df_y, cost="geodesic") mp[("2", "ref")].set_graph_y(df_y, cost="geodesic") - mp = mp.solve(max_iterations=2, lse_mode=False) + mp = mp.solve(max_iterations=2) ta = mp[("1", "ref")].xy assert isinstance(ta, TaggedArray) @@ -232,9 +231,10 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -258,8 +258,7 @@ def test_pass_arguments(self, adata_mapping: AnnData, args_to_check: Mapping[str el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] @@ -301,14 +300,14 @@ def test_problem_type( assert isinstance(sol._output, solution_kind) @pytest.mark.parametrize( - ("sc_attr", "alpha"), + ("sc_attr", "alpha", "raise_msg"), [ - (None, 0.5), - ({"attr": "X"}, 0), + (None, 0.5, re.escape("Expected `alpha` to be 0 for a `linear problem`.")), + ({"attr": "X"}, 0, re.escape("Expected `alpha` to be in interval `(0, 1]`, found `0`.")), ], ) def test_problem_type_corner_cases( - self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float] + self, adata_mapping: AnnData, sc_attr: Optional[Mapping[str, str]], alpha: Optional[float], raise_msg: str ): # initialize and prepare the MappingProblem adataref, adatasp = _adata_spatial_split(adata_mapping) @@ -316,5 +315,5 @@ def test_problem_type_corner_cases( mp = mp.prepare(batch_key="batch", sc_attr=sc_attr) # we test two incompatible combinations of `sc_attr` and `alpha` - with pytest.raises(ValueError, match=r"^Expected `alpha`"): + with pytest.raises(ValueError, match=raise_msg): mp.solve(alpha=alpha) diff --git a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py index 3b68561ba..49773d3d9 100644 --- a/tests/problems/spatio_temporal/test_spatio_temporal_problem.py +++ b/tests/problems/spatio_temporal/test_spatio_temporal_problem.py @@ -4,7 +4,6 @@ import numpy as np import pandas as pd -from ott.geometry import epsilon_scheduler from anndata import AnnData @@ -60,7 +59,7 @@ def test_solve_balanced(self, adata_spatio_temporal: AnnData): assert isinstance(subsol, BaseDiscreteSolverOutput) assert key in expected_keys - @pytest.mark.skip(reason="unbalanced does not work yet") + @pytest.mark.skip(reason="unbalanced does not work yet: https://github.com/ott-jax/ott/issues/519") def test_solve_unbalanced(self, adata_spatio_temporal: AnnData): taus = [9e-1, 1e-2] problem1 = SpatioTemporalProblem(adata=adata_spatio_temporal) @@ -198,9 +197,10 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg], arg + if arg != "initializer": + assert getattr(solver, val) == args_to_check[arg], arg - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -224,8 +224,7 @@ def test_pass_arguments(self, adata_spatio_temporal: AnnData, args_to_check: Map el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/time/test_lineage_problem.py b/tests/problems/time/test_lineage_problem.py index 5375c621e..21b0a7eca 100644 --- a/tests/problems/time/test_lineage_problem.py +++ b/tests/problems/time/test_lineage_problem.py @@ -1,9 +1,8 @@ -from typing import Any, List, Mapping +from typing import Any, Callable, List, Mapping import pytest import numpy as np -from ott.geometry import epsilon_scheduler from anndata import AnnData @@ -233,9 +232,10 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi args = gw_solver_args if args_to_check["rank"] == -1 else gw_lr_solver_args for arg, val in args.items(): assert hasattr(solver, val) - assert getattr(solver, val) == args_to_check[arg] + if arg == "initializer": + assert isinstance(getattr(solver, val), Callable) - sinkhorn_solver = solver.linear_ot_solver if args_to_check["rank"] == -1 else solver + sinkhorn_solver = solver.linear_solver if args_to_check["rank"] == -1 else solver lin_solver_args = gw_linear_solver_args if args_to_check["rank"] == -1 else gw_lr_linear_solver_args tmp_dict = args_to_check["linear_solver_kwargs"] if args_to_check["rank"] == -1 else args_to_check for arg, val in lin_solver_args.items(): @@ -259,8 +259,7 @@ def test_pass_arguments(self, adata_time_barcodes: AnnData, args_to_check: Mappi el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg] diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index bc116a704..f062adefa 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -1,4 +1,3 @@ -import sys from typing import Tuple import pytest @@ -236,7 +235,6 @@ def test_compute_interpolated_distance_pipeline(self, gt_temporal_adata: AnnData assert isinstance(interpolation_result, float) assert interpolation_result > 0 - @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnData): config = gt_temporal_adata.uns key = config["key"] @@ -264,7 +262,6 @@ def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnDa interpolation_result, gt_temporal_adata.uns["interpolated_distance_10_105_11"], rtol=1e-6, atol=1e-6 ) - @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnData): config = gt_temporal_adata.uns key = config["key"] @@ -316,7 +313,6 @@ def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData): assert isinstance(result, float) np.testing.assert_allclose(result, gt_temporal_adata.uns["batch_distances_10"], rtol=1e-5) - @pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9 or higher") def test_compute_random_distance_regression(self, gt_temporal_adata: AnnData): config = gt_temporal_adata.uns key = config["key"] diff --git a/tests/problems/time/test_temporal_base_problem.py b/tests/problems/time/test_temporal_base_problem.py index abdfe0f90..a52669c54 100644 --- a/tests/problems/time/test_temporal_base_problem.py +++ b/tests/problems/time/test_temporal_base_problem.py @@ -111,7 +111,7 @@ def test_posterior_growth_rates(self, adata_time_marginal_estimations: AnnData): b=True, marginal_kwargs={"proliferation_key": "proliferation"}, ) - prob = prob.solve(max_iterations=10) + prob = prob.solve(max_iterations=10, alpha=1.0) assert prob.delta == (t2 - t1) gr = prob.posterior_growth_rates diff --git a/tests/problems/time/test_temporal_problem.py b/tests/problems/time/test_temporal_problem.py index f53a745c9..6acd35e4f 100644 --- a/tests/problems/time/test_temporal_problem.py +++ b/tests/problems/time/test_temporal_problem.py @@ -1,4 +1,4 @@ -from typing import Any, List, Mapping, Optional +from typing import Any, Callable, List, Mapping, Optional import pytest @@ -6,7 +6,7 @@ import numpy as np import pandas as pd import scipy.sparse as sp -from ott.geometry import costs, epsilon_scheduler +from ott.geometry import costs from scipy.sparse import csr_matrix import scanpy as sc @@ -440,9 +440,13 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A solver = problem[key].solver.solver args = sinkhorn_solver_args if args_to_check["rank"] == -1 else lr_sinkhorn_solver_args for arg, val in args.items(): - assert hasattr(solver, val) - el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) - assert el == args_to_check[arg] + if arg != "initializer_kwargs": + assert hasattr(solver, val) + el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) + if arg == "initializer": + assert isinstance(el, Callable) + else: + assert el == args_to_check[arg] lin_prob = problem[key]._solver._problem for arg, val in lin_prob_args.items(): @@ -456,8 +460,7 @@ def test_pass_arguments(self, adata_time: AnnData, args_to_check: Mapping[str, A el = getattr(geom, val)[0] if isinstance(getattr(geom, val), tuple) else getattr(geom, val) if arg == "epsilon": eps_processed = getattr(geom, val) - assert isinstance(eps_processed, epsilon_scheduler.Epsilon) - assert eps_processed.target == args_to_check[arg], arg + assert eps_processed == args_to_check[arg], arg else: assert getattr(geom, val) == args_to_check[arg], arg assert el == args_to_check[arg]