diff --git a/docs/core.rst b/docs/core.rst
index 066c4b901..2acf61183 100644
--- a/docs/core.rst
+++ b/docs/core.rst
@@ -34,6 +34,15 @@ Sinkhorn
     sinkhorn.Sinkhorn
     sinkhorn.SinkhornOutput
 
+Sinkhorn Dual Initializers
+--------------------------
+.. autosummary::
+    :toctree: _autosummary
+
+    initializers.DefaultInitializer
+    initializers.GaussianInitializer
+    initializers.SortingInitializer
+
 Low-Rank Sinkhorn
 -----------------
 .. autosummary::
diff --git a/docs/references.bib b/docs/references.bib
index f5a81da4c..8ee9bbc05 100644
--- a/docs/references.bib
+++ b/docs/references.bib
@@ -504,3 +504,10 @@ @inproceedings{arthur:07
   location = {New Orleans, Louisiana},
   series = {SODA '07}
 }
+
+@article{thornton2022rethinking:22,
+  title = {Rethinking Initialization of the Sinkhorn Algorithm},
+  author = {Thornton, James and Cuturi, Marco},
+  journal = {arXiv preprint arXiv:2206.07630},
+  year = {2022}
+}
diff --git a/ott/core/__init__.py b/ott/core/__init__.py
index cce7e8ba0..11a9cfe08 100644
--- a/ott/core/__init__.py
+++ b/ott/core/__init__.py
@@ -23,6 +23,7 @@
     gromov_wasserstein,
     gw_barycenter,
     implicit_differentiation,
+    initializers,
     linear_problems,
     momentum,
     quad_problems,
diff --git a/ott/core/initializers.py b/ott/core/initializers.py
new file mode 100644
index 000000000..05e700a3e
--- /dev/null
+++ b/ott/core/initializers.py
@@ -0,0 +1,276 @@
+# Copyright 2022 The OTT Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Sinkhorn initializers."""
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional, Sequence, Tuple
+
+import jax
+import jax.numpy as jnp
+
+from ott.core import linear_problems
+from ott.geometry import pointcloud
+
+
+@jax.tree_util.register_pytree_node_class
+class SinkhornInitializer(ABC):
+  """Base class for Sinkhorn initializers."""
+
+  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
+    return [], {}
+
+  @classmethod
+  def tree_unflatten(
+      cls, aux_data: Dict[str, Any], children: Sequence[Any]
+  ) -> "SinkhornInitializer":
+    return cls(*children, **aux_data)
+
+  @abstractmethod
+  def init_dual_a(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
+  ) -> jnp.ndarray:
+    """Initialization for Sinkhorn potential/scaling f_u."""
+
+  @abstractmethod
+  def init_dual_b(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
+  ) -> jnp.ndarray:
+    """Initialization for Sinkhorn potential/scaling g_v."""
+
+
+@jax.tree_util.register_pytree_node_class
+class DefaultInitializer(SinkhornInitializer):
+  """Default initialization of Sinkhorn dual potentials/primal scalings."""
+
+  def init_dual_a(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
+  ) -> jnp.ndarray:
+    """Initialization for Sinkhorn potential/scaling f_u.
+
+    Args:
+      ot_problem: OT problem between discrete distributions of size n and m.
+      lse_mode: Return potential if true, scaling if false.
+
+    Returns:
+      potential/scaling, array of size n.
+    """
+    a = ot_problem.a
+    init_dual_a = jnp.zeros_like(a) if lse_mode else jnp.ones_like(a)
+    return init_dual_a
+
+  def init_dual_b(
+      self, ot_problem: linear_problems.LinearProblem, lse_mode: bool
+  ) -> jnp.ndarray:
+    """Initialization for Sinkhorn potential/scaling g_v.
+
+    Args:
+      ot_problem: OT problem between discrete distributions of size n and m.
+      lse_mode: Return potential if true, scaling if false.
+
+    Returns:
+      potential/scaling, array of size m.
+    """
+    b = ot_problem.b
+    init_dual_b = jnp.zeros_like(b) if lse_mode else jnp.ones_like(b)
+    return init_dual_b
+
+
+@jax.tree_util.register_pytree_node_class
+class GaussianInitializer(DefaultInitializer):
+  """Gaussian initializer.
+
+  From :cite:`thornton2022rethinking:22`.
+  Compute a Gaussian approximation of each point cloud, then compute the
+  closed-form Kantorovich potential between the two Gaussians using Brenier's
+  theorem (adapting the convex/Brenier potential to a Kantorovich potential).
+  Use this Gaussian potential to initialize the Sinkhorn potentials/scalings.
+  """
+
+  def init_dual_a(
+      self,
+      ot_problem: linear_problems.LinearProblem,
+      lse_mode: bool,
+  ) -> jnp.ndarray:
+    """Gaussian initialization function.
+
+    Args:
+      ot_problem: OT problem description with geometry and weights.
+      lse_mode: Return potential if true, scaling if false.
+
+    Returns:
+      potential/scaling f_u, array of size n.
+    """
+    # Import Gaussian here to avoid a circular import.
+    from ott.tools.gaussian_mixture import gaussian
+
+    assert isinstance(
+        ot_problem.geom, pointcloud.PointCloud
+    ), "Gaussian initializer valid only for point clouds."
+
+    x, y = ot_problem.geom.x, ot_problem.geom.y
+    a, b = ot_problem.a, ot_problem.b
+
+    gaussian_a = gaussian.Gaussian.from_samples(x, weights=a)
+    gaussian_b = gaussian.Gaussian.from_samples(y, weights=b)
+    # Brenier potential for cost ||x-y||^2 / 2; multiply by 2 for ||x-y||^2.
+    f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x)
+    f_potential = f_potential - jnp.mean(f_potential)
+    f_u = f_potential if lse_mode else ot_problem.geom.scaling_from_potential(
+        f_potential
+    )
+    return f_u
+
+
+@jax.tree_util.register_pytree_node_class
+class SortingInitializer(DefaultInitializer):
+  """Sorting initializer.
+
+  DualSort algorithm from :cite:`thornton2022rethinking:22`: solve the
+  non-regularized OT problem via sorting, then compute a potential through
+  an iterated minimum on the C-transform, and use this potential to
+  initialize the regularized potential.
+
+  Args:
+    vectorized_update: Use vectorized inner loop if true.
+    tolerance: DualSort convergence threshold.
+    max_iter: Max DualSort steps.
+  """
+
+  def __init__(
+      self,
+      vectorized_update: bool = True,
+      tolerance: float = 1e-2,
+      max_iter: int = 100
+  ):
+    super().__init__()
+    self.tolerance = tolerance
+    self.max_iter = max_iter
+    self.vectorized_update = vectorized_update
+    self.update_fn = (
+        _vectorized_update if vectorized_update else _coordinate_update
+    )
+
+  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
+    return ([], {
+        'tolerance': self.tolerance,
+        'max_iter': self.max_iter,
+        'vectorized_update': self.vectorized_update
+    })
+
+  def _init_sorting_dual(
+      self, modified_cost: jnp.ndarray, init_f: jnp.ndarray
+  ) -> jnp.ndarray:
+    """Run the DualSort algorithm.
+
+    Args:
+      modified_cost: cost matrix minus its diagonal, column-wise.
+      init_f: potential f, array of size n. This is the starting potential,
+        which is iteratively refined into the initialization potential itself.
+
+    Returns:
+      potential f, array of size n.
+    """
+
+    def body_fn(state):
+      prev_f, _, it = state
+      new_f = self.update_fn(prev_f, modified_cost)
+      diff = jnp.sum((new_f - prev_f) ** 2)
+      it += 1
+      return new_f, diff, it
+
+    def cond_fn(state):
+      _, diff, it = state
+      return jnp.logical_and(diff > self.tolerance, it < self.max_iter)
+
+    it = 0
+    diff = self.tolerance + 1.0
+    state = (init_f, diff, it)
+
+    f_potential, _, it = jax.lax.while_loop(
+        cond_fun=cond_fn, body_fun=body_fn, init_val=state
+    )
+
+    return f_potential
+
+  def init_dual_a(
+      self,
+      ot_problem: linear_problems.LinearProblem,
+      lse_mode: bool,
+      init_f: Optional[jnp.ndarray] = None,
+  ) -> jnp.ndarray:
+    """Apply the DualSort algorithm.
+
+    Args:
+      ot_problem: OT problem.
+      lse_mode: Return potential if true, scaling if false.
+      init_f: potential f, array of size n. This is the starting potential,
+        which is iteratively refined into the initialization potential itself.
+
+    Returns:
+      potential/scaling f_u, array of size n.
+    """
+    assert not ot_problem.geom.is_online, \
+        "Sorting initializer does not work for online geometry."
+    # Checking for sorted x, y would require a point cloud and could slow
+    # down the initializer.
+    cost_matrix = ot_problem.geom.cost_matrix
+
+    assert cost_matrix.shape[0] == cost_matrix.shape[1], \
+        "Requires square cost matrix."
+
+    modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :]
+
+    n = cost_matrix.shape[0]
+    init_f = jnp.zeros(n) if init_f is None else init_f
+
+    f_potential = self._init_sorting_dual(modified_cost, init_f)
+    f_potential = f_potential - jnp.mean(f_potential)
+
+    f_u = f_potential if lse_mode else ot_problem.geom.scaling_from_potential(
+        f_potential
+    )
+
+    return f_u
+
+
+def _vectorized_update(
+    f: jnp.ndarray, modified_cost: jnp.ndarray
+) -> jnp.ndarray:
+  """Inner-loop DualSort update.
+
+  Args:
+    f: potential f, array of size n.
+    modified_cost: cost matrix minus its diagonal, column-wise.
+
+  Returns:
+    updated potential vector, f.
+  """
+  f = jnp.min(modified_cost + f[None, :], axis=1)
+  return f
+
+
+def _coordinate_update(
+    f: jnp.ndarray, modified_cost: jnp.ndarray
+) -> jnp.ndarray:
+  """Coordinate-wise updates within the inner loop.
+
+  Args:
+    f: potential f, array of size n.
+    modified_cost: cost matrix minus its diagonal, column-wise.
+
+  Returns:
+    updated potential vector, f.
+ """ + + def body_fn(i, f): + new_f = jnp.min(modified_cost[i, :] + f) + f = f.at[i].set(new_f) + return f + + return jax.lax.fori_loop(0, len(f), body_fn, f) diff --git a/ott/core/sinkhorn.py b/ott/core/sinkhorn.py index 18a6395dd..944231560 100644 --- a/ott/core/sinkhorn.py +++ b/ott/core/sinkhorn.py @@ -23,6 +23,7 @@ from ott.core import anderson as anderson_lib from ott.core import fixed_point_loop from ott.core import implicit_differentiation as implicit_lib +from ott.core import initializers as init_lib from ott.core import linear_problems from ott.core import momentum as momentum_lib from ott.core import unbalanced_functions @@ -349,6 +350,8 @@ def __init__( use_danskin: Optional[bool] = None, implicit_diff: Optional[implicit_lib.ImplicitDiff ] = implicit_lib.ImplicitDiff(), # noqa: E124 + potential_initializer: init_lib.SinkhornInitializer = init_lib + .DefaultInitializer(), jit: bool = True ): self.lse_mode = lse_mode @@ -368,6 +371,7 @@ def __init__( self.anderson = anderson self.implicit_diff = implicit_diff self.parallel_dual_updates = parallel_dual_updates + self.potential_initializer = potential_initializer self.jit = jit # Force implicit_differentiation to True when using Anderson acceleration, @@ -400,18 +404,25 @@ def __call__( init: Optional[Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]]] = None ) -> SinkhornOutput: """Main interface to run sinkhorn.""" # noqa: D401 + # initialization init_dual_a, init_dual_b = (init if init is not None else (None, None)) - a, b = ot_prob.a, ot_prob.b + if init_dual_a is None: - init_dual_a = jnp.zeros_like(a) if self.lse_mode else jnp.ones_like(a) + init_dual_a = self.potential_initializer.init_dual_a( + ot_problem=ot_prob, lse_mode=self.lse_mode + ) + if init_dual_b is None: - init_dual_b = jnp.zeros_like(b) if self.lse_mode else jnp.ones_like(b) + init_dual_b = self.potential_initializer.init_dual_b( + ot_problem=ot_prob, lse_mode=self.lse_mode + ) + # Cancel dual variables for zero weights. init_dual_a = jnp.where( - a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0 + ot_prob.a > 0, init_dual_a, -jnp.inf if self.lse_mode else 0.0 ) init_dual_b = jnp.where( - b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0 + ot_prob.b > 0, init_dual_b, -jnp.inf if self.lse_mode else 0.0 ) run_fn = jax.jit(run) if self.jit else run @@ -691,6 +702,8 @@ def make( precondition_fun: Optional[Callable[[float], float]] = None, parallel_dual_updates: bool = False, use_danskin: bool = None, + potential_initializer: init_lib.SinkhornInitializer = init_lib + .DefaultInitializer(), jit: bool = False ) -> Sinkhorn: """For backward compatibility.""" @@ -725,6 +738,7 @@ def make( implicit_diff=implicit_diff, parallel_dual_updates=parallel_dual_updates, use_danskin=use_danskin, + potential_initializer=potential_initializer, jit=jit ) diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py index adeb8d9bc..00ea3e3a5 100644 --- a/ott/tools/gaussian_mixture/gaussian.py +++ b/ott/tools/gaussian_mixture/gaussian.py @@ -32,6 +32,32 @@ def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL): self._loc = loc self._scale = scale + @classmethod + def from_samples( + cls, points: jnp.ndarray, weights: jnp.ndarray = None + ) -> 'Gaussian': + """Construct a Gaussian from weighted samples. 
diff --git a/ott/tools/gaussian_mixture/gaussian.py b/ott/tools/gaussian_mixture/gaussian.py
index adeb8d9bc..00ea3e3a5 100644
--- a/ott/tools/gaussian_mixture/gaussian.py
+++ b/ott/tools/gaussian_mixture/gaussian.py
@@ -32,6 +32,32 @@ def __init__(self, loc: jnp.ndarray, scale: scale_tril.ScaleTriL):
     self._loc = loc
     self._scale = scale
 
+  @classmethod
+  def from_samples(
+      cls, points: jnp.ndarray, weights: jnp.ndarray = None
+  ) -> 'Gaussian':
+    """Construct a Gaussian from weighted samples.
+
+    Unbiased, weighted covariance formula from https://en.wikipedia.org/wiki/Sample_mean_and_covariance#Weighted_samples
+    and https://www.gnu.org/software/gsl/doc/html/statistics.html?highlight=weighted#weighted-samples
+
+    Args:
+      points: [n x d] array of samples
+      weights: [n] array of weights
+
+    Returns:
+      Gaussian.
+    """
+    n = points.shape[0]
+    if weights is None:
+      weights = jnp.ones(n) / n
+
+    mean = weights.dot(points)
+    centered_x = (points - mean)
+    scaled_centered_x = centered_x * weights.reshape(-1, 1)
+    cov = scaled_centered_x.T.dot(centered_x) / (1 - weights.dot(weights))
+    return cls.from_mean_and_cov(mean=mean, cov=cov)
+
   @classmethod
   def from_random(
       cls,
@@ -129,7 +155,40 @@ def w2_dist(self, other: 'Gaussian') -> jnp.ndarray:
     delta_sigma = self.scale.w2_dist(other.scale)
     return delta_mean + delta_sigma
 
+  def f_potential(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
+    """Optimal potential for W2 distance between Gaussians, evaluated on points.
+
+    Args:
+      dest: Gaussian object
+      points: samples
+
+    Returns:
+      Dual potential, f
+    """
+    scale_matrix = self.scale.gaussian_map(dest_scale=dest.scale)
+    centered_x = points - self.loc
+    scaled_x = (scale_matrix @ centered_x.T)
+
+    @jax.vmap
+    def batch_inner_product(x, y):
+      return x.dot(y)
+
+    return (
+        0.5 * batch_inner_product(points, points) -
+        0.5 * batch_inner_product(centered_x, scaled_x.T) -
+        points.dot(dest.loc)
+    )
+
   def transport(self, dest: 'Gaussian', points: jnp.ndarray) -> jnp.ndarray:
+    """Transport points according to the map between two Gaussian measures.
+
+    Args:
+      dest: Gaussian object
+      points: samples
+
+    Returns:
+      Transported samples
+    """
     return self.scale.transport(
         dest_scale=dest.scale, points=points - self.loc[None]
     ) + dest.loc[None]
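
Note (reviewer sketch, not part of the patch): this is how GaussianInitializer consumes the two methods added above: fit one Gaussian per (optionally weighted) point cloud, then evaluate the closed-form dual potential on the source samples. Shapes and the shift are illustrative.

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture import gaussian

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (128, 2))        # source samples
y = jax.random.normal(key_y, (128, 2)) + 3.0  # shifted target samples

g_x = gaussian.Gaussian.from_samples(x)  # uniform weights by default
g_y = gaussian.Gaussian.from_samples(y, weights=jnp.ones(128) / 128)
f_vals = g_x.f_potential(dest=g_y, points=x)  # dual potential, shape (128,)
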
diff --git a/ott/tools/gaussian_mixture/scale_tril.py b/ott/tools/gaussian_mixture/scale_tril.py
index 3fd5ab632..c4f7ea077 100644
--- a/ott/tools/gaussian_mixture/scale_tril.py
+++ b/ott/tools/gaussian_mixture/scale_tril.py
@@ -69,7 +69,7 @@ def from_random(
     )
 
     # random positive definite matrix
-    sigma = jnp.matmul(jnp.expand_dims(eigs, -2) * q, jnp.transpose(q))
+    sigma = q * jnp.expand_dims(eigs, -2) @ q.T
 
     # cholesky factorization
     chol = jnp.linalg.cholesky(sigma)
@@ -117,7 +117,7 @@ def cholesky(self) -> jnp.ndarray:
   def covariance(self) -> jnp.ndarray:
     """Get the covariance matrix."""
     cholesky = self.cholesky()
-    return jnp.matmul(cholesky, jnp.transpose(cholesky))
+    return cholesky @ cholesky.T
 
   def covariance_sqrt(self) -> jnp.ndarray:
     """Get the square root of the covariance matrix."""
@@ -134,7 +134,7 @@ def centered_to_z(self, x_centered: jnp.ndarray) -> jnp.ndarray:
 
   def z_to_centered(self, z: jnp.ndarray) -> jnp.ndarray:
     """Scale standardized points to points with the specified covariance."""
-    return jnp.transpose(jnp.matmul(self.cholesky(), jnp.transpose(z)))
+    return (self.cholesky() @ z.T).T
 
   def w2_dist(self, other: 'ScaleTriL') -> jnp.ndarray:
     r"""Wasserstein distance W_2^2 to another Gaussian with same mean.
@@ -157,17 +157,18 @@ def _flatten_cov(cov: jnp.ndarray) -> jnp.ndarray:
     return (cost_fn.norm(x0) + cost_fn.norm(x1) + cost_fn.pairwise(x0, x1))[...,]
 
-  def transport(
-      self, dest_scale: 'ScaleTriL', points: jnp.ndarray
-  ) -> jnp.ndarray:
-    """Transport between 0-mean normal w/ current scale to one w/ dest_scale.
+  def gaussian_map(self, dest_scale: 'ScaleTriL') -> jnp.ndarray:
+    """Scaling matrix used in the transport map between 0-mean Gaussians.
+
+    Sigma_mu^{-1/2} @
+    [Sigma_mu^{1/2} Sigma_nu Sigma_mu^{1/2}]^{1/2}
+    @ Sigma_mu^{-1/2}
 
     Args:
       dest_scale: destination Scale
-      points: points to transport
 
     Returns:
-      Points transported to a Gaussian with the new scale.
+      Gaussian scaling matrix, same dimension as self.covariance().
     """
     sqrt0, sqrt0_inv = linalg.matrix_powers(self.covariance(), (0.5, -0.5))
     sigma1 = dest_scale.covariance()
     m = matrix_square_root.sqrtm_only(
@@ -175,7 +176,22 @@ def transport(
         jnp.matmul(sqrt0, jnp.matmul(sigma1, sqrt0))
     )
     m = jnp.matmul(sqrt0_inv, jnp.matmul(m, sqrt0_inv))
-    return jnp.transpose(jnp.matmul(m, jnp.transpose(points)))
+    return m
+
+  def transport(
+      self, dest_scale: 'ScaleTriL', points: jnp.ndarray
+  ) -> jnp.ndarray:
+    """Apply the Monge map between two 0-mean Gaussians to points.
+
+    Args:
+      dest_scale: destination Scale
+      points: points to transport
+
+    Returns:
+      Points transported to a Gaussian with the new scale.
+    """
+    m = self.gaussian_map(dest_scale)
+    return (m @ points.T).T
 
   def tree_flatten(self):
     children = (self.params,)
diff --git a/ott/tools/transport.py b/ott/tools/transport.py
index da66df895..73674ed78 100644
--- a/ott/tools/transport.py
+++ b/ott/tools/transport.py
@@ -78,6 +78,8 @@ def solve(
     *args: Any,
     a: Optional[jnp.ndarray] = None,
     b: Optional[jnp.ndarray] = None,
+    init_dual_a: Optional[jnp.ndarray] = None,
+    init_dual_b: Optional[jnp.ndarray] = None,
     objective: Optional[Literal['linear', 'quadratic', 'fused']] = None,
     **kwargs: Any
 ) -> Transport:
@@ -121,9 +123,10 @@ def solve(
   linear = isinstance(pb, linear_problems.LinearProblem)
   solver_fn = sinkhorn.make if linear else gromov_wasserstein.make
   geom_keys = ['cost_fn', 'power', 'online']
+  remove_keys = geom_keys + eps_keys if linear else geom_keys
   for key in remove_keys:
     kwargs.pop(key, None)
   solver = solver_fn(**kwargs)
-  output = solver(pb)
+  output = solver(pb, (init_dual_a, init_dual_b))
 
   return Transport(pb, output)
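
Note (reviewer sketch, not part of the patch): gaussian_map factors the old transport into "compute the map, then apply it", so applying the returned matrix by hand should match transport(). The from_random(key, n_dimensions) call below assumes the existing constructor signature of ScaleTriL.

import jax
import jax.numpy as jnp

from ott.tools.gaussian_mixture import scale_tril

key0, key1, key2 = jax.random.split(jax.random.PRNGKey(0), 3)
s_mu = scale_tril.ScaleTriL.from_random(key0, n_dimensions=2)
s_nu = scale_tril.ScaleTriL.from_random(key1, n_dimensions=2)
pts = jax.random.normal(key2, (10, 2))

m = s_mu.gaussian_map(dest_scale=s_nu)  # the Sigma_mu^{-1/2}[...]^{1/2}Sigma_mu^{-1/2} matrix
assert jnp.allclose((m @ pts.T).T, s_mu.transport(dest_scale=s_nu, points=pts))
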
diff --git a/tests/core/initializers_test.py b/tests/core/initializers_test.py
new file mode 100644
index 000000000..340edc7b8
--- /dev/null
+++ b/tests/core/initializers_test.py
@@ -0,0 +1,266 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Tests for Sinkhorn initializers."""
+
+from functools import partial
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+import pytest
+
+from ott.core import initializers as init_lib
+from ott.core import linear_problems
+from ott.core.sinkhorn import sinkhorn
+from ott.geometry import geometry, pointcloud
+
+
+def create_sorting_problem(rng, n, epsilon=0.01, online=False):
+  # Define the OT problem.
+  x_init = jnp.array([-1., 0, .22])
+  y_init = jnp.array([0., 0, 1.1])
+  x_rng, y_rng = jax.random.split(rng)
+
+  x = jnp.concatenate([x_init, 10 + jnp.abs(jax.random.normal(x_rng, (n,)))])
+  y = jnp.concatenate([y_init, 10 + jnp.abs(jax.random.normal(y_rng, (n,)))])
+
+  x = jnp.sort(x)
+  y = jnp.sort(y)
+
+  n = len(x)
+  m = len(y)
+  a = jnp.ones(n) / n
+  b = jnp.ones(m) / m
+
+  batch_size = 3 if online else None
+  geom = pointcloud.PointCloud(
+      x.reshape(-1, 1),
+      y.reshape(-1, 1),
+      epsilon=epsilon,
+      batch_size=batch_size
+  )
+  ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b)
+
+  return ot_problem
+
+
+def create_ot_problem(rng, n, m, d, epsilon=0.01, online=False):
+  # Define the OT problem.
+  x_rng, y_rng = jax.random.split(rng)
+
+  mu_a = jnp.array([-1, 1]) * 5
+  mu_b = jnp.array([0, 0])
+
+  x = jax.random.normal(x_rng, (n, d)) + mu_a
+  y = jax.random.normal(y_rng, (m, d)) + mu_b
+
+  a = jnp.ones(n) / n
+  b = jnp.ones(m) / m
+
+  batch_size = 3 if online else None
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon, batch_size=batch_size)
+
+  ot_problem = linear_problems.LinearProblem(geom=geom, a=a, b=b)
+  return ot_problem
+
+
+# Define the Sinkhorn runners.
+@partial(jax.jit, static_argnames=['lse_mode', 'vector_min'])
+def run_sinkhorn_sort_init(
+    x, y, a=None, b=None, epsilon=0.01, vector_min=True, lse_mode=True
+):
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
+  sort_init = init_lib.SortingInitializer(vectorized_update=vector_min)
+  out = sinkhorn(
+      geom,
+      a=a,
+      b=b,
+      jit=True,
+      potential_initializer=sort_init,
+      lse_mode=lse_mode
+  )
+  return out
+
+
+@partial(jax.jit, static_argnames=['lse_mode'])
+def run_sinkhorn(x, y, a=None, b=None, epsilon=0.01, lse_mode=True):
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
+  out = sinkhorn(geom, a=a, b=b, jit=True, lse_mode=lse_mode)
+  return out
+
+
+@partial(jax.jit, static_argnames=['lse_mode'])
+def run_sinkhorn_gaus_init(x, y, a=None, b=None, epsilon=0.01, lse_mode=True):
+  geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
+  out = sinkhorn(
+      geom,
+      a=a,
+      b=b,
+      jit=True,
+      potential_initializer=init_lib.GaussianInitializer(),
+      lse_mode=lse_mode
+  )
+  return out
+
+
+@pytest.mark.fast
+class TestInitializers:
+
+  def test_init_pytree(self):
+
+    @jax.jit
+    def init_sort():
+      init = init_lib.SortingInitializer()
+      return init
+
+    @jax.jit
+    def init_gaus():
+      init = init_lib.GaussianInitializer()
+      return init
+
+    init_gaus()
+    init_sort()
+
+  @pytest.mark.parametrize(
+      "vector_min, lse_mode", [(True, True), (True, False), (False, True)]
+  )
+  def test_sorting_init(self, vector_min: bool, lse_mode: bool):
+    """Tests the sorting dual initializer."""
+    rng = jax.random.PRNGKey(42)
+    n = 500
+    epsilon = 0.01
+
+    ot_problem = create_sorting_problem(
+        rng=rng, n=n, epsilon=epsilon, online=False
+    )
+    # Run Sinkhorn with the default initialization.
+    sink_out_base = run_sinkhorn(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon
+    )
+    base_num_iter = jnp.sum(sink_out_base.errors > -1)
+
+    sink_out_init = run_sinkhorn_sort_init(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon,
+        vector_min=vector_min,
+        lse_mode=lse_mode
+    )
+    sort_num_iter = jnp.sum(sink_out_init.errors > -1)
+
+    # Check that the initializer is better or equal.
+    if lse_mode:
+      assert base_num_iter >= sort_num_iter
+
+  def test_sorting_init_online(self, rng: jnp.ndarray):
+    n = 100
+    epsilon = 0.01
+
+    ot_problem = create_sorting_problem(
+        rng=rng, n=n, epsilon=epsilon, online=True
+    )
+    sort_init = init_lib.SortingInitializer(vectorized_update=True)
+    with pytest.raises(AssertionError, match=r"online"):
+      sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True)
+
+  def test_sorting_init_square_cost(self, rng: jnp.ndarray):
+    n = 100
+    m = 150
+    d = 1
+    epsilon = 0.01
+
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+    sort_init = init_lib.SortingInitializer(vectorized_update=True)
+    with pytest.raises(AssertionError, match=r"square"):
+      sort_init.init_dual_a(ot_problem=ot_problem, lse_mode=True)
+
+  def test_default_initializer(self, rng: jnp.ndarray):
+    """Tests the default initializer."""
+    n = 200
+    m = 200
+    d = 2
+    epsilon = 0.01
+
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+
+    default_potential_a = init_lib.DefaultInitializer().init_dual_a(
+        ot_problem=ot_problem, lse_mode=True
+    )
+    default_potential_b = init_lib.DefaultInitializer().init_dual_b(
+        ot_problem=ot_problem, lse_mode=True
+    )
+
+    # Check that the default potentials are 0.
+    np.testing.assert_array_equal(0., default_potential_a)
+    np.testing.assert_array_equal(0., default_potential_b)
+
+  def test_gauss_pointcloud_geom(self, rng: jnp.ndarray):
+    n = 200
+    m = 200
+    d = 2
+    epsilon = 0.01
+
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+
+    gaus_init = init_lib.GaussianInitializer()
+    new_geom = geometry.Geometry(
+        cost_matrix=ot_problem.geom.cost_matrix, epsilon=epsilon
+    )
+    ot_problem = linear_problems.LinearProblem(
+        geom=new_geom, a=ot_problem.a, b=ot_problem.b
+    )
+
+    with pytest.raises(AssertionError, match=r"point cloud"):
+      gaus_init.init_dual_a(ot_problem=ot_problem, lse_mode=True)
+
+  @pytest.mark.parametrize('lse_mode', [True, False])
+  def test_gauss_initializer(self, lse_mode, rng: jnp.ndarray):
+    """Tests the Gaussian initializer."""
+    # Define the OT problem.
+    n = 200
+    m = 200
+    d = 2
+    epsilon = 0.01
+
+    ot_problem = create_ot_problem(rng, n, m, d, epsilon=epsilon, online=False)
+
+    # Run Sinkhorn with the default initialization.
+    sink_out = run_sinkhorn(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon,
+        lse_mode=lse_mode
+    )
+    base_num_iter = jnp.sum(sink_out.errors > -1)
+    sink_out = run_sinkhorn_gaus_init(
+        x=ot_problem.geom.x,
+        y=ot_problem.geom.y,
+        a=ot_problem.a,
+        b=ot_problem.b,
+        epsilon=epsilon,
+        lse_mode=lse_mode
+    )
+    gaus_num_iter = jnp.sum(sink_out.errors > -1)
+
+    # Check that the initializer is better.
+    if lse_mode:
+      assert base_num_iter >= gaus_num_iter
diff --git a/tests/core/sinkhorn_test.py b/tests/core/sinkhorn_test.py
index fd80babd1..8f3805e5d 100644
--- a/tests/core/sinkhorn_test.py
+++ b/tests/core/sinkhorn_test.py
@@ -406,6 +406,20 @@ def test_restart(self, lse_mode: bool):
         geom.scaling_from_potential(out.f),
         geom.scaling_from_potential(out.g)
     )
+
+    if lse_mode:
+      default_a = jnp.zeros_like(init_dual_a)
+      default_b = jnp.zeros_like(init_dual_b)
+    else:
+      default_a = jnp.ones_like(init_dual_a)
+      default_b = jnp.ones_like(init_dual_b)
+
+    with pytest.raises(AssertionError):
+      np.testing.assert_allclose(default_a, init_dual_a)
+
+    with pytest.raises(AssertionError):
+      np.testing.assert_allclose(default_b, init_dual_b)
+
     out_restarted = sinkhorn.sinkhorn(
         geom,
         a=self.a,
@@ -416,6 +430,7 @@ def test_restart(self, lse_mode: bool):
         init_dual_b=init_dual_b,
         inner_iterations=1
     )
+
     errors_restarted = out_restarted.errors
     err_restarted = errors_restarted[errors_restarted > -1][-1]
     assert threshold > err_restarted
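
Note (reviewer sketch, not part of the patch): an end-to-end view of the two ways this patch exposes initialization, mirroring the tests above: pick a potential_initializer when calling the solver, or warm-start explicitly from previously computed duals as the restart test does. Sizes and epsilon are illustrative.

import jax
import jax.numpy as jnp

from ott.core import initializers as init_lib
from ott.core.sinkhorn import sinkhorn
from ott.geometry import pointcloud

key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (200, 2))
y = jax.random.normal(key_y, (200, 2)) + 1.0
geom = pointcloud.PointCloud(x, y, epsilon=0.01)

# (1) Let the solver compute its own starting duals via an initializer.
out = sinkhorn(geom, potential_initializer=init_lib.GaussianInitializer())

# (2) Warm-start a second solve from the duals of the first (lse_mode=True).
out_restarted = sinkhorn(geom, init_dual_a=out.f, init_dual_b=out.g)
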